import numpy as np
import tqdm
import logging
from mumott import DataContainer
from mumott.optimization.optimizers.base_optimizer import Optimizer
from mumott.optimization.loss_functions.base_loss_function import LossFunction
from mumott.methods.basis_sets.base_basis_set import BasisSet
from numpy.typing import NDArray
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.regularizers.group_lasso import GroupLasso
from mumott.methods.utilities.preconditioning import get_largest_eigenvalue
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.basis_sets import SphericalHarmonics
logger = logging.getLogger(__name__)
def run_group_lasso(data_container: DataContainer,
                    regularization_parameter: float,
                    step_size_parameter: float = None,
                    x0: NDArray[float] = None,
                    basis_set: BasisSet = None,
                    ell_max: int = 8,
                    use_gpu: bool = False,
                    maxiter: int = 100,
                    enforce_non_negativity: bool = False,
                    no_tqdm: bool = False,
                    ):
    """A reconstruction pipeline to do least squares reconstructions regularized with the group-lasso
    regularizer and solved with the Iterative Soft-Thresholding Algorithm (ISTA), a proximal gradient
    descent method. This reconstruction automatically masks out voxels with zero scattering but
    needs the regularization weight as input.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        containing the data set of interest.
    regularization_parameter
        Scalar weight of the regularization term. Should be optimized by performing reconstructions
        for a range of possible values.
    step_size_parameter
        Step size parameter of the reconstruction. If no value is given, it is set to ``0.5``
        divided by the largest eigenvalue estimated by
        :func:`get_largest_eigenvalue <mumott.methods.utilities.preconditioning.get_largest_eigenvalue>`
        for the chosen basis set and projector.
    x0
        Starting guess for the solution. By default (``None``) the optimizer starts from the
        loss function's initial values.
    basis_set
        Optionally a basis set can be specified. By default (``None``)
        :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>` is used.
    ell_max
        If no basis set is given, this is the maximum spherical harmonics order used in the
        generated basis set.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`,
        otherwise :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
    maxiter
        Number of iterations for the ISTA optimization. No stopping rules are implemented.
    enforce_non_negativity
        Whether or not to enforce non-negativity of the solution coefficients
        (coefficients are clipped to ``>= 0`` after every iteration).
    no_tqdm
        If ``True``, no progress bar is printed during the reconstruction.

    Returns
    -------
    dict
        Dictionary with the entries ``result`` (a dict whose ``'x'`` entry holds the optimized
        coefficients), ``optimizer``, ``loss_function``, ``regularizer``, ``basis_set`` and
        ``projector``.
    """
    projector_class = SAXSProjectorCUDA if use_gpu else SAXSProjector
    projector = projector_class(data_container.geometry)
    if basis_set is None:
        basis_set = SphericalHarmonics(ell_max=ell_max)
    if step_size_parameter is None:
        logger.info('Calculating step size parameter.')
        # Derive the step from the largest eigenvalue of the (basis set, projector) system so the
        # gradient step of ISTA stays within a stable range.
        matrix_norm = get_largest_eigenvalue(basis_set, projector)
        step_size_parameter = 0.5 / matrix_norm
    loss_function = SquaredLoss(GradientResidualCalculator(data_container, basis_set, projector))
    # The regularizer also receives the step size, which its proximal operator is built around.
    reg_term = GroupLasso(regularization_parameter, step_size_parameter)
    optimizer = _ISTA(loss_function, reg_term, step_size_parameter, x0=x0, maxiter=maxiter,
                      enforce_non_negativity=enforce_non_negativity, no_tqdm=no_tqdm)
    opt_coeffs = optimizer.optimize()
    result = dict(result={'x': opt_coeffs}, optimizer=optimizer, loss_function=loss_function,
                  regularizer=reg_term, basis_set=basis_set, projector=projector)
    return result
