Coverage for local_installation_linux/mumott/optimization/regularizers/group_lasso.py: 88% (38 statements)
coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
2import numpy as np
3from mumott.optimization.regularizers.base_regularizer import Regularizer
4from numpy.typing import NDArray
6logger = logging.getLogger(__name__)
9class GroupLasso(Regularizer):
10 r"""Group lasso regularizer, where the coefficients are grouped by voxel.
11 This approach is well suited for handling voxels with zero scattering and for suppressing missing wedge
12 artifacts. Note that this type of regularization (:math:`L_1`) has convergence issues when using
13 gradient-based optimizers, because the :math:`L_1`-norm is not differentiable at zero.
14 This is why one commonly uses proximal operators for optimization.
16 .. math::
17 L(\mathrm{c}) = \sum_{xyz} \sqrt{\sum_{i}c_{xyzi}^2}
19 Parameters
20 ----------
21 regularization_parameter : float
22 Regularization weight used to define the proximal operator. Can be left as ``1`` (default)
23 for the normal mumott workflow.
25 step_size_parameter : float
26 Step-size parameter used to define the proximal operator.
27 """
29 def __init__(self, regularization_parameter: float = 1, step_size_parameter: float = None):
30 self._regularization_parameter = regularization_parameter
31 self._step_size_parameter = step_size_parameter
33 def get_regularization_norm(self, coefficients: NDArray[float], get_gradient: bool = False,
34 gradient_part: str = None) -> float:
35 """Retrieves the group lasso regularization weight of the coefficients.
37 Parameters
38 ----------
39 coefficients
40 An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
41 the last channel contains, e.g., tensor components.
42 get_gradient
43 If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
44 Otherwise, the entry ``'gradient'`` will be ``None``. Defaults to ``False``.
45 gradient_part
46 Used for reconstructions with zonal harmonics (ZHs) to determine what part of the gradient
47 is being calculated. Default is ``None``. If one of the flags in (``'full'``, ``'angles'``,
48 ``'coefficients'``) is passed, we assume that the ZH workflow is used and that the last two
49 coefficients are Euler angles, which should not be regularized by this regularizer.
51 Returns
52 -------
53 A dictionary with two entries, ``regularization_norm`` and ``gradient``.
54 """
56 if gradient_part is None:
57 grouped_norms = np.sqrt(np.sum(coefficients**2, axis=-1))
58 elif gradient_part in ('full', 'coefficients', 'angles'):    58 ↛ 61 (line 58 didn't jump to line 61, because the condition on line 58 was never false)
59 grouped_norms = np.sqrt(np.sum(coefficients[..., :-2]**2, axis=-1))
61 result = {'regularization_norm': np.sum(grouped_norms) * self._regularization_parameter}
63 if get_gradient:
65 if gradient_part == 'angles':
66 gradient = np.zeros(coefficients.shape)
67 else:
68 gradient = np.divide(coefficients,
69 grouped_norms[..., np.newaxis],
70 out=np.ones(coefficients.shape),
71 where=grouped_norms[..., np.newaxis] != 0)
72 if gradient_part in ('full', 'coefficients'):
73 gradient[..., -2:] = 0
75 result['gradient'] = gradient*self._regularization_parameter
77 return result
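# Illustrative note (added for clarity, not part of the mumott source): with
# regularization_parameter=1 and gradient_part=None, a voxel with coefficients (3, 4)
# contributes sqrt(3**2 + 4**2) = 5 to the norm and (3/5, 4/5) = (0.6, 0.8) to the
# gradient; voxels whose grouped norm is exactly zero receive gradient entries of 1
# through the ``out=np.ones(...)`` fallback of ``np.divide``.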
79 def proximal_operator(self, coefficients: NDArray[float]) -> NDArray[float]:
80 """Proximal operator of the group lasso regularizer.
82 Parameters
83 ----------
84 coefficients
85 An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
86 the last channel contains, e.g., tensor components.
88 Returns
89 -------
90 stepped_coefficient
91 Input coefficients vector after the application of the proximal operator.
92 """
94 if self._step_size_parameter is None:    94 ↛ 95 (line 94 didn't jump to line 95, because the condition on line 94 was never true)
95 logger.error('Proximal operator is not defined without a step-size parameter.')
96 return None
98 grouped_norms = np.sqrt(np.sum(coefficients**2, axis=-1))
99 stepped_coefficient = coefficients - self._regularization_parameter \
100 * self._step_size_parameter * coefficients / grouped_norms[..., np.newaxis]
101 mask = grouped_norms <= self._regularization_parameter * self._step_size_parameter
102 stepped_coefficient[mask, :] = 0
103 return stepped_coefficient
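# Illustrative note (added for clarity, not part of the mumott source): the operator
# above is block soft-thresholding per voxel. With regularization_parameter=1 and
# step_size_parameter=0.5, a voxel with coefficients (3, 4) has grouped norm 5 > 0.5
# and is shrunk by the factor (1 - 0.5 / 5) to (2.7, 3.6), while any voxel whose
# grouped norm is at most 0.5 is set to zero by the mask.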
105 @property
106 def _function_as_str(self) -> str:
107 return ('R(c) = lambda * sum_xyz( sqrt( sum_lm( c_lm(xyz)^2 ) ) )')
109 @property
110 def _function_as_tex(self) -> str:
111 return (r'$R(\vec{c}) = \lambda \sum_{xyz} \sqrt{ \sum_{\ell, m}c_{\ell, m}(x,y,z)^2 }$')
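# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the mumott source file): builds
# a small random coefficient volume and applies the regularizer defined above.
# The array shape (4, 4, 4, 6), the seed, and the parameter values are
# arbitrary assumptions made for this example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.default_rng(seed=0)
    coefficients = rng.normal(size=(4, 4, 4, 6))
    regularizer = GroupLasso(regularization_parameter=1.0, step_size_parameter=0.1)

    # Norm and gradient, as used by gradient-based optimizers.
    result = regularizer.get_regularization_norm(coefficients, get_gradient=True)
    print('regularization norm:', result['regularization_norm'])
    print('gradient shape:', result['gradient'].shape)

    # Proximal step, which zeroes voxels whose grouped norm falls below
    # regularization_parameter * step_size_parameter.
    stepped = regularizer.proximal_operator(coefficients)
    print('voxels zeroed by the proximal step:', int(np.sum(np.all(stepped == 0, axis=-1))))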