import logging
import numpy as np
from mumott.data_handling import DataContainer
from mumott.data_handling.utilities import get_absorbances
from mumott.methods.basis_sets import TrivialBasis, GaussianKernels
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.utilities import get_sirt_weights, get_sirt_preconditioner
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.optimizers import GradientDescent
logger = logging.getLogger(__name__)
def run_sirt(data_container: DataContainer,
             use_absorbances: bool = True,
             use_gpu: bool = False,
             maxiter: int = 20,
             enforce_non_negativity: bool = False,
             **kwargs) -> dict:
    """A reconstruction pipeline for the :term:`SIRT` algorithm, which uses
    a gradient preconditioner and a set of weights for the projections
    to achieve fast convergence. Generally, one varies the number of iterations
    until a good reconstruction is obtained.

    Advanced users may wish to also modify the ``preconditioner_cutoff`` and
    ``weights_cutoff`` keyword arguments.

    The weights in :attr:`data_container.projections.weights` are temporarily
    replaced while the reconstruction runs and are always restored before this
    function returns or raises.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``. If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
        Ignored if ``Projector`` is given as a keyword argument.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    enforce_non_negativity
        Forwarded to the optimizer to enforce non-negativity of the coefficients.
        Should only be used with local or scalar representations.
        Default value is ``False``.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
        A dictionary with the keys ``result`` (the output of the optimizer),
        ``optimizer``, ``loss_function``, ``residual_calculator``, ``basis_set``,
        ``projector``, ``absorbances`` (``None`` if :attr:`use_absorbances` is
        ``False``) and ``weights`` (a copy of the projection weights that were
        used during the reconstruction).

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

        Projector
            The :ref:`projector class <projectors>` to use.
        preconditioner_cutoff
            The cutoff to use when computing the :term:`SIRT` preconditioner.
            Default value is ``0.1``,
            which will lead to a roughly ellipsoidal mask.
        weights_cutoff
            The cutoff to use when computing the :term:`SIRT` weights.
            Default value is ``0.1``,
            which will clip some projection edges.
        absorbances
            If :attr:`use_absorbances` is set to ``True``, these absorbances
            will be used instead of ones calculated from the diode. In this case
            no transmittivity cutoff mask is applied to the weights.
        BasisSet
            The :ref:`basis set class <basis_sets>` to use. If not provided
            :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
            will be used for absorbances and
            :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
            for other data.
        basis_set_kwargs
            Keyword arguments for :attr:`BasisSet`. The dictionary passed in is
            copied and never modified. When the default
            :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>` is used,
            ``channels`` defaults to ``1``; when the default
            :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
            is used, ``grid_scale`` is always derived from the data shape.
        no_tqdm
            Used to avoid a ``tqdm`` progress bar in the optimizer.
    """
    # An explicitly supplied projector class takes precedence over ``use_gpu``.
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)
    preconditioner = get_sirt_preconditioner(projector, cutoff=preconditioner_cutoff)
    sirt_weights = get_sirt_weights(projector, cutoff=weights_cutoff)

    # Work on a private copy so that the defaults inserted below do not leak
    # into a dictionary owned (and possibly reused) by the caller.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs') or {})

    # Save previous weights to avoid accumulation.
    old_weights = data_container.projections.weights.copy()
    try:
        # Respect previous masking in data container: ``np.ceil`` keeps zero
        # weights at zero (assumes the stored weights lie in [0, 1]).
        data_container.projections.weights = sirt_weights * np.ceil(data_container.projections.weights)

        if use_absorbances:
            if 'absorbances' in kwargs:
                absorbances = kwargs.pop('absorbances')
            else:
                abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
                absorbances = abs_dict['absorbances']
                # Additionally mask out pixels rejected by the transmittivity cutoff.
                data_container.projections.weights *= abs_dict['cutoff_mask']
        else:
            absorbances = None

        if 'BasisSet' in kwargs:
            BasisSet = kwargs.pop('BasisSet')
        elif use_absorbances:
            BasisSet = TrivialBasis
            basis_set_kwargs.setdefault('channels', 1)
        else:
            BasisSet = GaussianKernels
            # The grid scale is tied to the size of the last data axis and
            # therefore overrides any user-provided value.
            basis_set_kwargs['grid_scale'] = (data_container.projections.data.shape[-1]) // 2 + 1
        basis_set = BasisSet(**basis_set_kwargs)

        residual_calculator = GradientResidualCalculator(data_container,
                                                         basis_set,
                                                         projector,
                                                         use_scalar_projections=use_absorbances,
                                                         scalar_projections=absorbances)
        loss_function = SquaredLoss(residual_calculator,
                                    use_weights=True,
                                    preconditioner=preconditioner)
        optimizer = GradientDescent(loss_function,
                                    maxiter=maxiter,
                                    no_tqdm=kwargs.get('no_tqdm', False),
                                    enforce_non_negativity=enforce_non_negativity)
        result = optimizer.optimize()

        # Keep a copy of the weights actually used, for the caller's reference.
        weights = data_container.projections.weights.copy()
    finally:
        # Always restore the caller's weights, even if the reconstruction failed.
        data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights)