Coverage for local_installation_linux/mumott/pipelines/reconstruction/group_lasso.py: 85%

68 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

import numpy as np
import tqdm
import logging

from mumott import DataContainer
from mumott.optimization.optimizers.base_optimizer import Optimizer
from mumott.optimization.loss_functions.base_loss_function import LossFunction
from mumott.methods.basis_sets.base_basis_set import BasisSet
from numpy.typing import NDArray

from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.regularizers.group_lasso import GroupLasso

from mumott.methods.utilities.preconditioning import get_largest_eigenvalue
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.basis_sets import SphericalHarmonics

logger = logging.getLogger(__name__)


def run_group_lasso(data_container: DataContainer,
                    regularization_parameter: float,
                    step_size_parameter: float = None,
                    x0: NDArray[float] = None,
                    basis_set: BasisSet = None,
                    ell_max: int = 8,
                    use_gpu: bool = False,
                    maxiter: int = 100,
                    enforce_non_negativity: bool = False,
                    no_tqdm: bool = False,
                    ):
    """A reconstruction pipeline for least-squares reconstructions regularized with the
    group-lasso regularizer and solved with the Iterative Soft-Thresholding Algorithm (ISTA),
    a proximal gradient descent method. This reconstruction automatically masks out voxels
    with zero scattering but needs the regularization weight as input.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        containing the data set of interest.
    regularization_parameter
        Scalar weight of the regularization term. Should be optimized by performing
        reconstructions for a range of possible values.
    step_size_parameter
        Step size parameter of the reconstruction. If no value is given, a largest-safe
        value is estimated.
    x0
        Starting guess for the solution. By default (``None``) the coefficients are
        initialized with zeros.
    basis_set
        Optionally a basis set can be specified. By default (``None``)
        :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>` is used.
    ell_max
        If no basis set is given, this is the maximum spherical harmonics order used in the
        generated basis set.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
    maxiter
        Number of iterations for the ISTA optimization. No stopping rules are implemented.
    enforce_non_negativity
        Whether or not to enforce non-negativity of the solution coefficients.
    no_tqdm
        Whether to suppress the progress bar during the reconstruction.
    """

    if use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    if basis_set is None:
        basis_set = SphericalHarmonics(ell_max=ell_max)

    if step_size_parameter is None:
        logger.info('Calculating step size parameter.')
        matrix_norm = get_largest_eigenvalue(basis_set, projector)
        step_size_parameter = 0.5 / matrix_norm

    loss_function = SquaredLoss(GradientResidualCalculator(data_container, basis_set, projector))
    reg_term = GroupLasso(regularization_parameter, step_size_parameter)
    optimizer = _ISTA(loss_function, reg_term, step_size_parameter, x0=x0, maxiter=maxiter,
                      enforce_non_negativity=enforce_non_negativity, no_tqdm=no_tqdm)

    opt_coeffs = optimizer.optimize()

    result = dict(result={'x': opt_coeffs}, optimizer=optimizer, loss_function=loss_function,
                  regularizer=reg_term, basis_set=basis_set, projector=projector)
    return result
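
# A minimal usage sketch (assuming a data set already prepared for mumott; the file name
# below is hypothetical and the regularization weight should be scanned over a range of
# values, as noted in the docstring):
#
#     from mumott import DataContainer
#     from mumott.pipelines.reconstruction.group_lasso import run_group_lasso
#
#     data_container = DataContainer('example_data.h5')
#     result = run_group_lasso(data_container, regularization_parameter=1e-2, maxiter=50)
#     coefficients = result['result']['x']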


class _ISTA(Optimizer):
    """Internal optimizer class for the group-lasso pipeline. Implements
    :class:`Optimizer <mumott.optimization.optimizers.base_optimizer.Optimizer>`.

    Parameters
    ----------
    loss_function : LossFunction
        The differentiable part of the :ref:`loss function <loss_functions>`
        to be minimized using this algorithm.
    reg_term : GroupLasso
        Non-differentiable regularization term to be applied in every iteration.
        Must have a `proximal_operator` method.
    step_size_parameter : float
        Step size for the differentiable part of the optimization.
    maxiter : int
        Maximum number of iterations. Default value is `50`.
    enforce_non_negativity : bool
        If `True`, forces all coefficients to be greater than `0` at the end of every iteration.
        Default value is `False`.

    Notes
    -----
    Valid entries in :attr:`kwargs` are
    x0
        Initial guess for solution vector. Must be the same size as
        :attr:`residual_calculator.coefficients`. Defaults to :attr:`loss_function.initial_values`.
    """
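
    # The update implemented in :meth:`ISTA_step` below is the standard proximal-gradient
    # (ISTA) iteration,
    #
    #     x_{k+1} = prox_{t * g}(x_k - t * grad f(x_k)),
    #
    # where f is the differentiable loss, g the group-lasso penalty, t the step size, and
    # prox the proximal operator provided by :attr:`reg_term`.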

    def __init__(self, loss_function: LossFunction, reg_term: GroupLasso, step_size_parameter: float,
                 maxiter: int = 50, enforce_non_negativity: bool = False, **kwargs):

        super().__init__(loss_function, **kwargs)
        self._maxiter = maxiter
        self._reg_term = reg_term
        self._step_size_parameter = step_size_parameter
        self.error_function_history = []
        self._enforce_non_negativity = enforce_non_negativity

    def ISTA_step(self, coefficients):

        d = self._loss_function.get_loss(coefficients, get_gradient=True)
        gradient = d['gradient']
        total_loss = d['loss'] +\
            self._reg_term.get_regularization_norm(coefficients)['regularization_norm']
        coefficients = coefficients - self._step_size_parameter * gradient
        coefficients = self._reg_term.proximal_operator(coefficients)
        if self._enforce_non_negativity:
            np.clip(coefficients, 0, None, out=coefficients)

        return coefficients, total_loss

    def optimize(self):

        coefficients = self._loss_function.initial_values
        if 'x0' in self._options.keys():
            if self['x0'] is not None:
                coefficients = self['x0']

        # Calculate the total loss for the starting guess
        loss_function_output = self._loss_function.get_loss(coefficients)
        reg_term_output = self._reg_term.get_regularization_norm(coefficients)
        total_loss = loss_function_output['loss'] + reg_term_output['regularization_norm']

        # Toggle between showing a progress bar or not
        if not self._no_tqdm:
            iterator = tqdm.tqdm(range(self._maxiter))
            iterator.set_description(f'Loss = {total_loss:.2E}')
        else:
            iterator = range(self._maxiter)

        for ii in iterator:
            # Do step
            coefficients, total_loss = self.ISTA_step(coefficients)
            # Update the loss history and the progress bar
            self.error_function_history.append(total_loss)
            if not self._no_tqdm:
                iterator.set_description(f'Loss = {total_loss:.2E}')

        return np.array(coefficients)
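

# For reference, a minimal sketch of the block soft-thresholding formula that a group-lasso
# proximal operator is based on, assuming one coefficient group per voxel along the last
# axis. This is only an illustration; the pipeline above uses ``GroupLasso.proximal_operator``
# from mumott, whose exact grouping and conventions may differ.
def _block_soft_thresholding_sketch(coefficients: NDArray[float],
                                    regularization_parameter: float,
                                    step_size_parameter: float) -> NDArray[float]:
    # Norm of each coefficient group (here: all coefficients of a voxel).
    group_norms = np.linalg.norm(coefficients, axis=-1, keepdims=True)
    # Shrink each group towards zero by max(0, 1 - t * lambda / ||v_g||); groups whose
    # norm falls below the threshold are set to zero entirely.
    shrinkage = np.clip(
        1.0 - step_size_parameter * regularization_parameter / np.maximum(group_norms, 1e-15),
        0.0, None)
    return coefficients * shrinkage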