Coverage for local_installation_linux/mumott/pipelines/phase_matching_alignment.py: 94%
86 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
2import sys
3from typing import Any, Callable, Set
5import numpy as np
6import tqdm
7from skimage.registration import phase_cross_correlation as phase_xcorr
8from scipy.ndimage import center_of_mass
10from mumott.data_handling import DataContainer
11from .reconstruction import run_mitra
# Module logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
# Shared, unseeded random generator; used below for sampling projections and for the
# stochastic relaxation term, so results differ between runs.
rng = np.random.default_rng(seed=None)
def _relax_offsets(offsets: np.ndarray[float]) -> np.ndarray[float]:
    """ Internal convenience function for computing a stochastic relaxation term
    for a set of offsets.

    The deviation of each offset from the mean of all offsets is shrunk toward zero
    by the magnitude of a normal random draw scaled by the standard deviation of the
    deviations; deviations smaller than that draw become exactly zero.

    Parameters
    ----------
    offsets
        Offsets in a single direction (the alignment pipeline passes one entry per
        projection).

    Returns
    -------
        An array with the same shape as ``offsets``, whose entries have the sign of
        the deviation from the mean and a magnitude no larger than that deviation.
    """
    deviations = offsets - offsets.mean()
    # Random shrinkage amount per entry: |sigma * N(0, 1)|.
    noise = np.abs(np.std(deviations) * rng.standard_normal(deviations.shape))
    shrunk = np.fmax(0, np.abs(deviations) - noise)
    return np.sign(deviations) * shrunk
