Coverage for local_installation_linux/mumott/optimization/loss_functions/squared_loss.py: 100%

28 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

import logging
from typing import Dict

import numpy as np
from numpy.typing import NDArray

from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction

logger = logging.getLogger(__name__)


class SquaredLoss(LossFunction):

    r"""Class object for obtaining the squared loss function and
    gradient from a given :ref:`residual_calculator <residual_calculators>`.

    This loss function can be written as :math:`L(r(x, d)) = 0.5 r(x, d)^2`, where :math:`r` is the
    residual, a function of :math:`x`, the optimization coefficients, and :math:`d`, the data.
    The gradient with respect to :math:`x` is then :math:`r \frac{\partial r}{\partial x}`.
    Computing the partial derivative of :math:`r` with respect to :math:`x` is the responsibility
    of the :attr:`residual_calculator`.

    Generally speaking, the squared loss function is easy to compute and has a well-behaved
    gradient, but it is not robust against outliers in the data. Using weights to normalize
    residuals by the variance can mitigate this somewhat.

    Parameters
    ----------
    residual_calculator : ResidualCalculator
        The :ref:`residual calculator instance <residual_calculators>` from which the
        residuals, weights, and gradient terms are obtained.
    use_weights : bool
        Whether to use weighting in the computation of the residual norm and gradient.
        Default is ``False``.
    preconditioner : np.ndarray
        A preconditioner to be applied to the gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients`, or be broadcastable to that shape by multiplication.
    residual_norm_multiplier : float
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: NDArray[float] = None,
                 residual_norm_multiplier: float = 1):
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)

    def _get_residual_norm_internal(self, get_gradient: bool = False, gradient_part: str = None) -> Dict:
        """ Gets the residual norm and, if needed,
        the gradient, using the attached :attr:`residual_calculator`.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the gradient is
            being calculated. Default is ``None``.

        Returns
        -------
            A ``dict`` with two entries, ``residual_norm`` and ``gradient``.
            If ``get_gradient`` is ``False``, the ``gradient`` entry will be ``None``.
        """
        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=get_gradient, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']
        if self.use_weights:
            # weights (1 / variance) need to be applied here, since how they enter
            # depends on the loss function
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh, ijkh -> ...', residuals, residuals, residual_calculator_output['weights'])
        else:
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh -> ...', residuals, residuals)

        if residual_norm < 1:
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')

        return dict(residual_norm=residual_norm, gradient=residual_calculator_output['gradient'])
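
    # Note that the einsum contractions in _get_residual_norm_internal are equivalent to
    # 0.5 * np.sum(weights * residuals ** 2) and 0.5 * np.sum(residuals ** 2), respectively,
    # with the sum running over all four axes (i, j, k, h) of the residual array.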

    def get_estimate_of_lifschitz_constant(self) -> float:
        """
        Calculate an estimate of the Lipschitz constant of this cost function. Used to determine a
        safe step-size for certain optimization algorithms.

        Returns
        -------
        lipschitz_constant
            An estimate of the Lipschitz constant.
        """
        matrix_norm = self._residual_calculator.get_estimate_of_matrix_norm()
        return 2 * matrix_norm
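
    # A gradient step no larger than the inverse of this Lipschitz estimate will not
    # increase a quadratic loss such as this one, which is how optimizers typically
    # turn the estimate into a safe step size.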

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
        return 'L(r) = r ** 2'

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX."""
        return r'$L(r) = r^2$'
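

if __name__ == '__main__':
    # Minimal self-contained sketch (illustration only, independent of the class above):
    # for a dense linear residual r = A @ x - d with per-point weights w, the weighted
    # squared loss is 0.5 * sum(w * r ** 2) and its gradient is A.T @ (w * r), mirroring
    # the einsum-based computation in SquaredLoss._get_residual_norm_internal.
    rng = np.random.default_rng(0)
    A = rng.standard_normal((20, 5))
    d = rng.standard_normal(20)
    w = np.full(20, 0.5)  # weights, e.g. inverse variances
    x = np.zeros(5)

    residuals = A @ x - d
    loss = 0.5 * np.einsum('i, i, i ->', residuals, residuals, w)
    gradient = A.T @ (w * residuals)

    # A Lipschitz constant of this gradient is the spectral norm of A.T @ diag(w) @ A;
    # a gradient step of size 1 / lipschitz is guaranteed not to increase the loss.
    lipschitz = np.linalg.norm(A.T @ (w[:, None] * A), ord=2)
    x_new = x - gradient / lipschitz
    new_loss = 0.5 * np.sum(w * (A @ x_new - d) ** 2)
    assert new_loss <= loss
    print(f'loss decreased from {loss:.4f} to {new_loss:.4f}')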