Source code for specparam.plts.model

"""Plots for the model object.

Notes
-----
This file contains plotting functions that take as input a model object.
"""

import numpy as np

from specparam.core.utils import nearest_ind
from specparam.core.modutils import safe_import, check_dependency
from specparam.sim.gen import gen_periodic
from specparam.utils.data import trim_spectrum
from specparam.utils.params import compute_fwhm
from specparam.plts.spectra import plot_spectra
from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS
from specparam.plts.utils import check_ax, check_plot_kwargs, savefig
from specparam.plts.style import style_spectrum_plot, style_plot

plt = safe_import('.pyplot', 'matplotlib')

###################################################################################################
###################################################################################################

[docs]@savefig @style_plot @check_dependency(plt, 'matplotlib') def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_spectrum=None, freq_range=None, plt_log=False, add_legend=True, ax=None, data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): """Plot the power spectrum and model fit results from a model object. Parameters ---------- model : SpectralModel Object containing a power spectrum and (optionally) results from fitting. plot_peaks : None or {'shade', 'dot', 'outline', 'line'}, optional What kind of approach to take to plot peaks. If None, peaks are not specifically plotted. Can also be a combination of approaches, separated by '-', for example: 'shade-line'. plot_aperiodic : boolean, optional, default: True Whether to plot the aperiodic component of the model fit. freqs : 1d array, optional Frequency values of the power spectrum to plot, in linear space. If provided, this overrides the values in the model object. power_spectrum : 1d array, optional Power values to plot, in linear space. If provided, this overrides the values in the model object. freq_range : list of [float, float], optional Frequency range to plot, defined in linear space. plt_log : boolean, optional, default: False Whether to plot the frequency values in log10 spacing. add_legend : boolean, optional, default: False Whether to add a legend describing the plot components. ax : matplotlib.Axes, optional Figure axes upon which to plot. data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional Keyword arguments to pass into the plot call for each plot element. **plot_kwargs Additional plot related keyword arguments, with styling options managed by ``style_plot``. Notes ----- The y-axis (power) is plotted in log spacing by default. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) # Check inputs for what to plot custom_spectrum = (np.any(freqs) and np.any(power_spectrum)) # Log settings - note that power values in model objects are already logged log_freqs = plt_log log_powers = False # Plot the data, if available if model.has_data or custom_spectrum: data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, 'label' : 'Original Spectrum' if add_legend else None} data_kwargs = check_plot_kwargs(data_kwargs, data_defaults) plot_spectra(freqs if custom_spectrum else model.freqs, power_spectrum if custom_spectrum else model.power_spectrum, log_freqs, log_powers if not custom_spectrum else True, freq_range, ax=ax, **data_kwargs) # Add the full model fit, and components (if requested) if model.has_model: model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) plot_spectra(model.freqs, model.modeled_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit if plot_aperiodic: aperiodic_defaults = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, 'alpha' : 0.5, 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None} aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) plot_spectra(model.freqs, model._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit if plot_peaks: _add_peaks(model, plot_peaks, plt_log, ax, peak_kwargs) # Apply default style to plot style_spectrum_plot(ax, log_freqs, True)
def _add_peaks(model, approach, plt_log, ax, peak_kwargs): """Add peaks to a model plot. Parameters ---------- model : SpectralModel Model object containing results from fitting. approach : {'shade', 'dot', 'outline', 'outline', 'line'} What kind of approach to take to plot peaks. Can also be a combination of approaches, separated by '-' (for example 'shade-line'). plt_log : boolean, optional, default: False Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. peak_kwargs : None or dict Keyword arguments to pass into the plot call. This can be a flat dictionary, with plot keyword arguments, or a dictionary of dictionaries, with keys as labels indicating an `approach`, and values which contain a dictionary of plot keywords for that approach. Notes ----- This is a pass through function, that takes a specification of one or multiple add peak approaches to use, and calls the relevant function(s). """ # Input for kwargs could be None, so check if dict and typecast if not peak_kwargs = peak_kwargs if isinstance(peak_kwargs, dict) else {} # Split up approaches, in case multiple are specified, and apply each for cur_approach in approach.split('-'): try: # This unpacks kwargs, if it's embedded dictionaries for each approach plot_kwargs = peak_kwargs.get(cur_approach, peak_kwargs) # Pass through to the peak plotting function ADD_PEAK_FUNCS[cur_approach](model, plt_log, ax, **plot_kwargs) except KeyError: raise ValueError("Plot peak type not understood.") def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): """Add a shading in of all peaks. Parameters ---------- model : SpectralModel Model object containing results from fitting. plt_log : boolean Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs Keyword arguments to pass into ``fill_between``. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in model.gaussian_params_: peak_freqs = np.log10(model.freqs) if plt_log else model.freqs peak_line = model._ap_fit + gen_periodic(model.freqs, peak) ax.fill_between(peak_freqs, peak_line, model._ap_fit, **plot_kwargs) def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): """Add a short line, from aperiodic to peak, with a dot at the top. Parameters ---------- model : SpectralModel Model object containing results from fitting. plt_log : boolean Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs Keyword arguments to pass into the plot call. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in model.peak_params_: ap_point = np.interp(peak[0], model.freqs, model._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **plot_kwargs) # Add an extra dot at the tip of the peak ax.plot(freq_point, ap_point + peak[1], marker='o', **plot_kwargs) def _add_peaks_outline(model, plt_log, ax, **plot_kwargs): """Add an outline of each peak. Parameters ---------- model : SpectralModel Model object containing results from fitting. plt_log : boolean Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs Keyword arguments to pass into the plot call. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in model.gaussian_params_: # Define the frequency range around each peak to plot - peak bandwidth +/- 3 peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] # Generate a peak reconstruction for each peak, and trim to desired range peak_line = model._ap_fit + gen_periodic(model.freqs, peak) peak_freqs, peak_line = trim_spectrum(model.freqs, peak_line, peak_range) # Plot the peak outline peak_freqs = np.log10(peak_freqs) if plt_log else peak_freqs ax.plot(peak_freqs, peak_line, **plot_kwargs) def _add_peaks_line(model, plt_log, ax, **plot_kwargs): """Add a long line, from the top of the plot, down through the peak, with an arrow at the top. Parameters ---------- model : SpectralModel Model object containing results from fitting. plt_log : boolean Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs Keyword arguments to pass into the plot call. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) ylims = ax.get_ylim() for peak in model.peak_params_: freq_point = np.log10(peak[0]) if plt_log else peak[0] ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) ax.plot(freq_point, ylims[1], 'v', **plot_kwargs) def _add_peaks_width(model, plt_log, ax, **plot_kwargs): """Add a line across the width of peaks. Parameters ---------- model : SpectralModel Model object containing results from fitting. plt_log : boolean Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs Keyword arguments to pass into the plot call. Notes ----- This line represents the bandwidth (width or gaussian standard deviation) of the peak, though what is literally plotted is the full-width half-max. """ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in model.gaussian_params_: peak_top = model.power_spectrum[nearest_ind(model.freqs, peak[0])] bw_freqs = [peak[0] - 0.5 * compute_fwhm(peak[2]), peak[0] + 0.5 * compute_fwhm(peak[2])] if plt_log: bw_freqs = np.log10(bw_freqs) ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **plot_kwargs) # Collect all the possible `add_peak_*` functions together ADD_PEAK_FUNCS = { 'shade' : _add_peaks_shade, 'dot' : _add_peaks_dot, 'outline' : _add_peaks_outline, 'line' : _add_peaks_line, 'width' : _add_peaks_width }