Coverage for local_installation_linux/mumott/optimization/regularizers/total_variation.py: 92%

52 statements  

coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1  import numpy as np
2  from numpy.typing import NDArray
3
4  from mumott.optimization.regularizers.base_regularizer import Regularizer
5  import logging
6  logger = logging.getLogger(__name__)
7
8
9  class TotalVariation(Regularizer):
10

11 r"""Regularizes using the symmetric total variation, i.e., the root-mean-square difference 

12 between nearest neighbours. It is combined with a `Huber norm 

13 <https://en.wikipedia.org/wiki/Huber_loss>`_, using the squared differences at small values, 

14 in order to improve convergence. Suitable for scalar fields or tensor fields in local 

15 representations. Tends to reduce noise. 

16 

17 In two dimensions, the total variation spliced with its squared function like a 

18 Huber loss can be written 

19 

20 .. math:: 

21 \mathrm{TV}_1(f(x, y)) 

22 = \frac{1}{h}\sum_i ((f(x_i, y_i) - f(x_i + h, y_i))^2 + 

23 (f(x_i, y_i) - f(x_i - h, y_i))^2 + \\ (f(x_i, y_i) - f(x_i, y_i + h))^2 + 

24 (f(x_i, y_i) - f(x_i, y_i - h))^2))^{\frac{1}{2}} - 0.5 \delta 

25 

26 If :math:`\mathrm{TV}_1 < 0.5 \delta` we instead use 

27 

28 .. math:: 

29 \mathrm{TV}_2(f(x, y)) 

30 = \frac{1}{2 \delta h^2}\sum_i (f(x_i, y_i) - f(x_i + h, y_i))^2 + 

31 (f(x_i, y_i) - f(x_i - h, y_i))^2 + \\ 

32 (f(x_i, y_i) - f(x_i, y_i + h))^2 + (f(x_i, y_i) - f(x_i, y_i - h))^2 

33 

34 See also the Wikipedia articles on 

35 `total variation denoising <https://en.m.wikipedia.org/wiki/Total_variation_denoising>`_ 

36 and `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_ 

37 

38 Parameters 

39 ---------- 

40 delta : float 

41 Below this value, the scaled square of the total variation is used as the norm. 

42 This makes the norm differentiable everywhere, and can improve convergence. 

43 If :attr`delta` is ``None``, the standard total variation will be used everywhere, 

44 and the gradient will be ``0`` at the singular point where the norm is ``0``. 

45 """ 

46
47      def __init__(self, delta: float = 1e-2):
48          if delta is not None:
49              if delta <= 0:    [49 ↛ 50: line 49 didn't jump to line 50, because the condition on line 49 was never true]
50                  raise ValueError('delta must be greater than zero, but a value'
51                                   f' of {delta} was specified! To use the total variation without'
52                                   ' Huber splicing, explicitly specify delta=None.')
53              self._delta = float(delta)
54          else:
55              self._delta = delta
56          super().__init__()
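
A brief usage sketch of the constructor contract above (illustrative only; the import path is assumed from the module location shown in the header of this report):

    from mumott.optimization.regularizers.total_variation import TotalVariation

    reg_default = TotalVariation()             # Huber splicing with the default delta = 1e-2
    reg_sharper = TotalVariation(delta=1e-4)   # smaller quadratic region around zero
    reg_plain = TotalVariation(delta=None)     # plain total variation, no Huber splicing

    try:
        TotalVariation(delta=0.0)              # rejected, since delta must be strictly positive
    except ValueError as error:
        print(error)

Incidentally, the last call would exercise the ``49 ↛ 50`` branch that the report flags as never taken.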

57 

58      def get_regularization_norm(self,
59                                  coefficients: NDArray[float],
60                                  get_gradient: bool = False,
61                                  gradient_part: str = None) -> dict[str, NDArray[float]]:
62          """Retrieves the isotropic total variation, i.e., the symmetric root-mean-square difference
63          between nearest neighbours.
64
65          Parameters
66          ----------
67          coefficients
68              An ``np.ndarray`` of values, with shape ``(X, Y, Z, W)``, where
69              the last channel contains, e.g., tensor components.
70          get_gradient
71              If ``True``, returns a ``'gradient'`` of the same shape as :attr:`coefficients`.
72              Otherwise, the entry ``'gradient'`` will be ``None``. Defaults to ``False``.
73          gradient_part
74              Used for the zonal harmonics (ZH) reconstructions to determine what part of the gradient is
75              being calculated. Default is ``None``.
76              If a flag is passed in (``'full'``, ``'angles'``, ``'coefficients'``),
77              we assume that the ZH workflow is used and that the last two coefficients are Euler angles,
78              which should not be regularized by this regularizer.
79
80          Returns
81          -------
82          A dictionary with two entries, ``regularization_norm`` and ``gradient``.
83          """

84
85          num = 6 * coefficients  # gradient numerator: in-bounds neighbour values are subtracted from 6 * coefficients below
86          denom = np.zeros_like(coefficients)  # accumulates the squared neighbour differences under the square root
87          slices_r = [np.s_[1:, :, :], np.s_[:-1, :, :],
88                      np.s_[:, 1:, :], np.s_[:, :-1, :],
89                      np.s_[:, :, 1:], np.s_[:, :, :-1]]
90          slices_coeffs = [np.s_[:-1, :, :], np.s_[1:, :, :],
91                           np.s_[:, :-1, :], np.s_[:, 1:, :],
92                           np.s_[:, :, :-1], np.s_[:, :, 1:]]
93
94          for s, v in zip(slices_r, slices_coeffs):
95              num[s] -= coefficients[v]
96              denom[s] += (coefficients[s] - coefficients[v]) ** 2

97 

98          result = dict(regularization_norm=None, gradient=None)
99          norm = np.sqrt(denom)
100          gradient = np.zeros_like(coefficients)
101
102          if self._delta is None:
103              where = norm == 0  # no splicing; only the singular points need special handling
104          else:
105              where = norm < self._delta  # Huber region: use the scaled squared norm below delta
106
107              gradient[where] = num[where] * np.reciprocal(self._delta)
108
109              norm[where] = np.reciprocal(self._delta) * 0.5 * norm[where] ** 2
110              norm[~where] -= 0.5 * self._delta  # shift the linear branch so the two pieces meet at delta
111
112          gradient[~where] = num[~where] / np.sqrt(denom[~where])

113 

114          if get_gradient:
115              if gradient_part is None:
116                  result['gradient'] = gradient
117              elif gradient_part in ('full', 'coefficients'):
118                  result['gradient'] = gradient
119                  result['gradient'][..., -2:] = 0  # do not regularize the Euler-angle channels
120              elif gradient_part in ('angles',):    [120 ↛ 123: line 120 didn't jump to line 123, because the condition on line 120 was never false]
121                  result['gradient'] = np.zeros_like(coefficients)
122              else:
123                  raise ValueError('Unexpected argument given for gradient part.')
124
125          if gradient_part is None:
126              result['regularization_norm'] = np.sum(norm)
127          elif gradient_part in ('full', 'coefficients', 'angles'):    [127 ↛ 130: line 127 didn't jump to line 130, because the condition on line 127 was never false]
128              result['regularization_norm'] = np.sum(norm[..., :-2])
129          else:
130              raise ValueError('Unexpected argument given for gradient part.')
131
132          return result
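
The slicing above is compact but easy to misread, so the following sketch (illustrative only, with a hypothetical ``brute_force_norm`` helper) spells out the same six-neighbour scheme with explicit loops, which can be compared against the vectorized implementation on a small array.

    import numpy as np
    from mumott.optimization.regularizers.total_variation import TotalVariation

    def brute_force_norm(x: np.ndarray, delta: float) -> float:
        """Huber-spliced total variation of a (X, Y, Z, W) array, written with explicit loops."""
        offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
        X, Y, Z, W = x.shape
        total = 0.0
        for i in range(X):
            for j in range(Y):
                for k in range(Z):
                    for w in range(W):
                        squared = 0.0
                        for di, dj, dk in offsets:
                            ii, jj, kk = i + di, j + dj, k + dk
                            if 0 <= ii < X and 0 <= jj < Y and 0 <= kk < Z:
                                squared += (x[i, j, k, w] - x[ii, jj, kk, w]) ** 2
                        norm = np.sqrt(squared)
                        total += 0.5 * norm ** 2 / delta if norm < delta else norm - 0.5 * delta
        return total

    x = np.random.default_rng(0).normal(size=(4, 4, 4, 2))
    reference = TotalVariation(delta=1e-2).get_regularization_norm(x)['regularization_norm']
    print(np.isclose(brute_force_norm(x, 1e-2), reference))  # the two are expected to agree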

133 

134      @property
135      def _function_as_str(self) -> str:
136          return ('R(x) = lambda * sqrt('
137                  '\n (x[i, j] - x[i + 1, j]) ** 2 + (x[i, j] - x[i - 1, j]) ** 2 +'
138                  '\n (x[i, j] - x[i, j + 1]) ** 2 + (x[i, j] - x[i, j - 1]) ** 2)')
139
140      @property
141      def _function_as_tex(self) -> str:
142          # we use html line breaks <br> since LaTeX line breaks appear unsupported.
143          return (r'$R(\vec{x}) = \sum_{ij} L(x_{ij})$ <br>'
144                  r'$L(x_{ij}) = \begin{Bmatrix}L_1(x_{ij})'
145                  r'\text{\quad if } L_1(x_{ij}) > 0.5 \delta \\ L_2(x_{ij})'
146                  r'\text{\quad otherwise}\end{Bmatrix}$<br>'
147                  r'$L_1(x_{ij}) = \lambda ((x_{ij} - x_{(i+1)j})^2 +$<br>'
148                  r'$(x_{ij} - x_{i(j+1)})^2 + (x_{ij} - x_{(i-1)j})^2 +$<br>'
149                  r'$(x_{ij} - x_{i(j-1)})^2 )^\frac{1}{2} - 0.5 \delta$<br>'
150                  r'$L_2(x_{ij}) = \dfrac{\lambda}{2 \delta} ((x_{ij} - x_{(i+1)j})^2 +$<br>'
151                  r'$(x_{ij} - x_{i(j+1)})^2 + (x_{ij} - x_{(i-1)j})^2 +$<br>'
152                  r'$(x_{ij} - x_{i(j-1)})^2 )$<br>')
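
Finally, a short end-to-end sketch of how the regularizer is typically queried (illustrative only; array sizes are arbitrary and the import path is assumed from the module location in the header), including the zonal-harmonics case in which the last two channels are treated as Euler angles and excluded from the regularization.

    import numpy as np
    from mumott.optimization.regularizers.total_variation import TotalVariation

    coefficients = np.random.default_rng(42).normal(size=(8, 8, 8, 6))   # (X, Y, Z, W)
    regularizer = TotalVariation(delta=1e-2)

    out = regularizer.get_regularization_norm(coefficients, get_gradient=True)
    print(out['regularization_norm'])   # scalar value of the spliced total variation
    print(out['gradient'].shape)        # (8, 8, 8, 6), same shape as the input

    # Zonal-harmonics workflow: the last two channels (Euler angles) are not regularized.
    zh = regularizer.get_regularization_norm(coefficients, get_gradient=True,
                                             gradient_part='coefficients')
    print(np.allclose(zh['gradient'][..., -2:], 0))   # True: the angle channels are zeroed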