Coverage for local_installation_linux/mumott/pipelines/reconstruction/sigtt.py: 98%
39 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
3from mumott.data_handling import DataContainer
4from mumott.methods.basis_sets import SphericalHarmonics
5from mumott.methods.residual_calculators import GradientResidualCalculator
6from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
7from mumott.optimization.loss_functions import SquaredLoss
8from mumott.optimization.optimizers import LBFGS
9from mumott.optimization.regularizers import Laplacian
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
def run_sigtt(data_container: DataContainer,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = 1e-2,
              regularization_weight: float = 1e-4,
              **kwargs):
    """A reconstruction pipeline for the :term:`SIGTT` algorithm, which uses
    a gradient and a regularizer to accomplish reconstruction.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`,
        otherwise :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
        Ignored if ``Projector`` is given in :attr:`kwargs`.
    maxiter
        Maximum number of iterations, passed to the optimizer as ``maxiter``.
        Default is ``20``. A ``maxiter`` entry in ``optimizer_kwargs``
        takes precedence.
    ftol
        Tolerance for the change in the loss function, passed to the
        optimizer as ``ftol``. Default is ``1e-2``. An ``ftol`` entry in
        ``optimizer_kwargs`` takes precedence.
    regularization_weight
        Regularization weight for the default
        :class:`Laplacian <mumott.optimization.regularizers.Laplacian>` regularizer.
        Default is ``1e-4``. Ignored if ``Regularizers`` is given in
        :attr:`kwargs`.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
    dict
        A dictionary with the entries ``result`` (the return value of the
        optimizer's ``optimize()`` method), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set`` and ``projector``.

    Notes
    -----
    Many options can be specified through :attr:`kwargs`. Options for the
    optimizer must be given through ``optimizer_kwargs``; any keyword not
    listed below is ignored and a warning is logged. The recognized keywords are:

    Projector
        The :ref:`projector class <projectors>` to use. It is instantiated
        with ``data_container.geometry``.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>`
        will be used.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`. If ``BasisSet`` is not
        provided and ``ell_max`` is missing, ``ell_max`` defaults to the
        largest even integer not exceeding ``data_container.data.shape[-1] - 1``.
    ResidualCalculator
        The :ref:`residual_calculator class <residual_calculators>` to use. If not provided
        :class:`GradientResidualCalculator
        <mumott.methods.residual_calculators.GradientResidualCalculator>`
        will be used.
    residual_calculator_kwargs
        Keyword arguments for :attr:`ResidualCalculator`.
    LossFunction
        The :ref:`loss function class <loss_functions>` to use. If not provided
        :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
        will be used.
    loss_function_kwargs
        Keyword arguments for :attr:`LossFunction`.
    Regularizers
        A list of dictionaries with the keys ``name`` (``str``),
        ``regularizer`` (a :ref:`regularizer object <regularizers>`) and
        ``regularization_weight`` (``float``); each dictionary is passed as
        keyword arguments to
        :func:`loss_function.add_regularizer()
        <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`.
        By default, a :class:`Laplacian
        <mumott.optimization.regularizers.Laplacian>` with the
        weight :attr:`regularization_weight` will be used. If
        ``Regularizers`` is given, the default Laplacian is not added.
    Optimizer
        The :ref:`optimizer class <optimizers>` to use. If not provided
        :class:`LBFGS <mumott.optimization.optimizers.LBFGS>` will be used.
    optimizer_kwargs
        Keyword arguments for :attr:`Optimizer`.

    The ``*_kwargs`` dictionaries are copied before defaults are filled in,
    so the dictionaries passed by the caller are never modified. Passing
    ``None`` for any of them is equivalent to passing an empty dictionary.
    """
    known_keys = {'Projector', 'BasisSet', 'basis_set_kwargs',
                  'ResidualCalculator', 'residual_calculator_kwargs',
                  'LossFunction', 'loss_function_kwargs', 'Regularizers',
                  'Optimizer', 'optimizer_kwargs'}
    unrecognized = sorted(set(kwargs) - known_keys)
    if unrecognized:
        # These used to be dropped silently, which hides typos in option names.
        logger.warning('Ignoring unrecognized keyword arguments: %s',
                       ', '.join(unrecognized))

    # Projector: explicit class wins, otherwise choose by ``use_gpu``.
    if 'Projector' in kwargs:
        Projector = kwargs['Projector']
    else:
        Projector = SAXSProjectorCUDA if use_gpu else SAXSProjector
    projector = Projector(data_container.geometry)

    # Copy so that filling in ``ell_max`` never mutates the caller's dict.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs') or {})
    if 'BasisSet' in kwargs:
        BasisSet = kwargs['BasisSet']
    else:
        BasisSet = SphericalHarmonics
        if 'ell_max' not in basis_set_kwargs:
            # Largest even integer <= n - 1, with n the size of the last data
            # axis (presumably the detector segments -- TODO confirm).
            basis_set_kwargs['ell_max'] = 2 * ((data_container.data.shape[-1] - 1) // 2)
    basis_set = BasisSet(**basis_set_kwargs)

    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs') or {})
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    # Only build the default Laplacian when the caller did not supply regularizers.
    if 'Regularizers' in kwargs:
        regularizers = kwargs['Regularizers']
    else:
        regularizers = [dict(name='laplacian',
                             regularizer=Laplacian(),
                             regularization_weight=regularization_weight)]

    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs') or {})
    loss_function = LossFunction(residual_calculator,
                                 **loss_function_kwargs)
    for reg in regularizers:
        loss_function.add_regularizer(**reg)

    Optimizer = kwargs.get('Optimizer', LBFGS)
    # Copy so the ``maxiter``/``ftol`` defaults never leak into the caller's dict;
    # explicit entries in ``optimizer_kwargs`` take precedence over the arguments.
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs') or {})
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer = Optimizer(loss_function,
                          **optimizer_kwargs)
    result = optimizer.optimize()
    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector)