Coverage for local_installation_linux/mumott/core/simulator.py: 95%

205 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import sys 

2 

3import tqdm 

4import numpy as np 

5from scipy.spatial import KDTree 

6 

7from mumott.methods.basis_sets.base_basis_set import BasisSet 

8 

9 

class Simulator:
    """ Simulator for tensor tomography samples based on a geometry and a few
    sources with associated influence functions.
    Designed primarily for local representations. Polynomial representations,
    such as spherical harmonics, may require other residual functions which take
    the different frequency bands into account.

    Parameters
    ----------
    volume_mask : np.ndarray[int]
        A three-dimensional mask. The shape of the mask defines the shape of the entire
        simulated volume.
        The non-zero entries of the mask determine where the simulated sample is
        located. The non-zero entries should be mostly contiguous for good results.
    basis_set : BasisSet
        The basis set used for the simulation. Ideally local representations such as
        :class:`GaussianKernels <mumott.methods.basis_sets.GaussianKernels>`
        should be used.
        Do not modify the basis set after creating the simulator.
    seed : int
        Seed for the random number generator. Useful for generating consistent
        simulations. By default no seed is used.
    distance_radius : float
        Radius for the balls used in determining interior distances.
        Usually, this value should not be changed, but it can be increased
        to take larger strides in the interior of the sample.
        Default value is ``np.sqrt(2) * 1.01``.

    Raises
    ------
    ValueError
        If ``volume_mask`` is not three-dimensional.
    """

    def __init__(self,
                 volume_mask: np.ndarray[int],
                 basis_set: 'BasisSet',
                 seed: int = None,
                 distance_radius: float = np.sqrt(2) * 1.01) -> None:
        volume_mask = np.asarray(volume_mask)
        if volume_mask.ndim != 3:
            raise ValueError('volume_mask must be three-dimensional,'
                             f' but it has {volume_mask.ndim} dimension(s).')
        self._volume_mask = volume_mask > 0
        self._basis_set = basis_set
        # Stored so that later modifications of the basis set can be detected.
        self._basis_set_hash = hash(basis_set)
        X, Y, Z = np.meshgrid(*(np.arange(n) for n in self.shape), indexing='ij')
        # set X-coordinates of non-considered points to be impossibly large,
        # so that they never end up inside a ball around a valid point
        X[~self._volume_mask] = np.prod(self.shape)
        # One row of (x, y, z) per voxel, in C (ravel) order.
        self._positions = np.stack((X.ravel(), Y.ravel(), Z.ravel()), axis=1)
        self._tree = KDTree(self._positions)
        self._source_locations = []
        self._source_distances = []
        self._source_exponents = []
        self._source_scale_parameters = []
        self._source_coefficients = []
        self._simulation = np.zeros(self.shape + (len(self._basis_set),), dtype=float)
        self._source_weights = np.ones(self.shape, dtype=float)
        self._rng = np.random.default_rng(seed)
        self._distance_radius = distance_radius

    def add_source(self,
                   location: tuple[int, int, int] = None,
                   coefficients: np.ndarray[float] = None,
                   influence_exponent: float = 2,
                   influence_scale_parameter: float = 10) -> None:
        """
        Add a source to your simulation.

        Parameters
        ----------
        location
            Location, in terms of indices, of the source.
            Any sequence of three integers is accepted.
            If not given, will be randomized weighted by the root-mean-square
            distance to the already existing source points.
        coefficients
            Coefficients defining the source; must contain one value per
            basis function. If not given, will be randomized in the
            interval ``[0, 1]``.
        influence_exponent
            Exponent of the influence of the source.
            Default value is ``2``, giving a Gaussian.
        influence_scale_parameter
            Scale parameter of the influence of the source.
            Default value is ``10``.

        Raises
        ------
        ValueError
            If ``location`` does not consist of three indices or lies outside
            the valid region of the volume mask, or if ``coefficients`` does not
            contain ``len(basis_set)`` values.
        RuntimeError
            If no location was given and none could be found by random search.

        Notes
        -----
        The equation for the source influence is ``np.exp(-(d / (p * s)) ** p)``,
        where ``d`` is the interior distance, ``s`` is the scale parameter,
        and ``p`` is the exponent.

        If a location is not given, a location will be searched for for no more than
        1e6 iterations. For large and sparse samples, consider specifying locations
        manually.
        """
        # Validate before touching any state, so a bad call leaves the simulator unchanged.
        if coefficients is not None and np.size(coefficients) != len(self.basis_set):
            raise ValueError(f'coefficients must contain {len(self.basis_set)} values'
                             f' (one per basis function), but {np.size(coefficients)} were given.')
        if location is None:
            iterations = 0
            while True:
                location = tuple(int(index) for index in self._rng.integers((0, 0, 0), self.shape))
                # Rejection sampling; the second draw only happens for points inside the mask.
                if self._volume_mask[location] and self._rng.random() < self._source_weights[location]:
                    break
                if iterations > 1e6:
                    raise RuntimeError('Maximum number of iterations exceeded.'
                                       ' Specify location manually instead.')
                iterations += 1
        else:
            # A list or array would trigger fancy indexing below; normalize to a tuple of ints.
            location = tuple(int(index) for index in location)
            if len(location) != 3:
                raise ValueError('location must consist of exactly three indices!')
            if not self._volume_mask[location]:
                raise ValueError('location must be inside the valid region of the volume mask!')
        self._source_locations.append(location)
        self._source_distances.append(self._get_distances_to_point(location))
        # weight likelihood of placing a source by rms of distances
        distance_sum = np.power(self._source_distances, 2).sum(0)
        self._source_weights = np.sqrt(distance_sum / distance_sum.max())
        if coefficients is None:
            coefficients = self._rng.random(len(self.basis_set))
        self._source_coefficients.append(coefficients)
        self._source_scale_parameters.append(influence_scale_parameter)
        self._source_exponents.append(influence_exponent)

    def _get_distances_to_point(self, point):
        """ Internal method for computing interior distances from ``point``
        to every voxel. Voxels that cannot be reached are assigned
        the distance ``np.prod(self.shape)``. """
        # -1 marks voxels whose distance has not been computed yet.
        distances = np.zeros_like(self._volume_mask, dtype=np.float64) - 1
        start = self._positions[np.ravel_multi_index(point, self.shape)]
        # Compute distances based on ball around source point
        ball = self._tree.query_ball_point(start, r=self._distance_radius, p=2)
        distances[tuple(point)] = 0
        self._recursive_ball_dijkstra(
            ball,
            start,
            distances,
            [],
            [],
            total_distance=0,
            generation=0)
        distances[distances < 0] = np.prod(self.shape)
        return distances

    def _recursive_ball_dijkstra(self,
                                 ball,
                                 point: tuple,
                                 distances: np.ndarray[float],
                                 list_of_balls: list,
                                 list_of_distances: list,
                                 total_distance,
                                 generation: int = 1):
        """ Recursive function that uses a variant of Dijkstra's Algorithm
        to find the internal distances in the simulation volume.

        Parameters
        ----------
        ball
            Flat indices of the voxels within ``distance_radius`` of ``point``.
        point
            Coordinates of the voxel the ball is centered on.
        distances
            Array of distances, modified in-place. Entries equal to ``-1``
            are regarded as not yet computed.
        list_of_balls
            Flat indices of voxels that are reached but not yet expanded; modified in-place.
        list_of_distances
            Distances belonging to ``list_of_balls``; modified in-place.
        total_distance
            Distance from the source to ``point``.
        generation
            Only the call with ``generation=0`` expands the pending voxels;
            all other calls merely register the voxels in ``ball``.
        """
        for i in ball:
            index = np.unravel_index(i, self.shape)
            # Ignore any already-computed distance
            if distances[index] != -1:
                continue
            # Add previously computed distance from source to new starting point
            new_distance = total_distance + np.sqrt(((self._positions[i] - point) ** 2).sum())
            list_of_balls.append(i)
            list_of_distances.append(new_distance)
            distances[index] = new_distance
        # Only carry out recursion from lowest level with generation 0.
        while (generation == 0) and list_of_balls:
            # Always expand the pending voxel that is closest to the source.
            next_ind = np.argmin(list_of_distances)
            next_point = list_of_balls.pop(next_ind)
            next_distance = list_of_distances.pop(next_ind)
            new_ball = self._tree.query_ball_point(
                self._positions[next_point], r=self._distance_radius, p=2)
            self._recursive_ball_dijkstra(
                new_ball,
                self._positions[next_point],
                distances,
                list_of_balls,
                list_of_distances,
                next_distance)

    def _get_default_weights(self) -> np.ndarray[float]:
        """ Weights that are 1 inside the volume mask and 0 outside of it. """
        return self._volume_mask.astype(float)

    def _reweight(self, values: np.ndarray[float], delta: float, exponent: float) -> np.ndarray[float]:
        """ Computes reweighting factors ``max(values, delta) ** ((exponent - 2) / 2)``
        inside the volume mask; the weights are 0 outside of the mask. """
        weights = np.zeros(self.shape, dtype=float)
        weights[self._volume_mask] = \
            values[self._volume_mask].clip(delta, None) ** ((exponent - 2) / 2)
        return weights

    def _get_power_factor(self,
                          power_factor_weights: np.ndarray[float] = None):
        """ Computes the norm, gradient and squared value of the power term, i.e.,
        the difference between the squared magnitude of the simulation coefficients
        and that of each source, weighted by the influence of the source.

        Returns a tuple ``(norm, gradient, residuals)`` where the gradient has
        the shape of the simulation and the residuals the shape of the volume. """
        # NOTE: the nesting of the default branch was reconstructed from a
        # whitespace-stripped listing; explicitly passed weights are used as given.
        if power_factor_weights is None:
            power_factor_weights = self._get_default_weights()
        # The property recomputes the influences on every access, so evaluate it once.
        influences = self.influences
        source_coefficients = np.array(self._source_coefficients).reshape(-1, 1, 1, 1, len(self.basis_set))
        model = self._simulation.reshape((1,) + self._simulation.shape)
        # pseudo-power
        model_power = (model ** 2).sum(-1)
        source_power = (source_coefficients ** 2).sum(-1)
        # influences affect only importance
        differences = (model_power - source_power) * influences
        residuals = (differences ** 2).sum(0)
        gradient = (
            2 * model * (power_factor_weights[None] * influences * differences)[..., None]).sum(0)
        norm = 0.5 * (residuals * power_factor_weights).sum()
        return norm, gradient, residuals

    def _get_residuals(self,
                       residual_weights: np.ndarray[float] = None):
        """ Computes the residual norm, gradient and the squared residuals of the model
        implied by the source points and the current simulation.

        Returns a tuple ``(norm, gradient, residuals)`` where the gradient has
        the shape of the simulation and the residuals the shape of the volume. """
        if residual_weights is None:
            residual_weights = self._get_default_weights()
        # The property recomputes the influences on every access, so evaluate it once.
        influences = self.influences
        source_coefficients = np.array(self._source_coefficients).reshape(-1, 1, 1, 1, len(self.basis_set))
        model = self._simulation.reshape((1,) + self._simulation.shape)
        # pseudo-covariance
        covariance = (model * source_coefficients).sum(-1)
        source_variance = (source_coefficients ** 2).sum(-1)
        # influences affect both importance and expected degree of similarity
        differences = (covariance - source_variance * influences) * influences
        residuals = (differences ** 2).sum(0)
        gradient = (
            source_coefficients * (residual_weights[None] * influences * differences)[..., None]).sum(0)
        norm = 0.5 * (residuals * residual_weights).sum()
        return norm, gradient, residuals

    def _get_squared_total_variation(self, tv_weights: np.ndarray[float] = None):
        """ Computes the norm, gradient and squared value of the
        squared total variation of the simulation.

        Only the voxels that are not on the boundary of the volume are considered.
        Returns a tuple ``(norm, gradient, value)`` where the gradient has
        the shape of the simulation and the value the shape of the volume. """
        if tv_weights is None:
            tv_weights = self._get_default_weights()
        interior = self._simulation[1:-1, 1:-1, 1:-1]
        # For each axis, a pair of slices selecting each voxel and its next neighbor.
        lower_slices = [np.s_[:-1, :, :], np.s_[:, :-1, :], np.s_[:, :, :-1]]
        upper_slices = [np.s_[1:, :, :], np.s_[:, 1:, :], np.s_[:, :, 1:]]
        value = np.zeros_like(self._simulation)
        gradient = np.zeros_like(self._simulation)
        # views into value and gradient; writing to them fills the full arrays
        interior_value = value[1:-1, 1:-1, 1:-1]
        interior_gradient = gradient[1:-1, 1:-1, 1:-1]
        for lower, upper in zip(lower_slices, upper_slices):
            difference = interior[lower] - interior[upper]
            interior_value[lower] += difference ** 2
            # NOTE(review): each difference is only attributed to the first of the
            # two voxels it involves; confirm that this one-sided gradient is intended.
            interior_gradient[lower] += difference
        value = value.sum(-1)
        value_norm = 0.5 * (value * tv_weights).sum()
        gradient = gradient * tv_weights[..., None]
        return value_norm, gradient, value

    def optimize(self,
                 step_size: float = 0.01,
                 iterations: int = 10,
                 weighting_iterations: int = 5,
                 momentum: float = 0.95,
                 tv_weight: float = 0.1,
                 tv_exponent: float = 1.,
                 tv_delta: float = 0.01,
                 residual_exponent: float = 1.,
                 residual_delta: float = 0.05,
                 power_weight: float = 0.01,
                 power_exponent: float = 1,
                 power_delta: float = 0.01,
                 lower_bound: float = 0.):
        """ Optimizer for carrying out the simulation. Can be called repeatedly.
        Uses iteratively reweighted least squares, with weights calculated based on
        the Euclidean norm over each voxel.

        Parameters
        ----------
        step_size
            Step size for gradient descent.
        iterations
            Number of iterations for each gradient descent solution.
        weighting_iterations
            Number of reweighting iterations.
        momentum
            Nesterov momentum term.
        tv_weight
            Weight for the total variation term.
        tv_exponent
            Exponent for the total variation reweighting.
            Default is 1, which will approach a Euclidean norm.
        tv_delta
            Huber-style cutoff for the total variation factor normalization.
        residual_exponent
            Exponent for the residual norm reweighting.
        residual_delta
            Huber-style cutoff for the residual normalization.
        power_weight
            Weight for the power term.
        power_exponent
            Exponent for the power term.
        power_delta
            Huber-style cutoff for the power normalization.
        lower_bound
            Lower bound for coefficients. Coefficients will be
            clipped at these bounds at every reweighting.
        """
        residual = np.ones(self.shape, dtype=float)
        tv_value = np.ones(self.shape, dtype=float)
        power_value = np.ones(self.shape, dtype=float)

        # The context manager makes sure the progress bar is closed, also on errors.
        with tqdm.trange(weighting_iterations, file=sys.stdout) as pbar:
            for _ in range(weighting_iterations):
                total_gradient = np.zeros_like(self._simulation)
                # Weights are based on the values found in the previous reweighting iteration.
                residual_weights = self._reweight(residual, residual_delta, residual_exponent)
                tv_weights = self._reweight(tv_value, tv_delta, tv_exponent)
                power_weights = self._reweight(power_value, power_delta, power_exponent)

                for _step in range(iterations):
                    _, residual_gradient, _ = self._get_residuals(residual_weights)
                    _, tv_gradient, _ = self._get_squared_total_variation(tv_weights)
                    _, power_gradient, _ = self._get_power_factor(power_weights)
                    gradient = residual_gradient + tv_gradient * tv_weight + power_gradient * power_weight
                    self._simulation -= gradient * step_size
                    total_gradient += gradient
                    total_gradient *= momentum
                    self._simulation -= total_gradient * step_size

                res_norm, _, residual = self._get_residuals(residual_weights)
                tv_norm, _, tv_value = self._get_squared_total_variation(tv_weights)
                power_norm, _, power_value = self._get_power_factor(power_weights)
                lf = res_norm + tv_norm * tv_weight + power_norm * power_weight
                pbar.set_description(
                    f'Loss: {lf:.2e} Resid: {res_norm:.2e} TV: {tv_norm:.2e} Pow: {power_norm:.2e}')
                pbar.update(1)
                # Clip once per reweighting iteration, as stated in the docstring.
                self._simulation.clip(lower_bound, None, out=self._simulation)

    def reset_simulation(self) -> None:
        """Resets the simulation by setting all elements to 0."""
        self._simulation[...] = 0

    @property
    def volume_mask(self) -> np.ndarray[bool]:
        """ Mask defining valid sample voxels within the sample.
        Read-only property; create a new simulation to modify. """
        return self._volume_mask.copy()

    @property
    def basis_set(self) -> 'BasisSet':
        """ Basis set defining the representation used in the sample.
        Read-only property; do not modify. """
        if hash(self._basis_set) != self._basis_set_hash:
            raise ValueError('Hash of basis set does not match! Please recreate simulator.')
        return self._basis_set

    @property
    def shape(self) -> tuple[int, int, int]:
        """ Shape of the simulation and volume mask. """
        return self._volume_mask.shape

    @property
    def simulation(self) -> np.ndarray[float]:
        """ The simulated sample. """
        return self._simulation

    @property
    def distance_radius(self) -> float:
        """ Distance for ball defining interior distances. """
        return self._distance_radius

    @property
    def sources(self) -> dict:
        """ Dictionary of source properties.
        They are given as arrays where the first index specifies
        the source, so that ``len(array)`` is the number of source points.

        Notes
        -----
        The items are, in order:

        coefficients
            The coefficients of each source point.
        distances
            The interior distance from each support point to each
            point in the volume.
        scale_parameters
            The scale parameter of each source point.
        locations
            The location of each source point.
        exponents
            The exponent of each source point.
        influences
            The influence of each source point.
        """
        return dict(coefficients=np.array(self._source_coefficients),
                    distances=np.array(self._source_distances),
                    scale_parameters=np.array(self._source_scale_parameters),
                    locations=np.array(self._source_locations),
                    exponents=np.array(self._source_exponents),
                    influences=np.array(self.influences))

    @property
    def influences(self) -> np.ndarray[float]:
        """ Influence of each source point through the volume.
        The first index specifies the source. The array is recomputed at
        every access and is empty if no sources have been added. """
        if len(self._source_distances) == 0:
            return np.zeros((0,) + self.shape, dtype=float)
        distances = np.array(self._source_distances)
        exponents = np.array(self._source_exponents).reshape(-1, 1, 1, 1)
        scale_parameters = np.array(self._source_scale_parameters).reshape(-1, 1, 1, 1)
        influence = np.exp(-(distances / (exponents * scale_parameters)) ** exponents)
        # Normalize so that largest value is 1 for all source points.
        influence = influence / influence.reshape(len(influence), -1).max(-1).reshape(-1, 1, 1, 1)
        # Normalize so that all influences sum to 1.
        inside = influence[:, self._volume_mask]
        total = inside.sum(0)
        # Where the influence of every source has underflowed to zero (e.g., in
        # regions that cannot be reached from any source) the influence
        # is left at 0 rather than becoming 0 / 0 = NaN.
        influence[:, self._volume_mask] = np.divide(
            inside, total, out=np.zeros_like(inside), where=total > 0)
        influence[:, ~self._volume_mask] = 0.
        return influence

    def _get_hash_prefix(self) -> str:
        """ First six hexadecimal digits of the hash of the basis set.
        Negative hashes are shown in 64-bit two's complement, so that the
        result never contains a sign or the ``0x`` prefix. """
        return format(hash(self.basis_set) & 0xFFFFFFFFFFFFFFFF, 'x')[:6]

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['-' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('shape', self.shape)]
            s += ['{:18} : {}'.format('distance_radius', self.distance_radius)]
            s += ['{:18} : {}'.format('basis_set (hash)', self._get_hash_prefix())]
            s += ['{:18} : {}'.format('sources', len(self._source_distances))]

        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">shape</td>']
            s += [f'<td>1</td><td>{self.shape}</td></tr>']
            s += ['<tr><td style="text-align: left;">distance_radius</td>']
            s += [f'<td>1</td><td>{self.distance_radius}</td></tr>']
            s += ['<tr><td style="text-align: left;">basis_set (hash)</td>']
            s += [f'<td>{len(hex(hash(self.basis_set)))}</td><td>{self._get_hash_prefix()}</td></tr>']
            s += ['<tr><td style="text-align: left;">sources</td>']
            s += [f'<td>1</td><td>{len(self._source_distances)}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)