Coverage for local_installation_linux/mumott/optimization/loss_functions/huber_loss.py: 83%
38 statements
import logging

import numpy as np
from numpy.typing import NDArray

from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction

logger = logging.getLogger(__name__)

class HuberLoss(LossFunction):

    r"""Class object for obtaining the Huber loss function and gradient from a given
    :ref:`residual_calculator <residual_calculators>`.

    This loss function is used for so-called
    `robust regression <https://en.wikipedia.org/wiki/Robust_regression>`_ and can be written as

    .. math::
        L(r(x, D)) = \begin{Bmatrix}
            \vert r(x, D) \vert - 0.5 \delta & \quad \text{if } \vert r(x, D) \vert > \delta \\
            \dfrac{r(x, D)^2}{2 \delta} & \quad \text{if } \vert r(x, D) \vert < \delta
        \end{Bmatrix},

    where :math:`r` is the residual, a function of :math:`x`, the optimization coefficients,
    and :math:`D`, the data. The gradient with respect to :math:`x` is then
    :math:`\sigma(r) \frac{\partial r}{\partial x}` for large :math:`r`, where :math:`\sigma(r)`
    is the sign function, and :math:`\frac{r}{\delta} \frac{\partial r}{\partial x}` for small
    :math:`r`. The partial derivative of :math:`r` with respect to :math:`x` is the
    responsibility of the :attr:`residual_calculator` to compute.

    Broadly speaking, the Huber loss function is less sensitive to outliers than the squared
    (or :math:`L_2`) loss function, while it is easier to minimize than the :math:`L_1` loss
    function, since its derivative is continuous over the entire domain.
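
    For example, with :math:`\delta = 1`, a residual of :math:`0.5` lies on the quadratic branch
    and contributes :math:`0.5^2 / 2 = 0.125` to the loss, whereas a residual of :math:`3` lies
    on the linear branch and contributes :math:`3 - 0.5 = 2.5`.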

    See also the Wikipedia articles on `robust regression <https://en.wikipedia.org/wiki/Robust_regression>`_
    and the `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_.

    Parameters
    ----------
    residual_calculator : ResidualCalculator
        The :ref:`residual calculator instance <residual_calculators>` from which the
        residuals, weights, and gradient terms are obtained.
    use_weights : bool
        Whether to use weighting in the computation of the residual
        norm and gradient. Default is ``False``.
    preconditioner : np.ndarray
        A preconditioner to be applied to the gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients` or it must be possible to broadcast by multiplication.
    residual_norm_multiplier : float
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
    delta : float
        The cutoff value where the :math:`L_1` loss function is spliced with the :math:`L_2`
        loss function. The default value is ``1.``, but the appropriate value to use depends
        on the data and the chosen representation.
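
    Example
    -------
    A minimal construction sketch; ``residual_calculator`` is assumed to be an already
    configured residual calculator instance created elsewhere:

    >>> loss_function = HuberLoss(residual_calculator, use_weights=True, delta=0.05)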
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: NDArray[float] = None,
                 residual_norm_multiplier: float = 1.,
                 delta: float = 1.):
        if delta < 0:
            raise ValueError('delta must be greater than or equal to zero, but a value'
                             f' of {delta} was specified!')
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)
        self._delta = float(delta)

    def _get_residual_norm_internal(self,
                                    get_gradient: bool = False,
                                    gradient_part: str = None
                                    ) -> dict[str, NDArray[float]]:
        """ Gets the residual norm and, if needed,
        the gradient, using the attached :attr:`residual_calculator`.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the gradient
            is being calculated. Default is ``None``.

        Returns
        -------
            A ``dict`` with two entries, ``residual_norm`` and ``gradient``.
            If ``get_gradient`` is ``False``, the value of ``gradient`` will be ``None``.
        """
        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=False, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']
        # where indicates small values, use l2 at these points
        where = abs(residuals) < self._delta
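        # the residual norm is accumulated in two parts: a quadratic (l2-like) term over the
        # small residuals and a linear (l1-like) term over the remaining large residuals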
        residual_norm = 0.
        if self.use_weights:
            # weights (e.g. 1/variance) need to be applied since they depend on the loss function
            residual_norm += 0.5 * np.reciprocal(self._delta) * np.einsum(
                'i, i, i',
                residuals[where].ravel(),
                residuals[where].ravel(),
                residual_calculator_output['weights'][where].ravel(),
                optimize='greedy')
            residual_norm += np.dot(abs(residuals[~where]) - 0.5 * self._delta,
                                    residual_calculator_output['weights'][~where])
        else:
            residual_norm += 0.5 * np.reciprocal(self._delta) * np.dot(
                residuals[where].ravel(), residuals[where].ravel())
            residual_norm += np.sum(abs(residuals[~where]) - 0.5 * self._delta)

        if get_gradient:
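            # reuse the residual array in place as dL/dr: the quadratic branch is scaled by
            # 1 / delta, while the large residuals are replaced by their sign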
            residuals[where] *= np.reciprocal(self._delta)
            residuals[~where] = np.sign(residuals[~where])
            if self.use_weights:
                residuals *= residual_calculator_output['weights']
            gradient = self._residual_calculator.get_gradient_from_residual_gradient(residuals)
        else:
            gradient = None

        if residual_norm < 1:
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')

        return dict(residual_norm=residual_norm, gradient=gradient)

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
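        # note: 'lambda' in the string below presumably denotes the residual norm multiplier
        # applied by the base class (an assumption; it is not defined in this module)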
        return ('L(r[abs(r) >= delta]) =\n'
                '    lambda * (abs(r) - 0.5 * delta)\n'
                'L(r[abs(r) < delta]) = lambda * (r ** 2) / (2 * delta)')

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX."""
        # we use html line breaks <br> since LaTeX line breaks appear unsupported.
        return (r'$L(x_i) = \lambda (\vert x_i \vert - 0.5\delta)'
                r'\quad \text{ if } \vert x_i \vert \geq \delta$<br>'
                r'$L(x_i) = \lambda \dfrac{x_i^2}{2 \delta} \quad \text{ if } \vert x_i \vert < \delta$<br>'
                r'$R(\vec{x}) = \sum_i L(x_i)$')