class _ISTA(Optimizer):
    """Internal optimizer class for the group lasso pipeline. Implements
    :class:`Optimizer <mumott.optimization.optimizers.base_optimizer.Optimizer>`.

    Parameters
    ----------
    loss_function : LossFunction
        The differentiable part of the :ref:`loss function <loss_functions>`
        to be minimized using this algorithm.
    reg_term : GroupLasso
        Non-differentiable regularization term to be applied in every iteration.
        Must have a `proximal_operator` method.
    step_size_parameter : float
        Step size for the differentiable part of the optimization.
    maxiter : int
        Maximum number of iterations. Default value is `50`.
    enforce_non_negativity : bool
        If `True`, clips all coefficients to be non-negative (``>= 0``) at the end of every
        iteration. Default value is `False`.

    Notes
    -----
    Valid entries in :attr:`kwargs` are
        x0
            Initial guess for solution vector. Must be the same size as
            :attr:`residual_calculator.coefficients`. Defaults to :attr:`loss_function.initial_values`.
        no_tqdm
            If `True`, no progress bar is shown during :meth:`optimize`.
    """

    def __init__(self, loss_function: LossFunction, reg_term: GroupLasso, step_size_parameter: float,
                 maxiter: int = 50, enforce_non_negativity: bool = False, **kwargs):
        super().__init__(loss_function, **kwargs)
        self._maxiter = maxiter
        self._reg_term = reg_term
        self._step_size_parameter = step_size_parameter
        # Total loss (data term + regularization norm) recorded once per iteration of
        # ``optimize``; never cleared, so repeated ``optimize`` calls keep appending.
        self.error_function_history = []
        self._enforce_non_negativity = enforce_non_negativity

    def ISTA_step(self, coefficients):
        """Perform a single ISTA (proximal gradient) iteration.

        Parameters
        ----------
        coefficients
            Current iterate.

        Returns
        -------
        tuple
            The updated coefficients and the total loss (data term plus regularization norm)
            evaluated at the *input* ``coefficients``, i.e. before the update.
        """
        d = self._loss_function.get_loss(coefficients, get_gradient=True)
        gradient = d['gradient']
        total_loss = d['loss'] + \
            self._reg_term.get_regularization_norm(coefficients)['regularization_norm']
        # Explicit gradient step on the differentiable part (creates a new array, so the
        # caller's array is not modified).
        coefficients = coefficients - self._step_size_parameter * gradient
        # Proximal step on the non-differentiable regularization term.
        coefficients = self._reg_term.proximal_operator(coefficients)
        if self._enforce_non_negativity:
            np.clip(coefficients, 0, None, out=coefficients)
        return coefficients, total_loss

    def optimize(self):
        """Run ``maxiter`` ISTA iterations and return the final coefficients.

        Starts from ``x0`` if it was passed and is not ``None``, otherwise from
        ``loss_function.initial_values``. No stopping criterion is applied. The total loss
        returned by every :meth:`ISTA_step` is appended to :attr:`error_function_history`.

        Returns
        -------
        numpy.ndarray
            Copy of the coefficients after the last iteration.
        """
        coefficients = self._loss_function.initial_values
        if 'x0' in self._options and self['x0'] is not None:
            coefficients = self['x0']

        # NOTE(review): ``self._no_tqdm`` is not set in this class; presumably the base
        # ``Optimizer`` derives it from the ``no_tqdm`` kwarg — confirm.
        show_progress = not self._no_tqdm
        if show_progress:
            # The starting loss is only used to label the progress bar, so the extra loss and
            # regularization evaluation is skipped entirely when no bar is shown.
            loss_function_output = self._loss_function.get_loss(coefficients)
            reg_term_output = self._reg_term.get_regularization_norm(coefficients)
            total_loss = loss_function_output['loss'] + reg_term_output['regularization_norm']
            iterator = tqdm.tqdm(range(self._maxiter))
            iterator.set_description(f'Loss = {total_loss:.2E}')
        else:
            iterator = range(self._maxiter)

        for _ in iterator:
            coefficients, total_loss = self.ISTA_step(coefficients)
            self.error_function_history.append(total_loss)
            if show_progress:
                iterator.set_description(f'Loss = {total_loss:.2E}')
        return np.array(coefficients)