Coverage for local_installation_linux/mumott/methods/projectors/saxs_projector.py: 98%

120 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2 

3from typing import Tuple 

4 

5import numpy as np 

6from numpy.typing import NDArray 

7 

8from mumott import Geometry 

9from mumott.core.john_transform import john_transform, john_transform_adjoint 

10from mumott.core.hashing import list_to_hash 

11from .base_projector import Projector 

12 

# Module-level logger; not referenced by SAXSProjector itself.
logger = logging.getLogger(__name__)

14 

15 

class SAXSProjector(Projector):
    """
    Projector for transforms of tensor fields from three-dimensional space
    to projection space using a bilinear interpolation algorithm that produces results similar
    to those of :class:`SAXSProjectorCUDA <mumott.methods.projectors.SAXSProjectorCUDA>`
    using CPU computation.

    Parameters
    ----------
    geometry : Geometry
        An instance of :class:`Geometry <mumott.Geometry>` containing the
        necessary vectors to compute forward and adjoint projections.
    """
    def __init__(self,
                 geometry: Geometry):

        super().__init__(geometry)
        self._update(force_update=True)
        # Cached compiled John transform kernels. They are (re)built lazily by
        # ``_ensure_compiled`` whenever the hash of the coefficient count and
        # the transform parameters differs from ``_numba_hash``.
        self._numba_hash = None
        self._compiled_john_transform = None
        self._compiled_john_transform_adjoint = None

    @staticmethod
    def _get_zeros_method(array: NDArray):
        """ Internal method for dispatching functions for array allocation.
        Included to simplify subclassing.

        The ``array`` argument is not used by this implementation; presumably
        it exists so that subclasses can choose an allocator matching the
        type of the input array. """
        return np.zeros

    def _get_john_transform_parameters(self,
                                       indices: NDArray[int] = None) -> Tuple:
        """ Collect the geometry-dependent arguments of the John transform.

        Parameters
        ----------
        indices
            Indices of the projections to include. If ``None``, all
            projections are included.

        Returns
        -------
            Tuple ``(vector_p, vector_j, vector_k, j_offsets, k_offsets)``
            holding the projection basis vectors and the ``j``/``k`` offsets
            of the geometry, each restricted to :attr:`indices`.
        """
        if indices is None:
            indices = np.s_[:]
        vector_p = self._basis_vector_projection[indices]
        vector_j = self._basis_vector_j[indices]
        vector_k = self._basis_vector_k[indices]
        j_offsets = self._geometry.j_offsets_as_array[indices]
        k_offsets = self._geometry.k_offsets_as_array[indices]
        return (vector_p, vector_j, vector_k, j_offsets, k_offsets)

    def forward(self,
                field: NDArray,
                indices: NDArray[int] = None) -> NDArray:
        """ Compute the forward projection of a tensor field.

        Parameters
        ----------
        field
            An array containing coefficients in its fourth dimension,
            which are to be projected into two dimensions. The first three
            dimensions should match the ``volume_shape`` of the sample.
        indices
            A one-dimensional array containing one or more indices
            indicating which projections are to be computed. If ``None``,
            all projections will be computed.

        Returns
        -------
            An array with four dimensions ``(I, J, K, L)``, where
            the first dimension matches :attr:`indices`, such that
            ``projection[i]`` corresponds to the geometry of projection
            ``indices[i]``. The second and third dimension contain
            the pixels in the ``J`` and ``K`` dimension respectively, whereas
            the last dimension is the coefficient dimension, matching
            ``field.shape[-1]``.

        Raises
        ------
        ValueError
            If the spatial dimensions of :attr:`field` do not exactly match
            the ``volume_shape`` of the geometry.
        """
        expected_shape = tuple(int(size) for size in self._geometry.volume_shape)
        # Exact comparison: ``np.allclose`` broadcasts, so a field with too few
        # dimensions could otherwise slip through the check.
        if tuple(field.shape[:-1]) != expected_shape:
            raise ValueError(f'The shape of the input field ({field.shape}) does not match the'
                             f' volume shape expected by the projector ({self._geometry.volume_shape})')
        self._update()
        if indices is None:
            return self._forward_stack(field)
        return self._forward_subset(field, indices)

    def _forward_subset(self,
                        field: NDArray,
                        indices: NDArray[int]) -> NDArray:
        """ Internal method for computing a subset of projections.

        Parameters
        ----------
        field
            The field to be projected.
        indices
            The indices indicating the subset of all projections in the
            system geometry to be computed.

        Returns
        -------
            The resulting projections.
        """
        indices = np.array(indices).ravel()
        # Validate the indices before allocating the output buffer.
        self._check_indices_kind_is_integer(indices)
        init_method = self._get_zeros_method(field)
        projections = init_method((indices.size,) +
                                  tuple(self._geometry.projection_shape) +
                                  (field.shape[-1],), dtype=self.dtype)
        return self._john_transform(
            field, projections, *self._get_john_transform_parameters(indices))

    def _forward_stack(self,
                       field: NDArray) -> NDArray:
        """Internal method for forward projecting an entire stack.

        Parameters
        ----------
        field
            The field to be projected.

        Returns
        -------
            The resulting projections, one per projection in the geometry.
        """
        init_method = self._get_zeros_method(field)
        projections = init_method((len(self._geometry),) +
                                  tuple(self._geometry.projection_shape) +
                                  (field.shape[-1],), dtype=self.dtype)
        return self._john_transform(field, projections, *self._get_john_transform_parameters())

    def adjoint(self,
                projections: NDArray,
                indices: NDArray[int] = None) -> NDArray:
        """ Compute the adjoint of a set of projections according to the system geometry.

        Parameters
        ----------
        projections
            An array containing coefficients in its last dimension,
            from e.g. the residual of measured data and forward projections.
            The first dimension should match :attr:`indices` in size, and the
            second and third dimensions should match the system projection geometry.
            The array must be contiguous and row-major.
        indices
            A one-dimensional array containing one or more indices
            indicating from which projections the adjoint is to be computed.
            If ``None``, :attr:`projections` must contain every projection
            of the geometry.

        Returns
        -------
            The adjoint of the provided projections.
            An array with four dimensions ``(X, Y, Z, P)``, where the first
            three dimensions are spatial and the last dimension runs over
            coefficients.

        Raises
        ------
        ValueError
            If the pixel dimensions of :attr:`projections` do not exactly
            match the projection shape of the geometry, or if the array is
            not C-contiguous.
        """
        expected_shape = tuple(int(size) for size in self._geometry.projection_shape)
        # Exact comparison with an explicit rank guard; ``np.allclose`` would
        # broadcast an under-dimensioned array and could accept it.
        if projections.ndim < 3 or tuple(projections.shape[-3:-1]) != expected_shape:
            raise ValueError(f'The shape of the projections ({projections.shape}) does not match the'
                             f' projection shape expected by the projector'
                             f' ({self._geometry.projection_shape})')
        if not projections.flags['C_CONTIGUOUS']:
            raise ValueError('The projections array must be contiguous and row-major, '
                             f'but has strides {projections.strides}.')

        self._update()
        if indices is None:
            return self._adjoint_stack(projections)
        return self._adjoint_subset(projections, indices)

    def _adjoint_subset(self,
                        projections: NDArray,
                        indices: NDArray[int]) -> NDArray:
        """ Internal method for computing the adjoint of only a subset of projections.

        Parameters
        ----------
        projections
            An array containing coefficients in its last dimension,
            from e.g. the residual of measured data and forward projections.
            The first dimension should match :attr:`indices` in size, and the
            second and third dimensions should match the system projection geometry.
            A three-dimensional array is treated as a single projection.
        indices
            A one-dimensional array containing one or more indices
            indicating from which projections the adjoint is to be computed.

        Returns
        -------
            The adjoint of the provided projections.
            An array with four dimensions ``(X, Y, Z, P)``, where the first
            three dimensions are spatial and the last dimension runs over
            coefficients.

        Raises
        ------
        AssertionError
            If the number of indices does not match the number of projections.
            Raised explicitly (not via ``assert``) so the check survives
            ``python -O``; the exception type is kept for backward compatibility.
        """
        indices = np.array(indices).ravel()
        if projections.ndim == 3:
            # A single projection given without a leading projection axis.
            if indices.size != 1:
                raise AssertionError(
                    f'Expected exactly one index for a single projection, got {indices.size}.')
            projections = projections[np.newaxis, ...]
        elif indices.size != projections.shape[0]:
            raise AssertionError(
                f'Number of indices ({indices.size}) does not match the number of'
                f' projections ({projections.shape[0]}).')
        self._check_indices_kind_is_integer(indices)
        init_method = self._get_zeros_method(projections)
        field = init_method(tuple(self._geometry.volume_shape) +
                            (projections.shape[-1],), dtype=self.dtype)
        return self._john_transform_adjoint(
            field, projections, *self._get_john_transform_parameters(indices))

    def _adjoint_stack(self,
                       projections: NDArray) -> NDArray:
        """ Internal method for computing the adjoint of a whole stack of projections.

        Parameters
        ----------
        projections
            An array containing coefficients in its last dimension,
            from e.g. the residual of measured data and forward projections.
            The first dimension should run over all the projection directions
            in the system geometry.

        Returns
        -------
            The adjoint of the provided projections.
            An array with four dimensions ``(X, Y, Z, P)``, where the first
            three dimensions are spatial, and the last dimension runs over
            coefficients.

        Raises
        ------
        AssertionError
            If the first dimension of :attr:`projections` does not equal the
            number of projections in the geometry (raised explicitly so the
            check survives ``python -O``).
        """
        if projections.shape[0] != len(self._geometry):
            raise AssertionError(
                f'Number of projections ({projections.shape[0]}) does not match the'
                f' geometry ({len(self._geometry)}).')
        init_method = self._get_zeros_method(projections)
        field = init_method(tuple(self._geometry.volume_shape) +
                            (projections.shape[-1],), dtype=self.dtype)
        return self._john_transform_adjoint(
            field, projections, *self._get_john_transform_parameters())

    def _compile_john_transform(self,
                                field: NDArray[float],
                                projections: NDArray[float],
                                *args) -> None:
        """ Internal method for compiling John transform only as needed.

        Builds both the forward and the adjoint kernel from the same
        arguments and stores them on the instance. """
        self._compiled_john_transform = john_transform(
            field, projections, *args)
        self._compiled_john_transform_adjoint = john_transform_adjoint(
            field, projections, *args)

    def _ensure_compiled(self,
                         field: NDArray[float],
                         projections: NDArray[float],
                         *args) -> None:
        """ Recompile the John transform kernels if the number of coefficients
        or the transform parameters changed since the last compilation. """
        current_hash = list_to_hash([field.shape[-1], *args])
        if current_hash != self._numba_hash:
            self._compile_john_transform(field, projections, *args)
            self._numba_hash = current_hash

    def _john_transform(self,
                        field: NDArray[float],
                        projections: NDArray[float],
                        *args) -> NDArray:
        """ Internal method for dispatching John Transform call. Included to
        simplify subclassing. Note that the result is calculated in-place.

        Returns whatever the compiled forward kernel returns. """
        self._ensure_compiled(field, projections, *args)
        return self._compiled_john_transform(field, projections)

    def _john_transform_adjoint(self,
                                field: NDArray[float],
                                projections: NDArray[float],
                                *args) -> NDArray:
        """ Internal method for dispatching john transform adjoint function call. Included to
        simplify subclassing. Note that the result is calculated in-place.

        Returns whatever the compiled adjoint kernel returns. """
        self._ensure_compiled(field, projections, *args)
        return self._compiled_john_transform_adjoint(field, projections)

    @property
    def john_transform_parameters(self) -> tuple:
        """ Tuple of John Transform parameters, which can be passed manually
        to compile John Transform kernels and construct low-level pipelines.
        For advanced users only."""
        return self._get_john_transform_parameters()

    @property
    def dtype(self) -> np.typing.DTypeLike:
        """ Preferred dtype of this ``Projector``. """
        return np.float64

    def __hash__(self) -> int:
        """ Hash combining the basis vectors, the geometry state and the
        hash of the currently compiled kernels. """
        to_hash = [self._basis_vector_projection,
                   self._basis_vector_j,
                   self._basis_vector_k,
                   self._geometry_hash,
                   hash(self._geometry),
                   self._numba_hash]
        return int(list_to_hash(to_hash), 16)

    def __str__(self) -> str:
        """ Plain-text summary showing the dirty flag and a short hash. """
        wdt = 74
        s = []
        s += ['-' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('is_dirty', self.is_dirty)]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ HTML table summary (used by notebook front-ends). """
        hash_hex = hex(hash(self))
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">is_dirty</td>']
            s += [f'<td>1</td><td>{self.is_dirty}</td></tr>']
            s += ['<tr><td style="text-align: left;">hash</td>']
            s += [f'<td>{len(hash_hex)}</td><td>{hash_hex[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)