Coverage for local_installation_linux/mumott/optimization/optimizers/base_optimizer.py: 94%

67 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-08-11 23:08 +0000

1import sys 

2import logging 

3from abc import ABC, abstractmethod 

4from typing import Any, Dict, Iterator, Tuple 

5 

6import tqdm 

7import numpy as np 

8 

9from mumott.optimization.loss_functions.base_loss_function import LossFunction 

10 

11logger = logging.getLogger(__name__) 

12 

13 

class Optimizer(ABC):
    """Base class from which specific optimizers are derived.

    All keyword arguments except ``no_tqdm`` are stored as options and are
    exposed through a ``dict``-like interface: ``optimizer[key]``,
    ``optimizer[key] = value`` and ``dict(optimizer)``.

    Parameters
    ----------
    loss_function
        The loss function object associated with this optimizer.
    kwargs
        Options for the optimizer. The key ``no_tqdm``, if present, is removed
        from the options and stored separately; it defaults to ``False``.
    """

    def __init__(self,
                 loss_function: LossFunction,
                 **kwargs: Any):
        self._loss_function = loss_function

        # ``no_tqdm`` is kept apart from the options, so it is popped before
        # the remaining keyword arguments are stored.
        self._no_tqdm = kwargs.pop('no_tqdm', False)
        # empty kwargs automatically yields empty dict
        self._options = kwargs

    # set, get and iter methods for kwargs interface
    def __setitem__(self, key: str, val: Any) -> None:
        """ Sets options akin to ``**kwargs`` during initialization. Allows
        access to instance as a dictionary.

        Parameters
        ----------
        key
            The key used in ``dict``-like interface to instance.
        val
            The value to store in association with ``key``.
        """
        # Only newly introduced keys are logged; overwriting is silent.
        if key not in self._options:
            logger.info(f'Key {key} added to options with value {val}.')
        self._options[key] = val

    def __getitem__(self, key: str) -> Any:
        """ Gets options set via ``**kwargs`` during initialization or via
        item assignment. Allows access to instance as a dictionary.

        Parameters
        ----------
        key
            The key used in ``dict``-like interface to instance.

        Returns
        -------
            Value stored in association with ``key``.

        Raises
        ------
        KeyError
            If ``key`` is not among the stored options.
        """
        try:
            return self._options[key]
        except KeyError:
            # Re-raise with the descriptive message; the bare lookup error
            # adds nothing, so its context is suppressed.
            raise KeyError(f'Unrecognized key: {key}') from None

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """ Allows casting as dict and use as an iterator.

        Yields ``(key, value)`` pairs of the stored options.
        """
        yield from self._options.items()

    @property
    def no_tqdm(self):
        """Whether to avoid making a ``tqdm`` progress bar."""
        return self._no_tqdm

    def _tqdm(self, length: int):
        """
        Returns tqdm iterable, unless ``no_tqdm`` is set to true, in which case
        it returns a ``range``.

        Parameters
        ----------
        length
            Number of iterations; the iterable runs over ``range(length)``.
        """
        if self._no_tqdm:
            return range(length)
        # The progress bar is written to standard output.
        return tqdm.tqdm(range(length), file=sys.stdout)

    @abstractmethod
    def optimize(self) -> Dict:
        """ Function for executing the optimization. Should return a ``dict`` of the
        optimization results. """
        pass

    def __str__(self) -> str:
        """ Returns a plain-text summary listing the loss function class,
        a short hash of the instance and all stored options. """
        wdt = 74
        s = []
        s += ['=' * wdt]
        # ``type(self)`` rather than the bare ``__class__`` cell: the latter always
        # refers to ``Optimizer`` itself, so subclasses would be mislabelled.
        s += [type(self).__name__.center(wdt)]
        s += ['-' * wdt]
        # The print options only affect how numpy array values are rendered.
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s += ['{:18} : {}'.format('LossFunction', self._loss_function.__class__.__name__)]
            s += ['{:18} : {}'.format('hash', hex(hash(self))[2:8])]
            for key, value in self._options.items():
                s += ['{:18} : {}'.format(f'option[{key}]', value)]
        s += ['-' * wdt]
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ Returns an HTML table with the same information as ``__str__``:
        the loss function class, a short hash and all stored options. """
        s = []
        # ``type(self)`` so that the heading names the concrete subclass.
        s += [f'<h3>{type(self).__name__}</h3>']
        s += ['<table border="1" class="dataframe">']
        s += ['<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>']
        s += ['<tbody>']
        # The print options only affect how numpy array values are rendered.
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s += ['<tr><td style="text-align: left;">LossFunction</td>']
            s += [f'<td>1</td><td>{self._loss_function.__class__.__name__}</td></tr>']
            h = hex(hash(self))
            s += ['<tr><td style="text-align: left;">Hash</td>']
            # The size column shows the length of the full hex string, while
            # only the first six hex digits are displayed.
            s += [f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>']
            for key, value in self._options.items():
                s += [f'<tr><td style="text-align: left;">options[{key}]</td>']
                s += [f'<td>1</td><td>{value}</td></tr>']
        s += ['</tbody>']
        s += ['</table>']
        return '\n'.join(s)