import logging
import numpy as np
from mumott.data_handling import DataContainer
from mumott.data_handling.utilities import get_absorbances
from mumott.methods.basis_sets import TrivialBasis, GaussianKernels
from mumott.methods.residual_calculators import GradientResidualCalculator
from mumott.methods.projectors import SAXSProjectorCUDA, SAXSProjector
from mumott.methods.utilities import (get_sirt_weights, get_sirt_preconditioner,
get_tensor_sirt_weights, get_tensor_sirt_preconditioner)
from mumott.optimization.loss_functions import SquaredLoss
from mumott.optimization.optimizers import GradientDescent
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def run_mitra(data_container: DataContainer,
              use_absorbances: bool = True,
              use_sirt_weights: bool = True,
              use_gpu: bool = False,
              maxiter: int = 20,
              ftol: float = None,
              **kwargs):
    """Reconstruction pipeline for the Modular Iterative Tomographic Reconstruction Algorithm (MITRA).

    This is a versatile, configurable interface for tomographic reconstruction that allows for
    various optimizers, projectors, loss functions and regularizers to be supplied.

    This is meant as a convenience interface for intermediate or advanced users to create
    customized reconstruction pipelines.

    Parameters
    ----------
    data_container
        The :class:`DataContainer <mumott.data_handling.DataContainer>`
        from loading the data set of interest.
    use_absorbances
        If ``True``, the reconstruction will use the absorbances
        calculated from the diode, or absorbances provided via a keyword argument.
        If ``False``, the data in :attr:`data_container.data` will be used.
    use_sirt_weights
        If ``True`` (default), SIRT or tensor SIRT weights will be computed
        for use in the reconstruction.
    use_gpu
        Whether to use GPU resources in computing the projections.
        Default is ``False``, which means
        :class:`SAXSProjector <mumott.methods.projectors.SAXSProjector>`.
        If set to ``True``, the method will use
        :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`.
    maxiter
        Maximum number of iterations for the gradient descent solution.
    ftol
        Tolerance for the change in the loss function. Default is ``None``,
        in which case the reconstruction will terminate once the maximum
        number of iterations has been performed.
    kwargs
        Miscellaneous keyword arguments. See notes for details.

    Returns
    -------
    dict
        Dictionary with the keys ``result`` (the return value of
        ``optimizer.optimize()``), ``optimizer``, ``loss_function``,
        ``residual_calculator``, ``basis_set``, ``projector``,
        ``absorbances`` (``None`` if :attr:`use_absorbances` is ``False``),
        ``weights`` and ``preconditioner``.

    Notes
    -----
    Many options can be specified through ``kwargs``. These include:

        Projector
            The :ref:`projector class <projectors>` to use. If not provided,
            the choice is made according to :attr:`use_gpu`.
        absorbances
            If :attr:`use_absorbances` is set to ``True``, these absorbances
            will be used instead of ones calculated from the diode.
        preconditioner_cutoff
            The cutoff to use when computing the :term:`SIRT` preconditioner.
            Default value is ``0.1``,
            which will lead to a roughly ellipsoidal mask.
        weights_cutoff
            The cutoff to use when computing the :term:`SIRT` weights.
            Default value is ``0.1``,
            which will clip some projection edges.
        weights
            Only read when :attr:`use_sirt_weights` is ``False``. It is returned
            in the output dictionary but is not applied to
            :attr:`data_container`. Defaults to
            ``data_container.projections.weights``.
        BasisSet
            The :ref:`basis set class <basis_sets>` to use. If not provided
            :class:`TrivialBasis <mumott.methods.basis_sets.TrivialBasis>`
            will be used for absorbances and
            :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
            for other data.
        basis_set_kwargs
            Keyword arguments for :attr:`BasisSet`.
        ResidualCalculator
            The :ref:`residual calculator class <residual_calculators>` to use.
            If not provided, then
            :class:`GradientResidualCalculator
            <mumott.methods.residual_calculators.GradientResidualCalculator>`
            will be used.
        residual_calculator_kwargs
            Keyword arguments for :attr:`ResidualCalculator`.
        LossFunction
            The :ref:`loss function class <loss_functions>` to use. If not provided
            :class:`SquaredLoss <mumott.optimization.loss_functions.SquaredLoss>`
            will be used.
        loss_function_kwargs
            Keyword arguments for :attr:`LossFunction`. When
            :attr:`use_sirt_weights` is ``True``, ``use_weights`` and
            ``preconditioner`` are overridden by the SIRT settings.
        Regularizers
            A list of dictionaries with three entries, a name
            (``str``), a :ref:`regularizer object <regularizers>`, and
            a regularization weight (``float``); used by
            :func:`loss_function.add_regularizer()
            <mumott.optimization.loss_functions.SquaredLoss.add_regularizer>`.
        Optimizer
            The optimizer class to use. If not provided
            :class:`GradientDescent <mumott.optimization.optimizers.GradientDescent>`
            will be used. By default, the keyword argument ``nestorov_weight`` is set to
            ``0.95``, and ``enforce_non_negativity`` is ``True``.
        optimizer_kwargs
            Keyword arguments for :attr:`Optimizer`.
        no_tqdm
            If given, overrides ``optimizer_kwargs['no_tqdm']``
            (default ``False``).

    The ``*_kwargs`` dictionaries are copied before defaults are injected, so
    the caller's dictionaries are never modified and can be reused safely.

    Side effects on :attr:`data_container`: when absorbances are computed from
    the diode, ``data_container.projections.weights`` is multiplied in place by
    the transmittivity cutoff mask, and this change persists. When
    :attr:`use_sirt_weights` is ``True``, the weights are temporarily replaced
    by the SIRT weights. They are restored afterwards, even if the optimization
    raises an exception.
    """
    # --- Projector -------------------------------------------------------
    if 'Projector' in kwargs:
        Projector = kwargs.pop('Projector')
    elif use_gpu:
        Projector = SAXSProjectorCUDA
    else:
        Projector = SAXSProjector
    projector = Projector(data_container.geometry)

    # --- Absorbances -----------------------------------------------------
    if use_absorbances:
        if 'absorbances' in kwargs:
            absorbances = kwargs.pop('absorbances')
        else:
            abs_dict = get_absorbances(data_container.diode, normalize_per_projection=True)
            absorbances = abs_dict['absorbances']
            transmittivity_cutoff_mask = abs_dict['cutoff_mask']
            # Persistent in-place masking of the caller's weights.
            data_container.projections.weights *= transmittivity_cutoff_mask
    else:
        absorbances = None

    # --- Basis set -------------------------------------------------------
    # Copy so injected defaults never leak into the caller's dictionary.
    basis_set_kwargs = dict(kwargs.get('basis_set_kwargs') or {})
    if 'BasisSet' in kwargs:
        BasisSet = kwargs.pop('BasisSet')
    elif use_absorbances:
        BasisSet = TrivialBasis
        basis_set_kwargs.setdefault('channels', 1)
    else:
        BasisSet = GaussianKernels
        basis_set_kwargs.setdefault(
            'grid_scale', data_container.projections.data.shape[-1] // 2 + 1)
    basis_set = BasisSet(**basis_set_kwargs)

    # --- Residual calculator ---------------------------------------------
    ResidualCalculator = kwargs.get('ResidualCalculator', GradientResidualCalculator)
    # Copy: otherwise ``scalar_projections`` (this call's absorbances) would
    # stick in the caller's dict and be reused by later calls.
    residual_calculator_kwargs = dict(kwargs.get('residual_calculator_kwargs') or {})
    residual_calculator_kwargs.setdefault('use_scalar_projections', use_absorbances)
    residual_calculator_kwargs.setdefault('scalar_projections', absorbances)
    residual_calculator = ResidualCalculator(data_container,
                                             basis_set,
                                             projector,
                                             **residual_calculator_kwargs)

    # --- Loss function setup and (optional) SIRT weighting ---------------
    Regularizers = kwargs.get('Regularizers', [])
    LossFunction = kwargs.get('LossFunction', SquaredLoss)
    # Copy: the SIRT branch overwrites ``use_weights``/``preconditioner``.
    loss_function_kwargs = dict(kwargs.get('loss_function_kwargs') or {})
    preconditioner_cutoff = kwargs.get('preconditioner_cutoff', 0.1)
    weights_cutoff = kwargs.get('weights_cutoff', 0.1)

    old_weights = None
    if use_sirt_weights:
        if use_absorbances:
            preconditioner = get_sirt_preconditioner(
                projector, cutoff=preconditioner_cutoff)
            sirt_weights = get_sirt_weights(
                projector, cutoff=weights_cutoff)
        else:
            preconditioner = get_tensor_sirt_preconditioner(
                projector, basis_set, cutoff=preconditioner_cutoff)
            sirt_weights = get_tensor_sirt_weights(
                projector, basis_set, cutoff=weights_cutoff)
        old_weights = data_container.projections.weights.copy()
        # Apply SIRT weights only where the existing weights are positive.
        weights = sirt_weights * (old_weights > 0).astype(float)
        loss_function_kwargs['use_weights'] = True
    else:
        # If not using SIRT weights, just fetch identically named arguments as normal from kwargs.
        # NOTE(review): ``weights`` is only echoed back in the result, never applied.
        weights = kwargs.get('weights', data_container.projections.weights)
        preconditioner = loss_function_kwargs.get('preconditioner', None)
        loss_function_kwargs.setdefault('use_weights', True)
    loss_function_kwargs['preconditioner'] = preconditioner

    # --- Optimizer options -----------------------------------------------
    optimizer_kwargs = dict(kwargs.get('optimizer_kwargs') or {})
    if 'Optimizer' in kwargs:
        Optimizer = kwargs.pop('Optimizer')
    else:
        Optimizer = GradientDescent
        optimizer_kwargs.setdefault('nestorov_weight', 0.95)
    optimizer_kwargs.setdefault('maxiter', maxiter)
    optimizer_kwargs.setdefault('ftol', ftol)
    optimizer_kwargs.setdefault('enforce_non_negativity', True)
    # A top-level ``no_tqdm`` keyword takes precedence over optimizer_kwargs.
    optimizer_kwargs['no_tqdm'] = kwargs.get('no_tqdm', optimizer_kwargs.get('no_tqdm', False))

    # --- Run --------------------------------------------------------------
    try:
        if old_weights is not None:
            # Temporarily install the SIRT weights; restored in ``finally``
            # so the caller's container is not left modified on failure.
            data_container.projections.weights = weights
        loss_function = LossFunction(residual_calculator,
                                     **loss_function_kwargs)
        for reg in Regularizers:
            loss_function.add_regularizer(**reg)
        optimizer = Optimizer(loss_function,
                              **optimizer_kwargs)
        result = optimizer.optimize()
    finally:
        if old_weights is not None:
            data_container.projections.weights = old_weights

    return dict(result=result, optimizer=optimizer, loss_function=loss_function,
                residual_calculator=residual_calculator, basis_set=basis_set, projector=projector,
                absorbances=absorbances, weights=weights, preconditioner=preconditioner)