Coverage for local_installation_linux/mumott/core/simulator.py: 95%
205 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import sys
3import tqdm
4import numpy as np
5from scipy.spatial import KDTree
7from mumott.methods.basis_sets.base_basis_set import BasisSet
class Simulator:
    """ Simulator for tensor tomography samples based on a geometry and a few
    sources with associated influence functions.
    Designed primarily for local representations. Polynomial representations,
    such as spherical harmonics, may require other residual functions which take
    the different frequency bands into account.

    Parameters
    ----------
    volume_mask : np.ndarray[int]
        A three-dimensional mask. The shape of the mask defines the shape of the entire
        simulated volume.
        The non-zero entries of the mask determine where the simulated sample is
        located. The non-zero entries should be mostly contiguous for good results.
    basis_set : BasisSet
        The basis set used for the simulation. Ideally local representations such as
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        should be used.
        Do not modify the basis set after creating the simulator.
    seed : int
        Seed for the random number generator. Useful for generating consistent
        simulations. By default no seed is used.
    distance_radius : float
        Radius for the balls used in determining interior distances.
        Usually, this value should not be changed, but it can be increased
        to take larger strides in the interior of the sample.
        Default value is ``np.sqrt(2) * 1.01``.

    Raises
    ------
    ValueError
        If ``volume_mask`` is not three-dimensional.
    """

    def __init__(self,
                 volume_mask: np.ndarray[int],
                 basis_set: BasisSet,
                 seed: int = None,
                 distance_radius: float = np.sqrt(2) * 1.01) -> None:
        volume_mask = np.asarray(volume_mask)
        if volume_mask.ndim != 3:
            raise ValueError('volume_mask must be three-dimensional,'
                             f' but it has {volume_mask.ndim} dimension(s)!')
        self._volume_mask = volume_mask > 0
        self._basis_set = basis_set
        # Stored so that later modifications of the basis set can be detected.
        self._basis_set_hash = hash(basis_set)
        self._distance_radius = distance_radius
        grid_x, grid_y, grid_z = np.meshgrid(
            *(np.arange(extent) for extent in self.shape), indexing='ij')
        # Park the X-coordinate of non-considered points far outside the volume, so
        # that they never fall within a ball around a valid point. The radius is
        # included in the offset so that this also holds for degenerate shapes
        # (e.g. ``(N, 1, 1)``) and for enlarged distance radii.
        far_away = int(np.prod(self.shape)) + int(np.ceil(distance_radius)) + 1
        grid_x[~self._volume_mask] = far_away
        self._positions = np.stack((grid_x.ravel(), grid_y.ravel(), grid_z.ravel()), axis=1)
        self._tree = KDTree(self._positions)
        self._source_locations = []
        self._source_distances = []
        self._source_exponents = []
        self._source_scale_parameters = []
        self._source_coefficients = []
        self._simulation = np.zeros(self.shape + (len(self._basis_set),), dtype=float)
        self._source_weights = np.ones(self.shape, dtype=float)
        self._rng = np.random.default_rng(seed)

    def add_source(self,
                   location: tuple[int, int, int] = None,
                   coefficients: np.ndarray[float] = None,
                   influence_exponent: float = 2,
                   influence_scale_parameter: float = 10) -> None:
        """
        Add a source to your simulation.

        Parameters
        ----------
        location
            Location, in terms of indices, of the source.
            If not given, will be randomized weighted by the normalized
            root-mean-square interior distance to the existing source points.
        coefficients
            Coefficients defining the source. If not given, will be
            randomized in the interval ``[0, 1]``.
        influence_exponent
            Exponent of the influence of the source.
            Default value is ``2``, giving a Gaussian.
        influence_scale_parameter
            Scale parameter of the influence of the source.
            Default value is ``10``.

        Raises
        ------
        ValueError
            If ``location`` does not consist of three indices, if it lies outside
            the valid region of the volume mask, or if the number of ``coefficients``
            does not match the length of the basis set.
        RuntimeError
            If no ``location`` is given and no valid location could be drawn.

        Notes
        -----
        The equation for the source influence is ``np.exp(-(d / (p * s)) ** p)``,
        where ``d`` is the interior distance, ``s`` is the scale parameter,
        and ``p`` is the exponent.

        If a location is not given, a location will be searched for no more than
        1e6 iterations. For large and sparse samples, consider specifying locations
        manually.
        """
        # Validate user input before touching any state or the random number
        # generator, so that a failed call leaves the simulator unchanged.
        number_of_coefficients = len(self.basis_set)
        if coefficients is not None:
            # Copy, so that later changes to the caller's array do not leak in.
            coefficients = np.array(coefficients, dtype=float).ravel()
            if coefficients.size != number_of_coefficients:
                raise ValueError(f'Expected {number_of_coefficients} coefficients to match the basis set,'
                                 f' but got {coefficients.size}!')
        if location is None:
            location = self._draw_random_location()
        else:
            # A list or array would trigger fancy indexing; a tuple indexes one voxel.
            location = tuple(location)
            if len(location) != len(self.shape):
                raise ValueError('location must consist of exactly three indices!')
            if not self._volume_mask[location]:
                raise ValueError('location must be inside the valid region of the volume mask!')
        self._source_locations.append(location)
        self._source_distances.append(self._get_distances_to_point(location))
        # weight likelihood of placing a source by rms of distances
        distance_sum = np.power(self._source_distances, 2).sum(0)
        largest = distance_sum.max()
        if largest > 0:
            self._source_weights = np.sqrt(distance_sum / largest)
        else:
            # Only possible if every voxel coincides with a source; avoids 0 / 0.
            self._source_weights = np.zeros(self.shape, dtype=float)
        if coefficients is None:
            # Drawn after the location so that seeded runs keep their draw order.
            coefficients = self._rng.random(number_of_coefficients)
        self._source_coefficients.append(coefficients)
        self._source_scale_parameters.append(influence_scale_parameter)
        self._source_exponents.append(influence_exponent)

    def _draw_random_location(self) -> tuple:
        """ Internal method that draws a random valid location by rejection sampling,
        accepting a candidate inside the mask with probability given by its source weight. """
        if not self._volume_mask.any():
            raise RuntimeError('The volume mask contains no valid voxels,'
                               ' so no source location can be drawn.')
        iterations = 0
        while True:
            candidate = tuple(self._rng.integers((0, 0, 0), self.shape))
            # The acceptance draw is only made for candidates inside the mask.
            if self._volume_mask[candidate] and self._rng.random() < self._source_weights[candidate]:
                return candidate
            if iterations > 1e6:
                raise RuntimeError('Maximum number of iterations exceeded.'
                                   ' Specify location manually instead.')
            iterations += 1

    def _get_distances_to_point(self, point):
        """ Internal method for computing interior distances.

        Returns an array with the shape of the volume. Voxels that were never
        reached are assigned the distance ``np.prod(self.shape)``. """
        # -1 marks voxels for which no distance has been computed yet.
        distances = np.full(self.shape, -1, dtype=np.float64)
        origin = self._positions[np.ravel_multi_index(point, self.shape)]
        # Compute distances based on ball around source point
        ball = self._tree.query_ball_point(origin, r=self._distance_radius, p=2)
        distances[tuple(point)] = 0
        self._recursive_ball_dijkstra(
            ball,
            origin,
            distances,
            [],
            [],
            total_distance=0,
            generation=0)
        distances[distances < 0] = np.prod(self.shape)
        return distances

    def _recursive_ball_dijkstra(self,
                                 ball,
                                 point: tuple,
                                 distances: np.ndarray[float],
                                 list_of_balls: list,
                                 list_of_distances: list,
                                 total_distance,
                                 generation: int = 1):
        """ Recursive function that uses a variant of Dijkstra's Algorithm
        to find the internal distances in the simulation volume.

        ``distances`` is modified in place. A voxel receives its distance the
        first time it is encountered and is not updated afterwards.
        ``list_of_balls`` and ``list_of_distances`` hold the pending voxels and
        their distances and are shared between all calls. """
        for index in ball:
            voxel = np.unravel_index(index, self.shape)
            # Ignore any already-computed distance
            if distances[voxel] != -1:
                continue
            # Add previously computed distance from source to new starting point
            new_distance = total_distance + np.sqrt(((self._positions[index] - point) ** 2).sum())
            list_of_balls.append(index)
            list_of_distances.append(new_distance)
            distances[voxel] = new_distance
        # Only carry out recursion from lowest level with generation 0.
        while (generation == 0) and list_of_balls:
            # Always continue from the pending voxel closest to the source;
            # ties are resolved in favour of the voxel that was added first.
            next_ind = int(np.argmin(list_of_distances))
            next_point = list_of_balls.pop(next_ind)
            next_distance = list_of_distances.pop(next_ind)
            new_ball = self._tree.query_ball_point(
                self._positions[next_point], r=self._distance_radius, p=2)
            self._recursive_ball_dijkstra(
                new_ball,
                self._positions[next_point],
                distances,
                list_of_balls,
                list_of_distances,
                next_distance)

    def _get_power_factor(self,
                          power_factor_weights: np.ndarray[float] = None):
        """ Computes the norm, gradient and squared differences of the power term,
        which compares the summed squared coefficients of the current simulation
        with those of each source, weighted by the source influences."""
        if power_factor_weights is None:
            power_factor_weights = np.ones(self.shape, dtype=float)
            power_factor_weights[~self._volume_mask] = 0.
        # The property recomputes the influences on every access, so read it once.
        influences = self.influences
        source_coefficients = np.array(self._source_coefficients).reshape(-1, 1, 1, 1, len(self.basis_set))
        model = self._simulation.reshape((1,) + self._simulation.shape)
        # pseudo-power
        model_power = (model ** 2).sum(-1)
        source_power = (source_coefficients ** 2).sum(-1)
        # influences affect only importance
        differences = (model_power - source_power) * influences
        residuals = (differences ** 2).sum(0)
        gradient = (
            2 * model * (power_factor_weights[None] * influences * differences)[..., None]).sum(0)
        norm = 0.5 * (residuals * power_factor_weights).sum()
        return norm, gradient, residuals

    def _get_residuals(self,
                       residual_weights: np.ndarray[float] = None):
        """ Computes the residual norm, gradient and the squared residuals of the model
        implied by the source points and the current simulation."""
        if residual_weights is None:
            residual_weights = np.ones(self.shape, dtype=float)
            residual_weights[~self._volume_mask] = 0.
        # The property recomputes the influences on every access, so read it once.
        influences = self.influences
        source_coefficients = np.array(self._source_coefficients).reshape(-1, 1, 1, 1, len(self.basis_set))
        model = self._simulation.reshape((1,) + self._simulation.shape)
        # pseudo-covariance
        covariance = (model * source_coefficients).sum(-1)
        source_variance = (source_coefficients ** 2).sum(-1)
        # influences affect both importance and expected degree of similarity
        differences = (covariance - source_variance * influences) * influences
        residuals = (differences ** 2).sum(0)
        gradient = (
            source_coefficients * (residual_weights[None] * influences * differences)[..., None]).sum(0)
        norm = 0.5 * (residuals * residual_weights).sum()
        return norm, gradient, residuals

    def _get_squared_total_variation(self, tv_weights: np.ndarray[float] = None):
        """ Computes the norm, gradient and squared value of the
        squared total variation of the simulation.

        Only the interior of the volume is considered; the outermost layer of
        voxels keeps a value and gradient of zero. """
        if tv_weights is None:
            tv_weights = np.ones(self.shape, dtype=float)
            tv_weights[~self._volume_mask] = 0.
        sub_sim = self._simulation[1:-1, 1:-1, 1:-1]
        # One pair of slices per axis; each pair selects neighbouring voxels.
        slices_1 = [np.s_[:-1, :, :], np.s_[:, :-1, :], np.s_[:, :, :-1]]
        slices_2 = [np.s_[1:, :, :], np.s_[:, 1:, :], np.s_[:, :, 1:]]
        value = np.zeros_like(self._simulation)
        # view into value
        sub_value = value[1:-1, 1:-1, 1:-1]
        gradient = np.zeros_like(self._simulation)
        # view into gradient
        sub_grad = gradient[1:-1, 1:-1, 1:-1]
        # NOTE(review): each difference is accumulated only at the first voxel of
        # the pair; no matching term is added at the neighbouring voxel.
        # Confirm that this one-sided form is intended.
        for s1, s2 in zip(slices_1, slices_2):
            difference = (sub_sim[s1] - sub_sim[s2])
            sub_value[s1] += difference ** 2
            sub_grad[s1] += difference
        value = value.sum(-1)
        value_norm = 0.5 * (value * tv_weights).sum()
        gradient = gradient * tv_weights[..., None]
        return value_norm, gradient, value

    def optimize(self,
                 step_size: float = 0.01,
                 iterations: int = 10,
                 weighting_iterations: int = 5,
                 momentum: float = 0.95,
                 tv_weight: float = 0.1,
                 tv_exponent: float = 1.,
                 tv_delta: float = 0.01,
                 residual_exponent: float = 1.,
                 residual_delta: float = 0.05,
                 power_weight: float = 0.01,
                 power_exponent: float = 1,
                 power_delta: float = 0.01,
                 lower_bound: float = 0.):
        """ Optimizer for carrying out the simulation. Can be called repeatedly.
        Uses iteratively reweighted least squares, with weights calculated based on
        the Euclidean norm over each voxel.

        Parameters
        ----------
        step_size
            Step size for gradient descent.
        iterations
            Number of iterations for each gradient descent solution.
        weighting_iterations
            Number of reweighting iterations.
        momentum
            Nesterov momentum term.
        tv_weight
            Weight for the total variation term.
        tv_exponent
            Exponent for the total variation reweighting.
            With the default of 1, the reweighted term approximates the
            (non-squared) Euclidean norm of the per-voxel differences.
        tv_delta
            Huber-style cutoff for the total variation factor normalization.
        residual_exponent
            Exponent for the residual norm reweighting.
        residual_delta
            Huber-style cutoff for the residual normalization.
        power_weight
            Weight for the power term.
        power_exponent
            Exponent for the power term.
        power_delta
            Huber-style cutoff for the power normalization.
        lower_bound
            Lower bound for coefficients. Coefficients will be
            clipped at this bound at every reweighting.
        """
        inside = self._volume_mask

        def reweight(squared_values, delta, exponent):
            # Weights are zero outside the mask; inside, the squared values
            # (clipped from below at delta) are raised to (exponent - 2) / 2.
            weights = np.ones(self.shape, dtype=float)
            weights[~inside] = 0.
            weights[inside] *= squared_values[inside].clip(delta, None) ** ((exponent - 2) / 2)
            return weights

        residual = np.ones(self.shape, dtype=float)
        tv_value = np.ones(self.shape, dtype=float)
        power_value = np.ones(self.shape, dtype=float)
        # The context manager makes sure the progress bar is closed, also on errors.
        with tqdm.trange(weighting_iterations, file=sys.stdout) as pbar:
            for _ in range(weighting_iterations):
                total_gradient = np.zeros_like(self._simulation)
                residual_weights = reweight(residual, residual_delta, residual_exponent)
                tv_weights = reweight(tv_value, tv_delta, tv_exponent)
                power_weights = reweight(power_value, power_delta, power_exponent)

                for _ in range(iterations):
                    _, residual_gradient, _ = self._get_residuals(residual_weights)
                    _, tv_gradient, _ = self._get_squared_total_variation(tv_weights)
                    _, power_gradient, _ = self._get_power_factor(power_weights)
                    gradient = residual_gradient + tv_gradient * tv_weight + power_gradient * power_weight
                    self._simulation -= gradient * step_size
                    # Accumulated, damped gradient provides the momentum step.
                    total_gradient += gradient
                    total_gradient *= momentum
                    self._simulation -= total_gradient * step_size

                # The squared values returned here define the next set of weights.
                res_norm, _, residual = self._get_residuals(residual_weights)
                tv_norm, _, tv_value = self._get_squared_total_variation(tv_weights)
                power_norm, _, power_value = self._get_power_factor(power_weights)
                lf = res_norm + tv_norm * tv_weight + power_norm * power_weight
                pbar.set_description(
                    f'Loss: {lf:.2e} Resid: {res_norm:.2e} TV: {tv_norm:.2e} Pow: {power_norm:.2e}')
                pbar.update(1)
                self._simulation.clip(lower_bound, None, out=self._simulation)

    def reset_simulation(self) -> None:
        """Resets the simulation by setting all elements to 0."""
        self._simulation[...] = 0

    @property
    def volume_mask(self) -> np.ndarray[bool]:
        """ Mask defining valid sample voxels within the sample.
        Read-only property; create a new simulation to modify. """
        return self._volume_mask.copy()

    @property
    def basis_set(self) -> BasisSet:
        """ Basis set defining the representation used in the sample.
        Read-only property; do not modify.
        Raises a ``ValueError`` if the hash of the basis set has changed
        since the simulator was created. """
        if hash(self._basis_set) != self._basis_set_hash:
            raise ValueError('Hash of basis set does not match! Please recreate simulator.')
        return self._basis_set

    @property
    def shape(self) -> tuple[int, int, int]:
        """ Shape of the simulation and volume mask. """
        return self._volume_mask.shape

    @property
    def simulation(self) -> np.ndarray[float]:
        """ The simulated sample. """
        return self._simulation

    @property
    def distance_radius(self) -> float:
        """ Distance for ball defining interior distances. """
        return self._distance_radius

    @property
    def sources(self) -> dict:
        """ Dictionary of source properties.
        They are given as arrays where the first index specifies
        the source, so that ``len(array)`` is the number of source points.

        Notes
        -----
        The items are, in order:

        coefficients
            The coefficients of each source point.
        distances
            The interior distance from each support point to each
            point in the volume.
        scale_parameters
            The scale parameter of each source point.
        locations
            The location of each source point.
        exponents
            The exponent of each source point.
        influences
            The influence of each source point.
        """
        return dict(coefficients=np.array(self._source_coefficients),
                    distances=np.array(self._source_distances),
                    scale_parameters=np.array(self._source_scale_parameters),
                    locations=np.array(self._source_locations),
                    exponents=np.array(self._source_exponents),
                    influences=np.array(self.influences))

    @property
    def influences(self) -> np.ndarray[float]:
        """ Influence of each source point through the volume.

        The returned array has shape ``(number_of_sources,) + self.shape``.
        Inside the mask the influences of all sources sum to 1 in every voxel,
        except in voxels where every influence is numerically zero; those
        voxels, like all voxels outside the mask, have zero influence. """
        if len(self._source_distances) == 0:
            # Without sources the reshapes below are ill-defined.
            return np.zeros((0,) + self.shape, dtype=float)
        distances = np.array(self._source_distances)
        exponents = np.array(self._source_exponents).reshape(-1, 1, 1, 1)
        scale_parameters = np.array(self._source_scale_parameters).reshape(-1, 1, 1, 1)
        influence = np.exp(-((distances / (exponents * scale_parameters)) ** exponents))
        # Normalize so that largest value is 1 for all source points.
        influence = influence / influence.reshape(len(influence), -1).max(-1).reshape(-1, 1, 1, 1)
        # Normalize so that all influences sum to 1.
        inside = influence[:, self._volume_mask]
        total = inside.sum(0)[None, ...]
        # Far from every source the influences can underflow to exactly zero;
        # leave those voxels at zero rather than producing 0 / 0 = NaN.
        normalized = np.zeros_like(inside)
        np.divide(inside, total, out=normalized, where=total > 0)
        influence[:, self._volume_mask] = normalized
        influence[:, ~self._volume_mask] = 0.
        return influence

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['-' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('shape', self.shape)]
            s += ['{:18} : {}'.format('distance_radius', self.distance_radius)]
            s += ['{:18} : {}'.format('basis_set (hash)', hex(hash(self.basis_set))[2:8])]
            s += ['{:18} : {}'.format('sources', len(self._source_distances))]

        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">shape</td>']
            s += [f'<td>1</td><td>{self.shape}</td></tr>']
            s += ['<tr><td style="text-align: left;">distance_radius</td>']
            s += [f'<td>1</td><td>{self.distance_radius}</td></tr>']
            s += ['<tr><td style="text-align: left;">basis_set (hash)</td>']
            s += [f'<td>{len(hex(hash(self.basis_set)))}</td><td>{hex(hash(self.basis_set))[2:8]}</td></tr>']
            s += ['<tr><td style="text-align: left;">sources</td>']
            s += [f'<td>1</td><td>{len(self._source_distances)}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)