Coverage for local_installation_linux/mumott/optimization/loss_functions/base_loss_function.py: 97%
155 statements
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any

import numpy as np

from mumott.core.hashing import list_to_hash
from mumott.methods.residual_calculators.base_residual_calculator import ResidualCalculator
from mumott.optimization.regularizers.base_regularizer import Regularizer


class LossFunction(ABC):

    """This is the base class from which specific loss functions are derived.

    Parameters
    ----------
    residual_calculator
        An instance of a class derived from
        :class:`ResidualCalculator
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator>`.
    use_weights
        Whether to multiply the residuals by weights before calculating the residual norm.
        The same weighting is applied to the gradient.
    preconditioner
        A preconditioner to be applied to the residual norm gradient. It must have the same shape as
        :attr:`residual_calculator.coefficients
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`
        or it must be possible to broadcast it by multiplication.
        Entries that are set to ``0`` in the preconditioner will be masked out in the application of
        the regularization gradient if ``use_preconditioner_mask`` is set to ``True``.
    residual_norm_multiplier
        A multiplier that is applied to the residual norm and its gradient. Useful in cases where
        a very small or large loss function value changes the optimizer behaviour.
    use_preconditioner_mask
        If set to ``True`` (default), a mask will be derived from the ``preconditioner``, which
        masks out the entire gradient in regions where the ``preconditioner`` is not greater
        than ``0``.
    """

    def __init__(self,
                 residual_calculator: ResidualCalculator,
                 use_weights: bool = False,
                 preconditioner: np.ndarray[float] = None,
                 residual_norm_multiplier: float = 1,
                 use_preconditioner_mask: bool = True):
        self._residual_calculator = residual_calculator
        self.use_weights = use_weights
        self._preconditioner = preconditioner
        self.use_preconditioner_mask = use_preconditioner_mask
        self.residual_norm_multiplier = residual_norm_multiplier
        self._regularizers = {}
        self._regularization_weights = {}
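
    # A hypothetical usage sketch (illustrative only; ``MyLossFunction`` and the
    # arguments shown are assumptions, standing in for a concrete subclass and
    # user-supplied objects):
    #
    #     loss_function = MyLossFunction(residual_calculator,
    #                                    use_weights=True,
    #                                    preconditioner=my_preconditioner,
    #                                    residual_norm_multiplier=1e-3)
    #
    # For coefficients of shape ``(X, Y, Z, N)``, a preconditioner of shape
    # ``(X, Y, Z, N)`` or ``(X, Y, Z, 1)`` is compatible, since the latter
    # broadcasts over the last axis by multiplication.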

    def get_loss(self,
                 coefficients: np.ndarray[float] = None,
                 get_gradient: bool = False,
                 gradient_part: str = None):
        """Returns the loss function value, and possibly the gradient, based on the
        given :attr:`coefficients`.

        Notes
        -----
        This method simply calls :meth:`get_residual_norm` and :meth:`get_regularization_norm`
        and sums up their respective contributions.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Default value is ``None``, which leaves it up to :meth:`get_residual_norm`
            to handle the choice of coefficients; in general, this defaults to using the
            coefficients of the attached :attr:`residual_calculator`.
        get_gradient
            If ``True``, the entry ``'gradient'`` of the returned dictionary will hold a gradient
            of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is calculated. Default is ``None``.

        Returns
        -------
        A dictionary with at least two entries, ``loss`` and ``gradient``.
        """
        # get_residual_norm updates the residual calculator coefficients if new ones
        # are given, so the regularization below acts on the same values.
        residual_norm = self.get_residual_norm(coefficients, get_gradient, gradient_part)
        regularization = self.get_regularization_norm(self._residual_calculator.coefficients,
                                                      get_gradient, gradient_part)

        result = dict(loss=0., gradient=None)
        result['loss'] += residual_norm['residual_norm'] * self.residual_norm_multiplier
        for name in regularization:
            result['loss'] += regularization[name]['regularization_norm'] * self.regularization_weights[name]

        if get_gradient:
            result['gradient'] = residual_norm['gradient'] * self.residual_norm_multiplier
            if self.preconditioner is not None:
                if self._residual_calculator.coefficients.shape[:-1] != self.preconditioner.shape[:-1]:
                    raise ValueError('All dimensions of the preconditioner except the last must'
                                     ' have the same size as those of the coefficients of the'
                                     ' residual calculator, and the last dimension must be the'
                                     ' same or 1, but the residual calculator coefficients have'
                                     f' shape {self._residual_calculator.coefficients.shape},'
                                     ' while the preconditioner has shape'
                                     f' {self.preconditioner.shape}!')
                result['gradient'] *= self.preconditioner
            for name in regularization:
                result['gradient'] += regularization[name]['gradient'] * self.regularization_weights[name]
            result['gradient'] *= self._gradient_mask
        return result
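
    # Illustrative summary of the computation in get_loss (not part of the
    # original source): with residual norm multiplier ``c`` and regularization
    # weights ``lambda_i``,
    #
    #     loss = c * residual_norm + sum_i(lambda_i * regularization_norm_i)
    #
    # and the gradient, when requested, is the correspondingly weighted sum,
    # with the preconditioner applied to the residual term only and the
    # gradient mask applied to the total.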

    def get_residual_norm(self,
                          coefficients: np.ndarray[float] = None,
                          get_gradient: bool = False,
                          gradient_part: str = None) -> dict:
        """Returns the residual norm, and possibly the gradient, based on the attached
        :attr:`residual_calculator
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator>`.
        If :attr:`coefficients` is given, :attr:`residual_calculator.coefficients
        <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`
        will be updated with these new values; otherwise, the residual norm and possibly the gradient
        will be calculated using the current coefficients.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
        get_gradient
            If ``True``, the entry ``'gradient'`` of the returned dictionary will hold a gradient
            of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is calculated. Default is ``None``.

        Returns
        -------
        A dictionary with at least two entries, ``residual_norm`` and ``gradient``.
        """
        if coefficients is not None:
            self._residual_calculator.coefficients = coefficients
        result = self._get_residual_norm_internal(get_gradient, gradient_part)
        return result

    @abstractmethod
    def _get_residual_norm_internal(self, get_gradient: bool, gradient_part: str) -> dict:
        """ Method that implements the actual calculation of the residual norm. """
        pass

    @property
    def use_weights(self) -> bool:
        """ Whether to use weights when calculating the residual
        norm and its gradient. """
        return self._use_weights

    @use_weights.setter
    def use_weights(self, val: bool) -> None:
        self._use_weights = val

    @property
    def residual_norm_multiplier(self) -> float:
        """ Multiplicative factor by which the residual norm will be scaled. Can be used,
        together with the :attr:`regularization_weights`, to scale the loss function,
        in order to address the unexpected behaviour that some optimizers exhibit when
        the loss function is very small or very large. """
        return self._residual_norm_multiplier

    @residual_norm_multiplier.setter
    def residual_norm_multiplier(self, val: float) -> None:
        self._residual_norm_multiplier = val

    @property
    def initial_values(self) -> np.ndarray[float]:
        """ Initial coefficient values for the optimizer; defaults to zeros. """
        return np.zeros_like(self._residual_calculator.coefficients)

    @property
    def use_preconditioner_mask(self) -> bool:
        """ Whether a mask is calculated from the preconditioner. """
        return self._use_preconditioner_mask

    @use_preconditioner_mask.setter
    def use_preconditioner_mask(self, value: bool) -> None:
        self._use_preconditioner_mask = value
        self._update_mask()

    @property
    def preconditioner(self) -> np.ndarray[float]:
        """ Preconditioner that is applied to the gradient by multiplication. """
        return self._preconditioner

    @property
    def preconditioner_hash(self) -> str:
        """ Hash of the preconditioner. """
        if self._preconditioner is None:
            return None
        return list_to_hash([self._preconditioner])[:6]

    @preconditioner.setter
    def preconditioner(self, value: np.ndarray[float]) -> None:
        self._preconditioner = value
        self._update_mask()

    def _update_mask(self):
        """ Updates the gradient mask. """
        if self.preconditioner is not None and self.use_preconditioner_mask:
            # Mask is 1 where the preconditioner is positive and 0 elsewhere.
            self._gradient_mask = (self.preconditioner > 0).astype(self._residual_calculator.dtype)
        else:
            self._gradient_mask = np.ones_like(self._residual_calculator.coefficients)
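
    # Illustrative example with assumed values (not part of the original
    # source): for ``preconditioner = np.array([0., 0.5, 2.])`` the derived
    # mask is ``[0., 1., 1.]``, so gradient entries are zeroed out wherever
    # the preconditioner is not greater than zero.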

    @property
    @abstractmethod
    def _function_as_str(self) -> str:
        """ Should return a string representation of the loss function
        of the residual in Python idiom, e.g., 'L(r) = 0.5 * r ** 2' for the squared loss. """
        pass

    @property
    @abstractmethod
    def _function_as_tex(self) -> str:
        r""" Should return a string representation of the loss function
        of the residual in MathJax-renderable TeX, e.g., $L(r) = \frac{r^2}{2}$
        for the squared loss. """
        pass

    def get_regularization_norm(self,
                                coefficients: np.ndarray[float] = None,
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, dict[str, Any]]:
        """ Returns the regularization norm, and if requested the gradient, from all
        regularizers attached to this instance, based on the provided :attr:`coefficients`.
        If no coefficients are provided, the ones of the attached :attr:`residual_calculator`
        are used.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values of the same shape as :attr:`residual_calculator.coefficients
            <mumott.methods.residual_calculators.base_residual_calculator.ResidualCalculator.coefficients>`.
        get_gradient
            Whether to compute and return the gradient. Default is ``False``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine which part of the
            gradient is calculated. Default is ``None``.

        Returns
        -------
        A dictionary with one entry for each regularizer in :attr:`regularizers`, each of which
        contains the entries ``'regularization_norm'`` and ``'gradient'``.
        """
        regularization = dict()

        if coefficients is None:
            coefficients = self._residual_calculator.coefficients

        for name in self._regularizers:
            reg = self._regularizers[name].get_regularization_norm(coefficients=coefficients,
                                                                   get_gradient=get_gradient)
            sub_dict = dict(gradient=None)
            sub_dict['regularization_norm'] = reg['regularization_norm']

            if get_gradient:
                sub_dict['gradient'] = reg['gradient']
            regularization[name] = sub_dict
        return regularization
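
    # Illustrative structure of the dictionary returned by
    # get_regularization_norm (the regularizer name 'l1' is a hypothetical
    # example):
    #
    #     {'l1': {'regularization_norm': 0.42, 'gradient': None}}
    #
    # where 'gradient' instead holds an array of the same shape as the
    # coefficients when ``get_gradient=True``.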

    def add_regularizer(self,
                        name: str,
                        regularizer: Regularizer,
                        regularization_weight: float) -> None:
        r""" Add a :ref:`regularizer <regularizers>` to the loss function.

        Parameters
        ----------
        name
            Name of the regularizer, to be used as its key.
        regularizer
            The :class:`Regularizer` instance to be attached.
        regularization_weight
            The regularization weight (often denoted :math:`\lambda`),
            by which the regularization norm and gradient will be scaled.
        """
        self._regularizers[name] = regularizer
        self._regularization_weights[name] = regularization_weight
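
    # Hypothetical usage sketch (the regularizer class ``TotalVariation`` is
    # an assumption, not taken from this module):
    #
    #     loss_function.add_regularizer(name='tv',
    #                                   regularizer=TotalVariation(),
    #                                   regularization_weight=1e-4)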

    @property
    def regularizers(self) -> dict[str, Regularizer]:
        """ The dictionary of regularizers attached to this loss function. """
        return self._regularizers

    @property
    def regularization_weights(self) -> dict[str, float]:
        """ The dictionary of regularization weights attached to this
        loss function. """
        return self._regularization_weights

    def __str__(self) -> str:
        s = []
        wdt = 74
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('ResidualCalculator', self._residual_calculator.__class__.__name__)]
            s += ['{:18} : {}'.format('Uses weights', self.use_weights)]
            s += ['{:18} : {}'.format('preconditioner_hash', self.preconditioner_hash)]
            s += ['{:18} : {}'.format('residual_norm_multiplier',
                                      self._residual_norm_multiplier)]
            s += ['{:18} : {}'.format('Function of residual', self._function_as_str)]
            s += ['{:18} : {}'.format('Use preconditioner mask', self.use_preconditioner_mask)]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        s = []
        s += [f'<h3>{__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">ResidualCalculator</td>']
            s += [f'<td>{1}</td><td>{self._residual_calculator.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">use_weights</td>']
            s += [f'<td>{1}</td><td>{self.use_weights}</td></tr>']
            s += ['<tr><td style="text-align: left;">preconditioner_hash</td>']
            s += [f'<td>{1}</td><td>{self.preconditioner_hash}</td></tr>']
            s += ['<tr><td style="text-align: left;">residual_norm_multiplier</td>']
            s += [f'<td>{1}</td><td>{self._residual_norm_multiplier}</td></tr>']
            s += ['<tr><td style="text-align: left;">Function of residual r</td>']
            s += [f'<td>1</td><td>{self._function_as_tex}</td></tr>']
            s += ['<tr><td style="text-align: left;">Use preconditioner mask</td>']
            s += [f'<td>1</td><td>{self._use_preconditioner_mask}</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)

    def __hash__(self) -> int:
        to_hash = [hash(self._residual_calculator),
                   self.use_weights,
                   self._residual_norm_multiplier,
                   self._preconditioner,
                   self._use_preconditioner_mask]
        for r in self.regularizers:
            to_hash.append(hash(self.regularizers[r]))
            to_hash.append(hash(self.regularization_weights[r]))
        return int(list_to_hash(to_hash), 16)
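
# A minimal, hypothetical sketch of a concrete subclass (illustrative only, not
# part of the original module; ``_compute_residuals`` and ``_backproject`` are
# assumed placeholders rather than mumott APIs). It shows the contract a
# subclass must fulfil: ``_get_residual_norm_internal`` must return a
# dictionary with the entries ``residual_norm`` and ``gradient``, and the two
# string representations of the loss function must be provided.
#
# class SquaredLossSketch(LossFunction):
#
#     def _get_residual_norm_internal(self, get_gradient: bool, gradient_part: str) -> dict:
#         residuals = self._compute_residuals()  # assumed placeholder
#         result = dict(residual_norm=0.5 * np.sum(residuals ** 2), gradient=None)
#         if get_gradient:
#             # The gradient of 0.5 * r ** 2 with respect to r is r; a concrete
#             # implementation would map this back onto the coefficients.
#             result['gradient'] = self._backproject(residuals)  # assumed placeholder
#         return result
#
#     @property
#     def _function_as_str(self) -> str:
#         return 'L(r) = 0.5 * r ** 2'
#
#     @property
#     def _function_as_tex(self) -> str:
#         return r'$L(r) = \frac{r^2}{2}$'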