Coverage for local_installation_linux/mumott/pipelines/reconstruction/sirt.py: 93%

54 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3import numpy as np 

4 

5from mumott.data_handling import DataContainer 

6from mumott.data_handling.utilities import get_absorbances 

7from mumott.methods.basis_sets import TrivialBasis, GaussianKernels 

8from mumott.methods.residual_calculators import GradientResidualCalculator 

9from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector 

10from mumott.methods.utilities import get_sirt_weights, get_sirt_preconditioner 

11from mumott.optimization.loss_functions import SquaredLoss 

12from mumott.optimization.optimizers import GradientDescent 

13 

14logger = logging.getLogger(__name__) 

15 

16 

def run_sirt(data_container: DataContainer,
             use_absorbances: bool = True,
             use_gpu: bool = False,
             maxiter: int = 20,
             enforce_non_negativity: bool = False,
             **kwargs):
    """A reconstruction pipeline for the :term:`SIRT` algorithm, which uses
    a gradient preconditioner and a set of weights for the projections
    to achieve fast convergence. Generally, one varies the number of iterations
    until a good reconstruction is obtained.

    Advanced users may wish to also modify the ``preconditioner_cutoff`` and
    ``weights_cutoff`` keyword arguments.

    The weights in :attr:`data_container.projections.weights` are temporarily
    replaced while the pipeline runs. The original weights are restored before
    the function exits, including when an exception is raised part-way through.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
        Ignored if a ``Projector`` is given through ``kwargs``.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    enforce_non_negativity
        Enforces strict positivity on all the coefficients. Should only be used
        with local or scalar representations. Default value is ``False``.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
        A dictionary with the entries ``result`` (the output of the optimizer),
        ``optimizer``, ``loss_function``, ``residual_calculator``, ``basis_set``,
        ``projector``, ``absorbances`` (``None`` if :attr:`use_absorbances` is
        ``False``) and ``weights`` (a copy of the weights that were used during
        the optimization).

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

    Projector
        The :ref:`projector class <projectors>` to use.
    preconditioner_cutoff
        The cutoff to use when computing the :term:`SIRT` preconditioner.
        Default value is ``0.1``,
        which will lead to a roughly ellipsoidal mask.
    weights_cutoff
        The cutoff to use when computing the :term:`SIRT` weights.
        Default value is ``0.1``,
        which will clip some projection edges.
    absorbances
        If :attr:`use_absorbances` is set to ``True``, these absorbances
        will be used instead of ones calculated from the diode.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
        will be used for absorbances and
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        for other data.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`. The provided dictionary is
        copied and not modified.
    no_tqdm
        Used to avoid a ``tqdm`` progress bar in the optimizer.
    """
    # An explicitly provided projector class takes precedence over ``use_gpu``.
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)
    preconditioner = get_sirt_preconditioner(projector, cutoff=preconditioner_cutoff)
    sirt_weights = get_sirt_weights(projector, cutoff=weights_cutoff)

    # Work on a shallow copy so that defaults filled in below do not leak
    # into the dictionary owned by the caller.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs', dict()))

    # Save previous weights to avoid accumulation.
    old_weights = data_container.projections.weights.copy()
    try:
        # Respect previous masking in data container; ``np.ceil`` keeps
        # zero-valued weights at zero so that masked entries stay masked.
        data_container.projections.weights = sirt_weights * np.ceil(data_container.projections.weights)

        if use_absorbances:
            if 'absorbances' in kwargs:
                absorbances = kwargs.pop('absorbances')
            else:
                abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
                absorbances = abs_dict['absorbances']
                # The cutoff mask is only available (and only applied)
                # when the absorbances are computed here.
                transmittivity_cutoff_mask = abs_dict['cutoff_mask']
                data_container.projections.weights *= transmittivity_cutoff_mask
        else:
            absorbances = None

        if 'BasisSet' in kwargs:
            BasisSet = kwargs.pop('BasisSet')
        elif use_absorbances:
            BasisSet = TrivialBasis
            if 'channels' not in basis_set_kwargs:
                basis_set_kwargs['channels'] = 1
        else:
            BasisSet = GaussianKernels
            # NOTE(review): this replaces any user-provided ``grid_scale``;
            # kept as in the original implementation -- confirm this is intended.
            basis_set_kwargs['grid_scale'] = (data_container.projections.data.shape[-1]) // 2 + 1
        basis_set = BasisSet(**basis_set_kwargs)

        residual_calculator_kwargs = dict(use_scalar_projections=use_absorbances,
                                          scalar_projections=absorbances)
        residual_calculator = GradientResidualCalculator(data_container,
                                                         basis_set,
                                                         projector,
                                                         **residual_calculator_kwargs)

        loss_function_kwargs = dict(use_weights=True, preconditioner=preconditioner)
        loss_function = SquaredLoss(residual_calculator,
                                    **loss_function_kwargs)

        optimizer_kwargs = dict(maxiter=maxiter)
        optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', False)
        optimizer_kwargs['enforce_non_negativity'] = enforce_non_negativity
        optimizer = GradientDescent(loss_function,
                                    **optimizer_kwargs)

        result = optimizer.optimize()
        # Snapshot the weights actually used before they are restored.
        weights = data_container.projections.weights.copy()
    finally:
        # Always hand the data container back with its original weights,
        # also if anything above raised.
        data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights)