Coverage for local_installation_linux/mumott/methods/residual_calculators/zonal_harmonic_gradient_calculator.py: 85%
188 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
1import logging
3import numpy as np
4from numpy.typing import NDArray
6from mumott import DataContainer
7from mumott.core.wigner_d_utilities import (
8 load_d_matrices, calculate_sph_coefficients_rotated_around_z,
9 calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x,
10 calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x,
11 calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle)
12from mumott.core.hashing import list_to_hash
13from mumott.methods.projectors.base_projector import Projector
14from mumott.methods.basis_sets.spherical_harmonics import SphericalHarmonics
15from .base_residual_calculator import ResidualCalculator
18logger = logging.getLogger(__name__)
class ZHTTResidualCalculator(ResidualCalculator):
    r"""Class that implements the gradient calculations for a model that uses a
    :class:`SphericalHarmonics` basis set restricted to zonal harmonics parametrized
    by a primary axis with polar coordinates :math:`\theta_0` and :math:`\phi_0`,
    defined as:

    .. math::

        \begin{pmatrix} x_0\\ y_0\\ z_0\end{pmatrix}
        = \begin{pmatrix}
        \sin(\theta_0) \sin(\phi_0) \\
        \sin(\theta_0) \cos(\phi_0) \\
        \cos(\theta_0)
        \end{pmatrix}

    This model is equivalent to the one used in [Liebi2015]_, but uses a different approach to
    computation.

    This implementation avoids doing some of the expensive calculations of trigonometric functions and
    Legendre polynomials by doing the rotation in the space of the spherical harmonics using
    `Wigner (small) d-matrices <https://en.wikipedia.org/wiki/Wigner_D-matrix>`_.
    The forward model only involves a small number of trigonometric functions to evaluate the
    :math:`d_z(\text{angle})` matrices for the :math:`\theta` and :math:`\phi` rotations.
    Everything else is expressed as matrix products with precomputed matrices.

    The full forward model may be written as:

    .. math::

        \boldsymbol{I} =
        \boldsymbol{W} \boldsymbol{P}
        \boldsymbol{d}_z(\phi_0) \boldsymbol{d}_y(\frac{\pi}{4})^T
        \boldsymbol{d}_z(-\theta_0) \boldsymbol{d}_y(\frac{\pi}{4})
        \boldsymbol{a}'_{l0},

    where :math:`\boldsymbol{W}` is the mapping from spherical harmonic modes to detector segments,
    which can be precomputed.
    :math:`\boldsymbol{P}` is the typical projector from normal 3D tomography and
    :math:`\boldsymbol{d}_i(\text{angle})` with :math:`i = x,y,z` are Wigner (small) d matrices
    for real spherical harmonics. :math:`\theta_0`, :math:`\phi_0`, and :math:`\boldsymbol{a}_{l0}`
    are the model parameters for each voxel.

    Derivatives are easy to evaluate because the angles only appear in the
    :math:`\boldsymbol{d}_z(\text{angle})`-matrices. All the expensive trigonometric and spherical
    harmonics calculations have been put into the precomputation of :math:`\boldsymbol{W}`
    and :math:`\boldsymbol{d}_y(\frac{\pi}{4})`.

    Parameters
    ----------
    data_container : DataContainer
        Container holding the data to be reconstructed.
    basis_set : SphericalHarmonics
        The basis set used for representing spherical functions.
    projector : Projector
        The type of projector used together with this method.
    """

    def __init__(self,
                 data_container: DataContainer,
                 basis_set: SphericalHarmonics,
                 projector: Projector):
        super().__init__(data_container, basis_set, projector)
        self._make_matrices()
        self._make_starting_guess()

    def _make_starting_guess(self) -> None:
        """Initializes the optimization parameters by setting the
        zonal coefficients to zero and randomizing the angles.

        The polar angle is drawn as the arccosine of a uniform variate
        on [0, 1] and the azimuthal angle uniformly on [-pi, pi].
        """
        # NOTE(review): cos(theta) is drawn from [0, 1], so only the upper
        # hemisphere is sampled -- presumably sufficient because
        # `_cast_angles_to_symmetric_zone` folds angles into that range anyway.
        volume_shape = self._projector.volume_shape
        self._zonal_coefficients = np.zeros((*volume_shape, self._basis_set.ell_max // 2 + 1))

        # Make random orientations by random sampling in 3D
        rng = np.random.default_rng()
        self._theta = np.arccos(rng.uniform(low=0, high=1, size=volume_shape))
        self._phi = rng.uniform(low=-np.pi, high=np.pi, size=volume_shape)

    def _make_matrices(self) -> None:
        """
        Loads Wigner d-matrices and creates the mapping from parameters
        to spherical harmonics coefficients.
        """
        # Load precomputed d-matrices
        ell_max = self._basis_set.ell_max
        self.d_matrices = load_d_matrices(ell_max)

        # Set up matrix for converting from zonal harmonics to full harmonics space.
        # Row ``full_index`` of ``_E`` selects the zonal coefficient ``ell // 2``
        # for every basis function with ``m == 0``; all other rows stay zero.
        ell_list = self._basis_set.ell_indices
        m_list = self._basis_set.emm_indices
        self._E = np.zeros((len(ell_list), ell_max // 2 + 1))
        for full_index, (ell, m) in enumerate(zip(ell_list, m_list)):
            if m == 0:
                self._E[full_index, ell // 2] = 1

    @property
    def coefficients(self) -> NDArray:
        """Optimization coefficients for this method.
        Contains both the zonal coefficients and the angles.
        The first N-2 elements are zonal coefficients.
        The N-1th element is the polar angle and the last element is the azimuthal angle.

        Reading this property first folds the stored angles into the
        symmetric zone (see :meth:`_cast_angles_to_symmetric_zone`).
        """
        self._cast_angles_to_symmetric_zone()
        return np.concatenate((self._zonal_coefficients,
                               self._theta[..., np.newaxis],
                               self._phi[..., np.newaxis]), axis=-1)

    @coefficients.setter
    def coefficients(self, val: NDArray) -> None:
        # Convert from external to internal representation of optimization parameters
        val = val.reshape((*self._projector.volume_shape, self._basis_set.ell_max // 2 + 1 + 2))
        assert np.shape(val[..., :-2]) == np.shape(self._zonal_coefficients), \
            'Shape of new array inconsistent with expectation (zonal_coefficients)'
        assert np.shape(val[..., -2]) == np.shape(self._theta), \
            'Shape of new array inconsistent with expectation (theta)'
        assert np.shape(val[..., -1]) == np.shape(self._phi), \
            'Shape of new array inconsistent with expectation (phi)'
        # These slices may be views of the caller's array; they are never
        # written to in place by this class.
        self._zonal_coefficients = val[..., :-2]
        self._theta = val[..., -2]
        self._phi = val[..., -1]

    def _rotate_coeffs(self) -> NDArray:
        """Expand from the zonal harmonics basis to a full spherical harmonics basis and
        rotate the spherical harmonics coefficients from the symmetric coordinate system
        to the sample xyz system.

        The result is also stored in ``self._coefficients``.

        Returns
        -------
            Array containing the rotated spherical harmonics coefficients.
        """
        ell_list = np.arange(0, self._basis_set.ell_max + 1, 2)
        # Expand symmetric coefficients into full basis
        self._coefficients = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        # Rotate by 90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        # Rotate by theta about z
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._theta, ell_list, output_array=self._coefficients)
        # Rotate by -90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        # Rotate by phi about z
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._phi, ell_list, output_array=self._coefficients)
        return self._coefficients

    def _rotate_and_derive(self) -> tuple[NDArray, NDArray, NDArray]:
        """
        Rotate spherical harmonics coefficients from the symmetric coordinate system
        to the sample xyz system and evaluate the derivative of the coefficients
        with respect to the two rotation angles.

        Returns
        -------
        self._coefficients : NDArray
            Array containing the rotated spherical harmonics coefficients.
        theta_derivative : NDArray
            Rotated spherical coefficients derived with respect to the polar rotation angle
            evaluated at the current value of the rotation angles.
        phi_derivative : NDArray
            Rotated spherical coefficients derived with respect to the azimuthal rotation angle
            evaluated at the current value of the rotation angles.
        """
        ell_list = np.arange(0, self._basis_set.ell_max + 1, 2)
        # Expand symmetric coefficients into full basis
        self._coefficients = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        theta_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))
        phi_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))

        # Do 90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)

        # Do z rotation of Theta and derivative. The derivative must be taken
        # before ``self._coefficients`` is overwritten by the theta rotation.
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._theta, ell_list, output_array=theta_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._theta, ell_list, output_array=self._coefficients)

        # Do -90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            theta_derivative, ell_list, self.d_matrices, output_array=theta_derivative)

        # Do z rotation of Phi. Again the derivative is evaluated before the
        # coefficients are overwritten; the theta derivative only gets rotated.
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._phi, ell_list, output_array=phi_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._phi, ell_list, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_around_z(
            theta_derivative, self._phi, ell_list, output_array=theta_derivative)

        return self._coefficients, theta_derivative, phi_derivative

    def _rotate_coeffs_inverse(self, coefficients: NDArray) -> NDArray:
        """
        Rotate spherical harmonics coefficients from the sample xyz system
        to the symmetric coordinate system.

        The rotations of :meth:`_rotate_coeffs` are applied in reverse order
        with negated angles. ``coefficients`` is passed as the output array of
        every step and is returned.
        """
        ell_list = np.arange(0, self._basis_set.ell_max + 1, 2)

        # Do z rotation of -phi
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._phi, ell_list, output_array=coefficients)

        # Do 90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            coefficients, ell_list, self.d_matrices, output_array=coefficients)

        # Do z rotation of -theta
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._theta, ell_list, output_array=coefficients)

        # Do -90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            coefficients, ell_list, self.d_matrices, output_array=coefficients)

        return coefficients

    def get_residuals(self,
                      get_gradient: bool = False,
                      get_weights: bool = False,
                      gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """ Calculates the residuals and possibly the gradient of the residual square sum
        (without the factor of -2!) with respect to the parameters.
        The coefficients are projected using the :attr:`SphericalHarmonics` and :attr:`Projector`
        attached to this instance.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        get_weights
            Whether to return weights. Default is ``False``. If ``True`` along with
            :attr:`get_gradient`, the gradient will be computed with weights.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) the gradient is computed with respect to all
            parameters;
            if :attr:`gradient_part` is ``'angles'`` only the gradient with respect to the angles is computed;
            if :attr:`gradient_part` is ``'coefficients'`` only the gradient with respect to the zonal
            spherical harmonics coefficients is computed.
            Only used when :attr:`get_gradient` is ``True``.

        Returns
        -------
            A dictionary containing the residuals, and possibly the
            gradient and/or weights. If gradient and/or weights
            are not returned, their value will be ``None``.
        """
        if get_gradient:
            # Forward ``gradient_part`` so that the documented partial
            # gradients are actually honoured.
            output = self.get_gradient(get_weights=get_weights, gradient_part=gradient_part)
        else:
            # Rotate the coefficients
            self._rotate_coeffs()
            # Project from voxel to detector space and from coefficient to angle space
            projection = self._basis_set.forward(
                self._projector.forward(self._coefficients.astype(self.dtype)))
            # Calculate residuals
            residuals = self._data - projection
            if get_weights:
                residuals *= self._weights
            output = {'residuals': residuals, 'gradient': None}

        # Pass on weights, if asked to
        output['weights'] = self._weights if get_weights else None
        return output

    def get_gradient(self,
                     get_weights: bool = False,
                     gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """ Calculates the gradient of *half* the sum of residuals squared.

        Parameters
        ----------
        get_weights
            Whether to multiply the returned residuals by the weights.
            Default is ``False``.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) the gradient is computed with respect to all
            parameters;
            if :attr:`gradient_part` is ``'angles'`` only the gradient with respect to the angles is computed;
            if :attr:`gradient_part` is ``'coefficients'`` only the gradient with respect to the zonal
            spherical harmonics coefficients is computed.

        Returns
        -------
            A dictionary containing the residuals and the gradient. If only a part of the
            gradient is computed, the rest of the elements will be filled with zeros.
        """
        # initialize output array; the last two entries hold the angle derivatives
        gradient = np.zeros((*self._projector.volume_shape, self._basis_set.ell_max // 2 + 3))

        # If only the coefficients are needed, do not evaluate the derivatives.
        if gradient_part == 'coefficients':
            coefficients = self._rotate_coeffs()
        else:
            coefficients, theta_derivative, phi_derivative = self._rotate_and_derive()

        # Project from voxel to detector space and then from coeff-space to angle-space
        projection = self._basis_set.forward(self._projector.forward(coefficients.astype(self.dtype)))
        # Calculate residuals
        residuals = self._data - projection
        if get_weights:
            residuals *= self._weights
        # Backproject residual
        # NOTE(review): the weights are applied here unconditionally, so with
        # ``get_weights=True`` the back-projected residuals carry the weights
        # twice -- confirm that this is intended.
        bp_res = self._projector.adjoint(
            self._basis_set.gradient(residuals * self._weights).astype(self.dtype))

        # If the gradient with respect to angles is needed, compute the inner products
        if gradient_part in ('full', 'angles'):
            gradient[..., -2] = -np.einsum('xyzm,xyzm->xyz', bp_res, theta_derivative)
            gradient[..., -1] = -np.einsum('xyzm,xyzm->xyz', bp_res, phi_derivative)

        if gradient_part in ('full', 'coefficients'):
            # back-rotate coefficients
            bp_res = self._rotate_coeffs_inverse(bp_res)
            # Collapse the full basis back onto the zonal coefficients
            gradient[..., :-2] += -np.einsum('...i,ij->...j', bp_res, self._E)

        return {'residuals': residuals, 'gradient': gradient}

    def _cast_angles_to_symmetric_zone(self) -> None:
        r"""
        Casts internal angle arrays into the range :math:`\theta \in [0, \pi/2]` and
        :math:`\phi \in [0, 2\pi[`.
        """
        theta = self._theta % np.pi
        southern_hemisphere = theta > (np.pi / 2)
        # Build new arrays rather than writing into ``self._phi``: after the
        # ``coefficients`` setter it may be a view of the caller's array, and
        # shifting phi there without also updating theta would change the
        # direction the caller's array describes.
        self._theta = np.where(southern_hemisphere, np.pi - theta, theta)
        self._phi = np.where(southern_hemisphere, self._phi + np.pi, self._phi) % (2 * np.pi)

    @property
    def rotated_coefficients(self) -> NDArray:
        """
        Returns the real spherical harmonics coefficients.
        """
        return self._rotate_coeffs()

    @property
    def directions(self) -> NDArray:
        """
        Returns the direction of symmetry as a unit vector in xyz coordinates.
        The vector index is the last index of the output.
        """
        # NOTE(review): the x and y components below follow the
        # (cos(phi), sin(phi)) convention, whereas the class docstring lists
        # (sin(phi), cos(phi)) -- confirm which convention is intended.
        sin_theta = np.sin(self._theta)
        # The z component depends on the polar angle only; this makes the
        # result a unit vector.
        return np.stack((np.cos(self._phi) * sin_theta,
                         np.sin(self._phi) * sin_theta,
                         np.cos(self._theta)), axis=-1)

    @property
    def ell_max(self) -> int:
        """l max"""
        return self._basis_set.ell_max

    @property
    def volume_shape(self) -> tuple[int, ...]:
        """Shape of voxel volume"""
        return self._projector.volume_shape

    def _update(self, force_update: bool = False) -> None:
        """ Carries out necessary updates if anything changes with respect to
        the geometry or basis set. """
        if not (self.is_dirty or force_update):
            return
        self._basis_set.probed_coordinates = self.probed_coordinates

        # See ell_max changed
        old_ellmax = (self._zonal_coefficients.shape[-1] - 1) * 2
        old_num_coeffs = (old_ellmax + 1) * (old_ellmax + 2) // 2
        len_diff = len(self._basis_set) - old_num_coeffs

        vol_diff = self._data_container.geometry.volume_shape - np.array(self._coefficients.shape[:-1])
        # TODO: Think about whether the ``Method`` should do this or handle it differently
        if len_diff != 0 and not np.any(vol_diff != 0):
            logger.warning('ell_max has changed. Coefficients will be truncated or appended with zeros.')
            self._make_matrices()

            old_params = np.array(self._zonal_coefficients)
            num_zonal = self._basis_set.ell_max // 2 + 1
            if len_diff > 0:
                # Pad the old coefficients with zeros for the new orders
                self._zonal_coefficients = np.zeros((*self._projector.volume_shape, num_zonal))
                self._zonal_coefficients[..., :old_params.shape[-1]] = old_params
            else:
                # Drop the orders that no longer exist
                self._zonal_coefficients = old_params[..., :num_zonal]
            self._coefficients = np.zeros((*self._data_container.geometry.volume_shape,
                                           len(self._basis_set)), dtype=self.dtype)

        elif np.any(vol_diff != 0):
            logger.warning('Volume shape has changed.'
                           ' Coefficients have been reset to zero and angles have been randomized.')
            self._make_matrices()
            self._make_starting_guess()

        self._geometry_hash = hash(self._data_container.geometry)
        self._basis_set_hash = hash(self._basis_set)

    def __hash__(self) -> int:
        """ Returns a hash of the current state of this instance. """
        to_hash = [self._zonal_coefficients,
                   self._theta,
                   self._phi,
                   hash(self._projector),
                   hash(self._data_container.geometry),
                   self._basis_set_hash,
                   self._geometry_hash]
        return int(list_to_hash(to_hash), 16)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if stored hashes of geometry or basis set objects do
        not match their current hashes. Used to trigger updates """
        return ((self._geometry_hash != hash(self._data_container.geometry)) or
                (self._basis_set_hash != hash(self._basis_set)))

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('BasisSet', self._basis_set.__class__.__name__)]
            s += ['{:18} : {}'.format('Projector', self._projector.__class__.__name__)]
            s += ['{:18} : {}'.format('is_dirty', self.is_dirty)]
            s += ['{:18} : {}'.format('probed_coordinates (hash)',
                                      hex(hash(self.probed_coordinates))[2:8])]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">BasisSet</td>']
            s += [f'<td>{1}</td><td>{self._basis_set.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Projector</td>']
            s += [f'<td>{len(self._projector.__class__.__name__)}</td>'
                  f'<td>{self._projector.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Is dirty</td>']
            s += [f'<td>{1}</td><td>{self.is_dirty}</td></tr>']
            s += ['<tr><td style="text-align: left;">probed_coordinates</td>']
            s += [f'<td>{self.probed_coordinates.vector.shape}</td>'
                  f'<td>{hex(hash(self.probed_coordinates))[2:8]} (hash)</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)