Coverage for local_installation_linux/mumott/methods/residual_calculators/gradient_residual_calculator.py: 98%

88 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2from typing import Dict 

3 

4import numpy as np 

5from numpy.typing import NDArray 

6 

7from mumott import DataContainer 

8from mumott.core.hashing import list_to_hash 

9from .base_residual_calculator import ResidualCalculator 

10from mumott.methods.basis_sets.base_basis_set import BasisSet 

11from mumott.methods.projectors.base_projector import Projector 

12 

13 

# Module-level logger named after this module; used by the residual
# calculator to warn when the coefficient array has to be resized.
logger = logging.getLogger(name=__name__)

15 

16 

class GradientResidualCalculator(ResidualCalculator):
    """Class that implements the GradientResidualCalculator method.
    This residual calculator is an appropriate choice for :term:`SAXS` tensor tomography, as it relies
    on the small-angle approximation. It works by inverting the John transform
    (also known as the X-ray transform) of a tensor field (where each tensor is a
    representation of a spherical function) by comparing it to scattering data which
    has been corrected for transmission.

    The residuals are ``projection - data``, where the projection is obtained by
    applying the projector and then the basis set to the coefficients. The gradient
    is obtained by applying the basis-set gradient followed by the projector adjoint
    to the (optionally weighted) residuals.

    Parameters
    ----------
    data_container : DataContainer
        Container for the data which is to be reconstructed.
    basis_set : BasisSet
        The basis set used for representing spherical functions.
    projector : Projector
        The type of projector used together with this method.
    use_scalar_projections : bool
        Whether to use a set of scalar projections, rather than the data
        in :attr:`data_container`.
    scalar_projections : NDArray[float]
        If :attr:`use_scalar_projections` is true, the set of scalar projections to use.
        Should have the same shape as :attr:`data_container.data`, except with
        only one channel in the last index.
    """

    def __init__(self,
                 data_container: DataContainer,
                 basis_set: BasisSet,
                 projector: Projector,
                 use_scalar_projections: bool = False,
                 scalar_projections: NDArray[float] = None):
        super().__init__(data_container,
                         basis_set,
                         projector,
                         use_scalar_projections,
                         scalar_projections,)

    def get_residuals(self,
                      get_gradient: bool = False,
                      get_weights: bool = False,
                      gradient_part: str = None) -> Dict[str, NDArray[float]]:
        """ Calculates the residuals, and possibly a gradient, between the data and
        the coefficients projected using the :attr:`BasisSet` and :attr:`Projector`
        attached to this instance.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        get_weights
            Whether to return weights. Default is ``False``. If ``True`` along with
            ``get_gradient``, the gradient will be computed with weights.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is ``None``. Raises a ``NotImplementedError`` for any other value.

        Returns
        -------
        A dictionary with the keys ``residuals``, ``gradient`` and ``weights``.
        The residuals are always unweighted. If gradient and/or weights
        are not requested, their value will be ``None``.

        Raises
        ------
        NotImplementedError
            If ``gradient_part`` is not ``None``.
        """
        if gradient_part is not None:
            raise NotImplementedError('The GradientResidualCalculator class does not work with optimizing '
                                      'angles. Use the ZHTTResidualCalculator class instead.')

        projection = self._basis_set.forward(self._projector.forward(self._coefficients))
        residuals = projection - self._data

        gradient = None
        if get_gradient:
            # todo: consider if more complicated behaviour is useful,
            # e.g. providing function to be applied to weights
            # Weights only enter the gradient; the returned residuals stay unweighted.
            residual_gradient = residuals * self._weights if get_weights else residuals
            gradient = self._projector.adjoint(
                self._basis_set.gradient(residual_gradient).astype(self.dtype))

        weights = self._weights if get_weights else None

        return dict(residuals=residuals, gradient=gradient, weights=weights)

    def get_gradient_from_residual_gradient(self, residual_gradient: NDArray[float]) -> NDArray[float]:
        """ Projects a residual gradient into coefficient and volume space. Used
        to get gradients from more complicated residuals, e.g., the Huber loss.
        Assumes that any weighting to the residual gradient has already been applied.

        Parameters
        ----------
        residual_gradient
            The residual gradient, from which to calculate the gradient.

        Returns
        -------
        An ``NDArray`` containing the gradient.
        """
        return self._projector.adjoint(
            self._basis_set.gradient(residual_gradient).astype(self.dtype))

    @staticmethod
    def _get_overlap_slices(old_shape: tuple, new_shape: tuple) -> tuple:
        """ Returns ``(new_slices, old_slices)`` such that
        ``new[new_slices] = old[old_slices]`` copies the overlapping region of an
        array of shape ``old_shape`` into one of shape ``new_shape``.

        Along every axis except the last, the smaller extent is placed (approximately)
        in the middle of the larger one. Along the last axis, the overlap starts at
        index 0 in both arrays. """
        new_slices = []
        old_slices = []
        for old_len, new_len in zip(old_shape[:-1], new_shape[:-1]):
            # Padding (new_len > old_len) offsets the start in the new array;
            # truncation (new_len < old_len) offsets the start in the old array.
            new_slices.append(slice(max(0, (new_len - old_len) // 2),
                                    min(new_len, (old_len + new_len) // 2)))
            old_slices.append(slice(max(0, (old_len - new_len) // 2),
                                    min(old_len, (old_len + new_len) // 2)))
        n_common = min(old_shape[-1], new_shape[-1])
        new_slices.append(slice(0, n_common))
        old_slices.append(slice(0, n_common))
        return tuple(new_slices), tuple(old_slices)

    def _update(self, force_update: bool = False) -> None:
        """ Carries out necessary updates if anything changes with respect to
        the geometry or basis set, or if ``force_update`` is ``True``.

        Passes the current probed coordinates to the basis set. If the volume shape
        or basis-set size no longer matches the coefficient array, a new zero array
        is allocated and the old coefficients are copied into it (padded or truncated).
        Finally, the geometry and basis-set hashes are stored so that
        :attr:`is_dirty` becomes ``False``. """
        if not (self.is_dirty or force_update):
            return
        self._basis_set.probed_coordinates = self.probed_coordinates
        volume_shape = self._data_container.geometry.volume_shape
        len_diff = len(self._basis_set) - self._coefficients.shape[-1]
        vol_diff = volume_shape - np.array(self._coefficients.shape[:-1])
        # TODO: Think about whether the ``Method`` should do this or handle it differently
        if np.any(vol_diff != 0) or len_diff != 0:
            logger.warning('Shape of coefficient array has changed, array will be padded'
                           ' or truncated.')
            # save old array, no copy needed
            old_coefficients = self._coefficients
            self._coefficients = np.zeros((*volume_shape, len(self._basis_set)),
                                          dtype=self.dtype)
            new_slices, old_slices = self._get_overlap_slices(old_coefficients.shape,
                                                              self._coefficients.shape)
            # Assumption made that old_coefficients[..., 0] corresponds to
            # self._coefficients[..., 0]. Assumption may not be true for all representations!
            # TODO: Consider more logic here using e.g. basis set properties.
            self._coefficients[new_slices] = old_coefficients[old_slices]
            if len_diff != 0:
                logger.warning('Size of basis set has changed. Coefficients have'
                               ' been copied over starting at index 0. If coefficients'
                               ' of new size do not line up with the old size,'
                               ' please reinitialize the coefficients.')
        self._geometry_hash = hash(self._data_container.geometry)
        self._basis_set_hash = hash(self._basis_set)

    def __hash__(self) -> int:
        """ Returns a hash of the current state of this instance, built from the
        coefficients, the projector and geometry hashes, and the stored
        basis-set and geometry hashes. """
        to_hash = [self._coefficients,
                   hash(self._projector),
                   hash(self._data_container.geometry),
                   self._basis_set_hash,
                   self._geometry_hash]
        return int(list_to_hash(to_hash), 16)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if stored hashes of geometry or basis set objects do
        not match their current hashes. Used to trigger updates. """
        return ((self._geometry_hash != hash(self._data_container.geometry)) or
                (self._basis_set_hash != hash(self._basis_set)))

    def __str__(self) -> str:
        """ Returns a plain-text summary table of this instance. """
        wdt = 74
        s = []
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('BasisSet', self._basis_set.__class__.__name__)]
            s += ['{:18} : {}'.format('Projector', self._projector.__class__.__name__)]
            s += ['{:18} : {}'.format('is_dirty', self.is_dirty)]
            s += ['{:18} : {}'.format('probed_coordinates (hash)',
                                      hex(hash(self.probed_coordinates))[:6])]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ Returns an HTML table summarizing this instance (used by Jupyter). """
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">BasisSet</td>']
            s += [f'<td>{1}</td><td>{self._basis_set.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Projector</td>']
            # A single projector object, reported like the BasisSet row (previously
            # this showed the length of the class-name string, which is meaningless).
            s += [f'<td>{1}</td>'
                  f'<td>{self._projector.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Is dirty</td>']
            s += [f'<td>{1}</td><td>{self.is_dirty}</td></tr>']
            s += ['<tr><td style="text-align: left;">probed_coordinates</td>']
            s += [f'<td>{self.probed_coordinates.vector.shape}</td>'
                  f'<td>{hex(hash(self.probed_coordinates))[:6]} (hash)</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)