Coverage for local_installation_linux/mumott/pipelines/phase_matching_alignment.py: 94%

86 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2import sys 

3from typing import Any, Callable, Set 

4 

5import numpy as np 

6import tqdm 

7from skimage.registration import phase_cross_correlation as phase_xcorr 

8from scipy.ndimage import center_of_mass 

9 

10from mumott.data_handling import DataContainer 

11from .reconstruction import run_mitra 

12 

logger = logging.getLogger(__name__)
# Module-level, unseeded random generator shared by the stochastic parts of the
# alignment: the relaxation noise in `_relax_offsets` and the per-iteration random
# choice of projections in `run_phase_matching_alignment`.
rng = np.random.default_rng()

15 

16 

17def _relax_offsets(offsets: np.ndarray[float]) -> np.ndarray[float]: 

18 """ Internal convenience function for adding a stochastic relaxation factor 

19 to offsets. """ 

20 diffs = offsets - offsets.mean() 

21 stds = np.std(diffs) 

22 relaxations = np.sign(diffs) * \ 

23 np.fmax(0, abs(diffs) - abs(stds * rng.standard_normal(diffs.shape))) 

24 return relaxations 

25 

26 

27def _shift_toward_center(center_of_mass_2d: np.ndarray[float], 

28 center_of_mass_3d: np.ndarray[float], 

29 j_vector: np.ndarray[float], 

30 k_vector: np.ndarray[float], 

31 j_offset: float, 

32 k_offset: float) -> np.ndarray[float]: 

33 """ Internal convenience function for aligning centers of mass. """ 

34 com_2d_xyz = j_vector * (center_of_mass_2d[0] + j_offset) + \ 

35 k_vector * (center_of_mass_2d[1] + k_offset) 

36 com_3d_diff = com_2d_xyz - center_of_mass_3d 

37 shifts = np.array((np.dot(j_vector, com_3d_diff), np.dot(k_vector, com_3d_diff))) 

38 return shifts 

39 

40 

