"""Style and aesthetics definitions for plots."""
from itertools import cycle
from functools import wraps
import matplotlib.pyplot as plt
from fooof.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS,
CUSTOM_STYLE_ARGS, STYLE_ARGS, TICK_LABELSIZE, TITLE_FONTSIZE,
LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC)
###################################################################################################
###################################################################################################
[docs]def check_style_options():
"""Check the list of valid style arguments that can be passed into plot functions."""
print('Valid style arguments:')
for label, options in zip(['Axis', 'Line', 'Collection', 'Custom'],
[AXIS_STYLE_ARGS, LINE_STYLE_ARGS,
COLLECTION_STYLE_ARGS, CUSTOM_STYLE_ARGS]):
print(' {:10s} {}'.format(label, ', '.join(options)))
def style_spectrum_plot(ax, log_freqs, log_powers, grid=True):
"""Apply style and aesthetics to a power spectrum plot.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply styling to.
log_freqs : bool
Whether the frequency axis is plotted in log space.
log_powers : bool
Whether the power axis is plotted in log space.
grid : bool, optional, default: True
Whether to add grid lines to the plot.
"""
# Get labels, based on log status
xlabel = 'Frequency' if not log_freqs else 'log(Frequency)'
ylabel = 'Power' if not log_powers else 'log(Power)'
# Aesthetics and axis labels
ax.set_xlabel(xlabel, fontsize=20)
ax.set_ylabel(ylabel, fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.grid(grid)
# If labels were provided, add a legend
if ax.get_legend_handles_labels()[0]:
ax.legend(prop={'size': 16}, loc='upper right')
def style_param_plot(ax):
"""Apply style and aesthetics to a peaks plot.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply styling to.
"""
# Set the top and right side frame & ticks off
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
# Set linewidth of remaining spines
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
# Aesthetics and axis labels
for item in ([ax.xaxis.label, ax.yaxis.label]):
item.set_fontsize(20)
ax.tick_params(axis='both', which='major', labelsize=16)
# If labels were provided, add a legend and standardize the dot size
if ax.get_legend_handles_labels()[0]:
legend = ax.legend(prop={'size': 16})
for handle in legend.legendHandles:
handle._sizes = [100]
def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
"""Apply axis plot style.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
style_args : list of str
A list of arguments to be sub-selected from `kwargs` and applied as axis styling.
**kwargs
Keyword arguments that define plot style to apply.
"""
# Apply any provided axis style arguments
plot_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
ax.set(**plot_kwargs)
def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
"""Apply line plot style.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
style_args : list of str
A list of arguments to be sub-selected from `kwargs` and applied as line styling.
**kwargs
Keyword arguments that define line style to apply.
"""
# Check how many lines are from the current plot call, to apply style to
# If available, this indicates the apply styling to the last 'n' lines
n_lines_apply = kwargs.pop('n_lines_apply', 0)
# Get the line related styling arguments from the keyword arguments
line_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
# Apply any provided line style arguments
for style, value in line_kwargs.items():
# Values should be either a single value, for all lines, or a list, of a value per line
# This line checks type, and makes a cycle-able / loop-able object out of the values
values = cycle([value] if isinstance(value, (int, float, str)) else value)
for line in ax.lines[-n_lines_apply:]:
line.set(**{style : next(values)})
def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs):
"""Apply collection plot style.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
style_args : list of str
A list of arguments to be sub-selected from `kwargs` and applied as collection styling.
**kwargs
Keyword arguments that define collection style to apply.
"""
# Get the collection related styling arguments from the keyword arguments
collection_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
# Apply any provided collection style arguments
for collection in ax.collections:
collection.set(**collection_kwargs)
def apply_custom_style(ax, **kwargs):
"""Apply custom plot style.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
**kwargs
Keyword arguments that define custom style to apply.
"""
# If a title was provided, update the size
if ax.get_title():
ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE))
# Settings for the axis labels
label_size = kwargs.pop('label_size', LABEL_SIZE)
ax.xaxis.label.set_size(label_size)
ax.yaxis.label.set_size(label_size)
# Settings for the axis ticks
ax.tick_params(axis='both', which='major',
labelsize=kwargs.pop('tick_labelsize', TICK_LABELSIZE))
# If labels were provided, add a legend
if ax.get_legend_handles_labels()[0]:
ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
loc=kwargs.pop('legend_loc', LEGEND_LOC))
# Apply tight layout to the figure object, if matplotlib is new enough
# If available, `.set_layout_engine` should be equivalent to
# `plt.tight_layout()`, but seems to raise fewer warnings...
try:
fig = plt.gcf()
fig.set_layout_engine('tight')
except:
plt.tight_layout()
def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
collection_styler=apply_collection_style, custom_styler=apply_custom_style,
**kwargs):
"""Apply plot style to a figure axis.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
axis_styler, line_styler, collection_style, custom_styler : callable, optional
Functions to apply style to aspects of the plot.
**kwargs
Keyword arguments that define style to apply.
Notes
-----
This function wraps sub-functions which apply style to different plot elements.
Each of these sub-functions can be replaced by passing in replacement callables.
"""
axis_styler(ax, **kwargs) if axis_styler is not None else None
line_styler(ax, **kwargs) if line_styler is not None else None
collection_styler(ax, **kwargs) if collection_styler is not None else None
custom_styler(ax, **kwargs) if custom_styler is not None else None
def style_plot(func, *args, **kwargs):
"""Decorator function to apply a plot style function, after plot generation.
Parameters
----------
func : callable
The plotting function for creating a plot.
*args, **kwargs
Arguments & keyword arguments.
These should include any arguments for the plot, and those for applying plot style.
Notes
-----
This decorator works by:
- catching all inputs that relate to plot style
- creating a plot, using the passed in plotting function & passing in all non-style arguments
- passing the style related arguments into a `apply_style` function which applies plot styling
By default, this function applies styling with the `apply_style` function. Custom
functions for applying style can be passed in using `apply_style` as a keyword argument.
The `apply_style` function calls sub-functions for applying different plot elements, including:
- `axis_styler`: apply style options to an axis
- `line_styler`: applies style options to lines objects in a plot
- `collection_styler`: applies style options to collections objects in a plot
- `custom_style`: applies custom style options
Each of these sub-functions can be overridden by passing in alternatives.
To see the full set of style arguments that are supported, run the following code:
>>> from fooof.plts.style import check_style_options
>>> check_style_options()
Valid style arguments:
Axis title, xlabel, ylabel, xlim, ylim, xticks, yticks, xticklabels, yticklabels
Line alpha, lw, linewidth, ls, linestyle, marker, ms, markersize
Collection alpha, edgecolor
Custom title_fontsize, label_size, tick_labelsize, legend_size, legend_loc
"""
@wraps(func)
def decorated(*args, **kwargs):
# Grab a custom style function, if provided, and grab any provided style arguments
style_func = kwargs.pop('plot_style', apply_style)
style_args = kwargs.pop('style_args', STYLE_ARGS)
style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs}
# Check how many lines are already on the plot, if it exists already
n_lines_pre = len(kwargs['ax'].lines) if 'ax' in kwargs and kwargs['ax'] is not None else 0
# Create the plot
func(*args, **kwargs)
# Get plot axis, if a specific one was provided, or if not, grab the current axis
cur_ax = kwargs['ax'] if 'ax' in kwargs and kwargs['ax'] is not None else plt.gca()
# Check how many lines were added to the plot, and make info available to plot styling
n_lines_apply = len(cur_ax.lines) - n_lines_pre
style_kwargs['n_lines_apply'] = n_lines_apply
# Apply the styling function
style_func(cur_ax, **style_kwargs)
return decorated