Source code for mumott.pipelines.phase_matching_alignment

import logging
import sys
from typing import Any, Callable, Set

import numpy as np
import tqdm
from skimage.registration import phase_cross_correlation as phase_xcorr
from scipy.ndimage import center_of_mass

from mumott.data_handling import DataContainer
from .reconstruction import run_mitra

# Module-level logger for this pipeline.
logger = logging.getLogger(__name__)
# Shared, unseeded random number generator used for the stochastic parts of the
# alignment (relaxation noise and the random sampling of projections per iteration).
rng = np.random.default_rng()

def _relax_offsets(offsets: np.ndarray[float],
                   generator: np.random.Generator = None) -> np.ndarray[float]:
    """ Internal convenience function for computing a stochastic relaxation
    term that pulls offsets toward their mean.

    The deviation of each offset from the mean offset is shrunk toward zero by
    the magnitude of a normally distributed random number whose standard
    deviation equals the standard deviation of the deviations. Deviations
    smaller than that magnitude become zero.

    Parameters
    ----------
    offsets
        Array of offsets in one shift direction.
    generator
        Random number generator to draw the noise from. Defaults to the
        module-level generator ``rng``; pass a seeded generator to obtain
        reproducible results.

    Returns
    -------
        Array with the same shape as :attr:`offsets`, with the sign of
        ``offsets - offsets.mean()`` and a magnitude no larger than it.
    """
    if generator is None:
        generator = rng
    deviations = offsets - offsets.mean()
    spread = np.std(deviations)
    # Non-negative random shrinkage, one draw per offset.
    noise = np.abs(spread * generator.standard_normal(deviations.shape))
    return np.sign(deviations) * np.fmax(0, np.abs(deviations) - noise)

def _shift_toward_center(center_of_mass_2d: np.ndarray[float],
                         center_of_mass_3d: np.ndarray[float],
                         j_vector: np.ndarray[float],
                         k_vector: np.ndarray[float],
                         j_offset: float,
                         k_offset: float) -> np.ndarray[float]:
    """ Internal convenience function for aligning centers of mass.

    Maps the offset-corrected 2D center of mass of a projection into
    three-dimensional coordinates using the basis vectors of the two shift
    directions, and projects its difference from the 3D center of mass back
    onto those basis vectors.

    Parameters
    ----------
    center_of_mass_2d
        Center of mass of the projection, ordered as ``(j, k)``.
    center_of_mass_3d
        Center of mass of the reconstruction.
    j_vector
        Three-dimensional basis vector of the ``j`` shift direction.
    k_vector
        Three-dimensional basis vector of the ``k`` shift direction.
    j_offset
        Current offset of the projection in the ``j`` direction.
    k_offset
        Current offset of the projection in the ``k`` direction.

    Returns
    -------
        Array of shape ``(2,)`` holding the components of the difference
        between the two centers of mass along :attr:`j_vector` and
        :attr:`k_vector`, respectively.
    """
    com_2d_xyz = j_vector * (center_of_mass_2d[0] + j_offset) + \
        k_vector * (center_of_mass_2d[1] + k_offset)
    com_3d_diff = com_2d_xyz - center_of_mass_3d
    # Project the 3D difference onto each of the two shift directions.
    shifts = np.array((np.dot(j_vector, com_3d_diff), np.dot(k_vector, com_3d_diff)))
    return shifts

