Coverage for local_installation_linux/mumott/methods/residual_calculators/gradient_residual_calculator.py: 98%
88 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
2from typing import Dict
4import numpy as np
5from numpy.typing import NDArray
7from mumott import DataContainer
8from mumott.core.hashing import list_to_hash
9from .base_residual_calculator import ResidualCalculator
10from mumott.methods.basis_sets.base_basis_set import BasisSet
11from mumott.methods.projectors.base_projector import Projector
# Module-level logger, named after this module, used for shape-change warnings.
logger = logging.getLogger(name=__name__)
class GradientResidualCalculator(ResidualCalculator):
    """Class that implements the GradientResidualCalculator method.
    This residual calculator is an appropriate choice for :term:`SAXS` tensor tomography, as it relies
    on the small-angle approximation. It relies on inverting the John transform
    (also known as the X-ray transform) of a tensor field (where each tensor is a
    representation of a spherical function) by comparing it to scattering data which
    has been corrected for transmission.

    Parameters
    ----------
    data_container : DataContainer
        Container for the data which is to be reconstructed.
    basis_set : BasisSet
        The basis set used for representing spherical functions.
    projector : Projector
        The type of projector used together with this method.
    use_scalar_projections : bool
        Whether to use a set of scalar projections, rather than the data
        in :attr:`data_container`.
    scalar_projections : NDArray[float]
        If :attr:`use_scalar_projections` is true, the set of scalar projections to use.
        Should have the same shape as :attr:`data_container.data`, except with
        only one channel in the last index.
    """

    def __init__(self,
                 data_container: DataContainer,
                 basis_set: BasisSet,
                 projector: Projector,
                 use_scalar_projections: bool = False,
                 scalar_projections: NDArray[float] = None):
        super().__init__(data_container,
                         basis_set,
                         projector,
                         use_scalar_projections,
                         scalar_projections,)

    def get_residuals(self,
                      get_gradient: bool = False,
                      get_weights: bool = False,
                      gradient_part: str = None) -> dict[str, NDArray[float]]:
        """ Calculates the residuals, and possibly a gradient, between the data and
        the coefficients projected using the :attr:`BasisSet` and :attr:`Projector`
        attached to this instance.

        The residuals are computed as ``projection - data``, where the projection is
        ``basis_set.forward(projector.forward(coefficients))``.

        Parameters
        ----------
        get_gradient
            Whether to return the gradient. Default is ``False``.
        get_weights
            Whether to return weights. Default is ``False``. If ``True`` along with
            :attr:`get_gradient`, the gradient will be computed with weights.
        gradient_part
            Used for the zonal harmonics reconstructions to determine what part of the gradient is
            being calculated. Default is ``None``. Raises a ``NotImplementedError`` for any other value.

        Returns
        -------
            A dictionary with the keys ``residuals``, ``gradient`` and ``weights``.
            If gradient and/or weights are not requested, their value will be ``None``.

        Raises
        ------
        NotImplementedError
            If :attr:`gradient_part` is not ``None``.
        """
        if gradient_part is not None:
            raise NotImplementedError('The GradientResidualCalculator class does not work with optimizing '
                                      'angles. Use the ZHTTResidualCalculator class instead.')

        projection = self._basis_set.forward(self._projector.forward(self._coefficients))
        residuals = projection - self._data

        gradient = None
        if get_gradient:
            # todo: consider if more complicated behaviour is useful,
            # e.g. providing function to be applied to weights
            # Weights (when requested) are applied to the residuals before they are
            # mapped back through the basis set and the projector adjoint.
            gradient_source = residuals * self._weights if get_weights else residuals
            gradient = self._projector.adjoint(
                self._basis_set.gradient(gradient_source).astype(self.dtype))

        weights = self._weights if get_weights else None

        return dict(residuals=residuals, gradient=gradient, weights=weights)

    def get_gradient_from_residual_gradient(self, residual_gradient: NDArray[float]) -> NDArray[float]:
        """ Projects a residual gradient into coefficient and volume space. Used
        to get gradients from more complicated residuals, e.g., the Huber loss.
        Assumes that any weighting to the residual gradient has already been applied.

        Parameters
        ----------
        residual_gradient
            The residual gradient, from which to calculate the gradient.

        Returns
        -------
            An ``NDArray`` containing the gradient, i.e. the projector adjoint of the
            basis-set gradient of :attr:`residual_gradient`, cast to :attr:`dtype`
            before the adjoint is applied.
        """
        return self._projector.adjoint(
            self._basis_set.gradient(residual_gradient).astype(self.dtype))

    def _update(self, force_update: bool = False) -> None:
        """ Carries out necessary updates if anything changes with respect to
        the geometry or basis set.

        Does nothing unless :attr:`is_dirty` is ``True`` or :attr:`force_update` is set.
        Otherwise the probed coordinates are pushed to the basis set. If the volume
        shape or basis-set size no longer matches the coefficient array, a new zero
        array is allocated and the overlapping region of the old coefficients is
        copied into it. Spatial axes are centered; the last (basis) axis is aligned
        at index 0. Finally the stored geometry and basis-set hashes are refreshed.
        """
        if not (self.is_dirty or force_update):
            return
        self._basis_set.probed_coordinates = self.probed_coordinates
        len_diff = len(self._basis_set) - self._coefficients.shape[-1]
        vol_diff = self._data_container.geometry.volume_shape - np.array(self._coefficients.shape[:-1])
        # TODO: Think about whether the ``Method`` should do this or handle it differently
        if np.any(vol_diff != 0) or len_diff != 0:
            logger.warning('Shape of coefficient array has changed, array will be padded'
                           ' or truncated.')
            # save old array, no copy needed
            old_coefficients = self._coefficients
            # initialize new array
            self._coefficients = \
                np.zeros((*self._data_container.geometry.volume_shape, len(self._basis_set)),
                         dtype=self.dtype)
            # Concrete shape tuples so both slice computations can iterate them.
            old_volume = old_coefficients.shape[:-1]
            new_volume = self._coefficients.shape[:-1]
            n_channels = min(old_coefficients.shape[-1], self._coefficients.shape[-1])
            # Old coefficients go into the middle of the new coefficients, except in the
            # last index. For each axis, ``s`` is the old size and ``d`` the new size; both
            # slices span min(s, d) elements.
            new_slices = tuple(slice(max(0, (d - s) // 2), min(d, (s + d) // 2))
                               for s, d in zip(old_volume, new_volume)) + (slice(0, n_channels),)
            old_slices = tuple(slice(max(0, (s - d) // 2), min(s, (s + d) // 2))
                               for s, d in zip(old_volume, new_volume)) + (slice(0, n_channels),)
            # assumption made that old_coefficients[..., 0] corresponds to self._coefficients[..., 0]
            self._coefficients[new_slices] = old_coefficients[old_slices]
            # Assumption may not be true for all representations!
            # TODO: Consider more logic here using e.g. basis set properties.
            if len_diff != 0:
                logger.warning('Size of basis set has changed. Coefficients have'
                               ' been copied over starting at index 0. If coefficients'
                               ' of new size do not line up with the old size,'
                               ' please reinitialize the coefficients.')
        self._geometry_hash = hash(self._data_container.geometry)
        self._basis_set_hash = hash(self._basis_set)

    def __hash__(self) -> int:
        """ Returns a hash of the current state of this instance. """
        to_hash = [self._coefficients,
                   hash(self._projector),
                   hash(self._data_container.geometry),
                   self._basis_set_hash,
                   self._geometry_hash]
        return int(list_to_hash(to_hash), 16)

    @property
    def is_dirty(self) -> bool:
        """ ``True`` if stored hashes of geometry or basis set objects do
        not match their current hashes. Used to trigger updates in :meth:`_update`. """
        return ((self._geometry_hash != hash(self._data_container.geometry)) or
                (self._basis_set_hash != hash(self._basis_set)))

    def __str__(self) -> str:
        """ Returns a fixed-width, human-readable summary of this instance. """
        wdt = 74
        s = []
        s += ['=' * wdt]
        s += [self.__class__.__name__.center(wdt)]
        s += ['-' * wdt]
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('BasisSet', self._basis_set.__class__.__name__)]
            s += ['{:18} : {}'.format('Projector', self._projector.__class__.__name__)]
            s += ['{:18} : {}'.format('is_dirty', self.is_dirty)]
            s += ['{:18} : {}'.format('probed_coordinates (hash)',
                                      hex(hash(self.probed_coordinates))[:6])]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ Returns an HTML table summarizing this instance (used by Jupyter). """
        s = []
        s += [f'<h3>{self.__class__.__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">BasisSet</td>']
            s += [f'<td>{1}</td><td>{self._basis_set.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Projector</td>']
            # A single projector object, matching the BasisSet row (previously this
            # reported the length of the class-name string).
            s += [f'<td>{1}</td>'
                  f'<td>{self._projector.__class__.__name__}</td></tr>']
            s += ['<tr><td style="text-align: left;">Is dirty</td>']
            s += [f'<td>{1}</td><td>{self.is_dirty}</td></tr>']
            s += ['<tr><td style="text-align: left;">probed_coordinates</td>']
            s += [f'<td>{self.probed_coordinates.vector.shape}</td>'
                  f'<td>{hex(hash(self.probed_coordinates))[:6]} (hash)</td></tr>']
            s += ['<tr><td style="text-align: left;">Hash</td>']
            h = hex(hash(self))
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)