"""Group model object and associated code for fitting the model to 2D groups of power spectra.

Notes
-----
Methods without their own docstrings have their docs copied in at runtime
from the external functions they alias.
"""
from specparam.objs.base import BaseObject2D
from specparam.objs.model import SpectralModel
from specparam.objs.algorithm import SpectralFitAlgorithm
from specparam.plts.group import plot_group_model
from specparam.core.reports import save_group_report
from specparam.core.strings import gen_group_results_str
from specparam.core.modutils import (copy_doc_func_to_method,
docs_get_section, replace_docstring_sections)
from specparam.data.conversions import group_to_dataframe
###################################################################################################
###################################################################################################
@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'),
                             docs_get_section(SpectralModel.__doc__, 'Notes')])
class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D):
    """Model a group of power spectra as a combination of aperiodic and periodic components.

    WARNING: frequency and power values inputs must be in linear space.

    Passing in logged frequencies and/or power spectra is not detected,
    and will silently produce incorrect results.

    Parameters
    ----------
    %copied in from SpectralModel object

    Attributes
    ----------
    freqs : 1d array
        Frequency values for the power spectra.
    power_spectra : 2d array
        Power values for the group of power spectra, as [n_power_spectra, n_freqs].
        Power values are stored internally in log10 scale.
    freq_range : list of [float, float]
        Frequency range of the power spectra, as [lowest_freq, highest_freq].
    freq_res : float
        Frequency resolution of the power spectra.
    group_results : list of FitResults
        Results of the model fit for each power spectrum.
    has_data : bool
        Whether data is loaded to the object.
    has_model : bool
        Whether model results are available in the object.
    n_peaks_ : int
        The number of peaks fit in the model.
    n_null_ : int
        The number of models that failed to fit and/or that are marked as null.
    null_inds_ : list of int
        The indices of any models that are null.

    Notes
    -----
    %copied in from SpectralModel object
    - The group object inherits from the model object. As such it also has data
      attributes (`power_spectrum` & `modeled_spectrum_`), and parameter attributes
      (`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`)
      which are defined in the context of individual model fits. These attributes are
      used during the fitting process, but in the group context do not store results
      post-fitting. Rather, all model fit results are collected and stored into the
      `group_results` attribute. To access individual parameters of the fit, use
      the `get_params` method.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the group model object.

        Parameters
        ----------
        *args
            Positional arguments, passed through to `SpectralFitAlgorithm`.
        **kwargs
            Keyword arguments. `aperiodic_mode` (default: 'fixed'), `periodic_mode`
            (default: 'gaussian'), `debug_mode` (default: False) and `verbose`
            (default: True) are removed from `kwargs` and passed to `BaseObject2D`.
            All remaining keyword arguments are passed to `SpectralFitAlgorithm`.
        """

        # Pop the base-object settings first, so they are not also forwarded
        #   to the fit algorithm initializer below.
        BaseObject2D.__init__(self,
                              aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'),
                              periodic_mode=kwargs.pop('periodic_mode', 'gaussian'),
                              debug_mode=kwargs.pop('debug_mode', False),
                              verbose=kwargs.pop('verbose', True))
        SpectralFitAlgorithm.__init__(self, *args, **kwargs)

    def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1,
               progress=None, **plot_kwargs):
        """Fit a group of power spectra and display a report, with a plot and printed results.

        Parameters
        ----------
        freqs : 1d array, optional
            Frequency values for the power_spectra, in linear space.
        power_spectra : 2d array, shape: [n_power_spectra, n_freqs], optional
            Matrix of power spectrum values, in linear space.
        freq_range : list of [float, float], optional
            Frequency range to fit the model to. If not provided, fits the entire given range.
        n_jobs : int, optional, default: 1
            Number of jobs to run in parallel.
            1 is no parallelization. -1 uses all available cores.
        progress : {None, 'tqdm', 'tqdm.notebook'}, optional
            Which kind of progress bar to use. If None, no progress bar is used.
        **plot_kwargs
            Keyword arguments to pass into the plot method.

        Notes
        -----
        Data is optional, if data has already been added to the object.
        """

        self.fit(freqs, power_spectra, freq_range, n_jobs=n_jobs, progress=progress)
        self.plot(**plot_kwargs)
        self.print_results(concise=False)

    # Docstring is copied in at runtime from `plot_group_model`.
    @copy_doc_func_to_method(plot_group_model)
    def plot(self, **plot_kwargs):

        plot_group_model(self, **plot_kwargs)

    # Docstring is copied in at runtime from `save_group_report`.
    @copy_doc_func_to_method(save_group_report)
    def save_report(self, file_name, file_path=None, add_settings=True):

        save_group_report(self, file_name, file_path, add_settings)

    def print_results(self, concise=False):
        """Print out the group results.

        Parameters
        ----------
        concise : bool, optional, default: False
            Whether to print the report in a concise mode, or not.
        """

        print(gen_group_results_str(self, concise))

    def save_model_report(self, index, file_name, file_path=None,
                          add_settings=True, **plot_kwargs):
        """Save out an individual model report for a specified model fit.

        Parameters
        ----------
        index : int
            Index of the model fit to save out.
        file_name : str
            Name to give the saved out file.
        file_path : Path or str, optional
            Path to directory to save to. If None, saves to current directory.
        add_settings : bool, optional, default: True
            Whether to add a print out of the model settings to the end of the report.
        **plot_kwargs
            Keyword arguments to pass into the plot method.
        """

        # Rebuild the individual model (with regenerated model components),
        #   then delegate the report saving to it.
        model = self.get_model(ind=index, regenerate=True)
        model.save_report(file_name, file_path, add_settings, **plot_kwargs)

    def to_df(self, peak_org):
        """Convert and extract the model results as a pandas object.

        Parameters
        ----------
        peak_org : int or Bands
            How to organize peaks.
            If int, extracts the first n peaks.
            If Bands, extracts peaks based on band definitions.

        Returns
        -------
        pd.DataFrame
            Model results organized into a pandas object.
        """

        return group_to_dataframe(self.get_results(), peak_org)

    def _check_width_limits(self):
        """Check and warn about bandwidth limits / frequency resolution interaction.

        The check is only run when the current power spectrum is the first spectrum
        in the group, to avoid spamming standard output for every spectrum.
        """

        # Identify the first spectrum by comparing the whole spectrum rather than only
        #   its first value, so that later spectra which happen to share the first power
        #   value do not re-trigger the warning. The shape check guards the elementwise
        #   comparison, which raises for mismatched shapes in recent NumPy versions.
        first_spectrum = self.power_spectra[0]
        is_first = (first_spectrum.shape == self.power_spectrum.shape
                    and bool((first_spectrum == self.power_spectrum).all()))

        if is_first:
            super()._check_width_limits()