def run_phase_matching_alignment(data_container: DataContainer,
                                 ignored_subset: Set[int] = None,
                                 projection_cropping: tuple[slice, slice] = np.s_[:, :],
                                 reconstruction_pipeline: Callable = run_mitra,
                                 reconstruction_pipeline_kwargs: dict[str, Any] = None,
                                 use_gpu: bool = False,
                                 use_absorbances: bool = True,
                                 maxiter: int = 20,
                                 upsampling: int = 1,
                                 shift_tolerance: float = None,
                                 shift_cutoff: float = None,
                                 relative_sample_size: float = 1.0,
                                 relaxation_weight: float = 0.0,
                                 center_of_mass_shift_weight: float = 0.0,
                                 align_j: bool = True,
                                 align_k: bool = True) -> dict[str, Any]:
    r"""A pipeline for alignment using the phase cross-correlation method as implemented by
    `scikit-image
    <https://scikit-image.org/docs/stable/api/skimage.registration.html#skimage.registration.phase_cross_correlation>`_.

    For details on the cross-correlation algorithm, see
    `this article by Guizar-Sicairos et al., (2008) <https://doi.org/10.1364/OL.33.000156>`_.
    Briefly, the algorithm calculates the cross-correlation between a reference image (the data)
    and the corresponding projection of a reconstruction, and finds the shift that would result
    in maximal correlation between the two. It supports large upsampling factors with very little
    computational overhead.

    This implementation applies this algorithm to a randomly sampled subset of the projections in
    each iteration, and adds to this two smoothing terms – a stochastic relaxation term, and a shift
    toward the center of mass of the reconstruction. These terms are added partly to reduce the
    determinism in the algorithm, and partly to improve the performance when no upsampling is used.

    The relaxation term is given by

    .. math::
        d(x_i) = \text{sgn}(x_i - \overline{x}) \cdot
        \max(0, \vert x_i - \overline{x} \vert - \vert \mathcal{N}(0, \sigma(x)) \vert)

    where :math:`x_i` is a given offset, :math:`\overline{x}` and :math:`\sigma(x)` are the mean and
    standard deviation of all offsets in the same direction, and :math:`\mathcal{N}(\mu, \sigma)` is
    a random variable from a normal distribution with mean :math:`\mu` and standard deviation
    :math:`\sigma`. For each sampled projection, :math:`x_i` is then updated by

    .. math::
        x_i \leftarrow x_i - \lambda \cdot \text{sgn}(d(x_i)) \cdot \min(1, \vert d(x_i) \vert)

    where :math:`\lambda` is the :attr:`relaxation_weight`.

    The shift toward the center of mass is given by

    .. math::
        t(x_i) = \mathbf{v_i} \cdot (\mathbf{v_i}(\text{CoM}(P_i) + x_i)_j - \text{CoM}(R))

    where :math:`\mathbf{v_i}` is the three-dimensional basis vector that maps out :math:`x_i`,
    :math:`P_i` is the corresponding projection of the reconstruction :math:`R`, and
    :math:`\text{CoM}` denotes the center of mass. This expression assumes that the basis vectors
    of the two shift directions are orthogonal, but the general expression is similar.
    For each sampled projection, :math:`x_i` is then updated by

    .. math::
        x_i \leftarrow x_i + \kappa \cdot \text{sgn}(t(x_i)) \cdot \min(1, \vert t(x_i) \vert)

    where :math:`\kappa` is the :attr:`center_of_mass_shift_weight`.

    Parameters
    ----------
    data_container
        The data container from loading the data set of interest. Note that the offset factors in
        :class:`data_container.geometry <mumott.core.geometry.Geometry>` will be modified during
        the alignment.
    ignored_subset
        A subset of projection numbers which will not have their alignment modified.
        The subset is still used in the reconstruction.
    projection_cropping
        A tuple of two slices (``slice``), which specify the cropping of the ``projection`` and
        ``data``. For example, to clip the first and last 5 pixels in each direction, set this
        parameter to ``(slice(5, -5), slice(5, -5))``.
    reconstruction_pipeline
        A ``callable``, typically from the
        :ref:`reconstruction pipelines <reconstruction_pipelines>`, that performs the
        reconstruction at each alignment iteration. Must return a dictionary with an entry
        labelled ``'result'``, which has an entry labelled ``'x'`` containing the reconstruction.
        Additionally, it must expose a ``'weights'`` entry containing the weights used during the
        reconstruction as well as a :ref:`Projector object <projectors>` under the keyword
        ``'projector'``. If the pipeline supports the :attr:`use_absorbances` keyword argument,
        then an ``absorbances`` entry must also be exposed. If the pipeline supports using
        multi-channel data rather than absorbances, then a :ref:`basis set object <basis_sets>`
        must be available under ``'basis_set'``.
    reconstruction_pipeline_kwargs
        Keyword arguments to pass to :attr:`reconstruction_pipeline`. If ``'data_container'``,
        ``'use_gpu'`` or ``'use_absorbances'`` are set as keys, they will override
        :attr:`data_container`, :attr:`use_gpu` and :attr:`use_absorbances`, respectively, in the
        call to the pipeline. ``'no_tqdm'`` is always set to ``True``. The dictionary passed in is
        not modified.
    use_gpu
        Whether to use GPU resources in computing the reconstruction.
        Default is ``False``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    use_absorbances
        Whether to use the absorbances to compute the reconstruction and align the projections.
        Default is ``True``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    maxiter
        Maximum number of iterations for the alignment. Must be at least ``1``.
    upsampling
        Upsampling factor during alignment. If used (i.e., not equal to ``1``), any masking in
        the reconstruction ``'weights'`` will be ignored, but :attr:`projection_cropping` will
        still be used. The suggested range of use is ``[1, 20]``.
    shift_tolerance
        Tolerance for the maximal shift distance of each iteration of the alignment.
        The alignment will terminate when the maximal shift falls below this value.
        The maximal shift is the largest Euclidean distance, over the enabled directions
        (see :attr:`align_j` and :attr:`align_k`), of the cross-correlation shift of any one
        projection, measured before rescaling by :attr:`shift_cutoff`.
        Default value is ``1 / upsampling``.
    shift_cutoff
        Largest permissible shift due to cross-correlation in each iteration, as measured by the
        Euclidean distance. Larger shifts will be rescaled so as to not exceed this value.
        Default value is ``5 / upsampling``.
    relative_sample_size
        Fraction of projections to align in each iteration. At each alignment iteration,
        ``ceil(number_of_projections * relative_sample_size)`` projections will be randomly
        selected for alignment, where ``number_of_projections`` excludes :attr:`ignored_subset`.
        The sample never exceeds ``number_of_projections``. If set to ``1``, all projections will
        be aligned at each iteration.
    relaxation_weight
        A relaxation parameter for stochastic relaxation; the larger this weight is, the more
        offsets will tend toward the mean offset in each direction. The relaxation step size in
        each direction at each iteration cannot be larger than this weight. This is
        :math:`\lambda` in the expression given above.
    center_of_mass_shift_weight
        A parameter that controls the tendency for the projection center of mass to be shifted
        toward the reconstruction center of mass. The center-of-mass step size in each direction
        at each iteration cannot be larger than this weight. This is :math:`\kappa` in the
        expression given above.
    align_j
        Whether to align in the ``j`` direction. Default is ``True``.
    align_k
        Whether to align in the ``k`` direction. Default is ``True``.

    Returns
    -------
        A dictionary with three entries for inspection:

        reconstruction
            The reconstruction used in the last alignment step.
        projections
            Projections of the ``reconstruction``.
        reference
            The reference image derived from the data used to align the ``projections``.

    Raises
    ------
    TypeError
        If :attr:`ignored_subset` is not a ``set``.
    ValueError
        If both :attr:`align_j` and :attr:`align_k` are ``False``, or if :attr:`maxiter` is
        less than ``1``.
    """
    # Avoid a mutable default argument; the set itself is never modified.
    if ignored_subset is None:
        ignored_subset = set()
    if not isinstance(ignored_subset, set):
        raise TypeError(f'ignored_subset must be a set, but a {type(ignored_subset).__name__} was given!')

    # Copy so that the caller's dictionary is left untouched. Entries the user already set
    # take precedence over the corresponding arguments of this function.
    reconstruction_pipeline_kwargs = dict(reconstruction_pipeline_kwargs or {})
    reconstruction_pipeline_kwargs.setdefault('data_container', data_container)
    reconstruction_pipeline_kwargs.setdefault('use_gpu', use_gpu)
    reconstruction_pipeline_kwargs.setdefault('use_absorbances', use_absorbances)
    reconstruction_pipeline_kwargs['no_tqdm'] = True

    if shift_tolerance is None:
        shift_tolerance = 1. / upsampling
    if shift_cutoff is None:
        shift_cutoff = 5. / upsampling
    if not (align_j or align_k):
        raise ValueError('At least one of align_j and align_k must be set to True,'
                         ' but both are set to False.')
    if maxiter < 1:
        # Otherwise the loop never runs and there is no reconstruction to return.
        raise ValueError(f'maxiter must be at least 1, but {maxiter} was given.')

    # Boolean mask over the two shift directions (j, k) that are being aligned.
    enabled_directions = np.array([align_j, align_k], dtype=bool)

    # Projections that may be moved; ignored projections still enter the reconstruction.
    valid_indices = sorted(set(range(len(data_container.geometry))) - ignored_subset)
    # ceil(number of movable projections * fraction), never more than are available.
    number_of_samples = min(len(valid_indices),
                            int(np.ceil(len(valid_indices) * relative_sample_size)))

    # Per-projection j and k direction vectors: the rotation matrices contracted with
    # the initial j and k directions.
    j_vectors = np.einsum('kij,i->kj',
                          data_container.geometry.rotations_as_array,
                          data_container.geometry.j_direction_0)
    k_vectors = np.einsum('kij,i->kj',
                          data_container.geometry.rotations_as_array,
                          data_container.geometry.k_direction_0)

    for _ in tqdm.tqdm(range(maxiter), file=sys.stdout):
        pipeline = reconstruction_pipeline(**reconstruction_pipeline_kwargs)
        if upsampling == 1:
            # Mask is boolean and has no "channel" index.
            mask = np.all(pipeline['weights'] > 0, -1)

        # Project the reconstruction into 2D.
        reconstruction = pipeline['result']['x']
        com_3d = np.array(center_of_mass(reconstruction[..., 0]))
        projector = pipeline['projector']
        projections = projector.forward(reconstruction)
        # Reinitialize shifts, since they are applied at the end of each iteration.
        shifts = np.zeros((len(projections), 2), dtype=np.float64)

        if reconstruction_pipeline_kwargs['use_absorbances']:
            reference = pipeline['absorbances']
        else:
            # If not absorbances, use the mean over detector segments.
            reference = np.mean(data_container.data, -1)
            reference = reference.reshape(*reference.shape, 1)
            projections = pipeline['basis_set'].forward(projections).mean(-1)
            projections = projections.reshape(*projections.shape, 1)

        sampled_subset = rng.choice(valid_indices, number_of_samples, replace=False)
        for index in sampled_subset:
            p = projections[index, ..., 0][projection_cropping]
            r = reference[index, ..., 0][projection_cropping]
            # The weight mask is only used without upsampling.
            m = mask[index][projection_cropping] if upsampling == 1 else None
            shifts[index, :] = phase_xcorr(
                p, r, return_error='always', upsample_factor=upsampling,
                reference_mask=m, moving_mask=m)[0]

        # The cross-correlation function is not always totally stable.
        shifts = np.nan_to_num(shifts, posinf=0, neginf=0, nan=0)
        # Drop disabled directions before measuring shift sizes, so that they neither shrink
        # the enabled direction during rescaling nor prevent convergence.
        shifts[:, ~enabled_directions] = 0.

        # Rescale shifts that are too large.
        shift_size = np.sqrt(shifts[:, 0] ** 2 + shifts[:, 1] ** 2)
        nonzero = shift_size > 0
        shifts[nonzero, :] *= (shift_size[nonzero].clip(None, shift_cutoff)
                               / shift_size[nonzero]).reshape(-1, 1)

        # Add stochastic relaxation factor, tending to move offsets toward their mean.
        shifts[sampled_subset, 0] -= _relax_offsets(
            data_container.geometry.j_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight
        shifts[sampled_subset, 1] -= _relax_offsets(
            data_container.geometry.k_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight

        # Add movement of projection center of mass toward reconstruction center of mass.
        for index in sampled_subset:
            com_2d = np.array(center_of_mass(projections[index, ..., 0]))
            com_shifts = _shift_toward_center(com_2d, com_3d,
                                              j_vectors[index], k_vectors[index],
                                              data_container.geometry.j_offsets[index],
                                              data_container.geometry.k_offsets[index])
            shifts[index, :] += np.clip(com_shifts, -1, 1) * center_of_mass_shift_weight

        # The relaxation and center-of-mass terms must not move disabled directions either.
        shifts[:, ~enabled_directions] = 0.
        data_container.geometry.j_offsets = data_container.geometry.j_offsets_as_array + shifts[:, 0]
        data_container.geometry.k_offsets = data_container.geometry.k_offsets_as_array + shifts[:, 1]

        max_shift = np.max(shift_size)
        if max_shift < shift_tolerance:
            logger.info(f'Maximal shift is {max_shift:.2f}, which is less than'
                        f' the specified tolerance {shift_tolerance:.2f}. Alignment completed.')
            break
    else:
        logger.info('Maximal number of iterations reached. Alignment completed.')

    return dict(reconstruction=reconstruction, projections=projections, reference=reference)