# Source code for mumott.optimization.loss_functions.squared_loss

```
import logging
from typing import Dict, Optional

import numpy as np
from numpy.typing import NDArray

from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction
logger = logging.getLogger(__name__)
class SquaredLoss(LossFunction):
    r"""Class object for obtaining the squared loss function and
    gradient from a given :ref:`residual_calculator <residual_calculators>`.

    This loss function can be written as :math:`L(r(x, d)) = 0.5 r(x, d)^2`, where :math:`r` is the
    residual, a function of :math:`x`, the optimization coefficients, and :math:`d`, the data.
    By the chain rule, the gradient with respect to :math:`x` is then
    :math:`r \frac{\partial r}{\partial x}`.
    The partial derivative of :math:`r` with respect to :math:`x` is the responsibility
    of the :attr:`residual_calculator` to compute; this class passes the ``gradient`` entry
    returned by the residual calculator through unchanged.

    Generally speaking, the squared loss function is easy to compute and has a well-behaved
    gradient, but it is not robust against outliers in the data. Using weights to normalize
    residuals by the variance can mitigate this somewhat.

    Parameters
    ----------
    residual_calculator : ResidualCalculator
        The :ref:`residual calculator instance <residual_calculators>` from which the
        residuals, weights, and gradient terms are obtained.
    use_weights : bool
        Whether to use weighting in the computation of the residual norm and gradient.
        Default is ``False``.
    preconditioner : np.ndarray
        A preconditioner to be applied to the gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients` or it must be possible to broadcast by multiplication.
        Default is ``None``.
    residual_norm_multiplier : float
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
        Default is ``1``.
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: Optional[NDArray[float]] = None,
                 residual_norm_multiplier: float = 1):
        # All arguments are forwarded unchanged to the base class initializer.
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)

    def _get_residual_norm_internal(self,
                                    get_gradient: bool = False,
                                    gradient_part: Optional[str] = None) -> Dict:
        """ Gets the residual norm, and if needed,
        the gradient, using the attached :attr:`residual_calculator`.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is ``None``.

        Returns
        -------
            A ``dict`` with two entries, ``residual_norm`` and ``gradient``.
            The ``gradient`` entry is whatever the residual calculator returned under
            its ``'gradient'`` key; if ``get_gradient`` is false, it is expected
            to be ``None``.
        """
        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=get_gradient, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']

        # The 'ijkh' subscripts require four-dimensional arrays; every index is
        # summed over, so the result is a single scalar value.
        if self.use_weights:
            # weights (1/variance) need to be applied since they depend on the loss function
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh, ijkh -> ...', residuals, residuals, residual_calculator_output['weights'])
        else:
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh -> ...', residuals, residuals)

        # Only warn; the value is still returned unchanged.
        if residual_norm < 1:
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')

        return dict(residual_norm=residual_norm, gradient=residual_calculator_output['gradient'])

    def get_estimate_of_lifschitz_constant(self) -> float:
        """
        Calculate an estimate of the Lipschitz constant of this cost function. Used to determine a
        safe step-size for certain optimization algorithms.

        The method name keeps its historical spelling for API compatibility.

        Returns
        -------
        lifschitz_constant
            The value ``2 / matrix_norm``, where ``matrix_norm`` is obtained from
            ``get_estimate_of_matrix_norm`` of the attached residual calculator.
        """
        matrix_norm = self._residual_calculator.get_estimate_of_matrix_norm()
        # NOTE(review): 2 / norm looks like a step-size bound rather than the
        # Lipschitz constant itself -- confirm against the callers before changing.
        return 2 / matrix_norm

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
        # Includes the factor 0.5 that the residual norm computation applies.
        return 'L(r) = 0.5 * r ** 2'

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX."""
        # Includes the factor 0.5 that the residual norm computation applies.
        return r'$L(r) = 0.5 r^2$'
```