import logging
from typing import Any, Dict, Tuple
from copy import deepcopy
import numpy as np
from numpy.typing import NDArray
from mumott import ProbedCoordinates, DataContainer, Geometry, SphericalHarmonicMapper
from mumott.core.hashing import list_to_hash
from mumott.methods.utilities.tensor_operations import (framewise_contraction,
framewise_contraction_transpose)
from .base_basis_set import BasisSet
logger = logging.getLogger(__name__)
[docs]class NearestNeighbor(BasisSet):
r""" Basis set class for nearest-neighbor interpolation. Used to construct methods similar to that
presented in `Schaff et al. (2015) <https://doi.org/10.1038/nature16060>`_.
By default this representation is sparse and maps only a single direction on the sphere
to each detector segment. This can be changed; see ``kwargs``.
Parameters
----------
directions : NDArray[float]
Two-dimensional Array containing the ``N`` sensitivity directions with shape ``(N, 3)``.
probed_coordinates : ProbedCoordinates
Optional. Coordinates on the sphere probed at each detector segment by the
experimental method. Its construction from the system geometry is method-dependent.
By default, an empty instance of :class:`mumott.ProbedCoordinates` is created.
enforce_friedel_symmetry : bool
If set to ``True``, Friedel symmetry will be enforced, using the assumption that points
on opposite sides of the sphere are equivalent.
kwargs
Miscellaneous arguments which relate to segment integrations can be
passed as keyword arguments:
integration_mode
Mode to integrate line segments on the reciprocal space sphere. Possible options are
``'simpson'``, ``'midpoint'``, ``'romberg'``, ``'trapezoid'``.
``'simpson'``, ``'trapezoid'``, and ``'romberg'`` use adaptive
integration with the respective quadrature rule from ``scipy.integrate``.
``'midpoint'`` uses a single mid-point approximation of the integral.
Default value is ``'simpson'``.
n_integration_starting_points
Number of points used in the first iteration of the adaptive integration.
The number increases by the rule ``N`` ← ``2 * N - 1`` for each iteration.
Default value is 3.
integration_tolerance
Tolerance for the maximum relative error between iterations before the integral
is considered converged. Default is ``1e-3``.
integration_maxiter
Maximum number of iterations. Default is ``10``.
enforce_sparsity
If ``True``, limites the number of basis set elements
that can map to each detector segemnt. Default is ``False``.
sparsity_count
If ``enforce_sparsity`` is set to ``True``, the number of
basis set elements that can map to each detector segment.
Default value is ``1``.
"""
def __init__(self,
directions: NDArray[float],
probed_coordinates: ProbedCoordinates = None,
enforce_friedel_symmetry: bool = True,
**kwargs):
# This basis set struggles with integral convergence due to sharp transitions
kwargs.update(dict(integration_tolerance=kwargs.get('integration_tolerance', 1e-3),
sparsity_count=kwargs.get('sparsity_count', 1)))
super().__init__(probed_coordinates, **kwargs)
# Handling grid of directions
self._number_of_coefficients = directions.shape[0]
if enforce_friedel_symmetry:
self._directions_full = np.concatenate((directions, -directions), axis=0)
else:
self._directions_full = np.array(directions)
self._probed_coordinates_hash = hash(self.probed_coordinates)
self._enforce_friedel_symmetry = enforce_friedel_symmetry
self._projection_matrix = self._get_integrated_projection_matrix()
[docs] def find_nearest_neighbor_index(self, probed_directions: NDArray[float]) -> NDArray[int]:
"""
Caluculate the nearest neighbor sensitivity directions for an array of x-y-z vectors.
Parameters
----------
probed_directions
Array with length 3 along its last axis
Returns
-------
Array with same shape as the input except for the last dimension, which
contains the index of the nearest-neighbor sensitivity direction.
"""
# normalize input directions
input_shape = probed_directions.shape
normed_probed_directions = probed_directions / \
np.linalg.norm(probed_directions, axis=-1)[..., np.newaxis]
# Find distance (3D euclidian) between each probed direction and sensitivity direction
pad_dimension = (1,) * (len(input_shape)-1)
distance = np.sum((normed_probed_directions[np.newaxis, ...] -
self._directions_full.reshape(self._directions_full.shape[0],
*pad_dimension, 3))**2, axis=-1)
# Find nearest_neighbor
best_dir = np.argmin(distance, axis=0)
if self._enforce_friedel_symmetry:
best_dir = best_dir % self._number_of_coefficients
return best_dir
[docs] def get_function_values(self, probed_directions: NDArray) -> NDArray[float]:
"""
Calculate the value of the basis functions from an array of x-y-z vectors.
Parameters
----------
probed_directions
Array with length 3 along its last axis
Returns
-------
Array with same shape as input array except for the last axis, which now
has length ``N``, i.e., the number of sensitivity directions.
"""
best_dir = self.find_nearest_neighbor_index(probed_directions)
input_shape = probed_directions.shape
output_array = np.zeros((*input_shape[:-1], self._number_of_coefficients))
for mode_number in range(self._number_of_coefficients):
output_array[best_dir == mode_number, mode_number] = 1.0
return output_array
[docs] def get_amplitudes(self, coefficients: NDArray[float],
probed_directions: NDArray[float]) -> NDArray[float]:
"""
Calculate function values of an array of coefficients.
Parameters
----------
coefficients
Array of coefficients with coefficient number along its last index.
probed_directions
Array with length 3 along its last axis.
Returns
-------
Array with function values. The shape of the array is
``(*coefficients.shape[:-1], *probed_directions.shape[:-1])``.
"""
final_shape = (*coefficients.shape[:-1], *probed_directions.shape[:-1])
nn_index = self.find_nearest_neighbor_index(probed_directions).ravel()
amplitudes = np.zeros((np.prod(coefficients.shape[:-1]), np.prod(probed_directions.shape[:-1])))
coefficients = np.reshape(coefficients, (np.prod(coefficients.shape[:-1]),
coefficients.shape[-1]))
for coeff_index in range(amplitudes.shape[0]):
amplitudes[coeff_index, :] = coefficients[coeff_index, nn_index]
return amplitudes.reshape(final_shape)
[docs] def get_second_moments(self, coefficients: NDArray[float]) -> NDArray[float]:
"""
Calculate the second moments of the functions described by :attr:`coefficients`.
Parameters
----------
coefficients
An array of coefficients (or residuals) of arbitrary shape so long as the last
axis has the same size as the number of detector channels.
Returns
-------
Array containing the second moments of the functions described by coefficients,
formatted as rank-two tensors with tensor indices in the last 2 dimensions.
"""
if not self._enforce_friedel_symmetry:
raise NotImplementedError('NearestNeighbor.get_second_moments does not support'
' cases with Friedel symmetry.')
second_moments_array = np.zeros((*coefficients.shape[:-1], 3, 3))
sumint = np.zeros(coefficients.shape[:-1])
sumxx = np.zeros(coefficients.shape[:-1])
sumxy = np.zeros(coefficients.shape[:-1])
sumxz = np.zeros(coefficients.shape[:-1])
sumyy = np.zeros(coefficients.shape[:-1])
sumyz = np.zeros(coefficients.shape[:-1])
sumzz = np.zeros(coefficients.shape[:-1])
for mode_number in range(len(self)):
sumint += coefficients[..., mode_number]
sumxx += coefficients[..., mode_number] * self._directions_full[mode_number, 0]**2
sumxy += coefficients[..., mode_number] * self._directions_full[mode_number, 0]\
* self._directions_full[mode_number, 1]
sumxz += coefficients[..., mode_number] * self._directions_full[mode_number, 0]\
* self._directions_full[mode_number, 2]
sumyy += coefficients[..., mode_number] * self._directions_full[mode_number, 1]**2
sumyz += coefficients[..., mode_number] * self._directions_full[mode_number, 1]\
* self._directions_full[mode_number, 2]
sumzz += coefficients[..., mode_number] * self._directions_full[mode_number, 2]**2
second_moments_array[..., 0, 0] = sumxx
second_moments_array[..., 0, 1] = sumxy
second_moments_array[..., 0, 2] = sumxz
second_moments_array[..., 1, 0] = sumxy
second_moments_array[..., 1, 1] = sumyy
second_moments_array[..., 1, 2] = sumyz
second_moments_array[..., 2, 0] = sumxz
second_moments_array[..., 2, 1] = sumyz
second_moments_array[..., 2, 2] = sumzz
return second_moments_array
[docs] def get_spherical_harmonic_coefficients(
self,
coefficients: NDArray[float],
ell_max: int = None
) -> NDArray[float]:
""" Computes and rturns the spherical harmonic coefficients of the spherical function
represented by the provided :attr:`coefficients` using a Driscoll-Healy grid.
For details on the Driscoll-Healy grid, see
`the SHTools page <https://shtools.github.io/SHTOOLS/grid-formats.html>`_ for a
comprehensive overview.
Parameters
----------
coefficients
An array of coefficients of arbitrary shape, provided that the
last dimension contains the coefficients for one function.
ell_max
The bandlimit of the spherical harmonic expansion.
"""
dh_grid_size = 2*ell_max + 1
mapper = SphericalHarmonicMapper(ell_max=ell_max, polar_resolution=dh_grid_size,
azimuthal_resolution=dh_grid_size,
enforce_friedel_symmetry=self._enforce_friedel_symmetry)
coordinates = mapper.unit_vectors
amplitudes = self.get_amplitudes(coefficients, coordinates)
spherical_harmonics_coefficients = mapper.get_harmonic_coefficients(amplitudes)
return spherical_harmonics_coefficients
def _get_projection_matrix(self, probed_coordinates: ProbedCoordinates = None) -> NDArray[float]:
""" Computes the matrix necessary for forward and gradient calculations.
Called when the coordinate system has been updated, or one of
:attr:`kernel_scale_parameter` or :attr:`grid_scale` has been changed."""
if probed_coordinates is None:
probed_coordinates = self._probed_coordinates
return self.get_function_values(probed_coordinates.vector)
[docs] def get_sub_geometry(self,
direction_index: int,
geometry: Geometry,
data_container: DataContainer = None,
) -> tuple[Geometry, tuple[NDArray[float], NDArray[float]]]:
""" Create and return a geometry object corresponding to a scalar tomography problem for
scattering along the sensitivity direction with index :attr:`direction_index`.
If optionally a :class:`mumott.DataContainer` is provided, the sinograms and weights for this
scalar tomography problem will alse be returned.
Used for an implementation of the algorithm descibed in [Schaff2015]_.
Parameters
----------
direction_index
Index of the sensitivity direction.
geometry
:class:`mumott.Geometry` object of the full problem.
data_container (optional)
:class:`mumott.DataContainer` compatible with :attr:`Geometry` from which a scalar dataset
will be constructed.
returns
-------
sub_geometry
Geometry of the scalar problem.
data_tuple
:class:`Tuple` containing two numpy arrays. :attr:`data_tuple[0]` is the data of the
scalar problem. :attr:`data_tuple[1]` are the weights.
"""
if self._integration_mode != 'midpoint':
logger.info("The 'Discrete Directions' reconstruction workflow has not been tested"
"with detector segment integration. Set :attr:`integration_mode` to ``'midpoint'``"
' or proceed with caution.')
# Get projection weights
probed_coordinates = ProbedCoordinates()
probed_coordinates.vector = geometry.probed_coordinates.vector
projection_matrix = self._get_integrated_projection_matrix(probed_coordinates)[..., direction_index]
# Copy over certain parts of geometry
sub_geometry = deepcopy(geometry)
sub_geometry.delete_projections()
sub_geometry.detector_angles = np.array([0])
sub_geometry.detector_direction_origin = np.array([0, 0, 0])
sub_geometry.detector_direction_positive_90 = np.array([0, 0, 0])
if data_container is not None:
data_list = []
weight_list = []
for projection_index in range(len(geometry)):
if np.any(projection_matrix[projection_index, :] > 0.0):
# append sub geometry
sub_geometry.append(deepcopy(geometry[projection_index]))
# Load data if given
if data_container is not None:
projection_weight = projection_matrix[projection_index, :]
weighted_weights = data_container.projections[projection_index].weights\
* projection_weight[np.newaxis, np.newaxis, :]
weighted_data = data_container.projections[projection_index].data\
* weighted_weights
weight_list.append(np.sum(weighted_weights, axis=-1))
summed_data = np.sum(weighted_data, axis=-1)
data_list.append(
np.divide(summed_data,
weight_list[-1],
out=np.zeros(summed_data.shape),
where=weight_list[-1] != 0)
) # Avoid runtime warning when weights are zero.
if data_container is None:
return sub_geometry, None
elif len(data_list) == 0:
logger.warning('No projections found for current direction.')
return sub_geometry, None
else:
data_array = np.stack(data_list, axis=0)
weight_array = np.stack(weight_list, axis=0)
return sub_geometry, (data_array, weight_array)
# TODO there could be a bit of a speedup by doing this without matrix products
[docs] def forward(self,
coefficients: NDArray[float],
indices: NDArray[int] = None) -> NDArray[float]:
""" Carries out a forward computation of projections from reciprocal space modes to
detector channels, for one or several tomographic projections.
Parameters
----------
coefficients
An array of coefficients, of arbitrary shape so long as the last
axis has the same size as this basis set.
indices
Optional. Indices of the tomographic projections for which the forward
computation is to be performed. If ``None``, the forward computation will
be performed for all projections.
Returns
-------
An array of values on the detector corresponding to the :attr:`coefficients` given.
If :attr:`indices` contains exactly one index, the shape is ``(coefficients.shape[:-1], J)``
where ``J`` is the number of detector segments. If :attr:`indices` is ``None`` or contains
several indices, the shape is ``(N, coefficients.shape[1:-1], J)`` where ``N``
is the number of tomographic projections for which the computation is performed.
"""
assert coefficients.shape[-1] == len(self)
self._update()
output = np.zeros(coefficients.shape[:-1] + (self._projection_matrix.shape[1],),
coefficients.dtype)
if indices is None:
framewise_contraction_transpose(self._projection_matrix,
coefficients,
output)
elif indices.size == 1:
np.einsum('ijk, ...k -> ...j',
self._projection_matrix[indices],
coefficients,
out=output,
optimize='greedy',
casting='unsafe')
else:
framewise_contraction_transpose(self._projection_matrix[indices],
coefficients,
output)
return output
[docs] def gradient(self,
coefficients: NDArray[float],
indices: NDArray[int] = None) -> NDArray[float]:
""" Carries out a gradient computation of projections of projections from reciprocal space modes to
detector channels, for one or several tomographic projections.
Parameters
----------
coefficients
An array of coefficients (or residuals) of arbitrary shape so long as the last
axis has the same size as the number of detector channels.
indices
Optional. Indices of the tomographic projections for which the gradient
computation is to be performed. If ``None``, the gradient computation will
be performed for all projections.
Returns
-------
An array of gradient values based on the :attr:`coefficients` given.
If :attr:`indices` contains exactly one index, the shape is ``(coefficients.shape[:-1], J)``
where ``J`` is the number of detector segments. If indices is ``None`` or contains
several indices, the shape is ``(N, coefficients.shape[1:-1], J)`` where ``N``
is the number of tomographic projections for which the computation is performed.
"""
self._update()
output = np.zeros(coefficients.shape[:-1] + (self._projection_matrix.shape[2],),
coefficients.dtype)
if indices is None:
framewise_contraction(self._projection_matrix,
coefficients,
output)
elif indices.size == 1:
np.einsum('ikj, ...k -> ...j',
self._projection_matrix[indices],
coefficients,
out=output,
optimize='greedy',
casting='unsafe')
else:
framewise_contraction(self._projection_matrix[indices],
coefficients,
output)
return output
[docs] def get_output(self,
coefficients: NDArray) -> Dict[str, Any]:
r""" Returns a dictionary of output data for a given array of basis set coefficients.
Parameters
----------
coefficients
An array of coefficients of arbitrary shape and dimensions, except
its last dimension must be the same length as the :attr:`len` of this instance.
Computations only operate over the last axis of :attr:`coefficients`, so derived
properties in the output will have the shape ``(*coefficients.shape[:-1], ...)``.
Returns
-------
A dictionary containing information about the optimized function.
"""
assert coefficients.shape[-1] == len(self)
# Update to ensure non-dirty output state.
self._update()
output_dictionary = {}
# basis set-specific information
output_dictionary['name'] = type(self).__name__
output_dictionary['coefficients'] = coefficients.copy()
output_dictionary['grid'] = self.grid
output_dictionary['enforce_friedel_symmetry'] = self._enforce_friedel_symmetry
output_dictionary['projection_matrix'] = self._projection_matrix.copy()
output_dictionary['hash'] = hex(hash(self))
# Analysis is easily done in real space.
tensors = self.get_second_moments(coefficients)
output_dictionary['second_moments'] = tensors
w, v = np.linalg.eigh(tensors.reshape(-1, 3, 3))
# Some complicated sorting logic to sort eigenvectors per ascending eigenvalues.
sorting = np.argsort(w, axis=1).reshape(-1, 3, 1)
v = v.transpose(0, 2, 1)
v = np.take_along_axis(v, sorting, axis=1)
v = v.transpose(0, 2, 1)
v = v / np.sqrt(np.sum(v ** 2, axis=1).reshape(-1, 1, 3))
eigenvectors = v.reshape(coefficients.shape[:-1] + (3, 3,))
output_dictionary['eigenvectors'] = eigenvectors
return output_dictionary
def __len__(self) -> int:
return self._number_of_coefficients
def __hash__(self) -> int:
"""Returns a hash reflecting the internal state of the instance.
Returns
-------
A hash of the internal state of the instance,
cast as an ``int``.
"""
to_hash = [self.grid,
self._enforce_friedel_symmetry,
self._projection_matrix,
self._probed_coordinates_hash]
return int(list_to_hash(to_hash), 16)
def _update(self) -> None:
# We only run updates if the hashes do not match.
if self.is_dirty:
self._projection_matrix = self._get_integrated_projection_matrix()
self._probed_coordinates_hash = hash(self._probed_coordinates)
@property
def is_dirty(self) -> bool:
return hash(self._probed_coordinates) != self._probed_coordinates_hash
@property
def projection_matrix(self) -> NDArray:
""" The matrix used to project spherical functions from the unit sphere onto the detector.
If ``v`` is a vector of gaussian kernel coefficients, and ``M`` is the ``projection_matrix``,
then ``M @ v`` gives the corresponding values on the detector segments associated with
each projection. ``M[i] @ v`` gives the values on the detector segments associated with
projection ``i``.
"""
self._update()
return self._projection_matrix
@property
def enforce_friedel_symmetry(self) -> bool:
""" If ``True``, Friedel symmetry is enforced, i.e., the point
:math:`-r` is treated as equivalent to :math:`r`. """
return self._enforce_friedel_symmetry
@property
def grid(self) -> Tuple[NDArray['float'], NDArray['float']]:
r""" Returns the polar and azimuthal angles of the grid used by the basis.
Returns
-------
A ``Tuple`` with contents ``(polar_angle, azimuthal_angle)``, where the
polar angle is defined as :math:`\arccos(z)`.
"""
return self._directions_full[:self._number_of_coefficients, :]
@property
def grid_hash(self) -> str:
""" Returns a hash of :attr:`grid`.
"""
return list_to_hash([self.grid])
@property
def projection_matrix_hash(self) -> str:
""" Returns a hash of :attr:`projection_matrix`.
"""
return list_to_hash([self.projection_matrix])
def __str__(self) -> str:
wdt = 74
s = [self.__class__.__name__]
s += ['-' * wdt]
s += [''.center(wdt)]
s += ['-' * wdt]
with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
s += ['{:18} : {}'.format('number of directions', len(self))]
s += ['{:18} : {}'.format('grid_hash', self.grid_hash[:6])]
s += ['{:18} : {}'.format('enforce_friedel_symmetry', self.enforce_friedel_symmetry)]
s += ['{:18} : {}'.format('projection_matrix_hash', self.projection_matrix_hash[2:8])]
s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
s += ['-' * wdt]
return '\n'.join(s)
def _repr_html_(self) -> str:
s = []
s += [f'<h3>{self.__class__.__name__}</h3>']
s += ['<table border="1" class="dataframe">']
s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
s += ['<tbody>']
with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
s += ['<tr><td style="text-align: left;">grid_hash</td>']
s += [f'<td>{len(self.grid_hash)}</td><td>{self.grid_hash[:6]}</td></tr>']
s += ['<tr><td style="text-align: left;">enforce_friedel_symmetry</td>']
s += [f'<td>1</td>'
f'<td>{self.enforce_friedel_symmetry}</td></tr>']
s += ['<tr><td style="text-align: left;">projection_matrix</td>']
s += [f'<td>{len(self.projection_matrix_hash)}</td>'
f'<td>{self.projection_matrix_hash[:6]}</td></tr>']
s += ['<tr><td style="text-align: left;">hash</td>']
s += [f'<td>{len(hex(hash(self)))}</td><td>{hex(hash(self))[2:8]}</td></tr>']
s += ['</tbody>']
s += ['</table>']
return '\n'.join(s)