def run_phase_matching_alignment(data_container: DataContainer,
                                 ignored_subset: Set[int] = None,
                                 projection_cropping: tuple[slice, slice] = np.s_[:, :],
                                 reconstruction_pipeline: Callable = run_mitra,
                                 reconstruction_pipeline_kwargs: dict[str, Any] = None,
                                 use_gpu: bool = False,
                                 use_absorbances: bool = True,
                                 maxiter: int = 20,
                                 upsampling: int = 1,
                                 shift_tolerance: float = None,
                                 shift_cutoff: float = None,
                                 relative_sample_size: float = 1.0,
                                 relaxation_weight: float = 0.0,
                                 center_of_mass_shift_weight: float = 0.0,
                                 align_j: bool = True,
                                 align_k: bool = True) -> dict[str, Any]:
    r"""A pipeline for alignment using the phase cross-correlation method as implemented
    by `scikit-image <https://scikit-image.org>`_.

    For details on the cross-correlation algorithm, see
    `this article by Guizar-Sicairos et al., (2008) <https://doi.org/10.1364/OL.33.000156>`_.
    Briefly, the algorithm calculates the cross-correlation between a reference image (the data) and the
    corresponding projection of a reconstruction, and finds the shift that would result in
    maximal correlation between the two. It supports large upsampling factors with
    very little computational overhead.

    This implementation applies this algorithm to a randomly sampled subset of the projections in each
    iteration, and adds to this two smoothing terms – a stochastic relaxation term, and a
    shift toward the center of mass of the reconstruction. These terms are added partly to reduce the
    determinism in the algorithm, and partly to improve the performance when no
    upsampling is used.

    The relaxation term is given by

    .. math::
        d(x_i) = \text{sgn}(x_i - \overline{x}) \cdot \text{max}
        (0, \vert x_i - \overline{x} \vert - \vert \mathcal{N}(0, \sigma(x)) \vert)

    where :math:`x_i` is a given offset, :math:`\overline{x}` and :math:`\sigma(x)` are the mean and
    standard deviation of all offsets in the same direction, and :math:`\mathcal{N}(\mu, \sigma)` is a
    random variable from a normal distribution with mean :math:`\mu` and standard deviation
    :math:`\sigma`. For each sampled projection, :math:`x_i` is then updated by

    .. math::
        x_i \leftarrow x_i - \lambda \cdot \text{sgn}(d(x_i)) \cdot \text{min}(1, \vert d(x_i) \vert)

    (in addition to the cross-correlation shift), where :math:`\lambda` is the
    :attr:`relaxation_weight`.

    The shift toward the center of mass is computed from

    .. math::
        \mathbf{c}_i = \mathbf{j}_i (\text{CoM}_j(P_i) + x^j_i) + \mathbf{k}_i (\text{CoM}_k(P_i) + x^k_i),
        \quad t^j_i = \mathbf{j}_i \cdot (\mathbf{c}_i - \text{CoM}(R)),
        \quad t^k_i = \mathbf{k}_i \cdot (\mathbf{c}_i - \text{CoM}(R))

    where :math:`\mathbf{j}_i` and :math:`\mathbf{k}_i` are the three-dimensional basis vectors of the
    two shift directions of projection :math:`i`, :math:`P_i` is the :math:`i`-th projection of the
    current reconstruction :math:`R` (first channel), and :math:`x^j_i, x^k_i` are the current offsets.
    Each offset is then updated by :math:`x_i \leftarrow x_i + \mu \cdot \text{clip}(t_i, -1, 1)`,
    where :math:`\mu` is the :attr:`center_of_mass_shift_weight`.


    Parameters
    ----------
    data_container
        The data container from loading the data set of interest. Note that the offset
        factors in :class:`data_container.geometry <mumott.core.geometry.Geometry>` will be
        modified during the alignment.
    ignored_subset
        A subset of projection numbers which will not have their alignment modified.
        The subset is still used in the reconstruction. Both ``set`` and ``frozenset`` are accepted.
    projection_cropping
        A tuple of two slices (``slice``), which specify the cropping of the
        ``projection`` and ``data``. For example, to clip the first and last 5
        pixels in each direction, set this parameter to ``(slice(5, -5), slice(5, -5))``.
    reconstruction_pipeline
        A ``callable``, typically from the :ref:`reconstruction pipelines <reconstruction_pipelines>`,
        that performs the reconstruction at each alignment iteration. Must return a dictionary with an
        entry labelled ``'result'``, which has an entry labelled ``'x'`` containing the
        reconstruction. Additionally, it must expose a ``'weights'`` entry containing the
        weights used during the reconstruction as well as a :ref:`Projector object <projectors>`
        under the keyword ``'projector'``.
        If :attr:`use_absorbances` is ``True``, an ``'absorbances'`` entry must also be exposed;
        otherwise, a :ref:`basis set object <basis_sets>` must be available under ``'basis_set'``.
    reconstruction_pipeline_kwargs
        Keyword arguments to pass to :attr:`reconstruction_pipeline`. If ``'data_container'``,
        ``'use_gpu'`` or ``'use_absorbances'`` are set as keys, they will override
        :attr:`data_container`, :attr:`use_gpu` and :attr:`use_absorbances` in the call to the pipeline.
        The dictionary is copied and not modified.
    use_gpu
        Whether to use GPU resources in computing the reconstruction.
        Default is ``False``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    use_absorbances
        Whether to use the absorbances to compute the reconstruction and align the projections.
        Default is ``True``. Will be overridden if set in :attr:`reconstruction_pipeline_kwargs`.
    maxiter
        Maximum number of iterations for the alignment. Must be at least ``1``.
    upsampling
        Upsampling factor during alignment. If ``1``, pixels with non-positive weight in any channel
        of the pipeline's ``'weights'`` are masked out of the cross-correlation; for other values no
        mask is used. :attr:`projection_cropping` is always used. The suggested range of use
        is ``[1, 20]``.
    shift_tolerance
        Tolerance for the maximal shift distance of each iteration of the alignment.
        The alignment will terminate when the maximal shift falls below this value.
        The maximal shift is the largest Euclidean norm of any one projection's cross-correlation
        shift, restricted to the enabled directions and measured before the cutoff, relaxation
        and center-of-mass terms are applied.
        Default value is ``1 / upsampling``.
    shift_cutoff
        Largest permissible shift due to cross-correlation in each iteration, as measured by the
        Euclidean distance. Larger shifts will be rescaled so as to not exceed this value.
        Default value is ``5 / upsampling``.
    relative_sample_size
        Fraction of projections to align in each iteration. At each alignment iteration,
        ``ceil(number_of_projections * relative_sample_size)`` (capped at the number of
        non-ignored projections) will be randomly selected for alignment, where
        ``number_of_projections`` excludes :attr:`ignored_subset`.
        If set to ``1``, all projections will be aligned at each iteration.
    relaxation_weight
        A relaxation parameter for stochastic relaxation; the larger this weight is, the more shifts will tend
        toward the mean shift in each direction. The relaxation step size in each direction at each iteration
        cannot be larger than this weight. This is :math:`\lambda` in the expression given above.
    center_of_mass_shift_weight
        A parameter that controls the tendency for the projection center of mass
        to be shifted toward the reconstruction center of mass. The step size in each direction
        at each iteration cannot be larger than this weight. This is :math:`\mu` above.
    align_j
        Whether to align in the ``j`` direction. Default is ``True``.
    align_k
        Whether to align in the ``k`` direction. Default is ``True``.

    Returns
    -------
    A dictionary with three entries for inspection:
        reconstruction
            The reconstruction used in the last alignment step.
        projections
            Projections of the ``reconstruction`` (averaged over detector channels
            if :attr:`use_absorbances` is ``False``).
        reference
            The reference image derived from the data used to align the ``projections``.

    Raises
    ------
    TypeError
        If :attr:`ignored_subset` is not a set.
    ValueError
        If both :attr:`align_j` and :attr:`align_k` are ``False``, or if :attr:`maxiter` is
        smaller than ``1``.
    """
    if ignored_subset is None:
        ignored_subset = set()
    if not isinstance(ignored_subset, (set, frozenset)):
        raise TypeError(f'ignored_subset must be a set, but a {type(ignored_subset).__name__} was given!')
    if not (align_j or align_k):
        raise ValueError('At least one of align_j and align_k must be set to True,'
                         ' but both are set to False.')
    if maxiter < 1:
        # Without at least one iteration there is no reconstruction to return.
        raise ValueError(f'maxiter must be at least 1, but {maxiter} was given.')

    # Work on a copy so that the caller's dictionary is not mutated. Entries already present
    # in it take precedence over the corresponding arguments of this function.
    reconstruction_pipeline_kwargs = dict(reconstruction_pipeline_kwargs or {})
    reconstruction_pipeline_kwargs.setdefault('data_container', data_container)
    reconstruction_pipeline_kwargs.setdefault('use_gpu', use_gpu)
    reconstruction_pipeline_kwargs.setdefault('use_absorbances', use_absorbances)
    # The outer alignment loop shows its own progress bar; presumably this silences the pipeline's.
    reconstruction_pipeline_kwargs['no_tqdm'] = True

    if shift_tolerance is None:
        shift_tolerance = 1. / upsampling
    if shift_cutoff is None:
        shift_cutoff = 5. / upsampling

    # Projections eligible for alignment; out-of-range entries of ignored_subset are harmless.
    valid_indices = sorted(set(range(len(data_container.geometry))) - ignored_subset)
    # ceil must be applied *after* scaling by relative_sample_size (as documented).
    number_of_samples = min(len(valid_indices),
                            int(np.ceil(len(valid_indices) * relative_sample_size)))

    # 3D basis vectors of the j and k shift directions for every projection.
    j_vectors = np.einsum(
        'kij,i->kj',
        data_container.geometry.rotations_as_array,
        data_container.geometry.j_direction_0)
    k_vectors = np.einsum(
        'kij,i->kj',
        data_container.geometry.rotations_as_array,
        data_container.geometry.k_direction_0)

    for _ in tqdm.tqdm(range(maxiter), file=sys.stdout):
        pipeline = reconstruction_pipeline(**reconstruction_pipeline_kwargs)
        if upsampling == 1:
            # Boolean mask without the channel index: a pixel is used only if it has
            # positive weight in every channel.
            mask = np.all(pipeline['weights'] > 0, -1)

        # Project reconstruction into 2D
        reconstruction = pipeline['result']['x']
        com_3d = np.array(center_of_mass(reconstruction[..., 0]))
        projections = pipeline['projector'].forward(reconstruction)

        # Reinitialize shifts since they are applied at the end of each iteration.
        shifts = np.zeros((len(projections), 2), dtype=np.float64)

        if reconstruction_pipeline_kwargs['use_absorbances']:
            reference = pipeline['absorbances']
        else:
            # If not absorbances, use the mean over detector segments of both data and projections.
            reference = np.mean(data_container.data, -1)
            reference = reference.reshape(*reference.shape, 1)
            projections = pipeline['basis_set'].forward(projections).mean(-1)
            projections = projections.reshape(*projections.shape, 1)

        sampled_subset = rng.choice(valid_indices, number_of_samples, replace=False)
        for index in sampled_subset:
            projection_crop = projections[index, ..., 0][projection_cropping]
            data_crop = reference[index, ..., 0][projection_cropping]
            # Masked cross-correlation is only used without upsampling.
            m = mask[index][projection_cropping] if upsampling == 1 else None
            shifts[index, :] = phase_xcorr(
                projection_crop, data_crop,
                upsample_factor=upsampling,
                reference_mask=m,
                moving_mask=m)[0]

        # The cross-correlation function is not always totally stable.
        shifts = np.nan_to_num(shifts, posinf=0, neginf=0, nan=0)

        # Drop disabled directions *before* measuring the shifts, so that a large shift in a
        # direction that is not aligned neither shrinks the enabled component during
        # rescaling nor prevents the tolerance check from being met.
        if not align_j:
            shifts[:, 0] = 0.
        if not align_k:
            shifts[:, 1] = 0.

        # Rescale shifts that are too large.
        shift_size = np.hypot(shifts[:, 0], shifts[:, 1])
        nonzero = shift_size > 0
        shifts[nonzero, :] *= (np.minimum(shift_size[nonzero], shift_cutoff) /
                               shift_size[nonzero]).reshape(-1, 1)

        # Add stochastic relaxation factor, tending to move offsets toward the mean.
        shifts[sampled_subset, 0] -= _relax_offsets(
            data_container.geometry.j_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight
        shifts[sampled_subset, 1] -= _relax_offsets(
            data_container.geometry.k_offsets_as_array)[sampled_subset].clip(-1, 1) * relaxation_weight

        # Add movement of projection center of mass toward reconstruction center of mass.
        for index in sampled_subset:
            com_2d = np.array(center_of_mass(projections[index, ..., 0]))
            com_shifts = _shift_toward_center(com_2d,
                                              com_3d,
                                              j_vectors[index],
                                              k_vectors[index],
                                              data_container.geometry.j_offsets[index],
                                              data_container.geometry.k_offsets[index])
            shifts[index, :] += np.clip(com_shifts, -1, 1) * center_of_mass_shift_weight

        # The smoothing terms above may have reintroduced shifts in disabled directions.
        if not align_j:
            shifts[:, 0] = 0.
        if not align_k:
            shifts[:, 1] = 0.

        data_container.geometry.j_offsets = data_container.geometry.j_offsets_as_array + shifts[:, 0]
        data_container.geometry.k_offsets = data_container.geometry.k_offsets_as_array + shifts[:, 1]

        if np.max(shift_size) < shift_tolerance:
            logger.info(f'Maximal shift is {np.max(shift_size):.2f}, which is less than'
                        f' the specified tolerance {shift_tolerance:.2f}. Alignment completed.')
            break
    else:
        logger.info('Maximal number of iterations reached. Alignment completed.')
    return dict(reconstruction=reconstruction, projections=projections, reference=reference)