Coverage for local_installation_linux/mumott/optimization/regularizers/group_lasso.py: 88%

38 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2import numpy as np 

3from mumott.optimization.regularizers.base_regularizer import Regularizer 

4from numpy.typing import NDArray 

5 

6logger = logging.getLogger(__name__) 

7 

8 

class GroupLasso(Regularizer):
    r"""Group lasso regularizer, where the coefficients are grouped by voxel.
    This approach is well suited for handling voxels with zero scattering and for suppressing missing wedge
    artifacts. Note that this type of regularization (:math:`L_1`) has convergence issues when using
    gradient-based optimizers due to the divergence of the derivative of the :math:`L_1`-norm at zero.
    This is why one commonly uses proximal operators for optimization.

    .. math::
        L(\mathrm{c}) = \sum_{xyz} \sqrt{\sum_{i} c_{xyzi}^2}

    Parameters
    ----------
    regularization_parameter : float
        Regularization weight used to define the proximal operator. Can be left as ``1`` (default)
        for the normal mumott workflow.

    step_size_parameter : float
        Step-size parameter used to define the proximal operator. If left as ``None`` (default),
        :meth:`proximal_operator` logs an error and returns ``None``.
    """

    def __init__(self, regularization_parameter: float = 1, step_size_parameter: float = None):
        self._regularization_parameter = regularization_parameter
        self._step_size_parameter = step_size_parameter

    def get_regularization_norm(self, coefficients: NDArray[float], get_gradient: bool = False,
                                gradient_part: str = None) -> dict:
        """Retrieves the group lasso regularization weight of the coefficients.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
            the last channel contains, e.g., tensor components.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
            Otherwise, the entry ``'gradient'`` will be ``None``. Defaults to ``False``.
        gradient_part
            Used for reconstructions with zonal harmonics (ZHs) to determine what part of the gradient
            is being calculated. Default is ``None``. If one of the flags in (``'full'``, ``'angles'``,
            ``'coefficients'``) is passed, we assume that the ZH workflow is used and that the last two
            coefficients are Euler angles, which should not be regularized by this regularizer.

        Returns
        -------
            A dictionary with two entries, ``regularization_norm`` and ``gradient``.

        Raises
        ------
        ValueError
            If ``gradient_part`` is neither ``None`` nor one of ``'full'``, ``'coefficients'``,
            ``'angles'``.
        """

        if gradient_part is None:
            grouped_norms = np.sqrt(np.sum(coefficients**2, axis=-1))
        elif gradient_part in ('full', 'coefficients', 'angles'):
            # ZH workflow: the last two channels are excluded from the per-voxel norm.
            grouped_norms = np.sqrt(np.sum(coefficients[..., :-2]**2, axis=-1))
        else:
            # Without this branch ``grouped_norms`` would be unbound further down.
            raise ValueError(f'Unexpected value for gradient_part: {gradient_part!r}. Expected None, '
                             '\'full\', \'coefficients\' or \'angles\'.')

        # The 'gradient' key is always present so that callers can rely on the documented layout.
        result = {'regularization_norm': np.sum(grouped_norms) * self._regularization_parameter,
                  'gradient': None}

        if get_gradient:

            if gradient_part == 'angles':
                gradient = np.zeros(coefficients.shape)
            else:
                # The division is skipped where the group norm is zero; those entries keep
                # the fill value of the ``out`` buffer (ones).
                gradient = np.divide(coefficients,
                                     grouped_norms[..., np.newaxis],
                                     out=np.ones(coefficients.shape),
                                     where=grouped_norms[..., np.newaxis] != 0)
                if gradient_part in ('full', 'coefficients'):
                    # The last two channels are not regularized in the ZH workflow.
                    gradient[..., -2:] = 0

            result['gradient'] = gradient*self._regularization_parameter

        return result

    def proximal_operator(self, coefficients: NDArray[float]) -> NDArray[float]:
        """Proximal operator of the group lasso regularizer.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
            the last channel contains, e.g., tensor components.

        Returns
        -------
        stepped_coefficient
            Input coefficients vector after the application of the proximal operator.
            ``None`` is returned (and an error is logged) if no step-size parameter has been set.
        """

        if self._step_size_parameter is None:
            logger.error('Proximal operator is not defined without a stepsize parameter.')
            return None

        threshold = self._regularization_parameter * self._step_size_parameter
        grouped_norms = np.sqrt(np.sum(coefficients**2, axis=-1))
        # Groups with zero norm trigger a division by zero here. The resulting non-finite
        # values are overwritten by the mask below, so the floating-point warnings are
        # silenced rather than propagated to the user.
        with np.errstate(divide='ignore', invalid='ignore'):
            stepped_coefficient = coefficients - threshold * coefficients / grouped_norms[..., np.newaxis]
        # Groups whose norm does not exceed the threshold are set to zero.
        mask = grouped_norms <= threshold
        stepped_coefficient[mask, :] = 0
        return stepped_coefficient

    @property
    def _function_as_str(self) -> str:
        # Plain-text representation of the regularization function.
        return ('R(c) = lambda * sum_xyz( sqrt( sum_lm( c_lm(xyz)^2 ) ) )')

    @property
    def _function_as_tex(self) -> str:
        # TeX representation of the regularization function, enclosed in ``$...$``.
        return (r'$R(\vec{c}) = \lambda \sum_{xyz} \sqrt{ \sum_{\ell, m}c_{\ell, m}(x,y,z)^2 }$')

108 

109 @property 

110 def _function_as_tex(self) -> str: 

111 return (r'$R(\vec{c}) = \lambda \sum_{xyz} \sqrt{ \sum_{\ell, m}c_{\ell, m}(x,y,z)^2 }_2')