Coverage for local_installation_linux/mumott/pipelines/reconstruction/sigtt.py: 98%

39 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

import logging

from mumott.data_handling import DataContainer
from mumott.methods.basis_sets import SphericalHarmonics
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.optimizers import LBFGS
from mumott.optimization.regularizers import Laplacian

logger = logging.getLogger(__name__)

# Keyword arguments that :func:`run_sigtt` reads from ``**kwargs``. Any other
# key is not used by the pipeline, so a warning is logged for it (this catches
# typos such as ``optimiser_kwargs`` that would otherwise be silently ignored).
_KNOWN_KWARGS = frozenset((
    'Projector',
    'BasisSet',
    'basis_set_kwargs',
    'ResidualCalculator',
    'residual_calculator_kwargs',
    'LossFunction',
    'loss_function_kwargs',
    'Regularizers',
    'Optimizer',
    'optimizer_kwargs',
))


def run_sigtt(data_container: DataContainer,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = 1e-2,
              regularization_weight: float = 1e-4,
              **kwargs) -> dict:
    """A reconstruction pipeline for the :term:`SIGTT` algorithm, which uses
    a gradient and a regularizer to accomplish reconstruction.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
        Ignored if ``Projector`` is given in :attr:`kwargs`.
    maxiter
        Maximum number of iterations for the gradient descent solution.
        Default is ``20``. Passed to the optimizer as ``maxiter`` unless
        ``optimizer_kwargs`` already contains a ``maxiter`` entry.
    ftol
        Tolerance for the change in the loss function. Default is ``1e-2``.
        Passed to the optimizer as ``ftol`` unless ``optimizer_kwargs``
        already contains an ``ftol`` entry.
    regularization_weight
        Regularization weight for the default
        :class:`Laplacian <mumott.optimization.regularizers.Laplacian>` regularizer.
        Ignored if ``Regularizers`` is given in :attr:`kwargs`.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
    dict
        A dictionary with the entries ``result`` (the return value of the
        optimizer's ``optimize()`` method), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set`` and ``projector``.

    Notes
    -----
    Many options can be specified through :attr:`kwargs`. Only the keywords
    listed below are used; any other keyword is ignored and a warning is
    logged. Dictionaries passed as ``*_kwargs`` are copied and are never
    modified by this function.

    Projector
        The :ref:`projector class <projectors>` to use. It is instantiated
        with ``data_container.geometry``.
    BasisSet
        The :ref:`basis set class <basis_sets>` to use. If not provided
        :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>`
        will be used. In that case, if ``basis_set_kwargs`` contains no
        ``ell_max``, it is set to ``2 * ((data_container.data.shape[-1] - 1) // 2)``.
    basis_set_kwargs
        Keyword arguments for :attr:`BasisSet`.
    ResidualCalculator
        The :ref:`residual_calculator class <residual_calculators>` to use. If not provided
        :class:`GradientResidualCalculator
        <mumott.methods.residual_calculators.GradientResidualCalculator>`
        will be used.
    residual_calculator_kwargs
        Keyword arguments for :attr:`ResidualCalculator`.
    LossFunction
        The :ref:`loss function class <loss_functions>` to use. If not provided
        :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
        will be used.
    loss_function_kwargs
        Keyword arguments for :attr:`LossFunction`.
    Regularizers
        A list of dictionaries, each with three entries: ``name``
        (``str``), ``regularizer`` (a :ref:`regularizer object <regularizers>`)
        and ``regularization_weight`` (``float``). Each dictionary is unpacked
        into :func:`loss_function.add_regularizer()
        <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`.
        By default, a single :class:`Laplacian
        <mumott.optimization.regularizers.Laplacian>` named ``'laplacian'`` with
        the weight :attr:`regularization_weight` is used. If ``Regularizers``
        is given, that default is not added.
    Optimizer
        The :ref:`optimizer class <optimizers>` to use. If not provided
        :class:`LBFGS <mumott.optimization.optimizers.LBFGS>` will be used.
    optimizer_kwargs
        Keyword arguments for :attr:`Optimizer`. Entries ``maxiter`` and
        ``ftol`` here take precedence over the function arguments of the
        same name.

    """
    # Surface keywords the pipeline does not understand instead of silently
    # dropping them.
    unrecognized = sorted(set(kwargs) - _KNOWN_KWARGS)
    if unrecognized:
        logger.warning('run_sigtt: ignoring unrecognized keyword arguments: %s',
                       ', '.join(unrecognized))

    # Projector: an explicit class wins over the ``use_gpu`` switch.
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    # Basis set. Copy the caller's kwargs so that filling in a default
    # ``ell_max`` does not mutate their dictionary.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs', {}))
    if 'BasisSet' in kwargs:
        BasisSet = kwargs.pop('BasisSet')
    else:
        if 'ell_max' not in basis_set_kwargs:
            # Largest even integer not exceeding data.shape[-1] - 1.
            basis_set_kwargs['ell_max'] = 2 * ((data_container.data.shape[-1] - 1) // 2)
        BasisSet = SphericalHarmonics
    basis_set = BasisSet(**basis_set_kwargs)

    # Residual calculator.
    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs', {}))
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    # Loss function and regularizers. The default Laplacian is only built
    # when the caller did not supply their own regularizers.
    if 'Regularizers' in kwargs:
        Regularizers = kwargs['Regularizers']
    else:
        Regularizers = [dict(name='laplacian',
                             regularizer=Laplacian(),
                             regularization_weight=regularization_weight)]
    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs', {}))
    loss_function = LossFunction(residual_calculator,
                                 **loss_function_kwargs)
    for reg in Regularizers:
        loss_function.add_regularizer(**reg)

    # Optimizer. Copy the caller's kwargs; explicit ``maxiter``/``ftol`` entries
    # there take precedence over the function arguments.
    Optimizer = kwargs.get('Optimizer', LBFGS)
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs', {}))
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer = Optimizer(loss_function,
                          **optimizer_kwargs)
    result = optimizer.optimize()
    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector)