Coverage for local_installation_linux/mumott/pipelines/reconstruction/sirt.py: 93%
54 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
3import numpy as np
5from mumott.data_handling import DataContainer
6from mumott.data_handling.utilities import get_absorbances
7from mumott.methods.basis_sets import TrivialBasis, GaussianKernels
8from mumott.methods.residual_calculators import GradientResidualCalculator
9from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
10from mumott.methods.utilities import get_sirt_weights, get_sirt_preconditioner
11from mumott.optimization.loss_functions import SquaredLoss
12from mumott.optimization.optimizers import GradientDescent
14logger = logging.getLogger(__name__)
def run_sirt(data_container: DataContainer,
             use_absorbances: bool = True,
             use_gpu: bool = False,
             maxiter: int = 20,
             enforce_non_negativity: bool = False,
             **kwargs) -> dict:
    """A reconstruction pipeline for the :term:`SIRT` algorithm, which uses
    a gradient preconditioner and a set of weights for the projections
    to achieve fast convergence. Generally, one varies the number of iterations
    until a good reconstruction is obtained.

    Advanced users may wish to also modify the ``preconditioner_cutoff`` and
    ``weights_cutoff`` keyword arguments.

    The weights in ``data_container.projections.weights`` are temporarily
    replaced while the pipeline runs and are restored before the function
    exits, including when an exception is raised part-way through.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    enforce_non_negativity
        Enforces strict positivity on all the coefficients. Should only be used
        with local or scalar representations. Default value is ``False``.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
        A dictionary with the entries ``result`` (the output of
        ``optimizer.optimize()``), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set``, ``projector``,
        ``absorbances`` (``None`` if :attr:`use_absorbances` is ``False``)
        and ``weights`` (a copy of the weights that were in effect during
        the optimization).

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

    Projector
        The :ref:`projector class <projectors>` to use.
    preconditioner_cutoff
        The cutoff to use when computing the :term:`SIRT` preconditioner.
        Default value is ``0.1``,
        which will lead to a roughly ellipsoidal mask.
    weights_cutoff
        The cutoff to use when computing the :term:`SIRT` weights.
        Default value is ``0.1``,
        which will clip some projection edges.
    absorbances
        If :attr:`use_absorbances` is set to ``True``, these absorbances
        will be used instead of ones calculated from the diode.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
        will be used for absorbances and
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        for other data.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`. The provided dictionary
        is copied and is not modified by this function.
    no_tqdm
        Used to avoid a ``tqdm`` progress bar in the optimizer.
    """
    # An explicitly provided projector class takes precedence over ``use_gpu``.
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)
    preconditioner = get_sirt_preconditioner(projector, cutoff=preconditioner_cutoff)
    sirt_weights = get_sirt_weights(projector, cutoff=weights_cutoff)

    # Save previous weights to avoid accumulation.
    old_weights = data_container.projections.weights.copy()
    try:
        # Respect previous masking in data container
        data_container.projections.weights = sirt_weights * np.ceil(data_container.projections.weights)

        if use_absorbances:
            if 'absorbances' in kwargs:
                absorbances = kwargs.pop('absorbances')
            else:
                abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
                absorbances = abs_dict['absorbances']
                transmittivity_cutoff_mask = abs_dict['cutoff_mask']
                data_container.projections.weights *= transmittivity_cutoff_mask
        else:
            absorbances = None

        # Work on a shallow copy so that the defaults filled in below do not
        # leak into a dictionary owned by the caller.
        basis_set_kwargs = dict(kwargs.get('basis_set_kwargs', dict()))
        if 'BasisSet' in kwargs:
            BasisSet = kwargs.pop('BasisSet')
        elif use_absorbances:
            BasisSet = TrivialBasis
            if 'channels' not in basis_set_kwargs:
                basis_set_kwargs['channels'] = 1
        else:
            BasisSet = GaussianKernels
            # NOTE: this overwrites any ``grid_scale`` given in ``basis_set_kwargs``.
            basis_set_kwargs['grid_scale'] = (data_container.projections.data.shape[-1]) // 2 + 1
        basis_set = BasisSet(**basis_set_kwargs)

        residual_calculator_kwargs = dict(use_scalar_projections=use_absorbances,
                                          scalar_projections=absorbances)
        residual_calculator = GradientResidualCalculator(data_container,
                                                         basis_set,
                                                         projector,
                                                         **residual_calculator_kwargs)

        loss_function_kwargs = dict(use_weights=True, preconditioner=preconditioner)
        loss_function = SquaredLoss(residual_calculator,
                                    **loss_function_kwargs)

        optimizer_kwargs = dict(maxiter=maxiter)
        optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', False)
        optimizer_kwargs['enforce_non_negativity'] = enforce_non_negativity
        optimizer = GradientDescent(loss_function,
                                    **optimizer_kwargs)

        result = optimizer.optimize()
        # Keep a copy of the weights actually used before they are restored.
        weights = data_container.projections.weights.copy()
    finally:
        # Always hand the container back with its original weights, even if
        # any step above raised.
        data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights)