"""Define object to manage algorithm implementations."""
import numpy as np
from specparam.utils.checks import check_input_options
from specparam.algorithms.settings import SettingsDefinition, SettingsValues
from specparam.modutils.docs import docs_get_section, replace_docstring_sections
from specparam.reports.strings import gen_settings_str
###################################################################################################
###################################################################################################
# Labels of the data formats that a fit algorithm can be declared as applying to
DATA_FORMATS = [
    'spectrum',
    'spectra',
    'spectrogram',
    'spectrograms',
]
class Algorithm:
    """Template object for defining a fit algorithm.

    Parameters
    ----------
    name : str
        Name of the fitting algorithm.
    description : str
        Description of the fitting algorithm.
    public_settings : SettingsDefinition or dict
        Name and description of public settings for the fitting algorithm.
    private_settings : SettingsDefinition or dict, optional
        Name and description of private settings for the fitting algorithm.
    data_format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'}
        Set base data format the model can be applied to.
    modes : Modes
        Modes object with fit mode definitions.
    data : Data
        Data object with spectral data and metadata.
    results : Results
        Results object with model fit results and metrics.
    debug : bool
        Whether to run in debug state, raising an error if encountered during fitting.
    """
    # NOTE: the 'Parameters' section of the docstring above is extracted with `docs_get_section`
    #   when defining `AlgorithmCF` - keep the layout of that section intact when editing it

    def __init__(self, name, description, public_settings, private_settings=None,
                 data_format='spectrum', modes=None, data=None, results=None, debug=False):
        """Initialize Algorithm object."""

        self.name = name
        self.description = description

        # Public settings: wrap a plain definition if needed, then create the store for values
        if isinstance(public_settings, SettingsDefinition):
            self.public_settings = public_settings
        else:
            self.public_settings = SettingsDefinition(public_settings)
        self.settings = SettingsValues(self.public_settings.names)

        # Private settings: same handling, with no input treated as an empty definition
        if not isinstance(private_settings, SettingsDefinition):
            private_settings = SettingsDefinition(
                {} if private_settings is None else private_settings)
        self.private_settings = private_settings
        self._settings = SettingsValues(self.private_settings.names)

        # Check the requested data format against the available options before storing it
        check_input_options(data_format, DATA_FORMATS, 'data_format')
        self.data_format = data_format

        # Start with all sub-object links empty, then attach any that were passed in
        self.modes = self.data = self.results = None
        self._reset_subobjects(modes, data, results)

        self.set_debug(debug)

    def _fit_prechecks(self, verbose):
        """Pre-checks to run before the fit function - if there are some, overload this function.

        Parameters
        ----------
        verbose : bool
            Unused in this template implementation, which does nothing.
        """

    def _fit(self):
        """Required fit function, to be overloaded - this template implementation does nothing."""

    def add_settings(self, settings):
        """Add settings into object from a ModelSettings object.

        Parameters
        ----------
        settings : ModelSettings
            A data object containing model settings.

        Notes
        -----
        Every field listed in `settings._fields` is copied onto the `settings` attribute.
        """

        for label in settings._fields:
            value = getattr(settings, label)
            setattr(self.settings, label, value)

    def get_settings(self):
        """Return user defined settings of the current object.

        Returns
        -------
        ModelSettings
            Object containing the settings from the current object.
        """

        # Get the settings object type, then fill it with the current public setting values
        model_settings = self.public_settings.make_model_settings()
        current_values = {}
        for label in self.public_settings.names:
            current_values[label] = getattr(self.settings, label)

        return model_settings(**current_values)

    def get_debug(self):
        """Return object debug status."""

        return self._debug

    def set_debug(self, debug):
        """Set debug state, which controls if an error is raised if model fitting is unsuccessful.

        Parameters
        ----------
        debug : bool
            Whether to run in debug state.
        """

        self._debug = debug

    def print(self, description=False, concise=False):
        """Print out the algorithm name and fit settings.

        Parameters
        ----------
        description : bool, optional, default: False
            Whether to print out a description with current settings.
        concise : bool, optional, default: False
            Whether to print the report in a concise mode, or not.
        """

        settings_str = gen_settings_str(self, description, concise)
        print(settings_str)

    def _reset_subobjects(self, modes=None, data=None, results=None):
        """Reset links to sub-objects (mode / data / results).

        Parameters
        ----------
        modes : Modes
            Model modes object.
        data : Data*
            Model data object.
        results : Results*
            Model results object.

        Notes
        -----
        Any input left as None is skipped, leaving the existing link as it is.
        """

        for label, subobject in (('modes', modes), ('data', data), ('results', results)):
            if subobject is not None:
                setattr(self, label, subobject)
## AlgorithmCF
# Definition (type and description) of the settings that control curve fitting
CURVE_FIT_SETTINGS = SettingsDefinition(dict(
    maxfev=dict(
        type='int',
        description='The maximum number of calls to the curve fitting function.',
    ),
    tol=dict(
        type='float',
        description=(
            'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).'
        ),
    ),
))
@replace_docstring_sections([docs_get_section(Algorithm.__doc__, 'Parameters')])
class AlgorithmCF(Algorithm):
    """Template object for defining a fit algorithm that uses `curve_fit`.

    Parameters
    ----------
    % copied in from Algorithm
    """

    def __init__(self, name, description, public_settings, private_settings=None,
                 data_format='spectrum', modes=None, data=None, results=None, debug=False):
        """Initialize AlgorithmCF object."""

        Algorithm.__init__(self, name=name, description=description,
                           public_settings=public_settings, private_settings=private_settings,
                           data_format=data_format, modes=modes, data=data, results=results,
                           debug=debug)

        # Curve fitting settings are stored separately from the public / private settings
        self._cf_settings_desc = CURVE_FIT_SETTINGS
        self._cf_settings = SettingsValues(self._cf_settings_desc.names)

    def _get_n_params(self, mode):
        """Get the number of parameters for a mode.

        Parameters
        ----------
        mode : {'aperiodic', 'periodic'}
            Which mode to get the number of parameters for.

        Returns
        -------
        n_params : int
            Number of parameters of the requested mode, or 0 if no modes are defined.
        """

        # Without a modes definition there are no parameters, so stores initialize as empty
        if self.modes is None:
            return 0

        return getattr(self.modes, mode).n_params

    def _initialize_bounds(self, mode):
        """Initialize a bounds definition.

        Parameters
        ----------
        mode : {'aperiodic', 'periodic'}
            Which mode to initialize for.

        Returns
        -------
        bounds : tuple of array
            Bounds values.

        Notes
        -----
        Output follows the needed bounds definition for curve_fit, which is:
            ([low_bound_param1, low_bound_param2],
             [high_bound_param1, high_bound_param2])

        Bounds are initialized as unbounded, with -inf / +inf for every parameter.
        If no modes are defined, both bound arrays are empty.
        """

        n_params = self._get_n_params(mode)
        bounds = (np.full(n_params, -np.inf), np.full(n_params, np.inf))

        return bounds

    def _initialize_guess(self, mode):
        """Initialize a guess definition.

        Parameters
        ----------
        mode : {'aperiodic', 'periodic'}
            Which mode to initialize for.

        Returns
        -------
        guess : 1d array
            Guess values.

        Notes
        -----
        Guess values are initialized as zeros, one per parameter of the mode.
        If no modes are defined, the guess is an empty array, consistent with how
        `_initialize_bounds` handles the same case.
        """

        guess = np.zeros(self._get_n_params(mode))

        return guess