Source code for mumott.pipelines.reconstruction.sigtt

import logging

from mumott.data_handling import DataContainer
from mumott.methods.basis_sets import SphericalHarmonics
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.optimizers import LBFGS
from mumott.optimization.regularizers import Laplacian

logger = logging.getLogger(__name__)


[docs]def run_sigtt(data_container: DataContainer, use_gpu: bool = False, maxiter: int = 20, ftol: float = 1e-2, regularization_weight: float = 1e-4, **kwargs): """A reconstruction pipeline for the :term:`SIGTT` algorithm, which uses a gradient and a regularizer to accomplish reconstruction. Parameters ---------- data_container The :class:`DataContainer <mumott.data_handling.DataContainer>` from loading the data set of interest. use_gpu Whether to use GPU resources in computing the projections. Default is ``False``. If set to ``True``, the method will use :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`. maxiter Maximum number of iterations for the gradient descent solution. ftol Tolerance for the change in the loss function. Default is ``None``, in which case the reconstruction will terminate once the maximum number of iterations have been performed. regularization_weight Regularization weight for the default :class:`Laplacian <mumott.optimization.regularizers.Laplacian>` regularizer. Ignored if a loss function is provided. kwargs Miscellaneous keyword arguments. See notes for details. Notes ----- Many options can be specified through :attr:`kwargs`. Miscellaneous ones are passed to the optimizer. Specific keywords include: Projector The :ref:`projector class <projectors>` to use. BasisSet The :ref:`basis set class <basis_sets>` to use. If not provided :class:`SphericalHarmonics <mumott.methods.basis_sets.SphericalHarmonics>` will be used. basis_set_kwargs Keyword arguments for :attr:`BasisSet`. ResidualCalculator The :ref:`residual_calculator class <residual_calculators>` to use. If not provided :class:`GradientResidualCalculator <mumott.methods.residual_calculators.GradientResidualCalculator>` will be used. residual_calculator_kwargs Keyword arguments for :attr:`ResidualCalculator`. LossFunction The :ref:`loss function class <loss_functions>` to use. If not provided :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>` will be used. loss_function_kwargs Keyword arguments for :attr:`LossFunction`. Regularizers A list of dictionaries with three entries, a name (``str``), a :ref:`regularizer object <regularizers>`, and a regularization weight (``float``); used by :func:`loss_function.add_regularizer() <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`. By default, a :class:`Laplacian <mumott.optimization.regularizers.Laplacian>` with the weight :attr:`regularization_weight` will be used. If other regularizers are specified, this will be overridden. Optimizer The :ref:`optimizer class <optimizers>` to use. If not provided :class:`LBFGS <mumott.optimization.optimizers.LBFGS>` will be used. optimizer_kwargs Keyword arguments for :attr:`Optimizer`. """ if 'Projector' in kwargs: Projector = kwargs.pop('Projector') else: if use_gpu: Projector = SAXSProjectorCUDA else: Projector = SAXSProjector projector = Projector(data_container.geometry) basis_set_kwargs = kwargs.get('basis_set_kwargs', dict()) if 'BasisSet' in kwargs: BasisSet = kwargs.pop('BasisSet') else: if 'ell_max' not in basis_set_kwargs: basis_set_kwargs['ell_max'] = 2 * ((data_container.data.shape[-1] - 1) // 2) BasisSet = SphericalHarmonics basis_set = BasisSet(**basis_set_kwargs) ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator) residual_calculator_kwargs = kwargs.get('residual_calculator_kwargs', dict()) residual_calculator = ResidualCalculator(data_container, basis_set, projector, **residual_calculator_kwargs) Regularizers = kwargs.get('Regularizers', [dict(name='laplacian', regularizer=Laplacian(), regularization_weight=regularization_weight)]) LossFunction = kwargs.get('LossFunction', SquaredLoss) loss_function_kwargs = kwargs.get('loss_function_kwargs', dict()) loss_function = LossFunction(residual_calculator, **loss_function_kwargs) for reg in Regularizers: loss_function.add_regularizer(**reg) Optimizer = kwargs.get('Optimizer', LBFGS) optimizer_kwargs = kwargs.get('optimizer_kwargs', dict()) optimizer_kwargs['maxiter'] = optimizer_kwargs.get('maxiter', maxiter) optimizer_kwargs['ftol'] = optimizer_kwargs.get('ftol', ftol) optimizer = Optimizer(loss_function, **optimizer_kwargs) result = optimizer.optimize() return dict(result=result, optimizer=optimizer, loss_function=loss_function, residual_calculator=residual_calculator, basis_set=basis_set, projector=projector)