import logging
import numpy as np
from numpy.typing import NDArray
from mumott import DataContainer
from mumott.core.wigner_d_utilities import (
    load_d_matrices, calculate_sph_coefficients_rotated_around_z,
    calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x,
    calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x,
    calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle)
from mumott.core.hashing import list_to_hash
from mumott.methods.projectors.base_projector import Projector
from mumott.methods.basis_sets.spherical_harmonics import SphericalHarmonics
from .base_residual_calculator import ResidualCalculator
logger = logging.getLogger(__name__)
class ZHTTResidualCalculator(ResidualCalculator):
    r"""Class that implements the gradient calculations for a model that uses a
    :class:`SphericalHarmonics` basis set restricted to zonal harmonics parametrized
    by a primary axis with polar coordinates :math:`\theta_0` and :math:`\phi_0`,
    defined as:

    .. math::

        \begin{pmatrix} x_0\\ y_0\\ z_0\end{pmatrix}
        = \begin{pmatrix}
        \sin(\theta_0) \sin(\phi_0) \\
        \sin(\theta_0) \cos(\phi_0) \\
        \cos(\theta_0)
        \end{pmatrix}

    This model is equivalent to the one used in [Liebi2015]_, but uses a different approach to
    computation.

    This implementation avoids doing some of the expensive calculations of trigonometric functions and
    Legendre polynomials by doing the rotation in the space of the spherical harmonics using
    `Wigner (small) d-matrices <https://en.wikipedia.org/wiki/Wigner_D-matrix>`_.

    The forward model only involves a small number of trigonometric functions to evaluate the
    :math:`d_z(\text{angle})` matrices for the :math:`\theta` and :math:`\phi` rotations.
    Everything else is expressed as matrix products with precomputed matrices.
    The full forward model may be written as:

    .. math::

        \boldsymbol{I} =
        \boldsymbol{W} \boldsymbol{P}
        \boldsymbol{d}_z(\phi_0) \boldsymbol{d}_y(\frac{\pi}{4})^T
        \boldsymbol{d}_z(-\theta_0) \boldsymbol{d}_y(\frac{\pi}{4})
        \boldsymbol{a}'_{l0},

    where :math:`\boldsymbol{W}` is the mapping from spherical harmonic modes to detector segments,
    which can be precomputed.
    :math:`\boldsymbol{P}` is the typical projector from normal 3D tomography and
    :math:`\boldsymbol{d}_i(\text{angle})` with :math:`i = x,y,z` are Wigner (small) d-matrices
    for real spherical harmonics. :math:`\theta_0`, :math:`\phi_0`, and :math:`\boldsymbol{a}_{l0}`
    are the model parameters for each voxel.

    Derivatives are easy to evaluate because the angles only appear in the
    :math:`\boldsymbol{d}_z(\text{angle})`-matrices. All the expensive trigonometric and spherical
    harmonics calculations have been put into the precomputation of :math:`\boldsymbol{W}`
    and :math:`\boldsymbol{d}_y(\frac{\pi}{4})`.

    Parameters
    ----------
    data_container : DataContainer
        Container holding the data to be reconstructed.
    basis_set : SphericalHarmonics
        The basis set used for representing spherical functions.
    projector : Projector
        The type of projector used together with this method.
    """
    def __init__(self,
                 data_container: DataContainer,
                 basis_set: SphericalHarmonics,
                 projector: Projector):
        """Hand the three collaborators to the base class, then prepare the
        optimization parameters and the precomputed matrices.
        """
        super().__init__(data_container, basis_set, projector)
        # These two set-up steps do not read each other's results: the starting
        # guess only needs the volume shape and ell_max, while the matrices only
        # depend on the basis set.
        self._make_starting_guess()
        self._make_matrices()
    def _make_starting_guess(self) -> None:
        """Create the initial optimization parameters.

        The zonal coefficients start at zero (one coefficient per even ``ell``
        up to ``ell_max``) and each voxel is given a random orientation.
        """
        shape = self._projector.volume_shape
        num_zonal = self._basis_set.ell_max // 2 + 1
        self._zonal_coefficients = np.zeros((*shape, num_zonal))
        generator = np.random.default_rng()
        # Drawing cos(theta) uniformly gives directions that are uniform in
        # solid angle; only cos(theta) in [0, 1) is sampled, i.e. theta is
        # restricted to [0, pi/2].
        cos_theta = generator.uniform(low=0, high=1, size=shape)
        self._theta = np.arccos(cos_theta)
        self._phi = generator.uniform(low=-np.pi, high=np.pi, size=shape)
    def _make_matrices(self) -> None:
        """Load the Wigner d-matrices for the current ``ell_max`` and build the
        expansion matrix ``_E`` that embeds zonal coefficients into the full
        spherical harmonics coefficient vector.
        """
        ell_max = self._basis_set.ell_max
        # Precomputed d-matrices used by the fixed 90-degree rotations
        self.d_matrices = load_d_matrices(ell_max)
        ell_values = np.asarray(self._basis_set.ell_indices)
        emm_values = np.asarray(self._basis_set.emm_indices)
        # Row = index into the full basis, column = index of the zonal coefficient.
        # Only the m == 0 modes receive an entry; the mode of degree ell maps to
        # column ell // 2.
        expansion = np.zeros((len(ell_values), ell_max // 2 + 1))
        zonal_rows = np.flatnonzero(emm_values == 0)
        expansion[zonal_rows, ell_values[zonal_rows] // 2] = 1
        self._E = expansion
    @property
    def coefficients(self) -> NDArray:
        """Optimization parameters of this method as a single array.

        The zonal coefficients and the two angles are joined along axis 3:
        the first N-2 entries are the zonal coefficients, entry N-1 is the polar
        angle and the last entry is the azimuthal angle. The angles are cast
        into the symmetric zone before they are returned.
        """
        self._cast_angles_to_symmetric_zone()
        polar = self._theta[..., np.newaxis]
        azimuthal = self._phi[..., np.newaxis]
        return np.concatenate((self._zonal_coefficients, polar, azimuthal), axis=3)
    @coefficients.setter
    def coefficients(self, val: NDArray) -> None:
        # Convert from the external representation (one array holding zonal
        # coefficients followed by theta and phi) to the internal one (three
        # separate arrays). ``reshape`` raises if ``val`` has the wrong size.
        num_zonal = self._basis_set.ell_max // 2 + 1
        val = val.reshape((*self._projector.volume_shape, num_zonal + 2))
        assert np.shape(val[..., :-2]) == np.shape(self._zonal_coefficients), \
            'Shape of new array inconsistent with expectation (zonal_coefficients)'
        assert np.shape(val[..., -2]) == np.shape(self._theta), \
            'Shape of new array inconsistent with expectation (theta)'
        assert np.shape(val[..., -1]) == np.shape(self._phi), \
            'Shape of new array inconsistent with expectation (phi)'
        # The stored arrays are slices of the reshaped input; no copy is made here.
        self._zonal_coefficients = val[..., :-2]
        self._theta = val[..., -2]
        self._phi = val[..., -1]
    def _rotate_coeffs(self) -> NDArray:
        """Embed the zonal coefficients in the full spherical harmonics basis and
        rotate them from the symmetric coordinate system to the sample xyz system.

        The result is also stored in ``self._coefficients``.

        Returns
        -------
            Array containing the rotated spherical harmonics coefficients.
        """
        even_ells = np.arange(0, self._basis_set.ell_max + 1, 2)
        # Zonal -> full basis; every rotation below writes its result back into
        # this same array through ``output_array``.
        coeffs = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        self._coefficients = coeffs
        # +90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            coeffs, even_ells, self.d_matrices, output_array=coeffs)
        # theta about z
        calculate_sph_coefficients_rotated_around_z(
            coeffs, self._theta, even_ells, output_array=coeffs)
        # -90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            coeffs, even_ells, self.d_matrices, output_array=coeffs)
        # phi about z
        calculate_sph_coefficients_rotated_around_z(
            coeffs, self._phi, even_ells, output_array=coeffs)
        return coeffs
    def _rotate_and_derive(self) -> tuple[NDArray, NDArray, NDArray]:
        """
        Rotate spherical harmonics coefficients from the symmetric coordinate system
        to the sample xyz system and evaluate the derivative of the coefficients
        with respect to the two rotation angles.

        The rotation sequence is the same as in :meth:`_rotate_coeffs`; the rotated
        coefficients are also stored in ``self._coefficients``.

        Returns
        -------
        self._coefficients : NDArray
            Array containing the rotated spherical harmonics coefficients.
        theta_derivative : NDArray
            Rotated spherical coefficients derived with respect to the polar rotation angle
            evaluated at the current value of the rotation angles.
        phi_derivative : NDArray
            Rotated spherical coefficients derived with respect to the azimuthal rotation angle
            evaluated at the current value of the rotation angles.
        """
        ell_list = np.arange(0, self._basis_set.ell_max+1, 2)
        # Expand symmetric coefficients into full basis
        self._coefficients = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        # Separate buffers for the derivatives, one full coefficient vector per voxel
        theta_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))
        phi_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))
        # Do 90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        # Do z rotation of Theta and derivative
        # The derivative is evaluated first: the plain rotation that follows writes
        # its output back into self._coefficients, which is the derivative's input.
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._theta, ell_list, output_array=theta_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._theta, ell_list, output_array=self._coefficients)
        # Do -90 degree rotation around x
        # The remaining rotations do not depend on theta, so they are applied
        # unchanged to the theta derivative as well.
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            theta_derivative, ell_list, self.d_matrices, output_array=theta_derivative)
        # Do z rotation of Phi
        # Again the phi derivative is taken before self._coefficients is overwritten.
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._phi, ell_list, output_array=phi_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._phi, ell_list, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_around_z(
            theta_derivative, self._phi, ell_list, output_array=theta_derivative)
        return self._coefficients, theta_derivative, phi_derivative
    def _rotate_coeffs_inverse(self, coefficients: NDArray):
        """
        Apply the inverse of the rotation used in the forward model, taking
        spherical harmonics coefficients from the sample xyz system back to
        the symmetric coordinate system.

        Each step writes its result into ``coefficients`` through ``output_array``,
        and that same array is returned.
        """
        even_ells = np.arange(0, self._basis_set.ell_max+1, 2)
        # Undo the forward steps in reverse order, each with the opposite angle:
        # first the azimuthal rotation ...
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._phi, even_ells, output_array=coefficients)
        # ... then the fixed -90 degree x rotation (undone by +90 degrees) ...
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            coefficients, even_ells, self.d_matrices, output_array=coefficients)
        # ... then the polar rotation ...
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._theta, even_ells, output_array=coefficients)
        # ... and finally the fixed +90 degree x rotation (undone by -90 degrees).
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            coefficients, even_ells, self.d_matrices, output_array=coefficients)
        return coefficients
[docs]    def get_residuals(self,
                      get_gradient: bool = False,
                      get_weights: bool = False,
                      gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """ Calculates the residuals and possibly the gradient of the residual square sum
        (without the factor of -2!) with respect to the parameters.
        The coefficients are projected using the :attr:`SphericalHarmonics` and :attr:`Projector`
        attached to this instance.
        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        get_weights
            Whether to return weights. Default is ``False``. If ``True`` along with
            :attr:`get_gradient`, the gradient will be computed with weights.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) the gradient is computed with respect to all
            parameters;
            if :attr:`gradient_part` is ``'angles'`` only the gradient with respect to the angles is computed;
            if :attr:`gradient_part` is ``'coefficients'`` only the gradient with respect to the zonal
            spherical harmonics coefficients is computed.
        Returns
        -------
            A dictionary containing the residuals, and possibly the
            gradient and/or weights. If gradient and/or weights
            are not returned, their value will be ``None``.
        """
        if not get_gradient:
            # Rotate the coefficients
            self._rotate_coeffs()
            # Project from voxel to detector space and from coefficient to angle space
            projection = self._basis_set.forward(
                self._projector.forward(self._coefficients.astype(self.dtype)))
            # Calculate residuals
            residuals = self._data - projection
            if get_weights:
                residuals *= self._weights
            output = {'residuals': residuals, 'gradient': None}
        elif get_gradient:
            output = self.get_gradient(get_weights=get_weights)
        # Pass on weights, if asked to
        if get_weights:
            output['weights'] = self._weights
        else:
            output['weights'] = None
        return output 
[docs]    def get_gradient(self,
                     get_weights: bool = False,
                     gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """ Calculates the gradient of *half* the sum of residuals squared.
        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) the gradient is computed with respect to all
            parameters;
            if :attr:`gradient_part` is ``'angles'`` only the gradient with respect to the angles is computed;
            if :attr:`gradient_part` is ``'coefficients'`` only the gradient with respect to the zonal
            spherical harmonics coefficients is computed.
        Returns
        -------
            A dictionary containing the residuals of the gradient. If only a part of the
            gradient is computed, the rest of the elements will be filled with zeros.
        """
        # initialize output array
        gradient = np.zeros((*self._projector.volume_shape, self._basis_set.ell_max // 2 + 3))
        # If only the coefficients are needed, do not evaluate the derivatives.
        if gradient_part == 'coefficients':
            coefficients = self._rotate_coeffs()
        else:
            coefficients, theta_derivative, phi_derivative = self._rotate_and_derive()
        # Project from voxel to detector space and the from coeff-space to angle-space
        projection = self._basis_set.forward(self._projector.forward(coefficients.astype(self.dtype)))
        # Calculate residuals
        residuals = self._data - projection
        if get_weights:
            residuals *= self._weights
        # Backproject residual
        bp_res = self._projector.adjoint(
                    self._basis_set.gradient(residuals * self._weights).astype(self.dtype))
        # If the gradient with respect to angles is needed, compute the inner products
        if gradient_part in ['full', 'angles']:
            gradient[:, :, :, -2] = -np.einsum('xyzm,xyzm->xyz', bp_res, theta_derivative)
            gradient[:, :, :, -1] = -np.einsum('xyzm,xyzm->xyz', bp_res, phi_derivative)
        if gradient_part == 'full' or gradient_part == 'coefficients':
            # back-rotate coefficients
            bp_res = self._rotate_coeffs_inverse(bp_res)
            gradient[..., :-2] += -np.einsum('...i,ij->...j', bp_res, self._E)
        return {'residuals': residuals, 'gradient': gradient} 
    def _cast_angles_to_symmetric_zone(self) -> None:
        r"""
        Casts internal angle arrays into the range :math:`\theta \in [0, \pi/2]` and
        :math:`\phi \in [0, 2\pi[`.

        Angles with :math:`\theta > \pi/2` (after reduction modulo :math:`\pi`) are
        replaced by :math:`(\pi - \theta, \phi + \pi)`.
        New arrays are assigned to the internal attributes; the previously stored
        arrays are not modified, since they may be slices of an array that was
        handed in from outside through the ``coefficients`` setter.
        """
        theta = np.mod(self._theta, np.pi)
        southern_hemisphere = theta > (np.pi / 2)
        self._theta = np.where(southern_hemisphere, np.pi - theta, theta)
        shifted_phi = np.where(southern_hemisphere, self._phi + np.pi, self._phi)
        self._phi = np.mod(shifted_phi, 2 * np.pi)
    @property
    def rotated_coefficients(self):
        """
        Real spherical harmonics coefficients in the sample xyz system, obtained
        by expanding and rotating the current zonal coefficients.
        """
        rotated = self._rotate_coeffs()
        return rotated
    @property
    def directions(self) -> NDArray:
        """
        Returns the direction of symmetry as a unit vector in xyz coordinates.
        The vector index is the last index of the output.
        """
        sin_theta = np.sin(self._theta)
        # Spherical-to-Cartesian conversion; the z component depends on the polar
        # angle theta only (using phi here would not give a unit vector).
        directions = np.stack((np.cos(self._phi) * sin_theta,
                               np.sin(self._phi) * sin_theta,
                               np.cos(self._theta)), axis=-1)
        return directions
    @property
    def ell_max(self) -> int:
        """Maximum ``ell`` of the attached basis set."""
        basis_set = self._basis_set
        return basis_set.ell_max
    @property
    def volume_shape(self) -> tuple[int, ...]:
        """Shape of voxel volume, as reported by the attached projector."""
        # NOTE(review): annotated as a tuple of ints because the value is unpacked
        # into array shapes elsewhere in this class; confirm the exact type
        # against the projector.
        return self._projector.volume_shape
    def _update(self, force_update: bool = False) -> None:
        """ Carries out necessary updates if anything changes with respect to
        the geometry or basis set.

        If the volume shape has changed, the zonal coefficients are reset to zero
        and the angles are randomized. Otherwise, if ``ell_max`` has changed, the
        zonal coefficients are truncated or padded with zeros while the angles are kept.

        Parameters
        ----------
        force_update
            If ``True``, run the update even if :attr:`is_dirty` is ``False``.
        """
        if not (self.is_dirty or force_update):
            return
        self._basis_set.probed_coordinates = self.probed_coordinates

        # Infer the previous ell_max from the number of stored zonal coefficients
        # (one per even ell) and compare the implied basis size with the current one.
        old_ellmax = (self._zonal_coefficients.shape[-1] - 1) * 2
        old_num_coeffs = (old_ellmax + 1) * (old_ellmax + 2) // 2
        len_diff = len(self._basis_set) - old_num_coeffs
        geometry = self._data_container.geometry
        vol_diff = geometry.volume_shape - np.array(self._coefficients.shape[:-1])

        # TODO: Think about whether the ``Method`` should do this or handle it differently
        if np.any(vol_diff != 0):
            logger.warning('Volume shape has changed.'
                           ' Coefficients have been reset to zero and angles have been randomized.')
            self._make_matrices()
            self._make_starting_guess()
        elif len_diff != 0:
            logger.warning('ell_max has changed. Coefficients will be truncated or appended with zeros.')
            self._make_matrices()
            old_params = np.array(self._zonal_coefficients)
            new_num_zonal = self._basis_set.ell_max // 2 + 1
            if len_diff > 0:
                # Basis grew: keep the old coefficients and pad with zeros
                padded = np.zeros((*self._projector.volume_shape, new_num_zonal))
                padded[..., :old_params.shape[-1]] = old_params
                self._zonal_coefficients = padded
            else:
                # Basis shrank: drop the highest-order coefficients
                self._zonal_coefficients = old_params[..., :new_num_zonal]
            self._coefficients = np.zeros((*geometry.volume_shape, len(self._basis_set)),
                                          dtype=self.dtype)

        # Remember the state that the internal arrays are now consistent with
        self._geometry_hash = hash(geometry)
        self._basis_set_hash = hash(self._basis_set)
    def __hash__(self) -> int:
        """ Returns a hash of the current state of this instance.

        The hash covers the optimization parameters, the projector, the geometry
        and the stored basis set and geometry hashes.
        """
        geometry = self._data_container.geometry
        state = [
            self._zonal_coefficients,
            self._theta,
            self._phi,
            hash(self._projector),
            hash(geometry),
            self._basis_set_hash,
            self._geometry_hash,
        ]
        # list_to_hash yields a hexadecimal string
        hex_digest = list_to_hash(state)
        return int(hex_digest, 16)
    @property
    def is_dirty(self) -> bool:
        """ ``True`` if the stored hash of the geometry or of the basis set no
        longer matches the current hash of that object. Used to trigger updates """
        if self._geometry_hash != hash(self._data_container.geometry):
            return True
        return self._basis_set_hash != hash(self._basis_set)
    def __str__(self) -> str:
        """Plain-text summary: a header with the class name followed by one
        ``label : value`` line per field."""
        width = 74
        lines = ['=' * width,
                 self.__class__.__name__.center(width),
                 '-' * width]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            entries = (
                ('BasisSet', self._basis_set.__class__.__name__),
                ('Projector', self._projector.__class__.__name__),
                ('is_dirty', self.is_dirty),
                ('probed_coordinates (hash)', hex(hash(self.probed_coordinates))[2:8]),
                ('hash', hex(hash(self))[2:8]),
            )
            for label, value in entries:
                lines.append('{:18} : {}'.format(label, value))
        lines.append('-' * width)
        return '\n'.join(lines)
    def _repr_html_(self) -> str:
        """HTML summary table (Field / Size / Data) for rich display front ends."""
        html = [f'<h3>{self.__class__.__name__}</h3>',
                '<table border="1" class="dataframe">',
                '<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>',
                '<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            projector_name = self._projector.__class__.__name__
            # Each entry is (field label, size column, data column)
            rows = [
                ('BasisSet', 1, self._basis_set.__class__.__name__),
                ('Projector', len(projector_name), projector_name),
                ('Is dirty', 1, self.is_dirty),
                ('probed_coordinates',
                 self.probed_coordinates.vector.shape,
                 f'{hex(hash(self.probed_coordinates))[2:8]} (hash)'),
            ]
            full_hash = hex(hash(self))
            rows.append(('Hash', len(full_hash), full_hash[2:8]))
            # Every row is emitted as two lines: the label cell, then size and data
            for field, size, data in rows:
                html.append(f'<tr><td style="text-align: left;">{field}</td>')
                html.append(f'<td>{size}</td><td>{data}</td></tr>')
        html.append('</tbody>')
        html.append('</table>')
        return '\n'.join(html)