"""Plots for the model object.
This file contains plotting functions that take as input a model object.
import numpy as np
from specparam.core.utils import nearest_ind
from specparam.core.modutils import safe_import, check_dependency
from specparam.sim.gen import gen_periodic
from specparam.utils.data import trim_spectrum
from specparam.utils.params import compute_fwhm
from specparam.plts.spectra import plot_spectra
from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS
from specparam.plts.utils import check_ax, check_plot_kwargs, savefig
from specparam.plts.style import style_spectrum_plot, style_plot
plt = safe_import('.pyplot', 'matplotlib')
@check_dependency(plt, 'matplotlib')
def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None,
               freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None,
               model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
    """Plot the power spectrum and model fit results from a model object.

    Parameters
    ----------
    model : SpectralModel
        Object containing a power spectrum and (optionally) results from fitting.
    plot_peaks : None or {'shade', 'dot', 'outline', 'line', 'width'}, optional
        What kind of approach to take to plot peaks. If None, peaks are not specifically plotted.
        Can also be a combination of approaches, separated by '-', for example: 'shade-line'.
    plot_aperiodic : boolean, optional, default: True
        Whether to plot the aperiodic component of the model fit.
    freqs : 1d array, optional
        Frequency values of the power spectrum to plot, in linear space.
        If provided together with `power_spectrum`, this overrides the values in the model object.
    power_spectrum : 1d array, optional
        Power values to plot, in linear space.
        If provided together with `freqs`, this overrides the values in the model object.
    freq_range : list of [float, float], optional
        Frequency range to plot, defined in linear space.
        This is applied to the plotted data only, not to the model fit components.
    plt_log : boolean, optional, default: False
        Whether to plot the frequency values in log10 spacing.
    add_legend : boolean, optional, default: True
        Whether to label the plot components so that a legend describes them.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
        Keyword arguments to pass into the plot call for each plot element.
    **plot_kwargs
        Additional plot related keyword arguments, with styling options managed by ``style_plot``.

    Notes
    -----
    The y-axis (power) is plotted in log spacing by default.
    """

    # NOTE(review): only 'figsize' is consumed from `plot_kwargs` within this function; other
    #   entries are not used here - presumably they are meant for a styling decorator such as
    #   ``style_plot`` (imported in this module but not applied here) - confirm.
    ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))

    # Only use a custom spectrum if both frequencies and powers were explicitly provided.
    #   Explicit None checks avoid relying on array truthiness (e.g. all-zero input).
    custom_spectrum = freqs is not None and power_spectrum is not None

    # Log settings - note that power values in model objects are already logged
    log_freqs = plt_log
    log_powers = False

    # Plot the data, if available
    if model.has_data or custom_spectrum:

        data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
                         'label' : 'Original Spectrum' if add_legend else None}
        data_kwargs = check_plot_kwargs(data_kwargs, data_defaults)

        # A custom spectrum is given in linear power, so it is flagged for log conversion,
        #   whereas the model's own power values are already logged
        plot_spectra(freqs if custom_spectrum else model.freqs,
                     power_spectrum if custom_spectrum else model.power_spectrum,
                     log_freqs, True if custom_spectrum else log_powers,
                     freq_range, ax=ax, **data_kwargs)

    # Add the full model fit, and components (if requested)
    if model.has_model:

        model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
                          'label' : 'Full Model Fit' if add_legend else None}
        model_kwargs = check_plot_kwargs(model_kwargs, model_defaults)
        plot_spectra(model.freqs, model.modeled_spectrum_,
                     log_freqs, log_powers, ax=ax, **model_kwargs)

        # Plot the aperiodic component of the model fit
        if plot_aperiodic:

            aperiodic_defaults = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0,
                                  'alpha' : 0.5, 'linestyle' : 'dashed',
                                  'label' : 'Aperiodic Fit' if add_legend else None}
            aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults)
            plot_spectra(model.freqs, model._ap_fit,
                         log_freqs, log_powers, ax=ax, **aperiodic_kwargs)

        # Plot the periodic components of the model fit
        if plot_peaks:
            _add_peaks(model, plot_peaks, plt_log, ax, peak_kwargs)

    # Apply default style to plot - powers are flagged as log-spaced, since model powers are logged
    style_spectrum_plot(ax, log_freqs, True)
def _add_peaks(model, approach, plt_log, ax, peak_kwargs):
    """Add peaks to a model plot.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    approach : {'shade', 'dot', 'outline', 'line', 'width'}
        What kind of approach to take to plot peaks.
        Can also be a combination of approaches, separated by '-' (for example 'shade-line').
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    peak_kwargs : None or dict
        Keyword arguments to pass into the plot call.
        This can be a flat dictionary, with plot keyword arguments applied to every approach,
        or a dictionary of dictionaries, with keys as labels indicating an `approach`,
        and values which contain a dictionary of plot keywords for that approach.
        The two forms may be mixed: flat entries are shared by all approaches, and
        approach-specific entries take precedence over them.

    Raises
    ------
    ValueError
        If any requested approach is not a known peak plotting approach.

    Notes
    -----
    This is a pass through function, that takes a specification of one
    or multiple add peak approaches to use, and calls the relevant function(s).
    """

    # Input for kwargs could be None, so check if dict and typecast if not
    peak_kwargs = peak_kwargs if isinstance(peak_kwargs, dict) else {}

    # Split up approaches, in case multiple are specified, and validate them all up front,
    #   so that nothing is drawn if any requested approach is invalid. Only the lookup is
    #   checked, so KeyErrors raised inside a plotting function are not masked.
    approaches = approach.split('-')
    unknown = [cur_approach for cur_approach in approaches if cur_approach not in ADD_PEAK_FUNCS]
    if unknown:
        raise ValueError(f"Plot peak type not understood: {', '.join(unknown)}.")

    # Entries not keyed by an approach name are plot kwargs shared across all approaches
    shared_kwargs = {key : val for key, val in peak_kwargs.items() if key not in ADD_PEAK_FUNCS}

    for cur_approach in approaches:

        # Approach-specific kwargs (if any) override the shared ones
        plot_kwargs = {**shared_kwargs, **(peak_kwargs.get(cur_approach) or {})}

        # Pass through to the peak plotting function
        ADD_PEAK_FUNCS[cur_approach](model, plt_log, ax, **plot_kwargs)
def _add_peaks_shade(model, plt_log, ax, **plot_kwargs):
    """Shade in the area between the aperiodic fit and each individual peak.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into ``fill_between``.
    """

    shade_kwargs = check_plot_kwargs(plot_kwargs,
                                     {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25})

    for gauss in model.gaussian_params_:
        ap_fit = model._ap_fit
        # The peak is generated on linear frequencies and stacked on top of the aperiodic fit
        upper = ap_fit + gen_periodic(model.freqs, gauss)
        x_vals = np.log10(model.freqs) if plt_log else model.freqs
        # Fill from the aperiodic fit up to the peak
        ax.fill_between(x_vals, upper, ap_fit, **shade_kwargs)
def _add_peaks_dot(model, plt_log, ax, **plot_kwargs):
    """Mark each peak with a short vertical line from the aperiodic fit, capped by a dot.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.
    """

    dot_kwargs = check_plot_kwargs(plot_kwargs, {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6,
                                                 'lw' : 2.5, 'ms' : 6})

    for params in model.peak_params_:
        center, height = params[0], params[1]

        # Aperiodic value at the peak center, interpolated on linear frequencies
        base = np.interp(center, model.freqs, model._ap_fit)
        x_pos = np.log10(center) if plt_log else center
        tip = base + height

        # Vertical segment from the aperiodic fit up to the tip of the peak
        ax.plot([x_pos, x_pos], [base, tip], **dot_kwargs)
        # Dot marking the tip of the peak
        ax.plot(x_pos, tip, marker='o', **dot_kwargs)
def _add_peaks_outline(model, plt_log, ax, **plot_kwargs):
    """Draw the outline of each peak, restricted to the region around its center.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.
    """

    outline_kwargs = check_plot_kwargs(plot_kwargs, {'color' : PLT_COLORS['periodic'],
                                                     'alpha' : 0.7, 'lw' : 1.5})

    for gauss in model.gaussian_params_:

        # Window spans three bandwidths (gaussian parameter index 2) either side of the center
        half_span = 3 * gauss[2]
        window = [gauss[0] - half_span, gauss[0] + half_span]

        # Rebuild this single peak on top of the aperiodic fit, then cut it to the window
        full_outline = model._ap_fit + gen_periodic(model.freqs, gauss)
        x_vals, y_vals = trim_spectrum(model.freqs, full_outline, window)

        if plt_log:
            x_vals = np.log10(x_vals)
        ax.plot(x_vals, y_vals, **outline_kwargs)
def _add_peaks_line(model, plt_log, ax, **plot_kwargs):
    """Draw a vertical line at each peak center, spanning the current y-limits, with a marker on top.

    The y-limits are read once, before any lines are added, and a downward-pointing
    triangle ('v') is placed at the upper y-limit for each peak.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.
    """

    line_kwargs = check_plot_kwargs(plot_kwargs, {'color' : PLT_COLORS['periodic'],
                                                  'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10})

    y_span = ax.get_ylim()

    for params in model.peak_params_:
        center = params[0]
        x_pos = np.log10(center) if plt_log else center
        ax.plot([x_pos, x_pos], y_span, '-', **line_kwargs)
        ax.plot(x_pos, y_span[1], 'v', **line_kwargs)
def _add_peaks_width(model, plt_log, ax, **plot_kwargs):
    """Draw a horizontal line across the width of each peak.

    Parameters
    ----------
    model : SpectralModel
        Model object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    This line represents the bandwidth (width or gaussian standard deviation) of
    the peak, though what is literally plotted is the full-width half-max.
    """

    width_kwargs = check_plot_kwargs(plot_kwargs, {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6,
                                                   'lw' : 2.5, 'ms' : 6})

    for gauss in model.gaussian_params_:

        # Power of the spectrum at the frequency bin closest to the peak center
        top = model.power_spectrum[nearest_ind(model.freqs, gauss[0])]

        # Line extends half the FWHM either side of the center
        half_fwhm = 0.5 * compute_fwhm(gauss[2])
        edges = [gauss[0] - half_fwhm, gauss[0] + half_fwhm]
        if plt_log:
            edges = np.log10(edges)

        # Drawn half of the peak height below the spectrum value at the center
        level = top - (0.5 * gauss[1])
        ax.plot(edges, [level, level], **width_kwargs)
# Dispatch table of every `_add_peaks_*` function, keyed by the approach name used in `plot_peaks`
ADD_PEAK_FUNCS = dict(
    shade=_add_peaks_shade,
    dot=_add_peaks_dot,
    outline=_add_peaks_outline,
    line=_add_peaks_line,
    width=_add_peaks_width,
)