27def _shift_toward_center(center_of_mass_2d: np.ndarray[float],
28 center_of_mass_3d: np.ndarray[float],
29 j_vector: np.ndarray[float],
30 k_vector: np.ndarray[float],
31 j_offset: float,
32 k_offset: float) -> np.ndarray[float]:
33 """ Internal convenience function for aligning centers of mass. """
34 com_2d_xyz = j_vector * (center_of_mass_2d[0] + j_offset) + \
35 k_vector * (center_of_mass_2d[1] + k_offset)
36 com_3d_diff = com_2d_xyz - center_of_mass_3d
37 shifts = np.array((np.dot(j_vector, com_3d_diff), np.dot(k_vector, com_3d_diff)))
38 return shifts
def run_phase_matching_alignment(data_container: DataContainer,
                                 ignored_subset: Set[int] = None,
                                 projection_cropping: tuple[slice, slice] = np.s_[:, :],
                                 reconstruction_pipeline: Callable = run_mitra,
                                 reconstruction_pipeline_kwargs: dict[str, Any] = None,
                                 use_gpu: bool = False,
                                 use_absorbances: bool = True,
                                 maxiter: int = 20,
                                 upsampling: int = 1,
                                 shift_tolerance: float = None,
                                 shift_cutoff: float = None,
                                 relative_sample_size: float = 1.0,
                                 relaxation_weight: float = 0.0,
                                 center_of_mass_shift_weight: float = 0.0,
                                 align_j: bool = True,
                                 align_k: bool = True) -> dict[str, Any]:
    r"""A pipeline for alignment using the phase cross-correlation method as implemented
    by `scikit-image <https://scikit-image.org>`_.

    For details on the cross-correlation algorithm, see
    `this article by Guizar-Sicairos et al., (2008) <https://doi.org/10.1364/OL.33.000156>`_.
    Briefly, the algorithm calculates the cross-correlation between a reference image (the data) and the
    corresponding projection of a reconstruction, and finds the shift that would result in
    maximal correlation between the two. It supports large upsampling factors with
    very little computational overhead.

    This implementation applies this algorithm to a randomly sampled subset of the projections in each
    iteration, and adds to this two smoothing terms – a stochastic relaxation term, and a
    shift toward the center of mass of the reconstruction. These terms are added partly to reduce the
    determinism in the algorithm, and partly to improve the performance when no
    upsampling is used.

    The relaxation term is given by

    .. math::
        d(x_i) = \text{sgn}(x_i - \overline{x}) \cdot \text{max}
        (0, \vert x_i - \overline{x} \vert - \vert \mathcal{N}(0, \sigma(x)) \vert)

    where :math:`x_i` is a given offset, :math:`\overline{x}` and :math:`\sigma(x)` are the mean
    and standard deviation of all offsets in the same direction, and :math:`\mathcal{N}(\mu, \sigma)`
    is a random variable from a normal distribution with mean :math:`\mu` and standard deviation
    :math:`\sigma`. :math:`x_i` is then updated by

    .. math::
        x_i \leftarrow x_i - \lambda \cdot \text{sgn}(d(x_i)) \cdot \text{min}(1, \vert d(x_i) \vert)

    where :math:`\lambda` is the :attr:`relaxation_weight`.

    The shift toward the center of mass is given by

    .. math::
        t(x_i) = \mathbf{v}_i \cdot \left(\mathbf{j}_i (\text{CoM}_j(P_i) + x_i^j)
        + \mathbf{k}_i (\text{CoM}_k(P_i) + x_i^k) - \text{CoM}(R)\right)

    where :math:`x_i^j` and :math:`x_i^k` are the offsets of projection :math:`i` in the ``j`` and
    ``k`` directions, :math:`\mathbf{j}_i` and :math:`\mathbf{k}_i` are the corresponding
    three-dimensional basis vectors, :math:`x_i` is either of the two offsets and :math:`\mathbf{v}_i`
    its basis vector, :math:`\text{CoM}(P_i)` is the center of mass of the first channel of the
    projection of the reconstruction, and :math:`\text{CoM}(R)` is the center of mass of the first
    channel of the reconstruction :math:`R`. :math:`x_i` is then updated by

    .. math::
        x_i \leftarrow x_i + \mu \cdot \text{sgn}(t(x_i)) \cdot \text{min}(1, \vert t(x_i) \vert)

    where :math:`\mu` is the :attr:`center_of_mass_shift_weight`.

    Parameters
    ----------
    data_container
        The data container from loading the data set of interest. Note that the offset
        factors in :class:`data_container.geometry <mumott.core.geometry.Geometry>` will be
        modified during the alignment.
    ignored_subset
        A subset of projection numbers which will not have their alignment modified.
        The subset is still used in the reconstruction.
    projection_cropping
        A tuple of two slices (``slice``), which specify the cropping of the
        ``projection`` and ``data``. For example, to clip the first and last 5
        pixels in each direction, set this parameter to ``(slice(5, -5), slice(5, -5))``.
    reconstruction_pipeline
        A ``callable``, typically from the :ref:`reconstruction pipelines <reconstruction_pipelines>`,
        that performs the reconstruction at each alignment iteration. Must return a dictionary with an
        entry labelled ``'result'``, which has an entry labelled ``'x'`` containing the
        reconstruction. Additionally, it must expose a ``'weights'`` entry containing the
        weights used during the reconstruction as well as a :ref:`Projector object <projectors>`
        under the keyword ``'projector'``.
        If the pipeline supports the :attr:`use_absorbances` keyword argument, then an
        ``absorbances`` entry must also be exposed. If the pipeline supports using multi-channel
        data rather than absorbances, then a :ref:`basis set object <basis_sets>` must be
        available under ``'basis_set'``.
    reconstruction_pipeline_kwargs
        Keyword arguments to pass to :attr:`reconstruction_pipeline`. If ``'data_container'``,
        ``'use_gpu'`` or ``'use_absorbances'`` are set as keys, they will override
        :attr:`data_container`, :attr:`use_gpu` and :attr:`use_absorbances` in the call to the
        pipeline. ``'no_tqdm'`` is always set to ``True``. The given dictionary is not modified.
    use_gpu
        Whether to use GPU resources in computing the reconstruction.
        Default is ``False``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    use_absorbances
        Whether to use the absorbances to compute the reconstruction and align the projections.
        Default is ``True``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    maxiter
        Maximum number of iterations for the alignment. Must be at least ``1``.
    upsampling
        Upsampling factor during alignment. If used, any masking in :attr:`data_container.weights` will
        be ignored, but :attr:`projection_cropping` will still be used. The suggested range of use
        is ``[1, 20]``.
    shift_tolerance
        Tolerance for the maximal shift distance of each iteration of the alignment.
        The alignment will terminate when the maximal shift falls below this value.
        The maximal shift is the largest Euclidean distance any one projection is shifted by
        due to the cross-correlation, counting only the directions being aligned and measured
        before :attr:`shift_cutoff` is applied.
        Default value is ``1 / upsampling``.
    shift_cutoff
        Largest permissible shift due to cross-correlation in each iteration, as measured by the
        Euclidean distance. Larger shifts will be rescaled so as to not exceed this value.
        Default value is ``5 / upsampling``.
    relative_sample_size
        Fraction of projections to align in each iteration. At each alignment iteration,
        ``ceil(number_of_projections * relative_sample_size)`` of the projections not in
        :attr:`ignored_subset` will be randomly selected for alignment (at most all of them).
        If set to ``1``, all projections will be aligned at each iteration.
    relaxation_weight
        A relaxation parameter for stochastic relaxation; the larger this weight is, the more shifts will tend
        toward the mean shift in each direction. The relaxation step size in each direction at each iteration
        cannot be larger than this weight. This is :math:`\lambda` in the expression given above.
    center_of_mass_shift_weight
        A parameter that controls the tendency for the projection center of mass
        to be shifted toward the reconstruction center of mass. The step size in each direction
        at each iteration cannot be larger than this weight. This is :math:`\mu` in the
        expression given above.
    align_j
        Whether to align in the ``j`` direction. Default is ``True``.
    align_k
        Whether to align in the ``k`` direction. Default is ``True``.

    Returns
    -------
    A dictionary with three entries for inspection:
        reconstruction
            The reconstruction used in the last alignment step.
        projections
            Projections of the ``reconstruction``.
        reference
            The reference image derived from the data used to align the ``projections``.

    Raises
    ------
    TypeError
        If :attr:`ignored_subset` is not a ``set``.
    ValueError
        If both :attr:`align_j` and :attr:`align_k` are ``False``, or if :attr:`maxiter` is
        smaller than ``1``.
    """
    # A ``None`` default avoids a shared mutable default; the set is never modified.
    if ignored_subset is None:
        ignored_subset = set()
    if not isinstance(ignored_subset, set):
        raise TypeError(f'ignored_subset must be a set, but a {type(ignored_subset).__name__} was given!')
    if not (align_j or align_k):
        raise ValueError('At least one of align_j and align_k must be set to True,'
                         ' but both are set to False.')
    if maxiter < 1:
        # Without a single iteration there is no reconstruction to return.
        raise ValueError(f'maxiter must be at least 1, but {maxiter} was given.')

    # Work on a copy so that the caller's dictionary is never modified; otherwise a
    # dictionary reused across calls would keep the data container of the first call.
    # Entries already present take precedence over the arguments of this function.
    pipeline_kwargs = dict(reconstruction_pipeline_kwargs or {})
    pipeline_kwargs.setdefault('data_container', data_container)
    pipeline_kwargs.setdefault('use_gpu', use_gpu)
    pipeline_kwargs.setdefault('use_absorbances', use_absorbances)
    pipeline_kwargs['no_tqdm'] = True

    if shift_tolerance is None:
        shift_tolerance = 1. / upsampling
    if shift_cutoff is None:
        shift_cutoff = 5. / upsampling

    # Projections eligible for alignment. The explicit integer dtype keeps the sampled
    # indices usable for indexing even when every projection is ignored.
    valid_indices = np.array(
        sorted(set(range(len(data_container.geometry))) - ignored_subset), dtype=int)
    # ceil(number_of_valid_projections * relative_sample_size), limited to what can be
    # drawn without replacement.
    number_of_samples = int(np.ceil(len(valid_indices) * relative_sample_size))
    number_of_samples = min(max(number_of_samples, 0), len(valid_indices))

    # Selects the shift components, ordered (j, k), that must stay zero.
    frozen_directions = ~np.array([align_j, align_k], dtype=bool)

    # Three-dimensional basis vectors of the j and k directions for each projection,
    # obtained by contracting each rotation matrix with the reference directions.
    j_vectors = np.einsum(
        'kij,i->kj',
        data_container.geometry.rotations_as_array,
        data_container.geometry.j_direction_0)
    k_vectors = np.einsum(
        'kij,i->kj',
        data_container.geometry.rotations_as_array,
        data_container.geometry.k_direction_0)

    for _ in tqdm.tqdm(range(maxiter), file=sys.stdout):
        pipeline = reconstruction_pipeline(**pipeline_kwargs)
        if upsampling == 1:
            # Mask is boolean and has no "channel" index.
            mask = np.all(pipeline['weights'] > 0, -1)

        # Project reconstruction into 2D.
        reconstruction = pipeline['result']['x']
        projections = pipeline['projector'].forward(reconstruction)

        if pipeline_kwargs['use_absorbances']:
            reference = pipeline['absorbances']
        else:
            # If not absorbances, use mean of detector segments.
            # NOTE(review): this reads ``data_container`` even if a different container was
            # passed through ``reconstruction_pipeline_kwargs`` — confirm this is intended.
            reference = np.mean(data_container.data, -1)
            reference = reference.reshape(*reference.shape, 1)
            projections = pipeline['basis_set'].forward(projections).mean(-1)
            projections = projections.reshape(*projections.shape, 1)

        # Reinitialize shifts since they are applied at the end of each iteration.
        shifts = np.zeros((len(projections), 2), dtype=np.float64)

        sampled_subset = rng.choice(valid_indices, number_of_samples, replace=False)
        for index in sampled_subset:
            p = projections[index, ..., 0][projection_cropping]
            r = reference[index, ..., 0][projection_cropping]
            m = mask[index][projection_cropping] if upsampling == 1 else None
            shifts[index, :] = phase_xcorr(
                p, r,
                upsample_factor=upsampling,
                reference_mask=m,
                moving_mask=m)[0]

        # The cross-correlation function is not always totally stable.
        shifts = np.nan_to_num(shifts, posinf=0, neginf=0, nan=0)
        # Drop components that will not be applied, so that the cutoff and the
        # convergence test only consider the directions actually being aligned.
        shifts[:, frozen_directions] = 0.

        # Rescale shifts that are too large.
        shift_size = np.hypot(shifts[:, 0], shifts[:, 1])
        nonzero = shift_size > 0
        shifts[nonzero, :] *= (shift_size[nonzero].clip(None, shift_cutoff) /
                               shift_size[nonzero]).reshape(-1, 1)

        # Add stochastic relaxation factor, tending to move offsets toward the mean.
        shifts[sampled_subset, 0] -= _relax_offsets(
            data_container.geometry.j_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight
        shifts[sampled_subset, 1] -= _relax_offsets(
            data_container.geometry.k_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight

        # Add movement of projection center of mass toward reconstruction center of mass.
        # Skipped entirely for a zero weight: an empty image has a NaN center of mass,
        # and NaN * 0 would still be NaN and corrupt the offsets.
        if center_of_mass_shift_weight != 0:
            com_3d = np.array(center_of_mass(reconstruction[..., 0]))
            for index in sampled_subset:
                com_2d = np.array(center_of_mass(projections[index, ..., 0]))
                com_shifts = _shift_toward_center(com_2d,
                                                  com_3d,
                                                  j_vectors[index],
                                                  k_vectors[index],
                                                  data_container.geometry.j_offsets[index],
                                                  data_container.geometry.k_offsets[index])
                com_shifts = np.nan_to_num(com_shifts, posinf=0, neginf=0, nan=0)
                shifts[index, :] += com_shifts.clip(-1, 1) * center_of_mass_shift_weight

        # The smoothing terms above may have reintroduced shifts in frozen directions.
        shifts[:, frozen_directions] = 0.

        data_container.geometry.j_offsets = data_container.geometry.j_offsets_as_array + shifts[:, 0]
        data_container.geometry.k_offsets = data_container.geometry.k_offsets_as_array + shifts[:, 1]

        max_shift = np.max(shift_size, initial=0.)
        if max_shift < shift_tolerance:
            logger.info(f'Maximal shift is {max_shift:.2f}, which is less than'
                        f' the specified tolerance {shift_tolerance:.2f}. Alignment completed.')
            break
    else:
        logger.info('Maximal number of iterations reached. Alignment completed.')
    return dict(reconstruction=reconstruction, projections=projections, reference=reference)