Coverage for local_installation_linux/mumott/optimization/regularizers/base_regularizer.py: 94%

47 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1from abc import ABC, abstractmethod 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5 

6from mumott.core.hashing import list_to_hash 

7 

8 

class Regularizer(ABC):

    """This is the base class from which specific regularizers are derived.

    Subclasses must implement :meth:`get_regularization_norm` as well as the
    two read-only properties :attr:`_function_as_str` and
    :attr:`_function_as_tex`. The hash of a regularizer (and hence the hash
    shown by :meth:`__str__` and :meth:`_repr_html_`) is computed from
    :attr:`_function_as_str` alone.
    """

    def __init__(self):
        pass

    @abstractmethod
    def get_regularization_norm(self,
                                coefficients: NDArray[float] = None,
                                get_gradient: bool = False,
                                gradient_part: str = None) -> dict[str, NDArray[float]]:
        """Returns regularization norm and possibly gradient based on the provided coefficients.

        Parameters
        ----------
        coefficients
            An ``np.ndarray`` of values, with shape `(X, Y, Z, W)`, where
            the last channel contains e.g. tensor components.
        get_gradient
            If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
            Otherwise, the entry ``'gradient'`` will be ``None``.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is ``None``.

        Returns
        -------
            A dictionary with at least two entries, ``regularization_norm`` and ``gradient``.
        """
        pass

    @property
    @abstractmethod
    def _function_as_str(self) -> str:
        """ Should return a string representation of the associated norm
        of the coefficients in Python idiom, e.g. 'R(x) = 0.5 * x ** 2' for L2. """
        pass

    @property
    @abstractmethod
    def _function_as_tex(self) -> str:
        r""" Should return a string representation of the associated norm
        of the coefficients in MathJax-renderable TeX, e.g. $R(x) = \frac{x^2}{2}$ for L2."""
        # Raw docstring: without the ``r`` prefix, ``\f`` in ``\frac`` would be
        # parsed as a form-feed escape and corrupt the documented example.
        pass

    def __hash__(self) -> int:
        """Returns a hash derived solely from :attr:`_function_as_str`.

        ``list_to_hash`` returns a hexadecimal string, which is converted to an
        integer here so it can serve as the object's hash.
        """
        to_hash = [self._function_as_str]
        return int(list_to_hash(to_hash), 16)

    def __str__(self) -> str:
        """Returns a plain-text summary table with the class name, the
        Python-idiom form of the norm and the first six hex digits of the hash.
        """
        s = []
        wdt = 74
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('Function of coefficients', self._function_as_str)]
            # [2:8] drops the '0x' prefix and keeps six hex digits.
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """Returns an HTML table (for rich notebook display) with the class
        name, the TeX form of the norm and the first six hex digits of the hash.
        """
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">Function of coefficients</td>']
            s += [f'<td>1</td><td>{self._function_as_tex}</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            # "Size" reports the length of the full hex string (including '0x'),
            # while only six digits are displayed.
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)