Coverage for local_installation_linux/mumott/optimization/regularizers/huber_norm.py: 81%
44 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import numpy as np
2from numpy.typing import NDArray
4from mumott.optimization.regularizers.base_regularizer import Regularizer
5import logging
6logger = logging.getLogger(__name__)
class HuberNorm(Regularizer):

    r"""Regularizes using the Huber norm of the coefficients, which splices the :math:`L_1`
    and :math:`L_2` norms.
    Suitable for scalar fields or tensor fields in local representations.
    Tends to reduce noise while converging more easily than the :math:`L_1` loss function.

    The Huber norm of a vector :math:`x` is given by :math:`R(\vec{x}) = \sum_i L(x_i)`,
    where :math:`L(x_i)` is given by

    .. math::
        L(x_i) = \begin{Bmatrix}\vert x_i \vert - 0.5 \delta & \quad \text{if } \vert x_i \vert > \delta \\
        \dfrac{x_i^2}{2 \delta} & \quad \text{if } \vert x_i \vert \leq \delta\end{Bmatrix}

    See also `the Wikipedia article on the Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_.

    Parameters
    ----------
    delta : float
        The threshold value for the Huber norm. Must be greater than ``0``. Default value
        is ``1.``, but the appropriate value is data-dependent.
    """

    def __init__(self, delta: float = 1.):
        if delta <= 0:
            raise ValueError('delta must be greater than zero, but a value'
                             f' of {delta} was specified! For pure L1 regularization, use L1Norm.')
        super().__init__()
        self._delta = float(delta)

    def get_regularization_norm(self,
                                coefficients: NDArray[float],
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, NDArray[float]]:
        """Retrieves the Huber loss of the coefficients. Appropriate for
        use with scalar fields or tensor fields in local representations.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
            the last channel contains, e.g., tensor components.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
            Otherwise, the entry ``'gradient'`` will be ``None``. Defaults to ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is None. If a flag is passed in ('full', 'angles', 'coefficients'),
            we assume that the ZH workflow is used and that the last two coefficients are Euler angles,
            which should not be regularized by this regularizer.

        Returns
        -------
        A dictionary with two entries, ``regularization_norm`` and ``gradient``.

        Raises
        ------
        ValueError
            If an unrecognized value is passed for :attr:`gradient_part`.
        """
        r = abs(coefficients)

        # ``where`` marks small values (|x| < delta); use the quadratic (L2)
        # branch there, and the linear (L1) branch elsewhere.
        where = r < self._delta
        r[where] *= r[where] * np.reciprocal(2 * self._delta)
        r[~where] -= 0.5 * self._delta

        result = dict(regularization_norm=None, gradient=None)

        if get_gradient:
            if gradient_part in (None, 'full', 'coefficients'):
                # dL/dx = x / delta in the quadratic region, sign(x) in the linear region.
                gradient = coefficients * np.reciprocal(self._delta)
                gradient[~where] = np.sign(gradient[~where])
                if gradient_part is not None:
                    # ZH workflow: last two channels are Euler angles and are not regularized.
                    gradient[..., -2:] = 0
                result['gradient'] = gradient
            elif gradient_part == 'angles':
                # Angle part carries no contribution from this regularizer.
                result['gradient'] = np.zeros(coefficients.shape)
            else:
                logger.warning('Unexpected argument given for gradient part.')
                raise ValueError(f'Unexpected gradient_part: {gradient_part!r}.')

        if gradient_part is None:
            result['regularization_norm'] = np.sum(r)
        elif gradient_part in ('full', 'coefficients', 'angles'):
            # Sum the Huber loss over all channels except the two Euler angles.
            # (The loss values in ``r`` are used directly; squaring them would be
            # inconsistent with the gradient computed above.)
            result['regularization_norm'] = np.sum(r[..., :-2])
        else:
            logger.warning('Unexpected argument given for gradient part.')
            raise ValueError(f'Unexpected gradient_part: {gradient_part!r}.')

        return result

    @property
    def _function_as_str(self) -> str:
        return ('R(x[abs(x) >= delta]) =\n'
                '    lambda * (abs(x) - 0.5 * delta)\n'
                '    R(x[abs(x) < delta]) = lambda * (x ** 2) / (2 * delta)')

    @property
    def _function_as_tex(self) -> str:
        # we use html line breaks <br> since LaTeX line breaks appear unsupported.
        return (r'$L(x_i) = \lambda (\vert \vec{x} \vert - 0.5\delta)'
                r'\quad \text{if } \vert x \vert \geq \delta$<br>'
                r'$L(x_i) = \lambda \dfrac{x^2}{2 \delta} \quad \text{ if } x < \delta$<br>'
                r'$R(\vec{x}) = \sum_i L(x_i)$')