Coverage for local_installation_linux/mumott/output_handling/saving.py: 91%

39 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import logging 

2import os 

3from typing import Dict, Any, Union, Tuple, List 

4 

5 

6import numpy as np 

7from numpy.typing import NDArray 

8import h5py as h5 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def save_array_like(key: str, val: Union[Tuple, List, NDArray], group):
    """Write an array-like value to ``group`` as a dataset named ``key``.

    The value is converted to a NumPy array first. Arrays whose dtype kind
    is numeric, boolean or bytes are written as they are. Unicode arrays
    are converted to bytes before writing. Any other dtype (for example
    ``object``) is skipped with a warning and nothing is written.

    Parameters
    ----------
    key
        Name of the dataset to create inside ``group``.
    val
        The data to store; anything ``np.asarray`` accepts.
    group
        Object providing ``create_dataset(key, data=..., shape=...)``,
        e.g. an open ``h5py`` file or group.
    """
    val = np.asarray(val)
    kind = val.dtype.kind

    if kind == 'U':
        try:
            # Plain ASCII text keeps the exact fixed-width bytes dtype that
            # ``astype(bytes)`` has always produced.
            val = val.astype(bytes)
        except UnicodeEncodeError:
            # ``astype(bytes)`` encodes as ASCII and fails on any other
            # character; fall back to UTF-8 instead of aborting the save.
            val = np.char.encode(val, 'utf-8')
    # NOTE(review): lowercase 'v' does not appear to be a dtype kind NumPy
    # reports; kept as in the original -- confirm whether 'V' was intended.
    elif kind not in 'fcbviuS':
        logger.warning(f'Data type {val.dtype} not supported, entry will be ignored!')
        return

    group.create_dataset(key, data=val, shape=val.shape)

22 

23 

def save_item(key: str, val: Any, group):
    """Store a single non-``dict`` entry under ``key`` in ``group``.

    ``None`` is not supported: a warning is logged and nothing is written.
    NumPy arrays, lists and tuples (including subclasses such as named
    tuples) are passed on unchanged. Every other value is treated as a
    scalar and wrapped in a one-element array before being passed on.

    Parameters
    ----------
    key
        Name under which the entry is stored.
    val
        The value to store.
    group
        Target container, forwarded to ``save_array_like``.
    """
    if val is None:
        logger.warning(f'Entry {key} has value None, which is not supported and will be ignored!')
        return

    # ``isinstance`` rather than an exact type match, so that list/tuple
    # subclasses are stored as sequences and not wrapped like scalars
    # (which would add a spurious leading dimension).
    if not isinstance(val, (np.ndarray, list, tuple)):
        val = np.array((val,))

    save_array_like(key, val, group)

33 

34 

def save_dict_recursively(inner_dict: Dict[str, Any], group):
    """Mirror a (possibly nested) dictionary into ``group``.

    Each ``dict`` value becomes a sub-group named after its key and is
    processed recursively. Every other value is handed to ``save_item``.

    Parameters
    ----------
    inner_dict
        The dictionary to write.
    group
        Container providing ``create_group(key)``, e.g. an open ``h5py``
        file or group.
    """
    for key, val in inner_dict.items():
        if isinstance(val, dict):
            # Recurse into the group returned by ``create_group`` rather
            # than looking it up again by name.
            save_dict_recursively(val, group.create_group(key))
        else:
            save_item(key, val, group)

42 

43 

def dict_to_h5(dict_to_output: Dict[str, Any], filename: str, overwrite: bool = False) -> None:
    """Function for recursively saving a dictionary as an hdf5 file.

    Example
    -------
    The following snippet demonstrates how to save and read an example dictionary using this functionality.
    In this example, the output file will be overwritten in case it already exists.

    >>> from mumott.output_handling import dict_to_h5
    >>> dict_to_h5(dict_to_output=dict(a=5, b='123', c=dict(L=['test'])),
    ...            filename='my-example-dict.h5', overwrite=True)

    To read the output file (typically at some later point) we use the ``h5py.File`` context.
    Note that this does not restore the original dictionary.

    >>> import h5py
    >>> file = h5py.File('my-example-dict.h5')

    We can loop over the fields in the hdf5 file using the dict-functionality of the ``h5py.File`` class.

    >>> for k, v in file.items():
    ...     print(k, v)
    a <HDF5 dataset "a": shape (1,), type "<i8">
    b <HDF5 dataset "b": shape (1,), type "|S3">
    c <HDF5 group "/c" (1 members)>

    This allows us also to access individual fields

    >>> print(file['a'][0])
    5

    as well as nested data.

    >>> print(file['c/L'][:])
    [b'test']

    Parameters
    ----------
    dict_to_output
        A ``dict`` with supported entry types, including other ``dict``s. File will be recursively
        structured according to the structure of the ``dict``.
    filename
        The name of the file, including the full path, which will be output.
    overwrite
        Whether to overwrite an existing file with name :attr:`filename`. Default is ``False``.

    Raises
    ------
    FileExistsError
        If :attr:`filename` already exists and :attr:`overwrite` is ``False``.

    Notes
    -----
    Supported entry types in :attr:`dict_to_output`:
        ``int``, ``float``, ``complex``
            Saved as array of size ``1``.
        ``bytes``, ``str``
            ``str`` is cast to ``bytes-like``. Saved as array of size ``1``.
        ``None``
            Ignored.
        ``np.ndarray``
            Provided ``dtype`` is not ``object``, ``datetime`` or ``timedelta``,
            hence arrays of ``None`` are not allowed.
        ``list``, ``tuple``
            Provided they can be cast to allowed, i.e. non-ragged ``np.ndarray``.
        ``dict``
            Assuming entries are allowed types, including ``dict``.
    """
    if os.path.exists(filename) and not overwrite:
        raise FileExistsError('File already exists, and overwrite is set to False.')

    with h5.File(filename, 'w') as file_to_save:
        # The file object is used as the root group, so the top level is
        # written by the same recursive routine as every nested level.
        save_dict_recursively(dict_to_output, file_to_save)

    logger.info(f'File {filename} saved successfully!')