Coverage for local_installation_linux/mumott/optimization/loss_functions/huber_loss.py: 83%

38 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

import logging

import numpy as np
from numpy.typing import NDArray

from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction

logger = logging.getLogger(__name__)


class HuberLoss(LossFunction):

14 r"""Class object for obtaining the Huber loss function and gradient from a given 

15 :ref:`residual_calculator <residual_calculators>`. 

16 

17 This loss function is used for so-called 

18 `robust regression <https://en.wikipedia.org/wiki/Robust_regression>`_ and can be written as 

19 

20 .. math:: 

21 L(r(x, D)) = \begin{Bmatrix} 

22 \vert r(x, D) \vert - 0.5 \delta & \quad \text{if } \vert r(x, D) \vert > \delta \\ 

23 \dfrac{r(x, D)^2}{2 \delta} & \quad \text{if } \vert r(x, D) \vert < \delta 

24 \end{Bmatrix}, 

25 

26 where :math:`r` is the residual, a function of :math:`x`, the optimization coefficients, 

27 and :math:`D`, the data. The gradient with respect to :math:`x` 

28 is then :math:`\sigma(\frac{\partial r}{\partial x})` for large :math:`r`, where :math:`\sigma(x)` 

29 is the sign function, and :math:`\frac{\partial r}{\partial x}` for small :math:`r`. The partial 

30 derivative of :math:`r` with respect to :math:`x` is the responsibility of the 

31 :attr:`residual_calculator` to compute. 

32 

33 Broadly speaking, the Huber loss function is less sensitive to outliers than the squared (or :math:`L_2`) 

34 loss function, while it is easier to minimize than the :math:`L_1` loss function 

35 since it its derivative is continuous in the entire domain. 

36 

37 See also the Wikipedia articles on `robust regression <https://en.wikipedia.org/wiki/Robust_regression>`_ 

38 and the `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_. 

39 

40 Parameters 

41 ---------- 

42 residual_calculator : ResidualCalculator 

43 The :ref:`residual calculator instance <residual_calculators>` from which the 

44 residuals, weights, and gradient terms are obtained. 

45 use_weights : bool 

46 Whether to use weighting in the computation of the residual 

47 norm and gradient. Default is ``False``. 

48 preconditioner : np.ndarray 

49 A preconditioner to be applied to the gradient. Must have the same shape as 

50 :attr:`residual_calculator.coefficients` or it must be possible to broadcast by multiplication. 

51 residual_norm_multiplier : float 

52 A multiplier that is applied to the residual norm and gradient. Useful in cases where 

53 a very small or large loss function value changes the optimizer behaviour. 

54 delta : float 

55 The cutoff value where the :math:`L_1` loss function is spliced with the :math:`L_2` loss function. 

56 The default value is ``1.``, but the appropriate value to use depends on the data 

57 and the chosen representation. 
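
    Example
    -------
    The snippet below is a minimal, illustrative sketch. It assumes that ``residual_calculator``
    is an already configured residual calculator instance; a loss function object is then created
    via ``loss_function = HuberLoss(residual_calculator, delta=1.)``. The piecewise expression
    above can also be evaluated directly with NumPy for a toy residual vector:

    >>> import numpy as np
    >>> delta = 1.
    >>> r = np.array([-3.0, -0.5, 0.5, 2.0])
    >>> small = np.abs(r) < delta
    >>> loss = np.where(small, r ** 2 / (2 * delta), np.abs(r) - 0.5 * delta)
    >>> float(loss.sum())
    4.25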

    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: NDArray[float] = None,
                 residual_norm_multiplier: float = 1.,
                 delta: float = 1.):

        if delta < 0:  # coverage: condition never true in the recorded run; ValueError branch untested
            raise ValueError('delta must be greater than or equal to zero, but a value'
                             f' of {delta} was specified!')
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)
        self._delta = float(delta)

    def _get_residual_norm_internal(self,
                                    get_gradient: bool = False,
                                    gradient_part: str = None
                                    ) -> dict[str, NDArray[float]]:

77 """ Gets the residual norm, and if needed, 

78 the gradient, using the attached :attr:`residual_calculator`. 

79 

80 Parameters 

81 ---------- 

82 get_gradient 

83 Whether to return the gradient. Default is ``False``. 

84 gradient_part 

85 Used for the zonal harmonics resonstructions to determine what part of the gradient is 

86 being calculated. Default is None. 

87 

88 Returns 

89 ------- 

90 A ``dict`` with two entries, ``residual_norm`` and ``gradient``. 

91 If ``get_gradient`` is false, its value will be ``None``. 

92 """ 

        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=False, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']
        # `where` marks the small residuals; the quadratic (L2-like) branch is used at these points
        where = abs(residuals) < self._delta
        residual_norm = 0.
        if self.use_weights:
            # weights (e.g. 1/variance) need to be applied since they depend on the loss function
            residual_norm += 0.5 * np.reciprocal(self._delta) * np.einsum(
                'i, i, i',
                residuals[where].ravel(),
                residuals[where].ravel(),
                residual_calculator_output['weights'][where].ravel(),
                optimize='greedy')
            residual_norm += np.dot(abs(residuals[~where]) - 0.5 * self._delta,
                                    residual_calculator_output['weights'][~where])
        else:
            residual_norm += 0.5 * np.reciprocal(self._delta) * np.dot(
                residuals[where].ravel(), residuals[where].ravel())
            residual_norm += np.sum(abs(residuals[~where]) - 0.5 * self._delta)

        if get_gradient:  # coverage: condition never false in the recorded run; `gradient = None` branch untested
            residuals[where] *= np.reciprocal(self._delta)
            residuals[~where] = np.sign(residuals[~where])
            if self.use_weights:
                residuals *= residual_calculator_output['weights']
            gradient = self._residual_calculator.get_gradient_from_residual_gradient(residuals)
        else:
            gradient = None

        if residual_norm < 1:  # coverage: condition never true in the recorded run; warning branch untested
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')

        return dict(residual_norm=residual_norm, gradient=gradient)

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
        return ('L(r[abs(r) >= delta]) =\n'
                '    lambda * (abs(r) - 0.5 * delta)\n'
                'L(r[abs(r) < delta]) = lambda * (r ** 2) / (2 * delta)')

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX."""
        # we use HTML line breaks <br> since LaTeX line breaks appear to be unsupported
        return (r'$L(x_i) = \lambda (\vert x_i \vert - 0.5\delta)'
                r'\quad \text{ if } \vert x_i \vert \geq \delta$<br>'
                r'$L(x_i) = \lambda \dfrac{x_i^2}{2 \delta} \quad \text{ if } \vert x_i \vert < \delta$<br>'
                r'$R(\vec{x}) = \sum_i L(x_i)$')
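
The piecewise logic of ``_get_residual_norm_internal`` can also be followed in isolation. The sketch
below is not part of huber_loss.py; it mirrors the unweighted branch of the method above with plain
NumPy, using made-up toy residuals, to show how the small/large split determines both the residual
norm and the factor that multiplies the residual gradient.

import numpy as np

# Toy residuals and cutoff, for illustration only.
residuals = np.array([-3.0, -0.5, 0.5, 2.0])
delta = 1.0

# Small residuals contribute the quadratic (L2-like) term, large ones the linear (L1-like) term.
where = np.abs(residuals) < delta
residual_norm = 0.5 / delta * np.dot(residuals[where], residuals[where])
residual_norm += np.sum(np.abs(residuals[~where]) - 0.5 * delta)

# Factor that multiplies dr/dx in the gradient: r / delta for small residuals, sign(r) for large ones.
gradient_factor = residuals.copy()
gradient_factor[where] *= 1.0 / delta
gradient_factor[~where] = np.sign(gradient_factor[~where])

print(residual_norm)    # 4.25
print(gradient_factor)  # [-1.  -0.5  0.5  1. ]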