Coverage for local_installation_linux/mumott/methods/basis_sets/nearest_neighbor.py: 92%
216 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
1import logging
2from typing import Any, Dict, Tuple
3from copy import deepcopy
5import numpy as np
6from numpy.typing import NDArray
8from mumott import ProbedCoordinates, DataContainer, Geometry, SphericalHarmonicMapper
9from mumott.core.hashing import list_to_hash
10from mumott.methods.utilities.tensor_operations import (framewise_contraction,
11 framewise_contraction_transpose)
12from .base_basis_set import BasisSet
# Logger for this module; its name equals the module's import path.
logger = logging.getLogger(name=__name__)
class NearestNeighbor(BasisSet):
    r""" Basis set class for nearest-neighbor interpolation. Used to construct methods similar to that
    presented in `Schaff et al. (2015) <https://doi.org/10.1038/nature16060>`_.
    By default this representation is sparse and maps only a single direction on the sphere
    to each detector segment. This can be changed; see ``kwargs``.

    Parameters
    ----------
    directions : NDArray[float]
        Two-dimensional array containing the ``N`` sensitivity directions with shape ``(N, 3)``.
    probed_coordinates : ProbedCoordinates
        Optional. Coordinates on the sphere probed at each detector segment by the
        experimental method. Its construction from the system geometry is method-dependent.
        By default, an empty instance of :class:`mumott.ProbedCoordinates` is created.
    enforce_friedel_symmetry : bool
        If set to ``True``, Friedel symmetry will be enforced, using the assumption that points
        on opposite sides of the sphere are equivalent.
    kwargs
        Miscellaneous arguments which relate to segment integrations can be
        passed as keyword arguments:

        integration_mode
            Mode to integrate line segments on the reciprocal space sphere. Possible options are
            ``'simpson'``, ``'midpoint'``, ``'romberg'``, ``'trapezoid'``.
            ``'simpson'``, ``'trapezoid'``, and ``'romberg'`` use adaptive
            integration with the respective quadrature rule from ``scipy.integrate``.
            ``'midpoint'`` uses a single mid-point approximation of the integral.
            Default value is ``'simpson'``.

        n_integration_starting_points
            Number of points used in the first iteration of the adaptive integration.
            The number increases by the rule ``N`` ← ``2 * N - 1`` for each iteration.
            Default value is 3.
        integration_tolerance
            Tolerance for the maximum relative error between iterations before the integral
            is considered converged. Default is ``1e-3``.

        integration_maxiter
            Maximum number of iterations. Default is ``10``.
        enforce_sparsity
            If ``True``, limits the number of basis set elements
            that can map to each detector segment. Default is ``False``.
        sparsity_count
            If ``enforce_sparsity`` is set to ``True``, the number of
            basis set elements that can map to each detector segment.
            Default value is ``1``.
    """
    def __init__(self,
                 directions: NDArray[float],
                 probed_coordinates: ProbedCoordinates = None,
                 enforce_friedel_symmetry: bool = True,
                 **kwargs):
        # This basis set struggles with integral convergence due to sharp transitions,
        # so these defaults are applied unless the caller supplied explicit values.
        kwargs.update(dict(integration_tolerance=kwargs.get('integration_tolerance', 1e-3),
                           sparsity_count=kwargs.get('sparsity_count', 1)))
        super().__init__(probed_coordinates, **kwargs)
        # Accept any array-like of shape (N, 3), not only ndarrays.
        directions = np.asarray(directions)
        self._number_of_coefficients = directions.shape[0]
        if enforce_friedel_symmetry:
            # Append the antipodal directions; indices N..2N-1 are folded back onto
            # 0..N-1 in find_nearest_neighbor_index.
            self._directions_full = np.concatenate((directions, -directions), axis=0)
        else:
            self._directions_full = np.array(directions)

        self._probed_coordinates_hash = hash(self.probed_coordinates)

        self._enforce_friedel_symmetry = enforce_friedel_symmetry
        self._projection_matrix = self._get_integrated_projection_matrix()

    def find_nearest_neighbor_index(self, probed_directions: NDArray[float]) -> NDArray[int]:
        """
        Calculate the nearest-neighbor sensitivity direction for an array of x-y-z vectors.

        Parameters
        ----------
        probed_directions
            Array with length 3 along its last axis. The vectors are normalized
            before comparison, so only their orientation matters.

        Returns
        -------
            Integer array with the same shape as the input with the last axis removed,
            containing for each vector the index (in ``range(len(self))``) of the
            nearest sensitivity direction.
        """
        # normalize input directions
        input_shape = probed_directions.shape
        normed_probed_directions = probed_directions / \
            np.linalg.norm(probed_directions, axis=-1)[..., np.newaxis]

        # Squared 3D Euclidean distance between each sensitivity direction (new leading
        # axis) and each probed direction; sensitivity directions broadcast over the input axes.
        pad_dimension = (1,) * (len(input_shape) - 1)
        distance = np.sum((normed_probed_directions[np.newaxis, ...] -
                           self._directions_full.reshape(self._directions_full.shape[0],
                                                         *pad_dimension, 3))**2, axis=-1)

        # Find nearest neighbor
        best_dir = np.argmin(distance, axis=0)

        if self._enforce_friedel_symmetry:
            # Indices N..2N-1 refer to the antipodal copies; map them onto 0..N-1.
            best_dir = best_dir % self._number_of_coefficients
        return best_dir

    def get_function_values(self, probed_directions: NDArray) -> NDArray[float]:
        """
        Calculate the value of the basis functions from an array of x-y-z vectors.

        Parameters
        ----------
        probed_directions
            Array with length 3 along its last axis

        Returns
        -------
            Array with same shape as input array except for the last axis, which now
            has length ``N``, i.e., the number of sensitivity directions. Each entry is
            ``1.0`` for the nearest sensitivity direction and ``0.0`` otherwise.
        """
        best_dir = self.find_nearest_neighbor_index(probed_directions)
        # One-hot encode the nearest-neighbor index along a new trailing axis.
        mode_numbers = np.arange(self._number_of_coefficients)
        return (best_dir[..., np.newaxis] == mode_numbers).astype(float)

    def get_amplitudes(self, coefficients: NDArray[float],
                       probed_directions: NDArray[float]) -> NDArray[float]:
        """
        Calculate function values of an array of coefficients.

        Parameters
        ----------
        coefficients
            Array of coefficients with coefficient number along its last index.
        probed_directions
            Array with length 3 along its last axis.

        Returns
        -------
            Floating-point array with function values. The shape of the array is
            ``(*coefficients.shape[:-1], *probed_directions.shape[:-1])``.
        """
        nn_index = self.find_nearest_neighbor_index(probed_directions)
        # Integer-array indexing on the last axis replaces it by nn_index's shape, which
        # yields (*coefficients.shape[:-1], *probed_directions.shape[:-1]) directly and
        # also works for 1-D coefficients and a single probed direction.
        return np.array(np.asarray(coefficients)[..., nn_index], dtype=float)

    def get_second_moments(self, coefficients: NDArray[float]) -> NDArray[float]:
        """
        Calculate the second moments of the functions described by :attr:`coefficients`.

        Parameters
        ----------
        coefficients
            An array of coefficients (or residuals) of arbitrary shape so long as the last
            axis has the same size as this basis set.

        Returns
        -------
            Array containing the second moments of the functions described by coefficients,
            formatted as rank-two tensors with tensor indices in the last 2 dimensions,
            i.e. ``sum_i c_i * d_i d_i^T`` over the sensitivity directions ``d_i``.

        Raises
        ------
        NotImplementedError
            If Friedel symmetry is not enforced.
        """
        if not self._enforce_friedel_symmetry:
            raise NotImplementedError('NearestNeighbor.get_second_moments does not support'
                                      ' cases without Friedel symmetry.')

        directions = np.asarray(self._directions_full[:len(self)], dtype=float)
        # Outer products d_i d_i^T, shape (N, 3, 3); exactly symmetric by construction.
        outer_products = directions[:, :, np.newaxis] * directions[:, np.newaxis, :]
        # Contract the last axis of coefficients with the direction axis.
        return np.tensordot(coefficients, outer_products, axes=1)

    def get_spherical_harmonic_coefficients(
        self,
        coefficients: NDArray[float],
        ell_max: int = None
    ) -> NDArray[float]:
        """ Computes and returns the spherical harmonic coefficients of the spherical function
        represented by the provided :attr:`coefficients` using a Driscoll-Healy grid.

        For details on the Driscoll-Healy grid, see
        `the SHTools page <https://shtools.github.io/SHTOOLS/grid-formats.html>`_ for a
        comprehensive overview.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape, provided that the
            last dimension contains the coefficients for one function.
        ell_max
            The bandlimit of the spherical harmonic expansion. Must be provided;
            the grid uses ``2 * ell_max + 1`` points along each angle.

        Raises
        ------
        TypeError
            If :attr:`ell_max` is not provided.
        """
        if ell_max is None:
            raise TypeError('ell_max must be specified to compute spherical harmonic coefficients.')
        dh_grid_size = 2 * ell_max + 1
        mapper = SphericalHarmonicMapper(ell_max=ell_max, polar_resolution=dh_grid_size,
                                         azimuthal_resolution=dh_grid_size,
                                         enforce_friedel_symmetry=self._enforce_friedel_symmetry)
        coordinates = mapper.unit_vectors
        amplitudes = self.get_amplitudes(coefficients, coordinates)
        spherical_harmonics_coefficients = mapper.get_harmonic_coefficients(amplitudes)
        return spherical_harmonics_coefficients

    def _get_projection_matrix(self, probed_coordinates: ProbedCoordinates = None) -> NDArray[float]:
        """ Computes the matrix necessary for forward and gradient calculations, i.e. the
        nearest-neighbor function values at :attr:`probed_coordinates.vector`.
        Uses the coordinates stored on this instance if none are given."""
        if probed_coordinates is None:
            probed_coordinates = self._probed_coordinates
        return self.get_function_values(probed_coordinates.vector)

    def get_sub_geometry(self,
                         direction_index: int,
                         geometry: Geometry,
                         data_container: DataContainer = None,
                         ) -> Tuple[Geometry, Tuple[NDArray[float], NDArray[float]]]:
        """ Create and return a geometry object corresponding to a scalar tomography problem for
        scattering along the sensitivity direction with index :attr:`direction_index`.
        If optionally a :class:`mumott.DataContainer` is provided, the sinograms and weights for this
        scalar tomography problem will also be returned.

        Used for an implementation of the algorithm described in [Schaff2015]_.

        Parameters
        ----------
        direction_index
            Index of the sensitivity direction.
        geometry
            :class:`mumott.Geometry` object of the full problem.
        data_container (optional)
            :class:`mumott.DataContainer` compatible with :attr:`Geometry` from which a scalar dataset
            will be constructed.

        Returns
        -------
        sub_geometry
            Geometry of the scalar problem.
        data_tuple
            :class:`Tuple` containing two numpy arrays. :attr:`data_tuple[0]` is the data of the
            scalar problem. :attr:`data_tuple[1]` are the weights. ``None`` if no
            :attr:`data_container` is given or no projection maps to this direction.
        """
        if self._integration_mode != 'midpoint':
            logger.info("The 'Discrete Directions' reconstruction workflow has not been tested "
                        "with detector segment integration. Set :attr:`integration_mode` to ``'midpoint'``"
                        ' or proceed with caution.')

        # Get projection weights for this single sensitivity direction.
        probed_coordinates = ProbedCoordinates()
        probed_coordinates.vector = geometry.probed_coordinates.vector
        projection_matrix = self._get_integrated_projection_matrix(probed_coordinates)[..., direction_index]

        # Copy over certain parts of geometry; the scalar problem has a single, dummy detector angle.
        sub_geometry = deepcopy(geometry)
        sub_geometry.delete_projections()
        sub_geometry.detector_angles = np.array([0])
        sub_geometry.detector_direction_origin = np.array([0, 0, 0])
        sub_geometry.detector_direction_positive_90 = np.array([0, 0, 0])

        if data_container is not None:
            data_list = []
            weight_list = []

        for projection_index in range(len(geometry)):
            # Only keep projections in which some detector segment maps to this direction.
            if np.any(projection_matrix[projection_index, :] > 0.0):

                # append sub geometry
                sub_geometry.append(deepcopy(geometry[projection_index]))

                # Load data if given
                if data_container is not None:

                    projection_weight = projection_matrix[projection_index, :]
                    weighted_weights = data_container.projections[projection_index].weights\
                        * projection_weight[np.newaxis, np.newaxis, :]
                    weighted_data = data_container.projections[projection_index].data\
                        * weighted_weights

                    # Weighted mean over detector segments; zero where the total weight is zero.
                    weight_list.append(np.sum(weighted_weights, axis=-1))
                    summed_data = np.sum(weighted_data, axis=-1)
                    data_list.append(
                        np.divide(summed_data,
                                  weight_list[-1],
                                  out=np.zeros(summed_data.shape),
                                  where=weight_list[-1] != 0)
                    )  # Avoid runtime warning when weights are zero.

        if data_container is None:
            return sub_geometry, None
        elif len(data_list) == 0:
            logger.warning('No projections found for current direction.')
            return sub_geometry, None
        else:
            data_array = np.stack(data_list, axis=0)
            weight_array = np.stack(weight_list, axis=0)
            return sub_geometry, (data_array, weight_array)

    # TODO there could be a bit of a speedup by doing this without matrix products
    def forward(self,
                coefficients: NDArray[float],
                indices: NDArray[int] = None) -> NDArray[float]:
        """ Carries out a forward computation of projections from reciprocal space modes to
        detector channels, for one or several tomographic projections.

        Parameters
        ----------
        coefficients
            An array of coefficients, of arbitrary shape so long as the last
            axis has the same size as this basis set.
        indices
            Optional. Indices of the tomographic projections for which the forward
            computation is to be performed. If ``None``, the forward computation will
            be performed for all projections.

        Returns
        -------
            An array of values on the detector corresponding to the :attr:`coefficients` given.
            If :attr:`indices` contains exactly one index, the shape is ``(coefficients.shape[:-1], J)``
            where ``J`` is the number of detector segments. If :attr:`indices` is ``None`` or contains
            several indices, the shape is ``(N, coefficients.shape[1:-1], J)`` where ``N``
            is the number of tomographic projections for which the computation is performed.
        """
        assert coefficients.shape[-1] == len(self)
        self._update()
        output = np.zeros(coefficients.shape[:-1] + (self._projection_matrix.shape[1],),
                          coefficients.dtype)
        if indices is None:
            framewise_contraction_transpose(self._projection_matrix,
                                            coefficients,
                                            output)
        elif indices.size == 1:
            np.einsum('ijk, ...k -> ...j',
                      self._projection_matrix[indices],
                      coefficients,
                      out=output,
                      optimize='greedy',
                      casting='unsafe')
        else:
            framewise_contraction_transpose(self._projection_matrix[indices],
                                            coefficients,
                                            output)
        return output

    def gradient(self,
                 coefficients: NDArray[float],
                 indices: NDArray[int] = None) -> NDArray[float]:
        """ Carries out a gradient computation of projections from detector channels back to
        reciprocal space modes, for one or several tomographic projections.

        Parameters
        ----------
        coefficients
            An array of coefficients (or residuals) of arbitrary shape so long as the last
            axis has the same size as the number of detector channels.
        indices
            Optional. Indices of the tomographic projections for which the gradient
            computation is to be performed. If ``None``, the gradient computation will
            be performed for all projections.

        Returns
        -------
            An array of gradient values based on the :attr:`coefficients` given.
            If :attr:`indices` contains exactly one index, the shape is ``(coefficients.shape[:-1], K)``
            where ``K`` is the size of this basis set, i.e. ``len(self)``. If indices is ``None``
            or contains several indices, the shape is ``(N, coefficients.shape[1:-1], K)`` where ``N``
            is the number of tomographic projections for which the computation is performed.
        """
        self._update()
        output = np.zeros(coefficients.shape[:-1] + (self._projection_matrix.shape[2],),
                          coefficients.dtype)
        if indices is None:
            framewise_contraction(self._projection_matrix,
                                  coefficients,
                                  output)
        elif indices.size == 1:
            np.einsum('ikj, ...k -> ...j',
                      self._projection_matrix[indices],
                      coefficients,
                      out=output,
                      optimize='greedy',
                      casting='unsafe')
        else:
            framewise_contraction(self._projection_matrix[indices],
                                  coefficients,
                                  output)
        return output

    def get_output(self,
                   coefficients: NDArray) -> Dict[str, Any]:
        r""" Returns a dictionary of output data for a given array of basis set coefficients.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape and dimensions, except
            its last dimension must be the same length as the :attr:`len` of this instance.
            Computations only operate over the last axis of :attr:`coefficients`, so derived
            properties in the output will have the shape ``(*coefficients.shape[:-1], ...)``.

        Returns
        -------
            A dictionary containing information about the optimized function. Arrays in
            the dictionary are copies and do not share memory with this instance.
        """
        assert coefficients.shape[-1] == len(self)
        # Update to ensure non-dirty output state.
        self._update()
        output_dictionary = {}

        # basis set-specific information
        output_dictionary['name'] = type(self).__name__
        output_dictionary['coefficients'] = coefficients.copy()
        # Copy so that callers cannot mutate the internal direction array through the output.
        output_dictionary['grid'] = self.grid.copy()
        output_dictionary['enforce_friedel_symmetry'] = self._enforce_friedel_symmetry
        output_dictionary['projection_matrix'] = self._projection_matrix.copy()
        output_dictionary['hash'] = hex(hash(self))

        # Analysis is easily done in real space.
        tensors = self.get_second_moments(coefficients)
        output_dictionary['second_moments'] = tensors

        w, v = np.linalg.eigh(tensors.reshape(-1, 3, 3))

        # Sort the eigenvectors (columns of v) by ascending eigenvalue, then normalize them.
        sorting = np.argsort(w, axis=1).reshape(-1, 3, 1)
        v = v.transpose(0, 2, 1)
        v = np.take_along_axis(v, sorting, axis=1)
        v = v.transpose(0, 2, 1)
        v = v / np.sqrt(np.sum(v ** 2, axis=1).reshape(-1, 1, 3))
        eigenvectors = v.reshape(coefficients.shape[:-1] + (3, 3,))
        output_dictionary['eigenvectors'] = eigenvectors

        return output_dictionary

    def __len__(self) -> int:
        return self._number_of_coefficients

    def __hash__(self) -> int:
        """Returns a hash reflecting the internal state of the instance.

        Returns
        -------
            A hash of the internal state of the instance,
            cast as an ``int``.
        """
        to_hash = [self.grid,
                   self._enforce_friedel_symmetry,
                   self._projection_matrix,
                   self._probed_coordinates_hash]
        return int(list_to_hash(to_hash), 16)

    def _update(self) -> None:
        """ Recomputes the projection matrix if the probed coordinates changed
        since it was last computed."""
        # We only run updates if the hashes do not match.
        if self.is_dirty:
            self._projection_matrix = self._get_integrated_projection_matrix()
            self._probed_coordinates_hash = hash(self._probed_coordinates)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if the probed coordinates have changed since the projection
        matrix was last computed. """
        return hash(self._probed_coordinates) != self._probed_coordinates_hash

    @property
    def projection_matrix(self) -> NDArray:
        """ The matrix used to project spherical functions from the unit sphere onto the detector.
        If ``v`` is a vector of basis set coefficients, and ``M`` is the ``projection_matrix``,
        then ``M @ v`` gives the corresponding values on the detector segments associated with
        each projection. ``M[i] @ v`` gives the values on the detector segments associated with
        projection ``i``.
        """
        self._update()
        return self._projection_matrix

    @property
    def enforce_friedel_symmetry(self) -> bool:
        """ If ``True``, Friedel symmetry is enforced, i.e., the point
        :math:`-r` is treated as equivalent to :math:`r`. """
        return self._enforce_friedel_symmetry

    @property
    def grid(self) -> NDArray[float]:
        r""" Returns the sensitivity directions of the basis.

        Returns
        -------
            Array of shape ``(N, 3)`` holding the ``N`` sensitivity directions as given
            at construction (without the antipodal copies added for Friedel symmetry).
            This is a view of internal state and should not be modified.
        """
        return self._directions_full[:self._number_of_coefficients, :]

    @property
    def grid_hash(self) -> str:
        """ Returns a hash of :attr:`grid`.
        """
        return list_to_hash([self.grid])

    @property
    def projection_matrix_hash(self) -> str:
        """ Returns a hash of :attr:`projection_matrix`.
        """
        return list_to_hash([self.projection_matrix])

    def __str__(self) -> str:
        wdt = 74
        s = [self.__class__.__name__]
        s += ['-' * wdt]
        s += [''.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('number of directions', len(self))]
            s += ['{:18} : {}'.format('grid_hash', self.grid_hash[:6])]
            s += ['{:18} : {}'.format('enforce_friedel_symmetry', self.enforce_friedel_symmetry)]
            s += ['{:18} : {}'.format('projection_matrix_hash', self.projection_matrix_hash[2:8])]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">grid_hash</td>']
            s += [f'<td>{len(self.grid_hash)}</td><td>{self.grid_hash[:6]}</td></tr>']
            s += ['<tr><td style="text-align: left;">enforce_friedel_symmetry</td>']
            s += [f'<td>1</td>'
                  f'<td>{self.enforce_friedel_symmetry}</td></tr>']
            s += ['<tr><td style="text-align: left;">projection_matrix</td>']
            s += [f'<td>{len(self.projection_matrix_hash)}</td>'
                  f'<td>{self.projection_matrix_hash[:6]}</td></tr>']
            s += ['<tr><td style="text-align: left;">hash</td>']
            s += [f'<td>{len(hex(hash(self)))}</td><td>{hex(hash(self))[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)