Coverage for local_installation_linux/mumott/methods/basis_sets/base_basis_set.py: 92%
159 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2024-08-11 23:08 +0000
1import logging
2import numpy as np
4from abc import ABC, abstractmethod
5from mumott import ProbedCoordinates
6from scipy.integrate import simpson, romb, trapezoid
7from scipy.sparse import csr_array
9from numpy.typing import NDArray
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
class BasisSet(ABC):
    """This is the base class from which specific basis sets are being derived.

    It holds the probed coordinates together with the integration and sparsity
    settings, and implements the machinery shared by all basis sets:
    integrating a subclass's projection matrix over detector segments,
    generating (theta, phi) maps, optional sparsification, and exporting the
    projection matrix in CSR form.

    Parameters
    ----------
    probed_coordinates
        Coordinates probed by the measurement. A default-constructed
        ``ProbedCoordinates`` is used when ``None``.
    **kwargs
        ``integration_mode`` (default ``'simpson'``; one of ``'simpson'``,
        ``'romberg'``, ``'trapezoid'``, ``'midpoint'``),
        ``integration_tolerance`` (default ``1e-5``),
        ``integration_maxiter`` (default ``10``),
        ``n_integration_starting_points`` (default ``3``),
        ``enforce_sparsity`` (default ``False``) and
        ``sparsity_count`` (default ``3``).
    """

    # Accepted integration modes, shared by ``__init__`` and the
    # ``integration_mode`` setter so the two validations cannot drift apart.
    _INTEGRATION_MODES = ('simpson', 'romberg', 'trapezoid', 'midpoint')

    def __init__(self,
                 probed_coordinates: ProbedCoordinates = None,
                 **kwargs):
        if probed_coordinates is None:
            probed_coordinates = ProbedCoordinates()
        self.probed_coordinates = probed_coordinates
        self._integration_mode = kwargs.get('integration_mode', 'simpson')
        if self._integration_mode not in self._INTEGRATION_MODES:
            raise ValueError('integration_mode must be "simpson" (for integration with Simpson\'s rule), '
                             ' "midpoint" (for center-of-segment approximation), "romberg", or "trapezoid"!')
        self._integration_tolerance = kwargs.get('integration_tolerance', 1e-5)
        self._integration_maxiter = kwargs.get('integration_maxiter', 10)
        self._n_integration_starting_points = kwargs.get('n_integration_starting_points', 3)
        self._enforce_sparsity = kwargs.get('enforce_sparsity', False)
        self._sparsity_count = kwargs.get('sparsity_count', 3)

    @property
    def probed_coordinates(self) -> ProbedCoordinates:
        """The coordinates used when computing the projection matrix."""
        return self._probed_coordinates

    @probed_coordinates.setter
    def probed_coordinates(self, pc: ProbedCoordinates) -> None:
        self._probed_coordinates = pc

    @abstractmethod
    def forward(self,
                coefficients: NDArray,
                indices: NDArray = None) -> NDArray:
        """Implemented by subclasses."""

    @abstractmethod
    def gradient(self,
                 coefficients: NDArray,
                 indices: NDArray = None) -> NDArray:
        """Implemented by subclasses."""

    @abstractmethod
    def get_spherical_harmonic_coefficients(self,
                                            coefficients: NDArray,
                                            ell_max: int = None,) -> NDArray:
        """Implemented by subclasses."""

    @abstractmethod
    def _get_projection_matrix(self, probed_coordinates: ProbedCoordinates):
        """Evaluate the basis functions at ``probed_coordinates.vector``.

        Implemented by subclasses. Callers in this class index the result as
        ``[..., point, basis_function]``, i.e. the point axis of
        ``probed_coordinates.vector`` is kept as the second-to-last axis.
        """

    def generate_map(self,
                     coefficients: NDArray[float],
                     resolution_in_degrees: int = 5,
                     map_half_sphere: bool = True) -> tuple[NDArray]:
        """Generate a (theta, phi) map of the function modeled by the input coefficients.

        If ``map_half_sphere=True`` (default) a map of only the z>0 half sphere is returned.

        Parameters
        ----------
        coefficients
            One dimensional numpy array with length ``len(self)`` containing the coefficients
            of the function to be plotted.
        resolution_in_degrees
            The resolution of the map in degrees. The map uses equidistant lines in longitude
            and latitude.
        map_half_sphere
            If ``True`` returns a map of the z>0 half sphere.

        Returns
        -------
        map_intensity
            Intensity values of the map, with the same shape as ``map_theta``.
        map_theta
            Polar coordinates of the map.
        map_phi
            Azimuthal coordinates of the map.
        """
        # Generate coordinates.
        if map_half_sphere:
            steps = int(np.ceil(90 / resolution_in_degrees))
            map_theta = np.linspace(0, np.pi / 2, steps + 1)
        else:
            steps = int(np.ceil(180 / resolution_in_degrees))
            map_theta = np.linspace(0, np.pi, steps + 1)
        steps = int(np.ceil(360 / resolution_in_degrees))
        map_phi = np.linspace(0, 2 * np.pi, steps + 1)
        map_theta, map_phi = np.meshgrid(map_theta, map_phi, indexing='ij')

        # Create a ProbedCoordinates to pass into `_get_projection_matrix`;
        # a singleton point axis is inserted before the xyz axis.
        x = np.cos(map_phi) * np.sin(map_theta)
        y = np.sin(map_phi) * np.sin(map_theta)
        z = np.cos(map_theta)
        vector = np.stack((x, y, z), axis=-1)
        probed_coordinates = ProbedCoordinates()
        probed_coordinates.vector = vector[:, :, np.newaxis, :]

        # Evaluate map intensity: drop the singleton point axis, then contract
        # the basis-function axis with the coefficients.
        basis_function_values = self._get_projection_matrix(probed_coordinates)[:, :, 0, :]
        map_intensity = np.einsum('m,tpm->tp', coefficients, basis_function_values)
        return map_intensity, map_theta, map_phi

    def _make_projection_matrix_sparse(self, matrix: NDArray[float]) -> NDArray[float]:
        """Keep only the ``sparsity_count`` largest entries along the last axis, in place.

        Does nothing unless ``enforce_sparsity`` was enabled. Entries are ranked
        by signed value (not magnitude), and all other entries are set to zero.

        Parameters
        ----------
        matrix
            Array to sparsify; it is modified in place.

        Returns
        -------
            ``matrix`` itself, for convenience.
        """
        if not self._enforce_sparsity:
            return matrix
        n_entries = matrix.shape[-1]
        # Number of smallest entries to discard per row, clamped to [0, n_entries].
        # Computing it explicitly avoids the ``order[:-0]`` pitfall, which kept
        # every entry when ``sparsity_count == 0``.
        n_discard = min(max(n_entries - self._sparsity_count, 0), n_entries)
        if n_discard > 0:
            # Sorting along the last axis treats every row independently, exactly
            # like sorting each ``matrix[i, j, :]`` separately.
            order = np.argsort(matrix, axis=-1)
            np.put_along_axis(matrix, order[..., :n_discard], 0., axis=-1)
        return matrix

    def _get_integrated_projection_matrix(self, probed_coordinates: ProbedCoordinates = None):
        """Integrate the basis-set projection matrix over each detector segment.

        Each segment is traced by spherical linear interpolation (slerp) between
        the first and last points of ``probed_coordinates.vector`` along its
        second-to-last axis, after subtracting ``great_circle_offset``. The
        projection matrix is evaluated along that path and integrated over the
        interpolation parameter ``t`` in [0, 1] with the rule selected by
        ``integration_mode`` (``'simpson'``, ``'romberg'`` or ``'trapezoid'``).
        The number of intervals is doubled until the relative change of the
        result drops below ``integration_tolerance``, or until
        ``integration_maxiter`` refinements have been done (a warning is logged
        in that case).

        In ``'midpoint'`` mode the matrix is instead evaluated only at the point
        with index 1 along the second-to-last axis. If all start and end points
        coincide, the matrix is evaluated at the first point without integration.

        Parameters
        ----------
        probed_coordinates
            Coordinates to integrate over; defaults to ``self.probed_coordinates``.

        Returns
        -------
            The integrated projection matrix, with the point axis removed.
        """
        # use 0 and -1 for backwards compatibility
        if probed_coordinates is None:
            probed_coordinates = self.probed_coordinates
        start = probed_coordinates.vector[..., 0, :]
        # don't normalize in-place to avoid modifying probed_coordinates
        start = start / np.linalg.norm(start, axis=-1)[..., None]
        start = start[..., None, :]
        end = probed_coordinates.vector[..., -1, :]
        end = end / np.linalg.norm(end, axis=-1)[..., None]
        end = end[..., None, :]
        # if segments lie on small circle, correct using waxs offset vector before slerp
        offset = probed_coordinates.great_circle_offset[..., 0, :]
        offset = offset[..., None, :]
        # when run initially with probed_coordinates = None or similar
        # NOTE(review): this path and the midpoint path below skip
        # _make_projection_matrix_sparse; confirm whether that is intended.
        if np.allclose(start, end):
            return self._get_projection_matrix(probed_coordinates)[:, :, 0]
        if self._integration_mode == 'midpoint':
            # Just use central point to get the projection matrix (presumably the
            # segment centre when three points per segment are stored).
            pc = ProbedCoordinates(probed_coordinates.vector[..., 1:2, :])
            return self._get_projection_matrix(pc)[..., 0, :]
        # segment length is subtended angle between start and end
        corr_start = start - offset
        corr_end = end - offset
        cos_angle = np.einsum('...i, ...i', corr_start, corr_end)
        # Rounding can push the dot product of (nearly) parallel unit vectors
        # slightly above 1, which would make arccos return NaN.
        segment_length = np.arccos(np.clip(cos_angle, -1., 1.))

        def slerp(t):
            # Interpolate from corr_start to corr_end at parameters t and shift
            # back by the offset. The reshapes assume 4-D coordinate arrays
            # (two leading axes, the point axis, the xyz axis).
            omega = segment_length.reshape(segment_length.shape + (1,))
            t = t.reshape(1, 1, -1, 1)
            # avoid division by 0, use 1st order approximation
            where = np.isclose(abs(omega[..., 0, 0]), 0.)
            sin_omega = np.sin(omega)
            sin_tomega = np.sin(t * omega)
            sin_1mtomega = np.sin((1 - t) * omega)
            a = np.zeros(start.shape[:2] + (t.shape[2], 1), dtype=float)
            b = np.zeros(start.shape[:2] + (t.shape[2], 1), dtype=float)
            a[~where] = sin_1mtomega[~where] / sin_omega[~where]
            b[~where] = sin_tomega[~where] / sin_omega[~where]
            # sin(ax) / sin(x) ~ a for x ~ 0
            a[where] = (1 - t) * np.ones_like(omega[where])
            b[where] = t * np.ones_like(omega[where])
            return a * corr_start + b * corr_end + offset

        def quadrature(matrix, t):
            # Integrate over the point axis (second-to-last) of the matrix.
            if self._integration_mode == 'simpson':
                return simpson(matrix, x=t, axis=-2)
            elif self._integration_mode == 'romberg':
                # romb needs 2**k + 1 equally spaced samples; the default of
                # 3 starting points with interval doubling preserves that.
                return romb(matrix, dx=1 / (t.size - 1), axis=-2)
            elif self._integration_mode == 'trapezoid':
                return trapezoid(matrix, x=t, axis=-2)
            raise ValueError(f'Unknown integration_mode {self._integration_mode!r}!')

        number_of_points = self._n_integration_starting_points
        t = np.linspace(0, 1, number_of_points)
        # get an initial matrix for comparison
        old_matrix = quadrature(self._get_projection_matrix(ProbedCoordinates(slerp(t))), t)
        # Fall back to the initial estimate when no refinement is performed
        # (integration_maxiter <= 0) instead of failing on an unbound name.
        new_matrix = old_matrix
        for _ in range(self._integration_maxiter):
            # double the number of intervals in each iteration
            number_of_points += max(number_of_points - 1, 1)
            t = np.linspace(0, 1, number_of_points)
            new_matrix = quadrature(self._get_projection_matrix(ProbedCoordinates(slerp(t))), t)
            # Relative change, written multiplicatively so an all-zero result
            # does not divide by zero.
            difference = np.max(np.abs(new_matrix - old_matrix))
            scale = np.max(np.abs(new_matrix))
            if difference == 0 or difference < self._integration_tolerance * scale:
                break
            old_matrix = new_matrix
        else:
            logger.warning('Projection matrix did not converge! '
                           'Try increasing integration_maxiter or integration_tolerance.')
        self._make_projection_matrix_sparse(new_matrix)
        return new_matrix

    @property
    def integration_mode(self) -> str:
        """
        Mode of integration for calculating projection matrix.
        Accepted values are ``'simpson'``, ``'romberg'``, ``'trapezoid'``,
        and ``'midpoint'``. Setting it recomputes the projection matrix.
        """
        return self._integration_mode

    @integration_mode.setter
    def integration_mode(self, val) -> None:
        # Validate against the same modes __init__ accepts ('trapezoid' was
        # previously rejected here despite being listed in the message).
        if val not in self._INTEGRATION_MODES:
            raise ValueError('integration_mode must have value "midpoint", '
                             '"romberg", "trapezoid" or "simpson", '
                             f'but a value of {val} was given!')
        self._integration_mode = val
        self._projection_matrix = self._get_integrated_projection_matrix()

    @abstractmethod
    def get_output(self,
                   coefficients: NDArray) -> dict:
        """Implemented by subclasses."""

    @abstractmethod
    def __len__(self) -> int:
        """Implemented by subclasses."""

    # NOTE(review): declaring ``__dict__`` as a method shadows the instance
    # attribute-dictionary descriptor; kept as-is because subclasses override it.
    @abstractmethod
    def __dict__(self) -> dict:
        """Implemented by subclasses."""

    @abstractmethod
    def __str__(self) -> str:
        """Implemented by subclasses."""

    @abstractmethod
    def _repr_html_(self) -> str:
        """Implemented by subclasses."""

    @property
    def csr_representation(self) -> tuple:
        """ The projection matrix as a stack of sparse matrices in
        CSR representation as a tuple. The information in the tuple consists of
        the 3 dense matrices making up the representation,
        in the order ``(pointers, indices, data)``.

        The first axis of each array runs over the first axis of the projection
        matrix. ``indices`` and ``data`` are zero-padded to the largest number
        of non-zeros of any slice.
        """
        projection_matrix = self._projection_matrix
        n_slices, n_rows = projection_matrix.shape[:2]
        # Count every non-zero (negative ones included), since csr_array keeps
        # them; counting only positive entries under-sized these buffers for
        # basis sets whose projection matrix has negative values.
        nnz = int(np.max((projection_matrix != 0).sum((-1, -2)), initial=0))
        sparse_data = np.zeros((n_slices, nnz), dtype=np.float32)
        sparse_indices = np.zeros((n_slices, nnz), dtype=np.int32)
        sparse_pointers = np.zeros((n_slices, n_rows + 1), dtype=np.int32)
        for i, dense_slice in enumerate(projection_matrix):
            sparse_matrix = csr_array(dense_slice)
            sparse_matrix.eliminate_zeros()
            sparse_data[i, :sparse_matrix.nnz] = sparse_matrix.data
            sparse_indices[i, :sparse_matrix.nnz] = sparse_matrix.indices
            sparse_pointers[i, :len(sparse_matrix.indptr)] = sparse_matrix.indptr
        return sparse_pointers, sparse_indices, sparse_data