Source code for mumott.optimization.optimizers.lbfgs

import logging

from typing import Dict, Any

from numpy import float64

from scipy.optimize import minimize

from mumott.core.hashing import list_to_hash
from mumott.optimization.loss_functions.base_loss_function import LossFunction
from .base_optimizer import Optimizer


logger = logging.getLogger(__name__)


[docs]class LBFGS(Optimizer): """This Optimizer makes the :term:`LBFGS` algorithm from `scipy.optimize <https://docs.scipy.org/doc/scipy/reference/optimize.html>`_ available for usage with a :class:`LossFunction <module-mumott.optimization.loss_functions>`. Parameters ---------- loss_function : LossFunction The :ref:`loss function <loss_functions>` to be minimized using this algorithm. kwargs : Dict[str, Any] Miscellaneous options. See notes for valid entries. Notes ----- Valid entries in :attr:`kwargs` are x0 Initial guess for solution vector. Must be the same size as :attr:`residual_calculator.coefficients`. Defaults to :attr:`loss_function.initial_values`. bounds Used to set the ``bounds`` of the optimization method, see `scipy.optimize.minimize <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html>`_ documentation for details. Defaults to ``None``. maxiter Maximum number of iterations. Defaults to ``10``. disp Whether to display output from the optimizer. Defaults to ``False`` maxcor Maximum number of Hessian corrections to the Jacobian. Defaults to ``10``. iprint If ``disp`` is true, controls output with no output if ``iprint < 0``, convergence output only if ``iprint == 0``, iteration-wise output if ``0 < iprint <= 99``, and sub-iteration output if ``iprint > 99``. maxfun Maximum number of function evaluations, including line search evaluations. Defaults to ``20``. ftol Relative change tolerance for objective function. Changes to absolute change tolerance if objective function is ``< 1``, see `scipy.optimize.minimize <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html>`_ documentation, which may lead to excessively fast convergence. Defaults to ``1e-3``. gtol Convergence tolerance for gradient. Defaults to ``1e-5``. """ def __init__(self, loss_function: LossFunction, **kwargs: Dict[str, Any]): super().__init__(loss_function, **kwargs) # This will later be used to reshape the flattened output. self._output_shape = None
[docs] def optimize(self) -> Dict: """ Executes the optimization using the options stored in this class instance. The optimization will continue until convergence, or until the maximum number of iterations (:attr:`maxiter`) is exceeded. Returns ------- A ``dict`` of optimization results. See `scipy.optimize.OptimizeResult <https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.OptimizeResult.html>`_ for details. The entry ``'x'``, which contains the result, will be reshaped using the shape of the gradient from :attr:`loss_function`. """ lbfgs_kwargs = dict(x0=self._loss_function.initial_values, bounds=None) misc_options = dict(maxiter=10, disp=False, maxcor=10, iprint=1, maxfun=20, ftol=1e-3, gtol=1e-5) for k in lbfgs_kwargs: if k in dict(self): lbfgs_kwargs[k] = self[k] for k in misc_options: if k in dict(self): misc_options[k] = self[k] for k in dict(self): if k not in lbfgs_kwargs and k not in misc_options: logger.warning(f'Unknown option {k}, with value {self[k]}, has been ignored.') lbfgs_kwargs['x0'] = lbfgs_kwargs['x0'].ravel() with self._tqdm(misc_options['maxiter']) as progress: def progress_callback(*args, **kwargs): progress.update(1) def loss_function_wrapper(coefficients): d = self._loss_function.get_loss(coefficients, get_gradient=True) # Store gradient shape to reshape flattened output. if self._output_shape is None: self._output_shape = d['gradient'].shape # LBFGS needs float64 return d['loss'], d['gradient'].ravel().astype(float64) result = minimize(fun=loss_function_wrapper, callback=progress_callback, **lbfgs_kwargs, jac=True, method='L-BFGS-B', options=misc_options) result = dict(result) result['x'] = result['x'].reshape(self._output_shape) return dict(result)
def __hash__(self) -> int: to_hash = [self._options, hash(self._loss_function)] return int(list_to_hash(to_hash), 16)