import logging
from typing import Dict
import numpy as np
from numpy.typing import NDArray
from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from .base_loss_function import LossFunction

logger = logging.getLogger(__name__)


class SquaredLoss(LossFunction):
    r"""Class object for obtaining the squared loss function and
    gradient from a given :ref:`residual_calculator <residual_calculators>`.

    This loss function can be written as :math:`L(r(x, d)) = 0.5 r(x, d)^2`, where :math:`r` is the
    residual, a function of :math:`x`, the optimization coefficients, and :math:`d`, the data.
    The gradient of :math:`L` with respect to :math:`x` is then :math:`r \frac{\partial r}{\partial x}`.
    Computing the partial derivative of :math:`r` with respect to :math:`x` is the
    responsibility of the :attr:`residual_calculator`.

    Generally speaking, the squared loss function is easy to compute and has a well-behaved
    gradient, but it is not robust against outliers in the data. Using weights to normalize
    the residuals by the variance can mitigate this somewhat (see the example below).

    Parameters
    ----------
    residual_calculator : ResidualCalculator
        The :ref:`residual calculator instance <residual_calculators>` from which the
        residuals, weights, and gradient terms are obtained.
    use_weights : bool
        Whether to use weighting in the computation of the residual norm and gradient.
        Default is ``False``.
    preconditioner : np.ndarray
        A preconditioner to be applied to the gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients`, or a shape that can be broadcast against it
        by multiplication.
    residual_norm_multiplier : float
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
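
    Examples
    --------
    The snippet below only illustrates the arithmetic of the weighted residual norm,
    :math:`0.5 \sum_i w_i r_i^2`; the shapes and values are made up for illustration and do
    not come from an actual reconstruction.

    >>> import numpy as np
    >>> residuals = np.array([[[[1.0, -2.0]]]])
    >>> weights = np.array([[[[1.0, 0.5]]]])
    >>> float(0.5 * np.einsum('ijkh, ijkh, ijkh -> ...', residuals, residuals, weights))
    1.5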
    """
    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: NDArray[float] = None,
                 residual_norm_multiplier: float = 1):
        super().__init__(residual_calculator, use_weights, preconditioner, residual_norm_multiplier)

    def _get_residual_norm_internal(self, get_gradient: bool = False, gradient_part: str = None) -> Dict:
        """ Gets the residual norm, and if needed,
        the gradient, using the attached :attr:`residual_calculator`.
        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics resonstructions to determine what part of the gradient is
            being calculated. Default is None.
        Returns
        -------
             A ``dict`` with two entries, ``residual_norm`` and ``gradient``.
             If ``get_gradient`` is false, its value will be ``None``.
        """
        residual_calculator_output = self._residual_calculator.get_residuals(
            get_gradient=get_gradient, get_weights=self._use_weights, gradient_part=gradient_part)
        residuals = residual_calculator_output['residuals']
        if self.use_weights:
            # weights (1/variance) need to be applied since they depend on the loss function
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh, ijkh -> ...', residuals, residuals, residual_calculator_output['weights'])
        else:
            residual_norm = 0.5 * np.einsum(
                'ijkh, ijkh -> ...', residuals, residuals)
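        # For reference, the weighted contraction above is equivalent to
        # 0.5 * np.sum(weights * residuals ** 2), and the unweighted one to
        # 0.5 * np.sum(residuals ** 2); the einsum formulation avoids allocating
        # the intermediate arrays.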
        if residual_norm < 1:
            logger.warning(f'The residual norm value ({residual_norm}) is < 1.'
                           ' Note that some optimizers change their convergence criteria for'
                           ' loss functions < 1!')
        return dict(residual_norm=residual_norm, gradient=residual_calculator_output['gradient'])

    def get_estimate_of_lifschitz_constant(self) -> float:
        """
        Calculate an estimate of the Lifschitz constant of this cost function. Used to determine a
        safe step-size for certain optimization algorithms.
        Returns
        -------
        lifschitz_constant
            Lifschitz constant.
        """
        matrix_norm = self._residual_calculator.get_estimate_of_matrix_norm()
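        # The norm estimate itself is delegated to the attached residual calculator;
        # see its documentation for the exact definition used.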
        return 2 / matrix_norm

    @property
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function. """
        return 'L(r) = 0.5 * r ** 2'

    @property
    def _function_as_tex(self) -> str:
        """ Should return a string representation of the associated loss function
        in MathJax-renderable TeX."""
        return r'$L(r) = 0.5 r^2$'
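

# A minimal usage sketch, not part of this module; ``my_residual_calculator`` is a
# placeholder for an already constructed residual calculator instance:
#
#     loss_function = SquaredLoss(my_residual_calculator, use_weights=True)
#
# The loss function is then typically handed to an optimizer, which queries the residual
# norm and gradient through the interface defined by the base class.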