Coverage for local_installation_linux/mumott/methods/basis_sets/trivial_basis.py: 83%
88 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
2from typing import Any, Dict, Iterator, Tuple
4import numpy as np
5from numpy.typing import NDArray
7from mumott import ProbedCoordinates
8from mumott.core.hashing import list_to_hash
9from .base_basis_set import BasisSet
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(__name__)
class TrivialBasis(BasisSet):
    """ Basis set class for the trivial basis, i.e., the identity basis.
    This can be used as a scaffolding class when implementing, e.g., scalar tomography,
    as it implements all the necessary functionality to qualify as a :class:`BasisSet`.

    Parameters
    ----------
    channels
        Number of channels in the last index. Default is ``1``. For scalar data,
        the default value of ``1`` is appropriate. For any other use-case, where the representation
        on the sphere and the representation in detector space are equivalent,
        such as reconstructing scalars of multiple q-ranges at once, a different
        number of channels can be set.
    """
    def __init__(self, channels: int = 1):
        # A single placeholder probe direction. The identity mapping does not depend on it,
        # but its hash is used by ``__hash__`` and ``is_dirty`` below.
        self._probed_coordinates = ProbedCoordinates(vector=np.array((0., 0., 1.)))
        self._probed_coordinates_hash = hash(self.probed_coordinates)
        self._channels = channels

    def _get_projection_matrix(self, probed_coordinates: ProbedCoordinates = None):
        """ Returns the identity matrix of size ``channels``; :attr:`probed_coordinates`
        is accepted for interface compatibility but ignored. """
        return np.eye(self._channels)

    def forward(self,
                coefficients: NDArray,
                *args,
                **kwargs) -> NDArray:
        """ Returns the provided coefficients with no modification.

        Parameters
        ----------
        coefficients
            An array of coefficients, of arbitrary shape, except the last index
            must specify the same number of channels as was specified for this basis.

        Returns
        -------
            The provided :attr:`coefficients` with no modification (the same object, not a copy).

        Notes
        -----
        The :attr:`args` and :attr:`kwargs` are ignored, but included for compatibility with methods
        that input other arguments.
        """
        assert coefficients.shape[-1] == len(self), \
            f'Expected last dimension {len(self)}, got {coefficients.shape[-1]}.'
        return coefficients

    def gradient(self,
                 coefficients: NDArray,
                 *args,
                 **kwargs) -> NDArray:
        """ Returns the provided coefficients with no modification.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape except the last index
            must specify the same number of channels as was specified for this basis.

        Returns
        -------
            The provided :attr:`coefficients` with no modification (the same object, not a copy).

        Notes
        -----
        The :attr:`args` and :attr:`kwargs` are ignored, but included for compatibility with methods
        that input other arguments.
        """
        assert coefficients.shape[-1] == len(self), \
            f'Expected last dimension {len(self)}, got {coefficients.shape[-1]}.'
        return coefficients

    def get_inner_product(self,
                          u: NDArray,
                          v: NDArray) -> NDArray:
        r""" Retrieves the inner product of two coefficient arrays, that is to say,
        the sum-product over the last axis.

        Parameters
        ----------
        u
            The first coefficient array, of arbitrary shape and dimension,
            whose last dimension equals the number of channels.
        v
            The second coefficient array, of the same shape as :attr:`u`.

        Returns
        -------
            An array of shape ``u.shape[:-1]`` holding the sum-product over the last axis.
        """
        assert u.shape[-1] == len(self), \
            f'Expected last dimension {len(self)}, got {u.shape[-1]}.'
        assert u.shape == v.shape, f'Shape mismatch: {u.shape} != {v.shape}.'
        return np.einsum('...i, ...i -> ...', u, v,
                         optimize='greedy')

    def get_output(self,
                   coefficients: NDArray) -> Dict[str, Any]:
        r""" Returns a dictionary of output data for a given array of coefficients.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape and dimension.
            Computations only operate over the last axis of :attr:`coefficients`, so derived
            properties in the output will have the shape ``(*coefficients.shape[:-1], ...)``.

        Returns
        -------
            A dictionary containing a dictionary with the field ``basis_set``.

        Notes
        -----
        In detail, the dictionary under the key ``basis_set`` contains:

            basis_set
                name
                    The name of the basis set, i.e., ``'TrivialBasis'``
                coefficients
                    A copy of :attr:`coefficients`.
                projection_matrix
                    The identity matrix of the same size as the number of channels.
                hash
                    The hash of this instance as a hexadecimal string (with ``0x`` prefix).
        """
        assert coefficients.shape[-1] == len(self), \
            f'Expected last dimension {len(self)}, got {coefficients.shape[-1]}.'
        # Update to ensure non-dirty output state.
        self._update()
        output_dictionary = {}

        # basis set-specific information
        basis_set = {}
        output_dictionary['basis_set'] = basis_set
        basis_set['name'] = type(self).__name__
        basis_set['coefficients'] = coefficients.copy()
        basis_set['projection_matrix'] = self.projection_matrix
        basis_set['hash'] = hex(hash(self))
        return output_dictionary

    def get_spherical_harmonic_coefficients(
        self,
        coefficients: NDArray[float],
        ell_max: int = None
    ) -> NDArray[float]:
        """ Convert coefficients of this basis into an array of spherical harmonic
        coefficients with band limit :attr:`ell_max`.

        Only the first channel, ``coefficients[..., 0]``, is used. It is written to the first
        output coefficient (presumably the isotropic ``ell = 0`` term — confirm against the
        spherical harmonic basis ordering). All other output coefficients are zero.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape, provided that the
            last dimension has the same length as the number of channels of this basis.
        ell_max
            The band limit of the spherical harmonic expansion. If ``None`` (the default),
            ``0`` is used, which yields a single output coefficient.

        Returns
        -------
            A float array of shape
            ``(*coefficients.shape[:-1], (ell_max + 1) * (ell_max + 2) // 2)``.

        Raises
        ------
        ValueError
            If the last dimension of :attr:`coefficients` does not match the number of
            channels, or if :attr:`ell_max` is negative.
        """
        if coefficients.shape[-1] != len(self):
            raise ValueError(f'The number of coefficients ({coefficients.shape[-1]}) does not match '
                             f'the expected value. ({len(self)})')

        # Previously ``None`` crashed with a TypeError in the arithmetic below.
        if ell_max is None:
            ell_max = 0
        if ell_max < 0:
            raise ValueError(f'ell_max must be non-negative, got {ell_max}.')

        num_coeff_output = (ell_max + 1) * (ell_max + 2) // 2

        output_array = np.zeros((*coefficients.shape[:-1], num_coeff_output))
        output_array[..., 0] = coefficients[..., 0]
        return output_array

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """ Allows class to be iterated over and in particular be cast as a dictionary.

        Yields the ``name``, the ``projection_matrix`` and the ``hash`` (hexadecimal string
        without the ``0x`` prefix) of the basis set.
        """
        yield 'name', type(self).__name__
        # Use the property: ``__init__`` never assigns a ``_projection_matrix`` attribute,
        # so reading it directly raised AttributeError.
        yield 'projection_matrix', self.projection_matrix
        yield 'hash', hex(hash(self))[2:]

    def __len__(self) -> int:
        """ Returns the number of channels. """
        return self._channels

    def __hash__(self) -> int:
        """Returns a hash reflecting the internal state of the instance.

        Returns
        -------
            A hash of the internal state of the instance,
            cast as an ``int``.
        """
        to_hash = [self._channels,
                   self._probed_coordinates_hash]
        return int(list_to_hash(to_hash), 16)

    def _update(self) -> None:
        """ Refreshes the cached probed-coordinates hash if it is out of date. """
        if self.is_dirty:
            self._probed_coordinates_hash = hash(self._probed_coordinates)

    @property
    def channels(self) -> int:
        """ The number of channels this basis supports. """
        return self._channels

    @channels.setter
    def channels(self, value: int) -> None:
        self._channels = value

    @property
    def projection_matrix(self):
        """The identity matrix of the same rank as the number of channels
        specified."""
        return np.eye(self._channels, dtype=np.float64)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if the probed coordinates changed since their hash was last cached. """
        return hash(self._probed_coordinates) != self._probed_coordinates_hash

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['-' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('Hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            h = hex(hash(self))
            # Each row must open with <tr> and supply the "Field" cell to match the three-column header.
            s += ['<tr><td style="text-align: left;">Hash</td>']
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)