Coverage for local_installation_linux/mumott/methods/basis_sets/trivial_basis.py: 83%

88 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2from typing import Any, Dict, Iterator, Tuple 

3 

4import numpy as np 

5from numpy.typing import NDArray 

6 

7from mumott import ProbedCoordinates 

8from mumott.core.hashing import list_to_hash 

9from .base_basis_set import BasisSet 

10 

11 

12logger = logging.getLogger(__name__) 

13 

14 

class TrivialBasis(BasisSet):
    """ Basis set class for the trivial basis, i.e., the identity basis.
    This can be used as a scaffolding class when implementing, e.g., scalar tomography,
    as it implements all the necessary functionality to qualify as a :class:`BasisSet`.

    Parameters
    ----------
    channels
        Number of channels in the last index. Default is ``1``. For scalar data,
        the default value of ``1`` is appropriate. For any other use-case, where the representation
        on the sphere and the representation in detector space are equivalent,
        such as reconstructing scalars of multiple q-ranges at once, a different
        number of channels can be set.
    """
    def __init__(self, channels: int = 1):
        # A single placeholder probe direction along z. The identity mapping below never
        # reads it, but its hash is tracked so that ``is_dirty``/``_update`` work.
        self._probed_coordinates = ProbedCoordinates(vector=np.array((0., 0., 1.)))
        self._probed_coordinates_hash = hash(self.probed_coordinates)
        self._channels = channels

    def _get_projection_matrix(self, probed_coordinates: ProbedCoordinates = None) -> NDArray:
        """ Returns the ``channels x channels`` identity matrix.
        ``probed_coordinates`` is accepted for interface compatibility and ignored. """
        return np.eye(self._channels)

    def forward(self,
                coefficients: NDArray,
                *args,
                **kwargs) -> NDArray:
        """ Returns the provided coefficients with no modification.

        Parameters
        ----------
        coefficients
            An array of coefficients, of arbitrary shape, except the last index
            must specify the same number of channels as was specified for this basis.

        Returns
        -------
            The provided :attr:`coefficients` with no modification (the same object, not a copy).

        Notes
        -----
        The :attr:`args` and :attr:`kwargs` are ignored, but included for compatibility with methods
        that input other arguments.
        """
        assert coefficients.shape[-1] == len(self), (
            f'Last axis of coefficients has size {coefficients.shape[-1]}, '
            f'expected {len(self)} channels.')
        return coefficients

    def gradient(self,
                 coefficients: NDArray,
                 *args,
                 **kwargs) -> NDArray:
        """ Returns the provided coefficients with no modification.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape except the last index
            must specify the same number of channels as was specified for this basis.

        Returns
        -------
            The provided :attr:`coefficients` with no modification (the same object, not a copy).

        Notes
        -----
        The :attr:`args` and :attr:`kwargs` are ignored, but included for compatibility with methods
        that input other arguments.
        """
        assert coefficients.shape[-1] == len(self), (
            f'Last axis of coefficients has size {coefficients.shape[-1]}, '
            f'expected {len(self)} channels.')
        return coefficients

    def get_inner_product(self,
                          u: NDArray,
                          v: NDArray) -> NDArray:
        r""" Retrieves the inner product of two coefficient arrays, that is to say,
        the sum-product over the last axis.

        Parameters
        ----------
        u
            The first coefficient array, of arbitrary shape and dimension. Its last
            axis must have as many entries as there are channels.
        v
            The second coefficient array, of the same shape as :attr:`u`.

        Returns
        -------
            An array of shape ``u.shape[:-1]`` containing the sum-product of
            :attr:`u` and :attr:`v` over the last axis.
        """
        assert u.shape[-1] == len(self), (
            f'Last axis of u has size {u.shape[-1]}, expected {len(self)} channels.')
        assert u.shape == v.shape, f'Shape mismatch: u has {u.shape}, v has {v.shape}.'
        return np.einsum('...i, ...i -> ...', u, v,
                         optimize='greedy')

    def get_output(self,
                   coefficients: NDArray) -> Dict[str, Any]:
        r""" Returns a dictionary of output data for a given array of coefficients.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape and dimension.
            Computations only operate over the last axis of :attr:`coefficients`, so derived
            properties in the output will have the shape ``(*coefficients.shape[:-1], ...)``.

        Returns
        -------
            A dictionary containing a dictionary with the field ``basis_set``.

        Notes
        -----
        In detail, the dictionary under the key ``basis_set`` contains:

            basis_set
                name
                    The name of the basis set, i.e., ``'TrivialBasis'``
                coefficients
                    A copy of :attr:`coefficients`.
                projection_matrix
                    The identity matrix of the same size as the number of channels.
                hash
                    The hash of this instance as a hexadecimal string (``hex(hash(self))``).
        """
        assert coefficients.shape[-1] == len(self), (
            f'Last axis of coefficients has size {coefficients.shape[-1]}, '
            f'expected {len(self)} channels.')
        # Update to ensure non-dirty output state.
        self._update()
        output_dictionary = {}

        # basis set-specific information
        basis_set = {}
        output_dictionary['basis_set'] = basis_set
        basis_set['name'] = type(self).__name__
        basis_set['coefficients'] = coefficients.copy()
        basis_set['projection_matrix'] = self.projection_matrix
        basis_set['hash'] = hex(hash(self))
        return output_dictionary

    def get_spherical_harmonic_coefficients(
        self,
        coefficients: NDArray[float],
        ell_max: int = None
    ) -> NDArray[float]:
        """ Convert coefficients of this basis into spherical harmonic coefficients
        with band limit :attr:`ell_max`.

        The first channel is placed in the first output coefficient (the isotropic
        ``ell = 0`` term). All other output coefficients are zero, and any channels
        beyond the first are not carried over.

        Parameters
        ----------
        coefficients
            An array of coefficients of arbitrary shape, provided that the
            last dimension contains the coefficients for one function, i.e.,
            has as many entries as there are channels.
        ell_max
            The band limit of the output spherical harmonic expansion.
            If ``None`` (the default), ``0`` is used, i.e., only the isotropic term.

        Returns
        -------
            An array of shape ``(*coefficients.shape[:-1], (ell_max + 1) * (ell_max + 2) // 2)``.

        Raises
        ------
        ValueError
            If the last dimension of :attr:`coefficients` does not match the number
            of channels, or if :attr:`ell_max` is negative.
        """
        if coefficients.shape[-1] != len(self):
            raise ValueError(f'The number of coefficients ({coefficients.shape[-1]}) does not match '
                             f'the expected value. ({len(self)})')

        # Previously ``None + 1`` raised a TypeError; the isotropic-only expansion
        # is the natural default for a basis that carries no angular information.
        if ell_max is None:
            ell_max = 0
        if ell_max < 0:
            raise ValueError(f'ell_max must be non-negative, but got {ell_max}.')

        # NOTE(review): this count (e.g. 6 = 1 + 5 for ell_max = 2) presumably follows
        # an even-orders-only spherical harmonic convention — confirm against the SH basis.
        num_coeff_output = (ell_max + 1) * (ell_max + 2) // 2

        output_array = np.zeros((*coefficients.shape[:-1], num_coeff_output))
        output_array[..., 0] = coefficients[..., 0]
        return output_array

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """ Allows class to be iterated over and in particular be cast as a dictionary.
        """
        yield 'name', type(self).__name__
        # Use the property: ``_projection_matrix`` is never set on this class.
        yield 'projection_matrix', self.projection_matrix
        yield 'hash', hex(hash(self))[2:]

    def __len__(self) -> int:
        """ The number of channels, i.e., the expected size of the last coefficient axis. """
        return self._channels

    def __hash__(self) -> int:
        """Returns a hash reflecting the internal state of the instance.

        Returns
        -------
            A hash of the internal state of the instance (number of channels and
            the stored probed-coordinates hash), cast as an ``int``.
        """
        to_hash = [self._channels,
                   self._probed_coordinates_hash]
        return int(list_to_hash(to_hash), 16)

    def _update(self) -> None:
        """ Refreshes the stored probed-coordinates hash if it is out of date. """
        if self.is_dirty:
            self._probed_coordinates_hash = hash(self._probed_coordinates)

    @property
    def channels(self) -> int:
        """ The number of channels this basis supports. """
        return self._channels

    @channels.setter
    def channels(self, value: int) -> None:
        self._channels = value

    @property
    def projection_matrix(self) -> NDArray:
        """The identity matrix of the same rank as the number of channels
        specified."""
        return np.eye(self._channels, dtype=np.float64)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if the probed coordinates have changed since their hash was last stored. """
        return hash(self._probed_coordinates) != self._probed_coordinates_hash

    def __str__(self) -> str:
        wdt = 74
        s = []
        s += ['-' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, edgeitems=2, precision=5, linewidth=60):
            s += ['{:18} : {}'.format('Hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        h = hex(hash(self))
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            # Each body row needs an opening <tr> and a Field cell to match the 3-column header.
            s += ['<tr><td style="text-align: left;">Hash</td>']
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)