Coverage for local_installation_linux/mumott/optimization/loss_functions/squared_loss.py: 100%
28 statements
coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
import logging
from typing import Dict

import numpy as np
from numpy.typing import NDArray

from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction

logger = logging.getLogger(__name__)

class SquaredLoss(LossFunction):

    r"""Class object for obtaining the squared loss function and
    gradient from a given :ref:`residual_calculator <residual_calculators>`.

    This loss function can be written as :math:`L(r(x, d)) = 0.5 r(x, d)^2`, where :math:`r` is the
    residual, a function of :math:`x`, the optimization coefficients, and :math:`d`, the data.
    The gradient with respect to :math:`x` is then :math:`r \frac{\partial r}{\partial x}`.
    Computing the partial derivative of :math:`r` with respect to :math:`x` is the
    responsibility of the :attr:`residual_calculator`.

    Generally speaking, the squared loss function is easy to compute and has a well-behaved
    gradient, but it is not robust against outliers in the data. Using weights to normalize
    residuals by the variance can mitigate this somewhat.

    Parameters
    ----------
    residual_calculator : ResidualCalculator
        The :ref:`residual calculator instance <residual_calculators>` from which the
        residuals, weights, and gradient terms are obtained.
    use_weights : bool
        Whether to use weighting in the computation of the residual norm and gradient.
        Default is ``False``.
    preconditioner : np.ndarray
        A preconditioner to be applied to the gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients` or it must be possible to broadcast by multiplication.
    residual_norm_multiplier : float
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: NDArray[float] = None,
                 residual_norm_multiplier: float = 1):
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)

    def _get_residual_norm_internal(self, get_gradient: bool = False, gradient_part: str = None) -> Dict:
        """ Gets the residual norm and, if needed,
        the gradient, using the attached :attr:`residual_calculator`.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            Used for zonal harmonics reconstructions to determine which part of the gradient
            is being calculated. Default is ``None``.

        Returns
        -------
            A ``dict`` with two entries, ``residual_norm`` and ``gradient``.
            If ``get_gradient`` is ``False``, the value of ``gradient`` will be ``None``.
        """
        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=get_gradient, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']
        if self.use_weights:
            # weights (1/variance) need to be applied since they depend on the loss function
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh, ijkh -> ...', residuals, residuals, residual_calculator_output['weights'])
        else:
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh -> ...', residuals, residuals)

        if residual_norm < 1:
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')

        return dict(residual_norm=residual_norm, gradient=residual_calculator_output['gradient'])

    def get_estimate_of_lifschitz_constant(self) -> float:
        """ Calculates an estimate of the Lipschitz constant of this loss function. Used to determine a
        safe step size for certain optimization algorithms.

        Returns
        -------
            An estimate of the Lipschitz constant.
        """
        matrix_norm = self._residual_calculator.get_estimate_of_matrix_norm()
        return 2 / matrix_norm

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
        return 'L(r) = 0.5 * r ** 2'

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX. """
        return r'$L(r) = 0.5 r^2$'
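
For reference, the weighted residual norm that _get_residual_norm_internal computes via np.einsum is the same quantity as a plain weighted sum of squared residuals. The following is a minimal sketch of that equivalence; the four-dimensional (i, j, k, h) array shape and all variable names are illustrative assumptions, not part of mumott.

import numpy as np

# Hypothetical residuals and weights (1 / variance) in the
# four-dimensional (i, j, k, h) layout assumed by the einsum above.
rng = np.random.default_rng(0)
residuals = rng.normal(size=(3, 4, 5, 6))
weights = 1.0 / rng.uniform(0.5, 2.0, size=(3, 4, 5, 6))

# Weighted squared loss written as in SquaredLoss ...
norm_einsum = 0.5 * np.einsum(
    'ijkh, ijkh, ijkh -> ...', residuals, residuals, weights)

# ... and the equivalent elementwise formulation.
norm_plain = 0.5 * np.sum(weights * residuals ** 2)

assert np.isclose(norm_einsum, norm_plain)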
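
The Lipschitz estimate returned by get_estimate_of_lifschitz_constant is intended for choosing a safe step size in gradient-based optimizers; its exact value depends on what get_estimate_of_matrix_norm returns, which is up to the residual calculator. The sketch below is not mumott code: it only illustrates the standard relationship for a linear residual r(x) = A x - d, where the gradient of L(x) = 0.5 ||A x - d||^2 is A^T (A x - d), its Lipschitz constant is the spectral norm of A^T A, and a step size of 1 / L yields a convergent gradient descent.

import numpy as np

# Illustrative only: a dense linear residual r(x) = A @ x - d.
# The names A, d and x are placeholders, not mumott objects.
rng = np.random.default_rng(1)
A = rng.normal(size=(50, 10))
d = rng.normal(size=50)

# The gradient of L(x) = 0.5 * ||A @ x - d||**2 is A.T @ (A @ x - d);
# its Lipschitz constant is the spectral norm of A.T @ A.
lipschitz_constant = np.linalg.norm(A.T @ A, ord=2)

# A standard safe gradient-descent step size is 1 / L.
step_size = 1.0 / lipschitz_constant

x = np.zeros(10)
for _ in range(200):
    x -= step_size * (A.T @ (A @ x - d))

# The iterates approach the least-squares solution.
x_ls, *_ = np.linalg.lstsq(A, d, rcond=None)
assert np.allclose(x, x_ls, atol=1e-3)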