Coverage for local_installation_linux/mumott/pipelines/reconstruction/group_lasso.py: 85%
68 statements
import numpy as np
import tqdm
import logging

from mumott import DataContainer
from mumott.optimization.optimizers.base_optimizer import Optimizer
from mumott.optimization.loss_functions.base_loss_function import LossFunction
from mumott.methods.basis_sets.base_basis_set import BasisSet
from numpy.typing import NDArray

from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.regularizers.group_lasso import GroupLasso

from mumott.methods.utilities.preconditioning import get_largest_eigenvalue
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.basis_sets import SphericalHarmonics

logger = logging.getLogger(__name__)


def run_group_lasso(data_container: DataContainer,
                    regularization_parameter: float,
                    step_size_parameter: float = None,
                    x0: NDArray[float] = None,
                    basis_set: BasisSet = None,
                    ell_max: int = 8,
                    use_gpu: bool = False,
                    maxiter: int = 100,
                    enforce_non_negativity: bool = False,
                    no_tqdm: bool = False,
                    ):
34 """A reconstruction pipeline to do least squares reconstructions regularized with the group-lasso
35 regularizer and solved with the Iterative Soft-Thresholding Algorithm (ISTA), a proximal gradient
36 decent method. This reconstruction automatically masks out voxels with zero scattering but
37 needs the regularization weight as input.
39 Parameters
40 ----------
41 data_container
42 The :class:`DataContainer <mumott.data_handling.DataContainer>`
43 containing the data set of interest.
44 regularization_parameter
45 Scalar weight of the regularization term. Should be optimized by performing reconstructions
46 for a range of possible values.
47 step_size_parameter
48 Step size parameter of the reconstruction. If no value is given, a largest-safe
49 value is estimated.
50 x0
51 Starting guess for the solution. By default (``None``) the coefficients are initialized with
52 zeros.
53 basis_set
54 Optionally a basis set can be specified. By default (``None``)
55 :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>` is used.
56 ell_max
57 If no basis set is given, this is the maximum spherical harmonics order used in the
58 generated basis set.
59 use_gpu
60 Whether to use GPU resources in computing the projections.
61 Default is ``False``. If set to ``True``, the method will use
62 :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
63 maxiter
64 Number of iterations for the ISTA optimization. No stopping rules are implemented.
65 enforce_non_negativity
66 Whether or not to enforce non-negativitu of the solution coefficients.
67 no_tqdm:
68 Flag whether ot not to print a progress bar for the reconstruction.
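
    Examples
    --------
    A minimal usage sketch; the file name and the regularization weight below
    are placeholders and need to be adapted to the data set at hand::

        from mumott import DataContainer
        from mumott.pipelines.reconstruction.group_lasso import run_group_lasso

        data_container = DataContainer('my_data.h5')  # placeholder file name
        result = run_group_lasso(data_container, regularization_parameter=1e-2)
        coefficients = result['result']['x']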
69 """
    if use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    if basis_set is None:
        basis_set = SphericalHarmonics(ell_max=ell_max)

    if step_size_parameter is None:
        logger.info('Calculating step size parameter.')
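        # The gradient of the squared loss is Lipschitz-continuous with a constant
        # given by the largest eigenvalue of the normal matrix of the forward model;
        # half of its inverse is therefore a conservative, safe ISTA step size
        # (standard ISTA convergence requires a step size of at most 1 / L).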
        matrix_norm = get_largest_eigenvalue(basis_set, projector)
        step_size_parameter = 0.5 / matrix_norm

    loss_function = SquaredLoss(GradientResidualCalculator(data_container, basis_set, projector))
    reg_term = GroupLasso(regularization_parameter, step_size_parameter)
    optimizer = _ISTA(loss_function, reg_term, step_size_parameter, x0=x0, maxiter=maxiter,
                      enforce_non_negativity=enforce_non_negativity, no_tqdm=no_tqdm)

    opt_coeffs = optimizer.optimize()

    result = dict(result={'x': opt_coeffs}, optimizer=optimizer, loss_function=loss_function,
                  regularizer=reg_term, basis_set=basis_set, projector=projector)
    return result


class _ISTA(Optimizer):
    """Internal optimizer class for the group lasso pipeline. Implements
    :class:`Optimizer <mumott.optimization.optimizers.base_optimizer.Optimizer>`.

    Parameters
    ----------
    loss_function : LossFunction
        The differentiable part of the :ref:`loss function <loss_functions>`
        to be minimized using this algorithm.
    reg_term : GroupLasso
        Non-differentiable regularization term to be applied in every iteration.
        Must have a `proximal_operator` method.
    step_size_parameter : float
        Step size for the differentiable part of the optimization.
    maxiter : int
        Maximum number of iterations. Default value is `50`.
    enforce_non_negativity : bool
        If `True`, forces all coefficients to be greater than `0` at the end of every iteration.
        Default value is `False`.

    Notes
    -----
    Valid entries in :attr:`kwargs` are
        x0
            Initial guess for the solution vector. Must be the same size as
            :attr:`residual_calculator.coefficients`. Defaults to :attr:`loss_function.initial_values`.
    """

    def __init__(self, loss_function: LossFunction, reg_term: GroupLasso, step_size_parameter: float,
                 maxiter: int = 50, enforce_non_negativity: bool = False, **kwargs):

        super().__init__(loss_function, **kwargs)
        self._maxiter = maxiter
        self._reg_term = reg_term
        self._step_size_parameter = step_size_parameter
        self.error_function_history = []
        self._enforce_non_negativity = enforce_non_negativity

    def ISTA_step(self, coefficients):
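        # One proximal-gradient (ISTA) update,
        #     x <- prox_g(x - step_size * grad f(x)),
        # where f is the differentiable loss and g the group-lasso term; the total
        # objective evaluated at the incoming coefficients is returned together
        # with the updated coefficients.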
        d = self._loss_function.get_loss(coefficients, get_gradient=True)
        gradient = d['gradient']
        total_loss = d['loss'] + \
            self._reg_term.get_regularization_norm(coefficients)['regularization_norm']
        coefficients = coefficients - self._step_size_parameter * gradient
        coefficients = self._reg_term.proximal_operator(coefficients)
        if self._enforce_non_negativity:
            np.clip(coefficients, 0, None, out=coefficients)

        return coefficients, total_loss

    def optimize(self):
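        # Run a fixed number of ISTA iterations starting from the loss function's
        # initial values (or from ``x0`` if provided), recording the total
        # objective at each iteration in ``error_function_history``.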
        coefficients = self._loss_function.initial_values
        if 'x0' in self._options.keys():
            if self['x0'] is not None:
                coefficients = self['x0']

        # Calculate total loss
        loss_function_output = self._loss_function.get_loss(coefficients)
        reg_term_output = self._reg_term.get_regularization_norm(coefficients)
        total_loss = loss_function_output['loss'] + reg_term_output['regularization_norm']

        # Toggle between printing a progress bar or not
        if not self._no_tqdm:
            iterator = tqdm.tqdm(range(self._maxiter))
            iterator.set_description(f'Loss = {total_loss:.2E}')
        else:
            iterator = range(self._maxiter)

        for ii in iterator:
            # Do step
            coefficients, total_loss = self.ISTA_step(coefficients)
            # Update progress bar
            self.error_function_history.append(total_loss)
            if not self._no_tqdm:
                iterator.set_description(f'Loss = {total_loss:.2E}')

        return np.array(coefficients)