Coverage for local_installation_linux/mumott/pipelines/reconstruction/mitra.py: 95%
78 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
3import numpy as np
5from mumott.data_handling import DataContainer
6from mumott.data_handling.utilities import get_absorbances
7from mumott.methods.basis_sets import TrivialBasis, GaussianKernels
8from mumott.methods.residual_calculators import GradientResidualCalculator
9from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
10from mumott.methods.utilities import (get_sirt_weights, get_sirt_preconditioner,
11 get_tensor_sirt_weights, get_tensor_sirt_preconditioner)
12from mumott.optimization.loss_functions import SquaredLoss
13from mumott.optimization.optimizers import GradientDescent
15logger = logging.getLogger(__name__)
def run_mitra(data_container: DataContainer,
              use_absorbances: bool = True,
              use_sirt_weights: bool = True,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = None,
              **kwargs):
    """Reconstruction pipeline for the Modular Iterative Tomographic Reconstruction Algorithm (MITRA).
    This is a versatile, configurable interface for tomographic reconstruction that allows for
    various optimizers, projectors, loss functions and regularizers to be supplied.

    This is meant as a convenience interface for intermediate or advanced users to create
    customized reconstruction pipelines.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_sirt_weights
        If ``True`` (default), SIRT or tensor SIRT weights will be computed
        for use in the reconstruction.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``, which means
        :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
        If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    ftol
        Tolerance for the change in the loss function. Default is ``None``,
        in which case the reconstruction will terminate once the maximum
        number of iterations has been performed.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
    dict
        Dictionary with the entries ``result`` (the return value of
        ``optimizer.optimize()``), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set``, ``projector``,
        ``absorbances`` (``None`` if :attr:`use_absorbances` is ``False``),
        ``weights`` and ``preconditioner``.

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

    Projector
        The :ref:`projector class <projectors>` to use.
    absorbances
        If :attr:`use_absorbances` is set to ``True``, these absorbances
        will be used instead of ones calculated from the diode.
    preconditioner_cutoff
        The cutoff to use when computing the :term:`SIRT` preconditioner.
        Default value is ``0.1``,
        which will lead to a roughly ellipsoidal mask.
    weights_cutoff
        The cutoff to use when computing the :term:`SIRT` weights.
        Default value is ``0.1``,
        which will clip some projection edges.
    weights
        Only consulted if :attr:`use_sirt_weights` is ``False``; the value is
        returned under the ``weights`` key of the output. Defaults to
        ``data_container.projections.weights``.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
        will be used for absorbances and
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        for other data.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`.
    ResidualCalculator
        The :ref:`residual calculator class <residual_calculators>` to use.
        If not provided, then
        :class:`GradientResidualCalculator
        <mumott.methods.residual_calculators.GradientResidualCalculator>`
        will be used.
    residual_calculator_kwargs
        Keyword arguments for :attr:`ResidualCalculator`.
    LossFunction
        The :ref:`loss function class <loss_functions>` to use. If not provided
        :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
        will be used.
    loss_function_kwargs
        Keyword arguments for :attr:`LossFunction`.
    Regularizers
        A list of dictionaries with three entries, a name
        (``str``), a :ref:`regularizer object <regularizers>`, and
        a regularization weight (``float``); used by
        :func:`loss_function.add_regularizer()
        <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`.
    Optimizer
        The optimizer class to use. If not provided
        :class:`GradientDescent <mumott.optimization.optimizers.GradientDescent>`
        will be used, and its keyword argument ``nestorov_weight`` defaults to
        ``0.95``. For any optimizer, ``maxiter`` and ``ftol`` default to the
        corresponding arguments of this function, ``enforce_non_negativity``
        defaults to ``True`` and ``no_tqdm`` defaults to ``False``.
    optimizer_kwargs
        Keyword arguments for :attr:`Optimizer`.
    no_tqdm
        If given, overrides any ``no_tqdm`` entry in :attr:`optimizer_kwargs`.

    The ``*_kwargs`` dictionaries passed in are copied before defaults are
    filled in, so the caller's dictionaries are never modified.

    If the absorbances are computed from the diode, the weights of
    ``data_container.projections`` are multiplied in place by the
    transmittivity cutoff mask, and this change persists after the call.
    SIRT weights, by contrast, are only installed for the duration of the
    optimization. The previous weights are restored afterwards, also if an
    exception is raised.
    """
    # Projector: an explicitly supplied class wins over the CPU/GPU default.
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    # Absorbances: user-supplied, computed from the diode, or not used at all.
    if use_absorbances:
        if 'absorbances' in kwargs:
            absorbances = kwargs.pop('absorbances')
        else:
            abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
            absorbances = abs_dict['absorbances']
            # Exclude pixels rejected by the transmittivity cutoff. This modifies the
            # container's weights in place and is deliberately not undone below.
            data_container.projections.weights *= abs_dict['cutoff_mask']
    else:
        absorbances = None

    # Copy the caller's dict so that defaults filled in here do not leak back to them.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs', {}))
    if 'BasisSet' in kwargs:
        BasisSet = kwargs.pop('BasisSet')
    elif use_absorbances:
        BasisSet = TrivialBasis
        basis_set_kwargs.setdefault('channels', 1)
    else:
        BasisSet = GaussianKernels
        basis_set_kwargs.setdefault(
            'grid_scale', data_container.projections.data.shape[-1] // 2 + 1)
    basis_set = BasisSet(**basis_set_kwargs)

    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    # Copied: writing ``scalar_projections`` into a caller-owned dict would carry this
    # call's absorbances into any later call that reuses the same dict.
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs', {}))
    residual_calculator_kwargs.setdefault('use_scalar_projections', use_absorbances)
    residual_calculator_kwargs.setdefault('scalar_projections', absorbances)
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    Regularizers = kwargs.get('Regularizers', [])
    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    # Copied for the same reason: ``preconditioner`` is written into it below.
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs', {}))
    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)

    if use_sirt_weights:
        if use_absorbances:
            preconditioner = get_sirt_preconditioner(
                projector, cutoff=preconditioner_cutoff)
            sirt_weights = get_sirt_weights(
                projector, cutoff=weights_cutoff)
        else:
            preconditioner = get_tensor_sirt_preconditioner(
                projector, basis_set, cutoff=preconditioner_cutoff)
            sirt_weights = get_tensor_sirt_weights(
                projector, basis_set, cutoff=weights_cutoff)
        old_weights = data_container.projections.weights.copy()
        # Keep the SIRT weights only where the existing weights are positive.
        weights = sirt_weights * (data_container.projections.weights > 0).astype(float)
        loss_function_kwargs['use_weights'] = True
    else:
        # If not using SIRT weights, just fetch identically named arguments as normal.
        # NOTE(review): a user-supplied ``weights`` is only returned, not written to
        # ``data_container`` -- confirm this is intended.
        weights = kwargs.get('weights', data_container.projections.weights)
        preconditioner = loss_function_kwargs.get('preconditioner', None)
        loss_function_kwargs.setdefault('use_weights', True)
    loss_function_kwargs['preconditioner'] = preconditioner

    # Copied so that the defaults below do not modify the caller's dict.
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs', {}))
    if 'Optimizer' in kwargs:
        Optimizer = kwargs.pop('Optimizer')
    else:
        Optimizer = GradientDescent
        optimizer_kwargs.setdefault('nestorov_weight', 0.95)
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer_kwargs.setdefault('enforce_non_negativity', True)
    # A top-level ``no_tqdm`` keyword takes precedence over one in ``optimizer_kwargs``.
    optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', optimizer_kwargs.get('no_tqdm', False))

    if use_sirt_weights:
        # Temporarily install the SIRT weights; the ``finally`` clause restores the
        # previous weights even if construction or optimization raises.
        data_container.projections.weights = weights
    try:
        loss_function = LossFunction(residual_calculator,
                                     **loss_function_kwargs)
        for reg in Regularizers:
            loss_function.add_regularizer(**reg)
        optimizer = Optimizer(loss_function,
                              **optimizer_kwargs)
        result = optimizer.optimize()
    finally:
        if use_sirt_weights:
            data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights, preconditioner=preconditioner)