Source code for specparam.plts.annotate

"""Plots for annotating power spectrum fittings and models."""

import numpy as np

from specparam.core.utils import nearest_ind
from specparam.core.errors import NoModelError
from specparam.core.funcs import gaussian_function
from specparam.core.modutils import safe_import, check_dependency

from specparam.sim.gen import gen_aperiodic
from specparam.analysis.periodic import get_band_peak
from specparam.utils.params import compute_knee_frequency, compute_fwhm

from specparam.plts.spectra import plot_spectra
from specparam.plts.utils import check_ax, savefig
from specparam.plts.settings import PLT_FIGSIZES, PLT_COLORS
from specparam.plts.style import style_spectrum_plot

plt = safe_import('.pyplot', 'matplotlib')
mpatches = safe_import('.patches', 'matplotlib')

###################################################################################################
###################################################################################################




@savefig
@check_dependency(plt, 'matplotlib')
def plot_annotated_model(model, plt_log=False, annotate_peaks=True,
                         annotate_aperiodic=True, ax=None):
    """Plot an annotated power spectrum and model, from a model object.

    Parameters
    ----------
    model : SpectralModel
        Model object, with model fit, data and settings available.
    plt_log : boolean, optional, default: False
        Whether to plot the frequency values in log10 spacing.
    annotate_peaks : boolean, optional, default: True
        Whether to annotate the periodic components of the model fit.
        If the model has no peaks, no peak annotations are drawn and the
        'Peak Parameters' legend entry is omitted.
    annotate_aperiodic : boolean, optional, default: True
        Whether to annotate the aperiodic components of the model fit.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.

    Raises
    ------
    NoModelError
        If there are no model results available to plot.

    Notes
    -----
    When `annotate_aperiodic` is True, y-axis autoscaling is switched off on `ax`
    (so the offset marker line does not extend the y-limits), and is not re-enabled.
    """

    # Check that model is available
    if not model.has_model:
        raise NoModelError("No model is available to plot, can not proceed.")

    # Settings: font size, line widths & marker size
    fontsize = 15
    lw1 = 4.0
    lw2 = 3.0
    ms1 = 12

    periodic_color = PLT_COLORS['periodic']
    aperiodic_color = PLT_COLORS['aperiodic']

    # Create the baseline figure
    ax = check_ax(ax, PLT_FIGSIZES['spectral'])
    model.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax,
               data_kwargs={'lw' : lw1, 'alpha' : 0.6},
               aperiodic_kwargs={'lw' : lw1, 'zorder' : 10},
               model_kwargs={'lw' : lw1, 'alpha' : 0.5},
               peak_kwargs={'dot' : {'color' : periodic_color, 'ms' : ms1, 'lw' : lw2},
                            'shade' : {'color' : periodic_color},
                            'width' : {'color' : periodic_color, 'alpha' : 0.75, 'lw' : lw2}})

    # Get freqs for plotting, and convert to log if needed
    freqs = np.log10(model.freqs) if plt_log else model.freqs

    # Buffers: for spacing things out on the plot (scaled by plot values)
    #   y_buff1 is computed after `model.plot`, so it reflects the drawn y-limits
    max_freq = np.max(freqs)
    x_buff1 = max_freq * 0.1
    x_buff2 = max_freq * 0.25
    y_buff1 = 0.15 * np.ptp(ax.get_ylim())
    shrink = 0.1

    # There is a bug in annotations for some perpendicular lines, so add small offset
    #   See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1.
    bug_buff = 0.000001

    # Only annotate peaks if requested AND the model actually has peaks to annotate
    #   (also used below to decide whether the peak legend entry is shown)
    show_peaks = bool(annotate_peaks and model.n_peaks_)

    if show_peaks:

        # Extract largest peak, to annotate, grabbing gaussian params
        peak_ctr, peak_hgt, peak_wid = get_band_peak(model, model.freq_range,
                                                     attribute='gaussian_params')

        # Lower edge of the peak bandwidth: center minus half the FWHM
        bw_low = peak_ctr - 0.5 * compute_fwhm(peak_wid)

        if plt_log:
            # A wide, low-frequency peak can have a non-positive lower edge, whose log10 is
            #   NaN / -inf; fall back to the lowest model frequency so the position stays finite
            if bw_low <= 0:
                bw_low = model.freqs[0]
            peak_ctr = np.log10(peak_ctr)
            bw_low = np.log10(bw_low)

        peak_top = model.power_spectrum[nearest_ind(freqs, peak_ctr)]

        # Annotate Peak CF
        ax.annotate('Center Frequency',
                    xy=(peak_ctr, peak_top),
                    xytext=(peak_ctr, peak_top + np.abs(0.6 * peak_hgt)),
                    verticalalignment='center',
                    horizontalalignment='center',
                    arrowprops=dict(facecolor=periodic_color, shrink=shrink),
                    color=periodic_color, fontsize=fontsize)

        # Annotate Peak PW
        ax.annotate('Power',
                    xy=(peak_ctr, peak_top - 0.3 * peak_hgt),
                    xytext=(peak_ctr + x_buff1, peak_top - 0.3 * peak_hgt),
                    verticalalignment='center',
                    arrowprops=dict(facecolor=periodic_color, shrink=shrink),
                    color=periodic_color, fontsize=fontsize)

        # Annotate Peak BW, pointing halfway between the center and the lower bandwidth edge
        bw_buff = (peak_ctr - bw_low) / 2
        ax.annotate('Bandwidth',
                    xy=(peak_ctr - bw_buff + bug_buff, peak_top - (0.5 * peak_hgt)),
                    xytext=(peak_ctr - bw_buff, peak_top - (1.5 * peak_hgt)),
                    verticalalignment='center',
                    horizontalalignment='right',
                    arrowprops=dict(facecolor=periodic_color, shrink=shrink),
                    color=periodic_color, fontsize=fontsize, zorder=20)

    if annotate_aperiodic:

        # Annotate Aperiodic Offset
        #   Add a line to indicate offset, without adjusting plot limits below it
        ax.set_autoscaley_on(False)
        ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.modeled_spectrum_[0]],
                color=aperiodic_color, linewidth=lw2, alpha=0.5)
        ax.annotate('Offset',
                    xy=(freqs[0] + bug_buff, model.power_spectrum[0] - y_buff1),
                    xytext=(freqs[0] - x_buff1, model.power_spectrum[0] - y_buff1),
                    verticalalignment='center',
                    horizontalalignment='center',
                    arrowprops=dict(facecolor=aperiodic_color, shrink=shrink),
                    color=aperiodic_color, fontsize=fontsize)

        # Annotate Aperiodic Knee
        if model.aperiodic_mode == 'knee':

            # Find the knee frequency point to annotate
            knee_freq = compute_knee_frequency(model.get_params('aperiodic', 'knee'),
                                               model.get_params('aperiodic', 'exponent'))
            knee_freq = np.log10(knee_freq) if plt_log else knee_freq
            knee_pow = model.power_spectrum[nearest_ind(freqs, knee_freq)]

            # Add a dot to the plot indicating the knee frequency
            ax.plot(knee_freq, knee_pow, 'o', color=aperiodic_color, ms=ms1 * 1.5, alpha=0.7)

            ax.annotate('Knee',
                        xy=(knee_freq, knee_pow),
                        xytext=(knee_freq - x_buff2, knee_pow - y_buff1),
                        verticalalignment='center',
                        arrowprops=dict(facecolor=aperiodic_color, shrink=shrink),
                        color=aperiodic_color, fontsize=fontsize)

        # Annotate Aperiodic Exponent, at the middle index of the frequency vector
        mid_ind = len(freqs) // 2
        ax.annotate('Exponent',
                    xy=(freqs[mid_ind], model.power_spectrum[mid_ind]),
                    xytext=(freqs[mid_ind] - x_buff2, model.power_spectrum[mid_ind] - y_buff1),
                    verticalalignment='center',
                    arrowprops=dict(facecolor=aperiodic_color, shrink=shrink),
                    color=aperiodic_color, fontsize=fontsize)

    # Apply style to plot & tune grid styling
    style_spectrum_plot(ax, plt_log, True)
    ax.grid(True, alpha=0.5)

    # Add labels to plot in the legend, only for the components that were annotated
    #   Order: data, aperiodic, peaks, full model
    handles = [mpatches.Patch(color=PLT_COLORS['data'], label='Original Data')]
    if annotate_aperiodic:
        handles.append(mpatches.Patch(color=aperiodic_color, label='Aperiodic Parameters'))
    if show_peaks:
        handles.append(mpatches.Patch(color=periodic_color, label='Peak Parameters'))
    handles.append(mpatches.Patch(color=PLT_COLORS['model'], label='Full Model'))

    ax.legend(handles=handles, handlelength=1, fontsize='x-large')