Source code for mumott.pipelines.reconstruction.mitra

import logging

import numpy as np

from mumott.data_handling import DataContainer
from mumott.data_handling.utilities import get_absorbances
from mumott.methods.basis_sets import TrivialBasis, GaussianKernels
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.utilities import (get_sirt_weights, get_sirt_preconditioner,
                                      get_tensor_sirt_weights, get_tensor_sirt_preconditioner)
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.optimizers import GradientDescent

logger = logging.getLogger(__name__)


def run_mitra(data_container: DataContainer,
              use_absorbances: bool = True,
              use_sirt_weights: bool = True,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = None,
              **kwargs) -> dict:
    """Reconstruction pipeline for the Modular Iterative Tomographic Reconstruction
    Algorithm (MITRA).

    This is a versatile, configurable interface for tomographic reconstruction that allows
    for various optimizers, projectors, loss functions and regularizers to be supplied.
    This is meant as a convenience interface for intermediate or advanced users to create
    customized reconstruction pipelines.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances calculated from the diode,
        or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_sirt_weights
        If ``True`` (default), SIRT or tensor SIRT weights will be computed
        for use in the reconstruction.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``, which means
        :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
        If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
        Ignored if ``Projector`` is given in ``kwargs``.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    ftol
        Tolerance for the change in the loss function. Default is ``None``,
        in which case the reconstruction will terminate once the maximum
        number of iterations have been performed.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
        A dictionary with the entries ``result`` (the return value of
        ``optimizer.optimize()``), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set``, ``projector``, ``absorbances``
        (``None`` if :attr:`use_absorbances` is ``False``), ``weights``, and
        ``preconditioner``.

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

        Projector
            The :ref:`projector class <projectors>` to use.
        absorbances
            If :attr:`use_absorbances` is set to ``True``, these absorbances
            will be used instead of ones calculated from the diode.
        preconditioner_cutoff
            The cutoff to use when computing the :term:`SIRT` preconditioner.
            Default value is ``0.1``, which will lead to a roughly ellipsoidal mask.
        weights_cutoff
            The cutoff to use when computing the :term:`SIRT` weights.
            Default value is ``0.1``, which will clip some projection edges.
        weights
            Only read if :attr:`use_sirt_weights` is ``False``. The value is passed
            through to the ``weights`` entry of the returned dictionary; this function
            does not itself assign it to ``data_container.projections.weights``.
        BasisSet
            The :ref:`basis set class <basis_sets>` to use. If not provided
            :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
            will be used for absorbances and
            :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
            for other data.
        basis_set_kwargs
            Keyword arguments for :attr:`BasisSet`.
        ResidualCalculator
            The :ref:`residual calculator class <residual_calculators>` to use.
            If not provided, then
            :class:`GradientResidualCalculator
            <mumott.methods.residual_calculators.GradientResidualCalculator>`
            will be used.
        residual_calculator_kwargs
            Keyword arguments for :attr:`ResidualCalculator`.
        LossFunction
            The :ref:`loss function class <loss_functions>` to use. If not provided
            :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
            will be used.
        loss_function_kwargs
            Keyword arguments for :attr:`LossFunction`.
        Regularizers
            A list of dictionaries with three entries, a name
            (``str``), a :ref:`regularizer object <regularizers>`, and
            a regularization weight (``float``); used by
            :func:`loss_function.add_regularizer()
            <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`.
        Optimizer
            The optimizer class to use. If not provided
            :class:`GradientDescent <mumott.optimization.optimizers.GradientDescent>`
            will be used. By default, the keyword argument ``nestorov_weight`` is set to
            ``0.95``, and ``enforce_non_negativity`` is ``True``
        optimizer_kwargs
            Keyword arguments for :attr:`Optimizer`.
        no_tqdm
            Forwarded to the optimizer as ``no_tqdm``; takes precedence over an entry
            of the same name in ``optimizer_kwargs``. Default is ``False``.

    The dictionaries passed as ``basis_set_kwargs``, ``residual_calculator_kwargs``,
    ``loss_function_kwargs`` and ``optimizer_kwargs`` are copied before defaults are
    filled in, so the caller's objects are left unmodified.

    Side effects on :attr:`data_container`: if absorbances are computed from the diode,
    ``data_container.projections.weights`` is multiplied in place by the returned cutoff
    mask and this change persists after the call. If :attr:`use_sirt_weights` is ``True``,
    the weights are temporarily replaced during the reconstruction and are restored
    afterwards, also when the reconstruction raises an exception.
    """
    # --- Projector -----------------------------------------------------------------
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    # --- Absorbances ---------------------------------------------------------------
    if use_absorbances:
        if 'absorbances' in kwargs:
            absorbances = kwargs.pop('absorbances')
        else:
            abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
            absorbances = abs_dict['absorbances']
            transmittivity_cutoff_mask = abs_dict['cutoff_mask']
            # In-place and permanent: this masking is applied before the SIRT backup
            # below is taken, so it is not undone at the end of the function.
            data_container.projections.weights *= transmittivity_cutoff_mask
    else:
        absorbances = None

    # --- Basis set -----------------------------------------------------------------
    # Work on a copy so that defaults inserted here do not leak into the caller's
    # dictionary (and from there into a later call with different settings).
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs') or {})
    if 'BasisSet' in kwargs:
        BasisSet = kwargs.pop('BasisSet')
    elif use_absorbances:
        BasisSet = TrivialBasis
        basis_set_kwargs.setdefault('channels', 1)
    else:
        BasisSet = GaussianKernels
        basis_set_kwargs.setdefault(
            'grid_scale', (data_container.projections.data.shape[-1]) // 2 + 1)
    basis_set = BasisSet(**basis_set_kwargs)

    # --- Residual calculator -------------------------------------------------------
    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs') or {})
    residual_calculator_kwargs.setdefault('use_scalar_projections', use_absorbances)
    residual_calculator_kwargs.setdefault('scalar_projections', absorbances)
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    # --- Loss function settings, weights and preconditioner ------------------------
    Regularizers = kwargs.get('Regularizers', [])
    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs') or {})
    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)

    old_weights = None  # set only if the container weights are temporarily replaced
    if use_sirt_weights:
        if use_absorbances:
            preconditioner = get_sirt_preconditioner(projector, cutoff=preconditioner_cutoff)
            sirt_weights = get_sirt_weights(projector, cutoff=weights_cutoff)
        else:
            preconditioner = get_tensor_sirt_preconditioner(
                projector, basis_set, cutoff=preconditioner_cutoff)
            sirt_weights = get_tensor_sirt_weights(
                projector, basis_set, cutoff=weights_cutoff)
        old_weights = data_container.projections.weights.copy()
        # Keep SIRT weights only where the existing weights are strictly positive.
        positive_mask = np.greater(data_container.projections.weights, 0).astype(float)
        weights = sirt_weights * positive_mask
        data_container.projections.weights = weights
        loss_function_kwargs['use_weights'] = True
    else:
        # If not using SIRT weights, just fetch identically named arguments as normal from kwargs
        weights = kwargs.get('weights', data_container.projections.weights)
        preconditioner = loss_function_kwargs.get('preconditioner', None)
        loss_function_kwargs.setdefault('use_weights', True)
    loss_function_kwargs['preconditioner'] = preconditioner

    # --- Optimizer settings --------------------------------------------------------
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs') or {})
    if 'Optimizer' in kwargs:
        Optimizer = kwargs.pop('Optimizer')
    else:
        Optimizer = GradientDescent
        # NOTE(review): applied only for the default optimizer; the nesting is inferred
        # from the docstring, confirm against the upstream source.
        optimizer_kwargs.setdefault('nestorov_weight', 0.95)
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer_kwargs.setdefault('enforce_non_negativity', True)
    # A top-level ``no_tqdm`` wins over one given inside ``optimizer_kwargs``.
    optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', optimizer_kwargs.get('no_tqdm', False))

    # --- Run -----------------------------------------------------------------------
    # Everything from here on may raise (bad kwargs, numerical failure, interrupt);
    # the ``finally`` guarantees the caller gets its original weights back.
    try:
        loss_function = LossFunction(residual_calculator, **loss_function_kwargs)
        for reg in Regularizers:
            loss_function.add_regularizer(**reg)
        optimizer = Optimizer(loss_function, **optimizer_kwargs)
        result = optimizer.optimize()
    finally:
        if old_weights is not None:
            data_container.projections.weights = old_weights

    return dict(result=result,
                optimizer=optimizer,
                loss_function=loss_function,
                residual_calculator=residual_calculator,
                basis_set=basis_set,
                projector=projector,
                absorbances=absorbances,
                weights=weights,
                preconditioner=preconditioner)