Coverage for local_installation_linux/mumott/optimization/optimizers/base_optimizer.py: 94%
67 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2025-05-05 21:21 +0000
1import sys
2import logging
3from abc import ABC, abstractmethod
4from typing import Any, Dict, Iterator, Tuple
6import tqdm
7import numpy as np
9from mumott.optimization.loss_functions.base_loss_function import LossFunction
11logger = logging.getLogger(__name__)
class Optimizer(ABC):
    """This is the base class from which specific optimizers are derived.

    Keyword arguments given at construction are stored as options and are
    exposed through a ``dict``-like interface: ``optimizer[key]`` reads an
    option, ``optimizer[key] = value`` sets one, and ``dict(optimizer)``
    returns a copy of all options.

    Parameters
    ----------
    loss_function
        The loss function instance associated with this optimizer. It is
        stored on the instance and its class name is shown in the string
        representations.
    kwargs
        Arbitrary options. The special key ``no_tqdm`` is removed from the
        options and controls whether progress bars are suppressed
        (default: ``False``). All remaining entries are stored as options.
    """

    def __init__(self,
                 loss_function: LossFunction,
                 **kwargs: Any):
        self._loss_function = loss_function
        # ``no_tqdm`` is consumed here so that it never appears among the options.
        self._no_tqdm = kwargs.pop('no_tqdm', False)
        # empty kwargs automatically yields empty dict
        self._options = kwargs

    # set, get and iter methods for kwargs interface
    def __setitem__(self, key: str, val: Any) -> None:
        """ Sets options akin to ``**kwargs`` during initialization. Allows
        access to instance as a dictionary.

        Parameters
        ----------
        key
            The key used in ``dict``-like interface to instance.
        val
            The value to store in association with ``key``.
        """
        # Only newly added keys are logged; overwriting an existing key is silent.
        if key not in self._options:
            logger.info(f'Key {key} added to options with value {val}.')
        self._options[key] = val

    def __getitem__(self, key: str) -> Any:
        """ Gets options set via ``**kwargs`` during initialization or via
        item assignment. Allows access to instance as a dictionary.

        Parameters
        ----------
        key
            The key used in ``dict``-like interface to instance.

        Returns
        -------
            Value stored in association with ``key``.

        Raises
        ------
        KeyError
            If ``key`` is not among the stored options.
        """
        try:
            return self._options[key]
        except KeyError:
            # Re-raise with a descriptive message; the bare lookup error adds nothing.
            raise KeyError(f'Unrecognized key: {key}') from None

    def __iter__(self) -> Iterator[Tuple[str, Any]]:
        """ Allows casting as dict and use as an iterator. Yields
        ``(key, value)`` tuples of the stored options. """
        yield from self._options.items()

    @property
    def no_tqdm(self) -> bool:
        """Whether to avoid making a ``tqdm`` progress bar."""
        return self._no_tqdm

    def _tqdm(self, length: int):
        """
        Returns tqdm iterable, unless ``no_tqdm`` is set to true, in which case
        it returns a ``range``.

        Parameters
        ----------
        length
            The number of iterations; both return variants iterate over
            ``range(length)``.
        """
        if self._no_tqdm:
            return range(length)
        # The progress bar is written to stdout rather than tqdm's default stream.
        return tqdm.tqdm(range(length), file=sys.stdout)

    @abstractmethod
    def optimize(self) -> Dict:
        """ Function for executing the optimization. Should return a ``dict`` of the
        optimization results. """
        pass

    def __str__(self) -> str:
        """ Returns a plain-text table with the class name, the class name of
        the loss function, a short hash, and all options. """
        wdt = 74
        # ``type(self)`` rather than the bare ``__class__`` cell: the latter always
        # refers to this base class, so derived optimizers would be mislabelled.
        s = ['=' * wdt,
             type(self).__name__.center(wdt),
             '-' * wdt]
        # Abbreviate any numpy arrays among the option values while formatting.
        with np.printoptions(threshold=4, precision=5, linewidth=60, edgeitems=1):
            s.append('{:18} : {}'.format('LossFunction', type(self._loss_function).__name__))
            s.append('{:18} : {}'.format('hash', hex(hash(self))[2:8]))
            for key, value in self._options.items():
                s.append('{:18} : {}'.format(f'option[{key}]', value))
        s.append('-' * wdt)
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """ Returns an HTML table with the class name, the class name of the
        loss function, a short hash, and all options. """
        # See ``__str__`` for why ``type(self)`` is used for the title.
        s = [f'<h3>{type(self).__name__}</h3>',
             '<table border="1" class="dataframe">',
             '<thead><tr><th style="text-align: left;">Field</th><th>Size</th><th>Data</th></tr></thead>',
             '<tbody>']
        # Abbreviate any numpy arrays among the option values while formatting.
        with np.printoptions(threshold=4, edgeitems=2, precision=2, linewidth=40):
            s.append('<tr><td style="text-align: left;">LossFunction</td>')
            s.append(f'<td>{1}</td><td>{type(self._loss_function).__name__}</td></tr>')
            h = hex(hash(self))
            s.append('<tr><td style="text-align: left;">Hash</td>')
            s.append(f'<td>{len(h)}</td><td>{h[2:8]}</td></tr>')
            for key, value in self._options.items():
                s.append(f'<tr><td style="text-align: left;">options[{key}]</td>')
                s.append(f'<td>{1}</td><td>{value}</td></tr>')
        s.append('</tbody>')
        s.append('</table>')
        return '\n'.join(s)