Coverage for local_installation_linux/mumott/pipelines/reconstruction/mitra.py: 95%

78 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3import numpy as np 

4 

5from mumott.data_handling import DataContainer 

6from mumott.data_handling.utilities import get_absorbances 

7from mumott.methods.basis_sets import TrivialBasis, GaussianKernels 

8from mumott.methods.residual_calculators import GradientResidualCalculator 

9from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector 

10from mumott.methods.utilities import (get_sirt_weights, get_sirt_preconditioner, 

11 get_tensor_sirt_weights, get_tensor_sirt_preconditioner) 

12from mumott.optimization.loss_functions import SquaredLoss 

13from mumott.optimization.optimizers import GradientDescent 

14 

logger = logging.getLogger(__name__)


def run_mitra(data_container: DataContainer,
              use_absorbances: bool = True,
              use_sirt_weights: bool = True,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = None,
              **kwargs):
    """Reconstruction pipeline for the Modular Iterative Tomographic Reconstruction Algorithm (MITRA).
    This is a versatile, configurable interface for tomographic reconstruction that allows for
    various optimizers, projectors, loss functions and regularizers to be supplied.

    This is meant as a convenience interface for intermediate or advanced users to create
    customized reconstruction pipelines.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_sirt_weights
        If ``True`` (default), SIRT or tensor SIRT weights will be computed
        for use in the reconstruction.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``, which means
        :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
        If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    ftol
        Tolerance for the change in the loss function. Default is ``None``,
        in which case the reconstruction will terminate once the maximum
        number of iterations has been performed.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
    dict
        A dictionary with the entries ``result`` (the return value of the
        optimizer's ``optimize()`` method), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set``, ``projector``,
        ``absorbances`` (``None`` if :attr:`use_absorbances` is ``False``),
        ``weights`` and ``preconditioner``.

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

    Projector
        The :ref:`projector class <projectors>` to use.
    absorbances
        If :attr:`use_absorbances` is set to ``True``, these absorbances
        will be used instead of ones calculated from the diode.
    preconditioner_cutoff
        The cutoff to use when computing the :term:`SIRT` preconditioner.
        Default value is ``0.1``,
        which will lead to a roughly ellipsoidal mask.
    weights_cutoff
        The cutoff to use when computing the :term:`SIRT` weights.
        Default value is ``0.1``,
        which will clip some projection edges.
    weights
        Only consulted if :attr:`use_sirt_weights` is ``False``. The value is
        returned under the ``weights`` key of the result, but it is not
        written to ``data_container.projections.weights``.
        Defaults to ``data_container.projections.weights``.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
        will be used for absorbances and
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        for other data.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`.
    ResidualCalculator
        The :ref:`residual calculator class <residual_calculators>` to use.
        If not provided, then
        :class:`GradientResidualCalculator
        <mumott.methods.residual_calculators.GradientResidualCalculator>`
        will be used.
    residual_calculator_kwargs
        Keyword arguments for :attr:`ResidualCalculator`.
    LossFunction
        The :ref:`loss function class <loss_functions>` to use. If not provided
        :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
        will be used.
    loss_function_kwargs
        Keyword arguments for :attr:`LossFunction`. If :attr:`use_sirt_weights`
        is ``False``, the ``preconditioner`` entry (if any) is used as the
        preconditioner.
    Regularizers
        A list of dictionaries with three entries, a name
        (``str``), a :ref:`regularizer object <regularizers>`, and
        a regularization weight (``float``); each dictionary is unpacked as
        keyword arguments into
        :func:`loss_function.add_regularizer()
        <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`,
        so its keys must match that method's parameter names.
    Optimizer
        The optimizer class to use. If not provided
        :class:`GradientDescent <mumott.optimization.optimizers.GradientDescent>`
        will be used. By default, the keyword argument ``nestorov_weight`` is set to
        ``0.95``, and ``enforce_non_negativity`` is ``True``.
    optimizer_kwargs
        Keyword arguments for :attr:`Optimizer`.
    no_tqdm
        If given, overrides ``optimizer_kwargs['no_tqdm']``. Default is ``False``.

    The dictionaries passed as ``basis_set_kwargs``, ``residual_calculator_kwargs``,
    ``loss_function_kwargs`` and ``optimizer_kwargs`` are copied before defaults
    are filled in, so the caller's dictionaries are not modified.

    Side effects on :attr:`data_container`: if absorbances are computed from the
    diode, ``data_container.projections.weights`` is multiplied in place by the
    transmittivity cutoff mask, and this is not undone. If :attr:`use_sirt_weights`
    is ``True``, the projection weights are temporarily replaced by the SIRT
    weights and restored afterwards, including when the optimization raises.
    """
    # ---- Projector -------------------------------------------------------
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    # ---- Input data ------------------------------------------------------
    if use_absorbances:
        if 'absorbances' in kwargs:
            absorbances = kwargs.pop('absorbances')
        else:
            abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
            absorbances = abs_dict['absorbances']
            # Applied in place and intentionally not reverted at the end of the run.
            data_container.projections.weights *= abs_dict['cutoff_mask']
    else:
        absorbances = None

    # ---- Basis set -------------------------------------------------------
    # Copy so that the defaults filled in below never leak into the caller's dict.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs', {}))
    if 'BasisSet' in kwargs:
        BasisSet = kwargs.pop('BasisSet')
    elif use_absorbances:
        BasisSet = TrivialBasis
        basis_set_kwargs.setdefault('channels', 1)
    else:
        BasisSet = GaussianKernels
        basis_set_kwargs.setdefault(
            'grid_scale', data_container.projections.data.shape[-1] // 2 + 1)
    basis_set = BasisSet(**basis_set_kwargs)

    # ---- Residual calculator ---------------------------------------------
    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs', {}))
    residual_calculator_kwargs.setdefault('use_scalar_projections', use_absorbances)
    residual_calculator_kwargs.setdefault('scalar_projections', absorbances)
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    # ---- Loss function configuration -------------------------------------
    Regularizers = kwargs.get('Regularizers', [])
    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs', {}))
    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)

    old_weights = None
    if use_sirt_weights:
        if use_absorbances:
            preconditioner = get_sirt_preconditioner(
                projector, cutoff=preconditioner_cutoff)
            sirt_weights = get_sirt_weights(
                projector, cutoff=weights_cutoff)
        else:
            preconditioner = get_tensor_sirt_preconditioner(
                projector, basis_set, cutoff=preconditioner_cutoff)
            sirt_weights = get_tensor_sirt_weights(
                projector, basis_set, cutoff=weights_cutoff)
        old_weights = data_container.projections.weights.copy()
        # Keep entries that already carry a zero (or negative) weight excluded.
        weights = sirt_weights * (old_weights > 0).astype(float)
        loss_function_kwargs['use_weights'] = True
    else:
        # If not using SIRT weights, just fetch identically named arguments as normal from kwargs.
        # NOTE(review): a 'weights' kwarg is only echoed back in the result; it is not written
        # to data_container.projections.weights, so the loss presumably still uses the
        # container's weights -- confirm this is intended.
        weights = kwargs.get('weights', data_container.projections.weights)
        preconditioner = loss_function_kwargs.get('preconditioner', None)
        loss_function_kwargs.setdefault('use_weights', True)
    loss_function_kwargs['preconditioner'] = preconditioner

    # ---- Optimizer configuration -----------------------------------------
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs', {}))
    if 'Optimizer' in kwargs:
        Optimizer = kwargs.pop('Optimizer')
    else:
        Optimizer = GradientDescent
        optimizer_kwargs.setdefault('nestorov_weight', 0.95)
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer_kwargs.setdefault('enforce_non_negativity', True)
    # A top-level 'no_tqdm' kwarg takes precedence over the optimizer_kwargs entry.
    optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', optimizer_kwargs.get('no_tqdm', False))

    # ---- Reconstruction --------------------------------------------------
    try:
        if use_sirt_weights:
            data_container.projections.weights = weights
        loss_function = LossFunction(residual_calculator,
                                     **loss_function_kwargs)
        for reg in Regularizers:
            loss_function.add_regularizer(**reg)
        optimizer = Optimizer(loss_function,
                              **optimizer_kwargs)
        result = optimizer.optimize()
    finally:
        # Restore the caller's weights even if construction or optimization fails.
        if use_sirt_weights:
            data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights, preconditioner=preconditioner)