Coverage for local_installation_linux/mumott/methods/residual_calculators/zonal_harmonic_gradient_calculator.py: 85%

188 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5 

6from mumott import DataContainer 

7from mumott.core.wigner_d_utilities import ( 

8 load_d_matrices, calculate_sph_coefficients_rotated_around_z, 

9 calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x, 

10 calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x, 

11 calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle) 

12from mumott.core.hashing import list_to_hash 

13from mumott.methods.projectors.base_projector import Projector 

14from mumott.methods.basis_sets.spherical_harmonics import SphericalHarmonics 

15from .base_residual_calculator import ResidualCalculator 

16 

17 

# Module-level logger; the residual calculator uses it to warn when ell_max
# or the volume shape changes between updates.
logger = logging.getLogger(name=__name__)

19 

20 

class ZHTTResidualCalculator(ResidualCalculator):
    r"""Class that implements the gradient calculations for a model that uses a
    :class:`SphericalHarmonics` basis set restricted to zonal harmonics parametrized
    by a primary axis with polar coordinates :math:`\theta_0` and :math:`\phi_0`,
    defined as (this is the convention returned by :attr:`directions`):

    .. math::

        \begin{pmatrix} x_0\\ y_0\\ z_0\end{pmatrix}
        = \begin{pmatrix}
            \sin(\theta_0) \cos(\phi_0) \\
            \sin(\theta_0) \sin(\phi_0) \\
            \cos(\theta_0)
          \end{pmatrix}

    This model is equivalent to the one used in [Liebi2015]_, but uses a different approach to
    computation.

    This implementation avoids doing some of the expensive calculations of trigonometric functions and
    Legendre polynomials by doing the rotation in the space of the spherical harmonics using
    `Wigner (small) d-matrices <https://en.wikipedia.org/wiki/Wigner_D-matrix>`_.
    The forward model only involves a small number of trigonometric functions to evaluate the
    :math:`d_z(\text{angle})` matrices for the :math:`\theta` and :math:`\phi` rotations.
    Everything else is expressed as matrix products with precomputed matrices.

    The full forward model may be written as:

    .. math::

        \boldsymbol{I} =
        \boldsymbol{W} \boldsymbol{P}
        \boldsymbol{d}_z(\phi_0) \boldsymbol{d}_x(-\frac{\pi}{2})
        \boldsymbol{d}_z(\theta_0) \boldsymbol{d}_x(\frac{\pi}{2})
        \boldsymbol{E} \boldsymbol{a}_{l0},

    where :math:`\boldsymbol{W}` is the mapping from spherical harmonic modes to detector segments,
    which can be precomputed.
    :math:`\boldsymbol{P}` is the typical projector from normal 3D tomography,
    :math:`\boldsymbol{E}` expands the zonal coefficients :math:`\boldsymbol{a}_{l0}` into the
    full (even :math:`\ell`) spherical harmonics basis, and
    :math:`\boldsymbol{d}_i(\text{angle})` with :math:`i = x,z` are Wigner (small) d matrices
    for real spherical harmonics. :math:`\boldsymbol{d}_x(\frac{\pi}{2})` and
    :math:`\boldsymbol{d}_x(-\frac{\pi}{2})` are the precomputed 90 degree rotations around the
    positive and the negative x-axis, respectively.
    :math:`\theta_0`, :math:`\phi_0`, and :math:`\boldsymbol{a}_{l0}`
    are the model parameters for each voxel.

    Derivatives are easy to evaluate because the angles only appear in the
    :math:`\boldsymbol{d}_z(\text{angle})`-matrices. All the expensive trigonometric and spherical
    harmonics calculations have been put into the precomputation of :math:`\boldsymbol{W}`
    and :math:`\boldsymbol{d}_x(\pm\frac{\pi}{2})`.

    Parameters
    ----------
    data_container : DataContainer
        Container holding the data to be reconstructed.
    basis_set : SphericalHarmonics
        The basis set used for representing spherical functions.
    projector : Projector
        The type of projector used together with this method.
    """

    # Accepted values for the ``gradient_part`` argument of :meth:`get_gradient`.
    _GRADIENT_PARTS = ('full', 'angles', 'coefficients')

    def __init__(self,
                 data_container: DataContainer,
                 basis_set: SphericalHarmonics,
                 projector: Projector):
        super().__init__(data_container, basis_set, projector)
        self._make_matrices()
        self._make_starting_guess()

    def _make_starting_guess(self) -> None:
        """Initializes the optimization parameters.

        The zonal coefficients are set to zero. The angles are randomized such that the
        directions are uniformly distributed on the upper hemisphere
        (``cos(theta)`` uniform in ``[0, 1]``, ``phi`` uniform in ``[-pi, pi[``).
        Because the model only contains even degrees, a direction and its
        opposite give the same model. This makes the sampling equivalent to uniform
        sampling on the whole unit sphere.
        """
        volume_shape = self._projector.volume_shape
        self._zonal_coefficients = np.zeros((*volume_shape, self._basis_set.ell_max // 2 + 1))

        # Make random orientations by random sampling in 3D
        rng = np.random.default_rng()
        self._theta = np.arccos(rng.uniform(low=0, high=1, size=volume_shape))
        self._phi = rng.uniform(low=-np.pi, high=np.pi, size=volume_shape)

    def _make_matrices(self) -> None:
        """Loads Wigner d-matrices and creates the matrix ``self._E`` that maps the
        zonal coefficients (one per even degree) to the full spherical harmonics basis.
        """
        # Load precomputed d-matrices
        ell_max = self._basis_set.ell_max
        self.d_matrices = load_d_matrices(ell_max)

        # Set up matrix for converting from zonal harmonics to full harmonics space:
        # the m == 0 mode of degree ell receives zonal coefficient number ell // 2.
        ell_list = self._basis_set.ell_indices
        m_list = self._basis_set.emm_indices
        self._E = np.zeros((len(ell_list), ell_max // 2 + 1))
        for full_index, (ell, m) in enumerate(zip(ell_list, m_list)):
            if m == 0:
                self._E[full_index, ell // 2] = 1

    def _even_ells(self) -> NDArray[int]:
        """Returns the even degrees ``0, 2, ..., ell_max`` passed to the rotation helpers."""
        return np.arange(0, self._basis_set.ell_max + 1, 2)

    @property
    def coefficients(self) -> NDArray:
        """Optimization coefficients for this method.

        Contains both the zonal coefficients and the angles along the last axis.
        The first N-2 elements are zonal coefficients.
        The N-1th element is the polar angle and the last element is the azimuthal angle.

        Reading this property first maps the stored angles into the symmetric zone
        (see :meth:`_cast_angles_to_symmetric_zone`); this updates the internal angles.
        """
        self._cast_angles_to_symmetric_zone()
        return np.concatenate((self._zonal_coefficients,
                               self._theta[..., np.newaxis],
                               self._phi[..., np.newaxis]), axis=-1)

    @coefficients.setter
    def coefficients(self, val: NDArray) -> None:
        # Convert from external to internal representation of optimization parameters
        val = np.asarray(val).reshape(
            (*self._projector.volume_shape, self._basis_set.ell_max // 2 + 1 + 2))
        assert np.shape(val[..., :-2]) == np.shape(self._zonal_coefficients), \
            'Shape of new array inconsistent with expectation (zonal_coefficients)'
        assert np.shape(val[..., -2]) == np.shape(self._theta), \
            'Shape of new array inconsistent with expectation (theta)'
        assert np.shape(val[..., -1]) == np.shape(self._phi), \
            'Shape of new array inconsistent with expectation (phi)'
        # Copy so that the internal state never aliases the caller's array;
        # otherwise later in-place updates of the parameters would leak into it.
        self._zonal_coefficients = val[..., :-2].copy()
        self._theta = val[..., -2].copy()
        self._phi = val[..., -1].copy()

    def _rotate_coeffs(self) -> NDArray:
        """Expand from the zonal harmonics basis to a full spherical harmonics basis and
        rotate the spherical harmonics coefficients from the symmetric coordinate system
        to the sample xyz system.

        The result is also stored in ``self._coefficients``.

        Returns
        -------
        Array containing the rotated spherical harmonics coefficients.
        """
        ell_list = self._even_ells()
        # Expand symmetric coefficients into full basis
        self._coefficients = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        # Rotate by 90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        # Rotate by theta about z
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._theta, ell_list, output_array=self._coefficients)
        # Rotate by -90 degrees about x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        # Rotate by phi about z
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._phi, ell_list, output_array=self._coefficients)
        return self._coefficients

    def _rotate_and_derive(self):
        """Rotate spherical harmonics coefficients from the symmetric coordinate system
        to the sample xyz system and evaluate the derivative of the coefficients
        with respect to the two rotation angles.

        Returns
        -------
        self._coefficients : NDArray
            Array containing the rotated spherical harmonics coefficients.
        theta_derivative : NDArray
            Rotated spherical coefficients derived with respect to the polar rotation angle
            evaluated at the current value of the rotation angles.
        phi_derivative : NDArray
            Rotated spherical coefficients derived with respect to the azimuthal rotation angle
            evaluated at the current value of the rotation angles.
        """
        ell_list = self._even_ells()
        # Expand symmetric coefficients into full basis
        self._coefficients = np.einsum('...i,ji->...j', self._zonal_coefficients, self._E)
        theta_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))
        phi_derivative = np.zeros((*self._projector.volume_shape, len(self._basis_set)))

        # Do 90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)

        # Do z rotation of theta and its derivative; the derivative must be taken
        # before the coefficients are rotated in place.
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._theta, ell_list, output_array=theta_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._theta, ell_list, output_array=self._coefficients)

        # Do -90 degree rotation around x (chain rule: the theta derivative goes through
        # all subsequent rotations as well)
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            self._coefficients, ell_list, self.d_matrices, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            theta_derivative, ell_list, self.d_matrices, output_array=theta_derivative)

        # Do z rotation of phi and its derivative
        calculate_sph_coefficients_rotated_around_z_derived_wrt_the_angle(
            self._coefficients, self._phi, ell_list, output_array=phi_derivative)
        calculate_sph_coefficients_rotated_around_z(
            self._coefficients, self._phi, ell_list, output_array=self._coefficients)
        calculate_sph_coefficients_rotated_around_z(
            theta_derivative, self._phi, ell_list, output_array=theta_derivative)

        return self._coefficients, theta_derivative, phi_derivative

    def _rotate_coeffs_inverse(self, coefficients: NDArray) -> NDArray:
        """Rotate spherical harmonics coefficients from the sample xyz system
        to the symmetric coordinate system.

        The rotations of :meth:`_rotate_coeffs` are undone in reverse order with negated
        angles. The input array is overwritten in place and also returned.
        In :meth:`get_gradient` this is used as the adjoint of the forward rotation.
        That assumes the real spherical harmonics rotation matrices are orthogonal
        (TODO confirm for the basis normalization in use).
        """
        ell_list = self._even_ells()

        # Do z rotation of -phi
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._phi, ell_list, output_array=coefficients)

        # Do 90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_positive_x(
            coefficients, ell_list, self.d_matrices, output_array=coefficients)

        # Do z rotation of -theta
        calculate_sph_coefficients_rotated_around_z(
            coefficients, -self._theta, ell_list, output_array=coefficients)

        # Do -90 degree rotation around x
        calculate_sph_coefficients_rotated_by_90_degrees_around_negative_x(
            coefficients, ell_list, self.d_matrices, output_array=coefficients)

        return coefficients

    def get_residuals(self,
                      get_gradient: bool = False,
                      get_weights: bool = False,
                      gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """Calculates the residuals and possibly the gradient of half the residual square sum
        with respect to the parameters (see :meth:`get_gradient`).
        The coefficients are projected using the :attr:`SphericalHarmonics` and :attr:`Projector`
        attached to this instance.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        get_weights
            Whether to return weights. Default is ``False``. If ``True``, the returned
            residuals are multiplied by the weights. If ``True`` along with
            :attr:`get_gradient`, the gradient will be computed with weights.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) or ``None``, the gradient is
            computed with respect to all parameters.
            If :attr:`gradient_part` is ``'angles'``, only the gradient with respect to the
            angles is computed.
            If :attr:`gradient_part` is ``'coefficients'``, only the gradient with respect to
            the zonal spherical harmonics coefficients is computed.
            It is only used when :attr:`get_gradient` is ``True``.

        Returns
        -------
        A dictionary containing the residuals, and possibly the
        gradient and/or weights. If gradient and/or weights
        are not returned, their value will be ``None``.

        Raises
        ------
        ValueError
            If :attr:`get_gradient` is ``True`` and :attr:`gradient_part` is not one of the
            accepted values.
        """
        if get_gradient:
            output = self.get_gradient(get_weights=get_weights, gradient_part=gradient_part)
        else:
            coefficients = self._rotate_coeffs()
            # Project from voxel to detector space and from coefficient to angle space
            projection = self._basis_set.forward(
                self._projector.forward(coefficients.astype(self.dtype)))
            # Calculate residuals
            residuals = self._data - projection
            if get_weights:
                residuals *= self._weights
            output = {'residuals': residuals, 'gradient': None}

        # Pass on weights, if asked to
        output['weights'] = self._weights if get_weights else None
        return output

    def get_gradient(self,
                     get_weights: bool = False,
                     gradient_part: str = 'full') -> dict[str, NDArray[float]]:
        """Calculates the gradient of *half* the sum of residuals squared.

        Parameters
        ----------
        get_weights
            Whether to multiply the returned residuals by the weights. Default is ``False``.
        gradient_part
            If :attr:`gradient_part` is ``'full'`` (Default) or ``None``, the gradient is
            computed with respect to all parameters.
            If :attr:`gradient_part` is ``'angles'``, only the gradient with respect to the
            angles is computed.
            If :attr:`gradient_part` is ``'coefficients'``, only the gradient with respect to
            the zonal spherical harmonics coefficients is computed.

        Returns
        -------
        A dictionary with the residuals (key ``'residuals'``) and the gradient
        (key ``'gradient'``). The gradient has the same layout as :attr:`coefficients`.
        If only a part of the gradient is computed, the rest of the elements are filled
        with zeros.

        Raises
        ------
        ValueError
            If :attr:`gradient_part` is not one of the accepted values.
        """
        if gradient_part is None:
            gradient_part = 'full'
        if gradient_part not in self._GRADIENT_PARTS:
            raise ValueError(f'gradient_part must be one of {self._GRADIENT_PARTS}, '
                             f'not {gradient_part!r}.')
        need_angles = gradient_part in ('full', 'angles')
        need_coefficients = gradient_part in ('full', 'coefficients')

        # initialize output array: zonal coefficients followed by theta and phi
        gradient = np.zeros((*self._projector.volume_shape, self._basis_set.ell_max // 2 + 3))

        # If the angle gradient is not needed, do not evaluate the derivatives.
        if need_angles:
            coefficients, theta_derivative, phi_derivative = self._rotate_and_derive()
        else:
            coefficients = self._rotate_coeffs()

        # Project from voxel to detector space and then from coeff-space to angle-space
        projection = self._basis_set.forward(self._projector.forward(coefficients.astype(self.dtype)))
        # Calculate residuals
        residuals = self._data - projection
        if get_weights:
            residuals *= self._weights
        # Backproject residual.
        # NOTE(review): the weights are applied here regardless of ``get_weights``, and a
        # second time when ``get_weights`` is True (the residuals were already weighted
        # above). This is kept unchanged; confirm it matches the loss functions that
        # consume this output.
        bp_res = self._projector.adjoint(
            self._basis_set.gradient(residuals * self._weights).astype(self.dtype))

        # If the gradient with respect to angles is needed, compute the inner products
        if need_angles:
            gradient[..., -2] = -np.einsum('...m,...m->...', bp_res, theta_derivative)
            gradient[..., -1] = -np.einsum('...m,...m->...', bp_res, phi_derivative)

        if need_coefficients:
            # back-rotate the backprojection into the symmetric frame (overwrites bp_res)
            bp_res = self._rotate_coeffs_inverse(bp_res)
            gradient[..., :-2] -= np.einsum('...i,ij->...j', bp_res, self._E)

        return {'residuals': residuals, 'gradient': gradient}

    def _cast_angles_to_symmetric_zone(self):
        r"""Casts internal angle arrays into the range :math:`\theta \in [0, \pi/2]` and
        :math:`\phi \in [0, 2\pi[`.

        Directions in the southern hemisphere are replaced by the opposite direction
        :math:`(\pi - \theta, \phi + \pi)`. The model only uses even degrees, so this
        leaves it unchanged. New arrays are assigned instead of modifying the old
        ones in place, so arrays that may be shared elsewhere are never written to.
        """
        theta = np.mod(self._theta, np.pi)
        southern_hemisphere = theta > (np.pi / 2)
        self._theta = np.where(southern_hemisphere, np.pi - theta, theta)
        self._phi = np.mod(np.where(southern_hemisphere, self._phi + np.pi, self._phi),
                           2 * np.pi)

    @property
    def rotated_coefficients(self):
        """Returns the real spherical harmonics coefficients of the current model in the
        sample xyz system (recomputed on each access via :meth:`_rotate_coeffs`).
        """
        return self._rotate_coeffs()

    @property
    def directions(self):
        """Returns the direction of symmetry as a unit vector in xyz coordinates,
        :math:`(\\cos\\phi \\sin\\theta, \\sin\\phi \\sin\\theta, \\cos\\theta)`.
        The vector index is the last index of the output.
        """
        # Make unit direction vectors
        sin_theta = np.sin(self._theta)
        directions = np.stack((np.cos(self._phi) * sin_theta,
                               np.sin(self._phi) * sin_theta,
                               np.cos(self._theta)), axis=-1)
        return directions

    @property
    def ell_max(self) -> int:
        """Maximum degree of the spherical harmonics basis set."""
        return self._basis_set.ell_max

    @property
    def volume_shape(self) -> tuple[int, ...]:
        """Shape of voxel volume"""
        return self._projector.volume_shape

    def _update(self, force_update: bool = False) -> None:
        """Carries out necessary updates if anything changes with respect to
        the geometry or basis set.

        If ``ell_max`` changed (and the volume did not), the zonal coefficients are
        truncated or padded with zeros and the angles are kept. If the volume shape
        changed, all parameters are re-initialized.
        """
        if not (self.is_dirty or force_update):
            return
        self._basis_set.probed_coordinates = self.probed_coordinates

        # See if ell_max changed
        old_ell_max = (self._zonal_coefficients.shape[-1] - 1) * 2
        old_num_coeffs = (old_ell_max + 1) * (old_ell_max + 2) // 2
        len_diff = len(self._basis_set) - old_num_coeffs

        # Compare against the zonal coefficients, which always exist and are always
        # re-initialized on a volume change (unlike the rotated-coefficient buffer).
        vol_diff = (self._data_container.geometry.volume_shape
                    - np.array(self._zonal_coefficients.shape[:-1]))
        # TODO: Think about whether the ``Method`` should do this or handle it differently
        if len_diff != 0 and not np.any(vol_diff != 0):
            logger.warning('ell_max has changed. Coefficients will be truncated or appended with zeros.')
            self._make_matrices()

            old_params = np.array(self._zonal_coefficients)
            num_new = self._basis_set.ell_max // 2 + 1
            num_kept = min(num_new, old_params.shape[-1])
            self._zonal_coefficients = np.zeros((*self._projector.volume_shape, num_new))
            self._zonal_coefficients[..., :num_kept] = old_params[..., :num_kept]
            self._coefficients = np.zeros((*self._data_container.geometry.volume_shape,
                                           len(self._basis_set)), dtype=self.dtype)

        elif np.any(vol_diff != 0):
            logger.warning('Volume shape has changed.'
                           ' Coefficients have been reset to zero and angles have been randomized.')
            self._make_matrices()
            self._make_starting_guess()

        self._geometry_hash = hash(self._data_container.geometry)
        self._basis_set_hash = hash(self._basis_set)

    def __hash__(self) -> int:
        """ Returns a hash of the current state of this instance. """
        to_hash = [self._zonal_coefficients,
                   self._theta,
                   self._phi,
                   hash(self._projector),
                   hash(self._data_container.geometry),
                   self._basis_set_hash,
                   self._geometry_hash]
        return int(list_to_hash(to_hash), 16)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if stored hashes of geometry or basis set objects do
        not match their current hashes. Used to trigger updates """
        return ((self._geometry_hash != hash(self._data_container.geometry)) or
                (self._basis_set_hash != hash(self._basis_set)))

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('BasisSet', self._basis_set.__class__.__name__)]
            s += ['{:18} : {}'.format('Projector', self._projector.__class__.__name__)]
            s += ['{:18} : {}'.format('is_dirty', self.is_dirty)]
            s += ['{:18} : {}'.format('probed_coordinates (hash)',
                                      hex(hash(self.probed_coordinates))[2:8])]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">BasisSet</td>']
            s += [f'<td>{1}</td><td>{self._basis_set.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Projector</td>']
            s += [f'<td>{len(self._projector.__class__.__name__)}</td>'
                  f'<td>{self._projector.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Is dirty</td>']
            s += [f'<td>{1}</td><td>{self.is_dirty}</td></tr>']
            s += ['<tr><td style="text-align: left;">probed_coordinates</td>']
            s += [f'<td>{self.probed_coordinates.vector.shape}</td>'
                  f'<td>{hex(hash(self.probed_coordinates))[2:8]} (hash)</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)