Coverage for local_installation_linux/mumott/optimization/regularizers/huber_norm.py: 81%

44 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import numpy as np 

2from numpy.typing import NDArray 

3 

4from mumott.optimization.regularizers.base_regularizer import Regularizer 

5import logging 

6logger = logging.getLogger(__name__) 

7 

8 

class HuberNorm(Regularizer):

    r"""Regularizes using the Huber norm of the coefficient, which splices the :math:`L_1`
    and :math:`L_2` norms.
    Suitable for scalar fields or tensor fields in local representations.
    Tends to reduce noise while converging more easily than the :math:`L_1` loss function.

    The Huber norm of a vector :math:`x` is given by :math:`R(\vec{x}) = \sum_i L(x_i)`,
    where :math:`L(x_i)` is given by

    .. math::
        L(x_i) = \begin{Bmatrix}\vert x_i \vert - 0.5 \delta & \quad \text{if } \vert x_i \vert > \delta \\
        \dfrac{x_i^2}{2 \delta} & \quad \text{if } \vert x_i \vert \leq \delta\end{Bmatrix}

    See also `the Wikipedia article on the Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_.

    Parameters
    ----------
    delta : float
        The threshold value for the Huber norm. Must be greater than ``0``. Default value
        is ``1.``, but the appropriate value is data-dependent.
    """

    def __init__(self, delta: float = 1.):
        # ``not delta > 0`` (rather than ``delta <= 0``) also rejects NaN,
        # for which every ordered comparison is False.
        if not delta > 0:
            raise ValueError('delta must be greater than zero, but a value'
                             f' of {delta} was specified! For pure L1 regularization, use L1Norm.')
        super().__init__()
        self._delta = float(delta)

    def get_regularization_norm(self,
                                coefficients: NDArray[float],
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, NDArray[float]]:
        """Retrieves the Huber loss of the coefficients. Appropriate for
        use with scalar fields or tensor fields in local representations.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
            the last channel contains, e.g., tensor components.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
            Otherwise, the entry ``'gradient'`` will be ``None``. Defaults to ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is None. If a flag is passed in ('full', 'angles', 'coefficients'),
            we assume that the ZH workflow is used and that the last two coefficients are euler angles,
            which should not be regularized by this regularizer.

        Returns
        -------
            A dictionary with two entries, ``regularization_norm`` and ``gradient``.

        Raises
        ------
        ValueError
            If ``gradient_part`` is neither ``None`` nor one of
            ``'full'``, ``'coefficients'``, ``'angles'``.
        """
        # Validate once, up front. A real tuple is required here: ``x in ('angles')``
        # is a substring test against the string 'angles', not tuple membership.
        zh_parts = ('full', 'coefficients', 'angles')
        if gradient_part is not None and gradient_part not in zh_parts:
            logger.warning('Unexpected argument given for gradient part.')
            raise ValueError("gradient_part must be None, 'full', 'coefficients' or 'angles',"
                             f' but {gradient_part!r} was given.')

        r = np.abs(coefficients)
        # The in-place float updates below cannot be cast back into an integer
        # array, so promote non-floating input to float first.
        if not np.issubdtype(r.dtype, np.floating):
            r = r.astype(float)

        where = r < self._delta

        # where indicates small values, use l2 at these points
        r[where] *= r[where] * np.reciprocal(2 * self._delta)
        # everywhere else use l1, shifted so that both pieces meet at abs(x) == delta
        r[~where] -= 0.5 * self._delta

        result = dict(regularization_norm=None, gradient=None)

        if get_gradient:
            if gradient_part == 'angles':
                # This regularizer contributes nothing to the 'angles' part.
                result['gradient'] = np.zeros(coefficients.shape)
            else:
                # x / delta in the quadratic region, sign(x) in the linear region.
                gradient = coefficients * np.reciprocal(self._delta)
                gradient[~where] = np.sign(gradient[~where])
                if gradient_part is not None:
                    # ZH workflow: the last two channels are not regularized.
                    gradient[..., -2:] = 0
                result['gradient'] = gradient

        if gradient_part is None:
            result['regularization_norm'] = np.sum(r)
        else:
            # ZH workflow: the last two channels are excluded from the norm.
            # NOTE(review): ``r`` already holds the per-element Huber loss, so the
            # ``** 2`` sums the *squared* loss, whereas the gradient above is that
            # of the un-squared loss. Kept as-is to preserve existing numerical
            # behaviour -- confirm whether the square is intended.
            result['regularization_norm'] = np.sum(r[..., :-2] ** 2)

        return result

    @property
    def _function_as_str(self) -> str:
        # Plain-text rendering of the piecewise regularization function.
        return ('R(x[abs(x) >= delta]) =\n'
                '    lambda * (abs(x) - 0.5 * delta)\n'
                '    R(x[abs(x) < delta]) = lambda * (x ** 2) / (2 * delta)')

    @property
    def _function_as_tex(self) -> str:
        # we use html line breaks <br> since LaTeX line breaks appear unsupported.
        return (r'$L(x_i) = \lambda (\vert \vec{x} \vert - 0.5\delta)'
                r'\quad \text{if } \vert x \vert \geq \delta$<br>'
                r'$L(x_i) = \lambda \dfrac{x^2}{2 \delta} \quad \text{ if } x < \delta$<br>'
                r'$R(\vec{x}) = \sum_i L(x_i)$')