Coverage for local_installation_linux/mumott/optimization/regularizers/base_regularizer.py: 94%
47 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1from abc import ABC, abstractmethod
3import numpy as np
4from numpy.typing import NDArray
6from mumott.core.hashing import list_to_hash
class Regularizer(ABC):
    """This is the base class from which specific regularizers are derived.

    Subclasses must implement :meth:`get_regularization_norm` as well as the
    :attr:`_function_as_str` and :attr:`_function_as_tex` properties.
    The hash of an instance is derived from :attr:`_function_as_str`, which is
    also shown by :meth:`__str__`; :attr:`_function_as_tex` is shown by
    :meth:`_repr_html_`.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_regularization_norm(self,
                                coefficients: NDArray[float] = None,
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, NDArray[float]]:
        """Returns regularization norm and possibly gradient based on the provided coefficients.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape `(X, Y, Z, W)`, where
            the last channel contains e.g. tensor components.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is ``None``.

        Returns
        -------
        dict
            A dictionary with at least two entries, ``residual_norm`` and ``gradient``.
        """
        # NOTE(review): the key name ``residual_norm`` is carried over from the original
        # docstring; confirm it matches the key concrete subclasses actually return.
        pass

    @property
    @abstractmethod
    def _function_as_str(self) -> str:
        """Should return a string representation of the associated norm
        of the coefficients in Python idiom, e.g. ``'R(x) = 0.5 * x ** 2'`` for L2."""
        pass

    @property
    @abstractmethod
    def _function_as_tex(self) -> str:
        # Raw string: in a normal string literal ``\f`` in ``\frac`` would be
        # interpreted as a form-feed character and corrupt the docstring.
        r"""Should return a string representation of the associated norm
        of the coefficients in MathJax-renderable TeX, e.g. ``$R(x) = \frac{x^2}{2}$`` for L2."""
        pass

    def __hash__(self) -> int:
        """Returns a hash derived solely from :attr:`_function_as_str`.

        ``list_to_hash`` is expected to return a hexadecimal string, since its
        result is parsed as a base-16 integer.
        """
        to_hash = [self._function_as_str]
        return int(list_to_hash(to_hash), 16)

    def __str__(self) -> str:
        """Returns a fixed-width (74 character) plain-text summary showing the class name,
        the Python-idiom function string and the first six hex digits of the hash."""
        s = []
        wdt = 74
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        # NOTE(review): presumably kept so that any NumPy arrays formatted by a
        # subclass's ``_function_as_str`` are printed compactly — confirm.
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('Function of coefficients', self._function_as_str)]
            # hex() yields '0x...'; [2:8] drops the prefix and keeps six digits.
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """Returns an HTML table (for rich display, e.g. in notebooks) with the
        TeX form of the function and a shortened hash."""
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        # NOTE(review): presumably kept so that any NumPy arrays formatted by a
        # subclass's ``_function_as_tex`` are printed compactly — confirm.
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">Function of coefficients</td>']
            s += [f'<td>1</td><td>{self._function_as_tex}</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            # 'Size' reports the length of the full hex string (including the
            # '0x' prefix), while only six hex digits are displayed.
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)