Coverage for local_installation_linux/mumott/data_handling/data_container.py: 85%
277 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import logging
3import h5py as h5
4import numpy as np
6from numpy.typing import NDArray
7from scipy.spatial.transform import Rotation
9from mumott.core.deprecation_warning import print_deprecation_warning
10from mumott.core.geometry import Geometry
11from mumott.core.projection_stack import ProjectionStack, Projection
13logger = logging.getLogger(__name__)
15# used to easily keep track of preferred keys
16_preferred_keys = dict(rotations='inner_angle',
17 tilts='outer_angle',
18 offset_j='j_offset',
19 offset_k='k_offset',
20 rot_mat='rotation_matrix')
def _deprecated_key_warning(deprecated_key: str):
    """Internal helper that reports use of a deprecated entry name.

    The replacement name is looked up in ``_preferred_keys`` and removed from
    it, so a given deprecated key triggers the warning only the first time it
    is encountered. Keys that are unknown (or were already reported) are
    silently ignored.
    """
    preferred_key = _preferred_keys.pop(deprecated_key, None)
    if preferred_key is None:
        # Not deprecated, or already warned about once.
        return
    print_deprecation_warning(
        f'Entry name {deprecated_key} is deprecated. Use {preferred_key} instead.')
class DataContainer:

    """
    Instances of this class represent data read from an input file in a format suitable for further analysis.
    The two core components are :attr:`geometry` and :attr:`projections`.
    The latter comprises a list of :class:`Projection <mumott.core.projection_stack.Projection>`
    instances, each of which corresponds to a single measurement.

    By default all data is read, which can be rather time consuming and unnecessary in some cases,
    e.g., when aligning data.
    In those cases, one can skip loading the actual measurements by setting :attr:`skip_data` to ``True``.
    The geometry information and supplementary information such as the diode data will still be read.

    Example
    -------
    The following code snippet illustrates the basic use of the :class:`DataContainer` class.

    First we create a :class:`DataContainer` instance, providing the path to the data file to be read.

    >>> from mumott.data_handling import DataContainer
    >>> dc = DataContainer('tests/test_full_circle.h5')

    One can then print a short summary of the content of the :class:`DataContainer` instance.

    >>> print(dc)
    ==========================================================================
                                  DataContainer
    --------------------------------------------------------------------------
    Corrected for transmission : False
    ...

    To access individual measurements we can use the :attr:`projections` attribute.
    The latter behaves like a list, where the elements of the list are
    :class:`Projection <mumott.core.projection_stack.Projection>` objects,
    each of which represents an individual measurement.
    We can print a summary of the content of the first projection.

    >>> print(dc.projections[0])
    --------------------------------------------------------------------------
                                      Projection
    --------------------------------------------------------------------------
    hash_data          : 3f0ba8
    hash_diode         : 808328
    hash_weights       : 088d39
    rotation           : [1. 0. 0.], [ 0. -1.  0.], [ 0.  0. -1.]
    j_offset           : 0.0
    k_offset           : 0.3
    inner_angle        : None
    outer_angle        : None
    inner_axis         : 0.0, 0.0, -1.0
    outer_axis         : 1.0, 0.0, 0.0
    --------------------------------------------------------------------------


    Parameters
    ----------
    data_path : str, optional
        Path of the data file relative to the directory of execution.
        If None, a data container with an empty :attr:`projections`
        attached will be initialized.
    data_type : str, optional
        The type (or format) of the data file. The only supported value is
        ``h5`` (default) for hdf5 format. This argument is only used if
        ``data_path`` is given; to obtain an empty ``DataContainer`` that can be
        manually populated, leave ``data_path`` as ``None``.
    skip_data : bool, optional
        If ``True``, will skip data from individual measurements when loading the file.
        This will result in a functioning :attr:`geometry` instance as well as
        :attr:`diode` and :attr:`weights` entries in each projection, but
        :attr:`data` will be ``None``.
    nonfinite_replacement_value : float, optional
        Value to replace nonfinite values (``np.nan``, ``np.inf``, and ``-np.inf``) with in the
        data, diode, and weights. If ``None`` (default), an error is raised
        if any nonfinite values are present in these input fields.
        This check is only performed if ``skip_data`` is ``False``.

    Raises
    ------
    ValueError
        If ``data_path`` is given together with an unsupported ``data_type``,
        or if nonfinite values are found in the input while
        ``nonfinite_replacement_value`` is ``None``.
    """
    def __init__(self,
                 data_path: str = None,
                 data_type: str = 'h5',
                 skip_data: bool = False,
                 nonfinite_replacement_value: float = None):
        self._correct_for_transmission_called = False
        self._projections = ProjectionStack()
        self._geometry_dictionary = dict()
        self._skip_data = skip_data
        self._nonfinite_replacement_value = nonfinite_replacement_value
        if data_path is None:
            return
        if data_type == 'h5':
            self._h5_to_projections(data_path)
        else:
            raise ValueError(f'Unknown data_type: {data_type}. '
                             "The only supported value is 'h5'.")

    def _h5_to_projections(self, file_path: str):
        """
        Internal method for loading data from hdf5 file.

        Every projection stored in the ``projections`` group (keyed ``'0'``,
        ``'1'``, ...) is read, padded to a common 2D frame shape, and appended
        to :attr:`projections`. Afterwards, global geometry entries are read
        from the base directory of the file. The file is closed again once
        all entries have been copied into memory.

        Parameters
        ----------
        file_path
            Path of the hdf5 file to read.
        """
        with h5.File(file_path, 'r') as h5_data:
            projections = h5_data['projections']
            number_of_projections = len(projections)
            inner_axis = np.array((0., 0., -1.))
            outer_axis = np.array((1., 0., 0.))
            found_inner_in_base = False
            found_outer_in_base = False
            if 'inner_axis' in h5_data:
                inner_axis = h5_data['inner_axis'][:]
                logger.info('Inner axis found in dataset base directory. This will override the default.')
                found_inner_in_base = True
            if 'outer_axis' in h5_data:
                outer_axis = h5_data['outer_axis'][:]
                logger.info('Outer axis found in dataset base directory. This will override the default.')
                found_outer_in_base = True

            max_shape = self._get_max_shape(projections, number_of_projections)

            for i in range(number_of_projections):
                p = projections[f'{i}']
                diode, pad_sequence = self._read_diode(p, max_shape)
                data, weights = self._read_data_and_weights(p, diode, pad_sequence)

                # Look for rotation information and load if available.
                p_inner_axis, inner_axis, found_inner_in_base = self._read_axis(
                    p, 'inner_axis', 'Inner', i, inner_axis, found_inner_in_base)
                p_outer_axis, outer_axis, found_outer_in_base = self._read_axis(
                    p, 'outer_axis', 'Outer', i, outer_axis, found_outer_in_base)
                inner_angle, outer_angle = self._read_angles(p)
                rotation = self._read_rotation(p, i, inner_angle, outer_angle,
                                               p_inner_axis, p_outer_axis)

                # Offsets are reduced by half of the padding appended at the
                # end of the corresponding axis (presumably to keep the
                # original frame centered -- TODO confirm).
                j_offset = self._read_offset(p, 'j_offset', 'offset_j')
                j_offset -= pad_sequence[0, 1] * 0.5
                k_offset = self._read_offset(p, 'k_offset', 'offset_k')
                k_offset -= pad_sequence[1, 1] * 0.5

                if not self._skip_data:
                    self._handle_nonfinite_values(data)
                    self._handle_nonfinite_values(weights)
                    self._handle_nonfinite_values(diode)

                projection = Projection(data=data,
                                        diode=diode,
                                        weights=weights,
                                        rotation=rotation,
                                        j_offset=j_offset,
                                        k_offset=k_offset,
                                        outer_angle=outer_angle,
                                        inner_angle=inner_angle,
                                        inner_axis=p_inner_axis,
                                        outer_axis=p_outer_axis
                                        )
                self._projections.append(projection)

            self._read_global_geometry(h5_data, max_shape)

    @staticmethod
    def _get_max_shape(projections, number_of_projections: int):
        """Return the largest 2D frame shape among all projections.

        The shape of the ``diode`` entry is used where present. For projections
        without a diode, the first two dimensions of ``data`` are used instead,
        so that the padding widths computed later are never negative.
        """
        max_shape = (0, 0)
        for i in range(number_of_projections):
            p = projections[f'{i}']
            if 'diode' in p:
                frame_shape = p['diode'].shape
            elif 'data' in p:
                frame_shape = p['data'].shape[:2]
            else:
                continue
            max_shape = np.max((max_shape, frame_shape), axis=0)
        return np.asarray(max_shape)

    @staticmethod
    def _read_diode(p, max_shape):
        """Read the diode of projection group ``p`` and pad it to ``max_shape``.

        Returns
        -------
        diode
            The zero-padded diode as a contiguous float64 array, or ``None`` if
            ``p`` has no ``diode`` entry.
        pad_sequence
            Integer array of shape ``(3, 2)`` holding ``np.pad`` widths for a
            3D array of this projection. Padding is only appended (never
            prepended) along the first two axes; the last axis is not padded.
        """
        if 'diode' in p:
            diode = np.ascontiguousarray(np.copy(p['diode']).astype(np.float64))
            frame_shape = diode.shape
        elif 'data' in p:
            diode = None
            frame_shape = p['data'].shape
        else:
            # Nothing to pad against. np.pad requires integral pad widths.
            return None, np.zeros((3, 2), dtype=int)
        pad_sequence = np.array(((0, max_shape[0] - frame_shape[0]),
                                 (0, max_shape[1] - frame_shape[1]),
                                 (0, 0)))
        if diode is not None:
            diode = np.pad(diode, pad_sequence[:-1])
        return diode, pad_sequence

    def _read_data_and_weights(self, p, diode, pad_sequence):
        """Read and pad the data and weights of projection group ``p``.

        Parameters
        ----------
        p
            The hdf5 group of a single projection.
        diode
            The already padded diode of this projection, or ``None``.
        pad_sequence
            Pad widths as returned by :meth:`_read_diode`.

        Returns
        -------
        data
            Padded float64 data, or ``None`` if :attr:`skip_data` is set.
        weights
            Padded float64 weights. If no ``weights`` entry exists, they default
            to ones shaped like the data (or, when skipping data, like the
            diode; ``None`` if there is no diode either).
        """
        if self._skip_data:
            if 'weights' in p:
                weights = np.ascontiguousarray(np.copy(p['weights']).astype(np.float64))
                weights = np.pad(weights, pad_sequence[:weights.ndim])
            elif diode is None:
                weights = None
            else:
                # The diode has already been padded to the common frame shape,
                # so weights derived from it must not be padded a second time.
                weights = np.ones_like(diode)
            return None, weights

        data = np.ascontiguousarray(np.copy(p['data']).astype(np.float64))
        data = np.pad(data, pad_sequence)
        if 'weights' in p:
            weights = np.ascontiguousarray(np.copy(p['weights']).astype(np.float64))
            if weights.ndim == 2 or (weights.ndim == 3 and weights.shape[-1] == 1):
                # Per-pixel weights are broadcast along the last axis of data.
                weights = weights.reshape(weights.shape[:2])
                weights = weights[..., np.newaxis] * \
                    np.ones((1, 1, data.shape[-1])).astype(np.float64)
            weights = np.pad(weights, pad_sequence)
        else:
            weights = np.ones_like(data)
        return data, weights

    @staticmethod
    def _read_axis(p, key: str, label: str, index: int, default_axis, found_in_base: bool):
        """Resolve the inner or outer rotation axis of projection ``index``.

        Parameters
        ----------
        p
            The hdf5 group of the projection.
        key
            Entry name of the axis (``'inner_axis'`` or ``'outer_axis'``).
        label
            Capitalized axis name used in log messages (``'Inner'``/``'Outer'``).
        index
            Index of the projection.
        default_axis
            Axis used for projections that do not specify their own.
        found_in_base
            Whether the axis was (still) taken from the file's base directory.

        Returns
        -------
        tuple
            ``(axis for this projection, updated default axis, updated found_in_base)``.
            A projection-level axis in projection 0 becomes the new default,
            unless the default came from the base directory.
        """
        if key not in p:
            return default_axis, default_axis, found_in_base
        p_axis = p[key][:]
        if found_in_base:
            logger.info(f'{label} axis found in projection {index}. This will override '
                        'the value found in the base directory for all projections '
                        'where it is found.')
            found_in_base = False
        elif index == 0:
            # override default only if projection zero
            logger.info(f'{label} axis found in projection {index}. This will override '
                        'the default value for all projections, if they do not specify '
                        'another axis.')
            default_axis = p_axis
        return p_axis, default_axis, found_in_base

    @staticmethod
    def _read_scalar(p, key: str, deprecated_key: str):
        """Return the first element of entry ``key`` (or its deprecated alias) in ``p``.

        Returns ``None`` if neither entry exists. Reading the deprecated alias
        triggers a one-time deprecation warning.
        """
        if key in p:
            return np.copy(p[key][...]).flatten()[0]
        if deprecated_key in p:
            value = np.copy(p[deprecated_key][...]).flatten()[0]
            _deprecated_key_warning(deprecated_key)
            return value
        return None

    def _read_angles(self, p):
        """Return ``(inner_angle, outer_angle)`` of projection group ``p``.

        Both are ``None`` if neither angle is stored. If at least one angle
        exists, the other one defaults to 0.
        """
        inner_angle = self._read_scalar(p, 'inner_angle', 'rotations')
        outer_angle = self._read_scalar(p, 'outer_angle', 'tilts')
        if inner_angle is not None or outer_angle is not None:
            # if at least one angle exists, assume other is 0 by default.
            if inner_angle is None:
                inner_angle = 0
            if outer_angle is None:
                outer_angle = 0
        return inner_angle, outer_angle

    @staticmethod
    def _read_rotation(p, index: int, inner_angle, outer_angle, inner_axis, outer_axis):
        """Return the 3x3 rotation matrix of projection ``index``.

        An explicitly stored matrix takes precedence. Otherwise, the matrix is
        built from the inner and outer angles and axes as ``R_outer @ R_inner``
        (angles taken as rotation-vector magnitudes, i.e., radians). If no
        rotation information is available, the identity is returned.
        """
        if 'rotation_matrix' in p:
            rotation = p['rotation_matrix'][...]
            if index == 0:
                logger.info('Rotation matrices were loaded from the input file.')
        elif 'rot_mat' in p:
            rotation = p['rot_mat'][...]
            _deprecated_key_warning('rot_mat')
            if index == 0:
                logger.info('Rotation matrices were loaded from the input file.')
        elif outer_angle is not None:
            R_inner = Rotation.from_rotvec(inner_angle * inner_axis).as_matrix()
            R_outer = Rotation.from_rotvec(outer_angle * outer_axis).as_matrix()
            rotation = R_outer @ R_inner
            if index == 0:
                logger.info('Rotation matrix generated from inner and outer angles,'
                            ' along with inner and outer rotation axis vectors.'
                            ' Rotation and tilt angles assumed to be in radians.')
        else:
            rotation = np.eye(3)
            if index == 0:
                logger.info('No rotation information found.')
        return rotation

    @staticmethod
    def _read_offset(p, key: str, deprecated_key: str):
        """Return the scalar offset stored under ``key`` (or its deprecated alias).

        Defaults to 0 if neither entry exists.
        """
        # default to 0-dim array to simplify subsequent code
        offset = np.array(0)
        if key in p:
            offset = p[key][...]
        elif deprecated_key in p:
            offset = p[deprecated_key][...]
            _deprecated_key_warning(deprecated_key)
        # offset will be either numpy/size-0 or size-1 array, ravel and extract.
        return offset.ravel()[0]

    def _read_global_geometry(self, h5_data, max_shape):
        """Read the geometry entries stored in the base directory of the file.

        Parameters
        ----------
        h5_data
            The open hdf5 file.
        max_shape
            Common 2D frame shape of the projections; used to derive the volume
            shape if the file does not provide ``volume_shape``.
        """
        if not self._skip_data:
            self._projections.geometry.detector_angles = np.copy(h5_data['detector_angles'])
            self._estimate_angular_coverage(self._projections.geometry.detector_angles)
        if 'volume_shape' in h5_data.keys():
            self._projections.geometry.volume_shape = np.copy(h5_data['volume_shape']).astype(int)
        else:
            self._projections.geometry.volume_shape = np.array(max_shape)[[0, 0, 1]]
        # Load sample geometry information
        if 'p_direction_0' in h5_data.keys():  # TODO check for orthogonality, normality
            self._projections.geometry.p_direction_0 = np.copy(h5_data['p_direction_0'][...])
            self._projections.geometry.j_direction_0 = np.copy(h5_data['j_direction_0'][...])
            self._projections.geometry.k_direction_0 = np.copy(h5_data['k_direction_0'][...])
            logger.info('Sample geometry loaded from file.')
        else:
            logger.info('No sample geometry information was found. Default mumott geometry assumed.')

        # Load detector geometry information
        if 'detector_direction_origin' in h5_data.keys():  # TODO check for orthogonality, normality
            self._projections.geometry.detector_direction_origin = np.copy(
                h5_data['detector_direction_origin'][...])
            self._projections.geometry.detector_direction_positive_90 = np.copy(
                h5_data['detector_direction_positive_90'][...])
            logger.info('Detector geometry loaded from file.')
        else:
            logger.info('No detector geometry information was found. Default mumott geometry assumed.')

        # Load scattering angle
        if 'two_theta' in h5_data:
            self._projections.geometry.two_theta = np.array(h5_data['two_theta'])
            logger.info('Scattering angle loaded from data.')

    def _estimate_angular_coverage(self, detector_angles: list):
        """Check if full circle appears covered in data or not.

        Sets ``geometry.full_circle_covered`` depending on whether the spread
        between the first and last detector angle is closer to pi (half circle)
        or to 0/2 pi (full circle).
        """
        # NOTE(review): `%` binds tighter than `-`, so only the last angle is
        # wrapped into [0, 2 pi) before the difference is taken.
        delta = np.abs(detector_angles[0] - detector_angles[-1] % (2 * np.pi))
        if abs(delta - np.pi) < min(delta, abs(delta - 2 * np.pi)):
            self.geometry.full_circle_covered = False
        else:
            logger.warning('The detector angles appear to cover a full circle. This '
                           'is only expected for WAXS data.')
            self.geometry.full_circle_covered = True

    def _handle_nonfinite_values(self, array):
        """ Internal convenience function for handling nonfinite values.

        Nonfinite entries of ``array`` are replaced in place by
        ``nonfinite_replacement_value``; if that value is ``None``, a
        ``ValueError`` is raised instead. ``None`` input (e.g., a missing
        diode) is ignored.
        """
        if array is None:
            return
        if np.any(~np.isfinite(array)):
            if self._nonfinite_replacement_value is not None:
                np.nan_to_num(array, copy=False, nan=self._nonfinite_replacement_value,
                              posinf=self._nonfinite_replacement_value,
                              neginf=self._nonfinite_replacement_value)
            else:
                raise ValueError('Nonfinite values detected in input, which is not permitted by default. '
                                 'To permit and replace nonfinite values, please set '
                                 'nonfinite_replacement_value to desired value.')

    def __len__(self) -> int:
        """
        Length of the :attr:`projections <mumott.core.projection_stack.ProjectionStack>`
        attached to this :class:`DataContainer` instance.
        """
        return len(self._projections)

    def append(self, f: Projection) -> None:
        """
        Appends a :class:`Projection <mumott.core.projection_stack.Projection>`
        to the :attr:`projections` attached to this :class:`DataContainer` instance.
        """
        self._projections.append(f)

    @property
    def projections(self) -> ProjectionStack:
        """ The projections, containing data and geometry. """
        return self._projections

    @property
    def geometry(self) -> Geometry:
        """ Container of geometry information. """
        return self._projections.geometry

    @property
    def data(self) -> NDArray[np.float64]:
        """
        The data in the :attr:`projections` object
        attached to this :class:`DataContainer` instance.
        """
        return self._projections.data

    @property
    def diode(self) -> NDArray[np.float64]:
        """
        The diode data in the :attr:`projections` object
        attached to this :class:`DataContainer` instance.
        """
        return self._projections.diode

    @property
    def weights(self) -> NDArray[np.float64]:
        """
        The weights in the :attr:`projections` object
        attached to this :class:`DataContainer` instance.
        """
        return self._projections.weights

    def correct_for_transmission(self) -> None:
        """
        Applies correction from the input provided in the :attr:`diode
        <mumott.core.projection_stack.Projection.diode>` field. Should
        only be used if this correction has *not* been applied yet.

        The data of every projection is divided by its diode value
        (broadcast along the last data axis). Repeat calls are ignored.
        """
        if self._correct_for_transmission_called:
            logger.info(
                'DataContainer.correct_for_transmission() has been called already.'
                ' The correction has been applied previously, and the repeat call is ignored.')
            return

        data = self._projections.data / self._projections.diode[..., np.newaxis]

        for i, f in enumerate(self._projections):
            f.data = data[i]
        self._correct_for_transmission_called = True

    def _Rx(self, angle: float) -> NDArray[float]:
        """ Generate a rotation matrix for rotations around
        the x-axis, following the convention that vectors
        have components ordered ``(x, y, z)``.

        Parameters
        ----------
        angle
            The angle of the rotation.

        Returns
        -------
        R
            The rotation matrix.

        Notes
        -----
        For a vector ``v`` with shape ``(..., 3)`` and a rotation angle :attr:`angle`,
        ``np.einsum('ji, ...i', _Rx(angle), v)`` rotates the vector around the
        ``x``-axis by :attr:`angle`. If the
        coordinate system is being rotated, then
        ``np.einsum('ij, ...i', _Rx(angle), v)`` gives the vector in the
        new coordinate system.
        """
        return Rotation.from_euler('X', angle).as_matrix()

    def _Ry(self, angle: float) -> NDArray[float]:
        """ Generate a rotation matrix for rotations around
        the y-axis, following the convention that vectors
        have components ordered ``(x, y, z)``.

        Parameters
        ----------
        angle
            The angle of the rotation.

        Returns
        -------
        R
            The rotation matrix.

        Notes
        -----
        For a vector ``v`` with shape ``(..., 3)`` and a rotation angle :attr:`angle`,
        ``np.einsum('ji, ...i', _Ry(angle), v)`` rotates the vector around the
        ``y``-axis by :attr:`angle`. If the
        coordinate system is being rotated, then
        ``np.einsum('ij, ...i', _Ry(angle), v)`` gives the vector in the
        new coordinate system.
        """
        return Rotation.from_euler('Y', angle).as_matrix()

    def _Rz(self, angle: float) -> NDArray[float]:
        """ Generate a rotation matrix for rotations around
        the z-axis, following the convention that vectors
        have components ordered ``(x, y, z)``.

        Parameters
        ----------
        angle
            The angle of the rotation.

        Returns
        -------
        R
            The rotation matrix.

        Notes
        -----
        For a vector ``v`` with shape ``(..., 3)`` and a rotation angle :attr:`angle`,
        ``np.einsum('ji, ...i', _Rz(angle), v)`` rotates the vector around the
        ``z``-axis by :attr:`angle`. If the
        coordinate system is being rotated, then
        ``np.einsum('ij, ...i', _Rz(angle), v)`` gives the vector in the
        new coordinate system.
        """
        return Rotation.from_euler('Z', angle).as_matrix()

    def _get_str_representation(self, max_lines=50) -> str:
        """ Retrieves a string representation of the object with specified
        maximum number of lines.

        Parameters
        ----------
        max_lines
            The maximum number of lines to return.
        """
        wdt = 74
        s = []
        s += ['=' * wdt]
        s += ['DataContainer'.center(wdt)]
        s += ['-' * wdt]
        s += ['{:26} : {}'.format('Corrected for transmission', self._correct_for_transmission_called)]
        truncated_s = []
        leave_loop = False
        while not leave_loop:
            line = s.pop(0).split('\n')
            for split_line in line:
                if split_line != '':
                    truncated_s += [split_line]
                if len(truncated_s) > max_lines - 2:
                    # Too many lines: close with an ellipsis and a rule.
                    if split_line != '...':
                        truncated_s += ['...']
                    if split_line != ('=' * wdt):
                        truncated_s += ['=' * wdt]
                    leave_loop = True
                    break
            if len(s) == 0:
                leave_loop = True
        truncated_s += ['=' * wdt]
        return '\n'.join(truncated_s)

    def __str__(self) -> str:
        return self._get_str_representation()

    def _get_html_representation(self, max_lines=25) -> str:
        """ Retrieves an html representation of the object with specified
        maximum number of lines.

        Parameters
        ----------
        max_lines
            The maximum number of lines to return.
        """
        s = []
        s += ['<h3>DataContainer</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th></tr></thead>']
        s += ['<tbody>']
        s += ['<tr><td style="text-align: left;">Number of projections</td>']
        s += [f'<td>{len(self._projections)}</td></tr>']
        s += ['<tr><td style="text-align: left;">Corrected for transmission</td>']
        s += [f'<td>{self._correct_for_transmission_called}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        truncated_s = []
        line_count = 0
        # Initialized so that very small max_lines cannot hit an unbound name.
        last_tr = None
        leave_loop = False
        while not leave_loop:
            line = s.pop(0).split('\n')
            for split_line in line:
                truncated_s += [split_line]
                if '</tr>' in split_line:
                    line_count += 1
                    # Catch if last line had ellipses
                    last_tr = split_line
                if line_count > max_lines - 1:
                    if last_tr != '<tr><td style="text-align: left;">...</td></tr>':
                        truncated_s += ['<tr><td style="text-align: left;">...</td></tr>']
                    truncated_s += ['</tbody>']
                    truncated_s += ['</table>']
                    leave_loop = True
                    break
            if len(s) == 0:
                leave_loop = True
        return '\n'.join(truncated_s)

    def _repr_html_(self) -> str:
        return self._get_html_representation()