Coverage for local_installation_linux/mumott/optimization/loss_functions/base_loss_function.py: 97%

155 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import numpy as np

from mumott.core.hashing import list_to_hash
from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from mumott.optimization.regularizers.base_regularizer import Regularizer


class LossFunction(ABC):

    """This is the base class from which specific loss functions are derived.

    Parameters
    ----------
    residual_calculator
        A class derived from
        :class:`ResidualCalculator
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator>`.
    use_weights
        Whether to multiply residuals with weights before calculating the residual norm.
        The calculation is also applied to the gradient.
    preconditioner
        A preconditioner to be applied to the residual norm gradient. Must have the same shape as
        :attr:`residual_calculator.coefficients
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`
        or it must be possible to broadcast by multiplication.
        Entries that are set to ``0`` in the preconditioner will be masked out in the application of
        the regularization gradient if ``use_preconditioner_mask`` is set to ``True``.
    residual_norm_multiplier
        A multiplier that is applied to the residual norm and gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
    use_preconditioner_mask
        If set to ``True`` (default), a mask will be derived from the ``preconditioner``, which
        masks out the entire gradient in areas where the ``preconditioner`` is not greater than 0.
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: np.ndarray[float] = None,
                 residual_norm_multiplier: float = 1,
                 use_preconditioner_mask: bool = True):
        self._residual_calculator = residual_calculator
        self.use_weights = use_weights
        self._preconditioner = preconditioner
        self.use_preconditioner_mask = use_preconditioner_mask
        self.residual_norm_multiplier = residual_norm_multiplier
        self._regularizers = {}
        self._regularization_weights = {}
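
    # Note that ``use_weights``, ``use_preconditioner_mask`` and
    # ``residual_norm_multiplier`` are assigned via their property setters above;
    # in particular, assigning ``use_preconditioner_mask`` triggers
    # :meth:`_update_mask`, which precomputes the gradient mask used by
    # :meth:`get_loss`.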

    def get_loss(self,
                 coefficients: np.ndarray[float] = None,
                 get_gradient: bool = False,
                 gradient_part: str = None):
        """Returns the loss function value and possibly the gradient based on the
        given :attr:`coefficients`.

        Notes
        -----
        This method simply calls the methods :meth:`get_residual_norm` and
        :meth:`get_regularization_norm` and sums up their respective contributions.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Default value is ``None``, which leaves it up to :meth:`get_residual_norm`
            to handle the choice of coefficients, which in general defaults to using the coefficients
            of the attached :attr:`residual_calculator`.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is being calculated. Default is ``None``.

        Returns
        -------
        A dictionary with at least two entries, ``loss`` and ``gradient``.
        """

        residual_norm = self.get_residual_norm(coefficients, get_gradient, gradient_part)
        regularization = self.get_regularization_norm(self._residual_calculator.coefficients,
                                                      get_gradient, gradient_part)

        result = dict(loss=0., gradient=None)
        result['loss'] += residual_norm['residual_norm'] * self.residual_norm_multiplier
        for name in regularization:
            result['loss'] += regularization[name]['regularization_norm'] * self.regularization_weights[name]

        if get_gradient:
            result['gradient'] = residual_norm['gradient'] * self.residual_norm_multiplier
            if self.preconditioner is not None:
                if self._residual_calculator.coefficients.shape[:-1] != self.preconditioner.shape[:-1]:
                    raise ValueError('All dimensions of the preconditioner except the last must'
                                     ' have the same size as those of the coefficients of the'
                                     ' residual calculator, and the last index must be the same'
                                     ' or 1, but the residual calculator coefficients have shape'
                                     f' {self._residual_calculator.coefficients.shape},'
                                     ' while the preconditioner has shape'
                                     f' {self.preconditioner.shape}!')
                result['gradient'] *= self.preconditioner
            for name in regularization:
                result['gradient'] += regularization[name]['gradient'] * self.regularization_weights[name]
            result['gradient'] *= self._gradient_mask
        return result

    def get_residual_norm(self,
                          coefficients: np.ndarray[float] = None,
                          get_gradient: bool = False,
                          gradient_part: str = None) -> dict:
        """Returns the residual norm and possibly the gradient based on the attached
        :attr:`residual_calculator
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator>`.
        If :attr:`coefficients` is given, :attr:`residual_calculator.coefficients
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`
        will be updated with these new values; otherwise, the residual norm and possibly the gradient
        will simply be calculated using the current coefficients.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is being calculated. Default is ``None``.

        Returns
        -------
        A dictionary with at least two entries, ``residual_norm`` and ``gradient``.
        """
        if coefficients is not None:
            self._residual_calculator.coefficients = coefficients
        result = self._get_residual_norm_internal(get_gradient, gradient_part)
        return result

    @abstractmethod
    def _get_residual_norm_internal(self, get_gradient: bool, gradient_part: str) -> dict:
        """ Method that implements the actual calculation of the residual norm. """
        pass
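
    # A minimal sketch of a concrete subclass (illustrative only; how the residuals
    # are actually obtained depends on the attached residual calculator):
    #
    #     class SquaredLoss(LossFunction):
    #         def _get_residual_norm_internal(self, get_gradient, gradient_part):
    #             residuals = ...  # fetch (possibly weighted) residuals
    #             result = dict(residual_norm=0.5 * np.sum(residuals ** 2), gradient=None)
    #             if get_gradient:
    #                 result['gradient'] = ...  # chain rule: back-project the residuals
    #             return result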

    @property
    def use_weights(self) -> bool:
        """ Whether to use weights or not in calculating the residual
        and gradient. """
        return self._use_weights

    @use_weights.setter
    def use_weights(self, val: bool) -> None:
        self._use_weights = val

    @property
    def residual_norm_multiplier(self) -> float:
        """ Multiplicative factor by which the residual norm will be scaled. Can be used,
        together with any :attr:`regularization_weights`, to scale the loss function,
        in order to address unexpected behaviour that arises when some optimizers are given
        very small or very large loss function values. """
        return self._residual_norm_multiplier

    @residual_norm_multiplier.setter
    def residual_norm_multiplier(self, val: float) -> None:
        self._residual_norm_multiplier = val

    @property
    def initial_values(self) -> np.ndarray[float]:
        """ Initial coefficient values for the optimizer; defaults to zeros. """
        return np.zeros_like(self._residual_calculator.coefficients)

    @property
    def use_preconditioner_mask(self) -> bool:
        """ Determines whether a mask is calculated from the preconditioner. """
        return self._use_preconditioner_mask

    @use_preconditioner_mask.setter
    def use_preconditioner_mask(self, value: bool) -> None:
        self._use_preconditioner_mask = value
        self._update_mask()

    @property
    def preconditioner(self) -> np.ndarray[float]:
        """ Preconditioner that is applied to the gradient by multiplication. """
        return self._preconditioner

    @property
    def preconditioner_hash(self) -> str:
        """ Hash of the preconditioner. """
        if self._preconditioner is None:
            return None
        return list_to_hash([self._preconditioner])[:6]

    @preconditioner.setter
    def preconditioner(self, value: np.ndarray[float]) -> None:
        self._preconditioner = value
        self._update_mask()

    def _update_mask(self):
        """Updates the gradient mask."""
        if self.preconditioner is not None and self.use_preconditioner_mask is True:
            self._gradient_mask = (self.preconditioner > 0).astype(self._residual_calculator.dtype)
        else:
            self._gradient_mask = np.ones_like(self._residual_calculator.coefficients)
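
    # Masking behaviour (illustrative): entries where the preconditioner is not
    # greater than 0 map to mask entries of 0, so the corresponding coefficients
    # receive no gradient at all, from either the residual norm or the
    # regularizers; all other entries map to 1.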

    @property
    @abstractmethod
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated loss function
        of the residual in Python idiom, e.g. 'L(r) = 0.5 * r ** 2' for squared loss. """
        pass

    @property
    @abstractmethod
    def _function_as_tex(self) -> str:
        r""" Should return a string representation of the associated loss function
        of the residual in MathJax-renderable TeX, e.g. $L(r) = \frac{r^2}{2}$ for squared loss. """
        pass

    def get_regularization_norm(self,
                                coefficients: np.ndarray[float] = None,
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, dict[str, Any]]:
        """ Returns the regularization norm, and if requested, the gradient, from all
        regularizers attached to this instance, based on the provided :attr:`coefficients`.
        If no coefficients are provided, the ones from the attached :attr:`residual_calculator` are used.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
        get_gradient
            Whether to compute and return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is being calculated. Default is ``None``.

        Returns
        -------
        A dictionary with one entry for each regularizer in :attr:`regularizers`, containing
        ``'regularization_norm'`` and ``'gradient'`` as entries.
        """
        regularization = dict()

        if coefficients is None:
            coefficients = self._residual_calculator.coefficients

        for name in self._regularizers:
            reg = self._regularizers[name].get_regularization_norm(coefficients=coefficients,
                                                                   get_gradient=get_gradient)
            sub_dict = dict(gradient=None)
            sub_dict['regularization_norm'] = reg['regularization_norm']

            if get_gradient:
                sub_dict['gradient'] = reg['gradient']
            regularization[name] = sub_dict
        return regularization
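
    # Shape of the returned dictionary (illustrative values):
    #
    #     {'my_regularizer': {'regularization_norm': 0.42, 'gradient': None}}
    #
    # where 'gradient' holds an ndarray instead of None when get_gradient is True.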

    def add_regularizer(self,
                        name: str,
                        regularizer: Regularizer,
                        regularization_weight: float) -> None:
        r""" Add a :ref:`regularizer <regularizers>` to the loss function.

        Parameters
        ----------
        name
            Name of the regularizer, to be used as its key.
        regularizer
            The :class:`Regularizer` instance to be attached.
        regularization_weight
            The regularization weight (often denoted :math:`\lambda`),
            by which the regularization norm and gradient will be scaled.
        """
        self._regularizers[name] = regularizer
        self._regularization_weights[name] = regularization_weight
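
    # Typical usage (illustrative; ``MyRegularizer`` is a hypothetical stand-in
    # for any class derived from Regularizer):
    #
    #     loss_function.add_regularizer(name='my_regularizer',
    #                                   regularizer=MyRegularizer(),
    #                                   regularization_weight=1e-2)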

    @property
    def regularizers(self) -> dict[str, Regularizer]:
        """ The dictionary of regularizers attached to this loss function. """
        return self._regularizers

    @property
    def regularization_weights(self) -> dict[str, float]:
        """ The dictionary of regularization weights attached to this
        loss function. """
        return self._regularization_weights

    def __str__(self) -> str:
        s = []
        wdt = 74
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('ResidualCalculator', self._residual_calculator.__class__.__name__)]
            s += ['{:18} : {}'.format('Uses weights', self.use_weights)]
            s += ['{:18} : {}'.format('preconditioner_hash', self.preconditioner_hash)]
            s += ['{:18} : {}'.format('residual_norm_multiplier',
                                      self._residual_norm_multiplier)]
            s += ['{:18} : {}'.format('Function of residual', self._function_as_str)]
            s += ['{:18} : {}'.format('Use preconditioner mask', self.use_preconditioner_mask)]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">ResidualCalculator</td>']
            s += [f'<td>{1}</td><td>{self._residual_calculator.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">use_weights</td>']
            s += [f'<td>{1}</td><td>{self.use_weights}</td></tr>']
            s += ['<tr><td style="text-align: left;">preconditioner_hash</td>']
            s += [f'<td>{1}</td><td>{self.preconditioner_hash}</td></tr>']
            s += ['<tr><td style="text-align: left;">residual_norm_multiplier</td>']
            s += [f'<td>{1}</td><td>{self._residual_norm_multiplier}</td></tr>']
            s += ['<tr><td style="text-align: left;">Function of residual r</td>']
            s += [f'<td>1</td><td>{self._function_as_tex}</td></tr>']
            s += ['<tr><td style="text-align: left;">Use preconditioner mask</td>']
            s += [f'<td>1</td><td>{self._use_preconditioner_mask}</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)

    def __hash__(self) -> int:
        to_hash = [hash(self._residual_calculator),
                   self.use_weights,
                   self._residual_norm_multiplier,
                   self._preconditioner,
                   self._use_preconditioner_mask]
        for r in self.regularizers:
            to_hash.append(hash(self.regularizers[r]))
            to_hash.append(hash(self.regularization_weights[r]))
        return int(list_to_hash(to_hash), 16)
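
# A minimal sketch of how a concrete loss function is typically driven (illustrative
# only; ``SquaredLoss``, ``MyRegularizer`` and ``residual_calculator`` are
# hypothetical stand-ins, not part of this module):
#
#     loss_function = SquaredLoss(residual_calculator, use_weights=True)
#     loss_function.add_regularizer('my_regularizer', MyRegularizer(), 1e-2)
#     result = loss_function.get_loss(get_gradient=True)
#     new_coefficients = loss_function.initial_values - 0.1 * result['gradient']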