"""Classes & functions for managing parameters for simulating power spectra."""
import numpy as np
from specparam.core.utils import group_three, check_flat
from specparam.core.info import get_indices
from specparam.core.funcs import infer_ap_func
from specparam.core.errors import InconsistentDataError
from specparam.data import SimParams
###################################################################################################
###################################################################################################
def collect_sim_params(aperiodic_params, periodic_params, nlv):
    """Bundle a set of simulation parameters into a SimParams object.

    Parameters
    ----------
    aperiodic_params : list of float
        Parameters of the aperiodic component of the power spectrum.
    periodic_params : list of float or list of list of float
        Parameters of the periodic component of the power spectrum.
    nlv : float
        Noise level of the power spectrum.

    Returns
    -------
    SimParams
        Object containing the simulation parameters.

    Notes
    -----
    The aperiodic parameters are stored as a copy of the input. The periodic
    parameters are passed through `check_flat` and `group_three`, and the
    result is sorted before being stored.
    """

    # Copy the aperiodic definition, so the stored values are detached from the input
    ap_copy = aperiodic_params.copy()

    # Normalize the peak definition with the shared helpers, then order the result
    peak_groups = sorted(group_three(check_flat(periodic_params)))

    return SimParams(ap_copy, peak_groups, nlv)
def update_sim_ap_params(sim_params, delta, field=None):
    """Update the aperiodic parameter definition in a SimParams object.

    Parameters
    ----------
    sim_params : SimParams
        Object storing the current parameter definition.
    delta : float or list of float
        Value(s) by which to update the parameters.
    field : {'offset', 'knee', 'exponent'} or list of string
        Field of the aperiodic parameter(s) to update.
        If not given, `delta` must provide one value per aperiodic parameter.

    Returns
    -------
    new_sim_params : SimParams
        Updated object storing the new parameter definition.

    Raises
    ------
    InconsistentDataError
        If the input parameters and update values are inconsistent.
        This covers both a `delta` whose length does not match the number of aperiodic
        parameters (when no field is given), and a mismatch between the number of
        fields and the number of update values (when fields are given).
    """

    # Work on a copy, so the parameter list held by the input object is left untouched
    ap_params = sim_params.aperiodic_params.copy()

    # If field isn't specified, check shapes line up and update across all parameters
    if not field:

        if not len(ap_params) == len(delta):
            raise InconsistentDataError("The number of items to update and "
                                        "number of new values is inconsistent.")

        ap_params = [param + update for param, update in zip(ap_params, delta)]

    # If labels are given, update each labelled parameter by its matching delta
    else:

        # Cast single values to lists, to treat single & multiple inputs the same way
        fields = field if isinstance(field, list) else [field]
        deltas = delta if isinstance(delta, list) else [delta]

        # Without this check, zip would silently drop any unmatched fields or deltas
        if len(fields) != len(deltas):
            raise InconsistentDataError("The number of fields to update and "
                                        "number of new values is inconsistent.")

        for cur_field, cur_delta in zip(fields, deltas):
            data_ind = get_indices(infer_ap_func(ap_params))[cur_field]
            ap_params[data_ind] = ap_params[data_ind] + cur_delta

    # Replace parameters. Note that this creates a new object, rather than editing the input
    new_sim_params = sim_params._replace(aperiodic_params=ap_params)

    return new_sim_params
class Stepper():
    """Object for stepping across parameter values.

    Parameters
    ----------
    start : float
        Start value to iterate from.
    stop : float
        End value to iterate to.
    step : float
        Increment of each iteration.

    Attributes
    ----------
    len : int
        Length of generator range.
        This is the number of values that the object will produce.
    data : iterator
        Set of specified parameters to iterate across.

    Examples
    --------
    Define a stepper object for center frequency values for an alpha peak:

    >>> alpha_cf_steps = Stepper(8, 12.5, 0.5)
    """

    def __init__(self, start, stop, step):
        """Initialize a Stepper object."""

        self._check_values(start, stop, step)

        self.start = start
        self.stop = stop
        self.step = step

        # Take the length from the generated values themselves, so that it always agrees
        #   with what iteration produces, including when the range is not an exact
        #   multiple of the step size
        values = np.arange(start, stop, step)
        self.len = len(values)
        self.data = iter(values)

    def __len__(self):
        return self.len

    def __next__(self):
        # Values are rounded to 4 decimal places, to trim floating point noise from arange
        return round(next(self.data), 4)

    def __iter__(self):
        # NOTE: this returns the underlying iterator directly, so values obtained by
        #   looping over the object are not rounded, unlike values obtained with `next`
        return self.data

    @staticmethod
    def _check_values(start, stop, step):
        """Checks if provided values are valid.

        Parameters
        ----------
        start, stop, step : float
            Definition of the parameter range to iterate over.

        Raises
        ------
        ValueError
            If the given values for defining the iteration range are inconsistent.
        """

        if any(ii < 0 for ii in [start, stop]):
            raise ValueError("Inputs 'start' and 'stop' should be positive values.")

        if (stop - start) * step < 0:
            raise ValueError("The sign of 'step' does not align with 'start' / 'stop' values.")

        if start == stop:
            raise ValueError("Input 'start' and 'stop' must be different values.")

        if not abs(step) < abs(stop - start):
            raise ValueError("Input 'step' is too large given values for 'start' and 'stop'.")

        # A zero step passes all of the checks above, but cannot define a range
        if step == 0:
            raise ValueError("Input 'step' must be a non-zero value.")
def param_iter(params):
    """Create a generator to iterate across parameter ranges.

    Parameters
    ----------
    params : list of floats and Stepper
        Parameters over which to iterate, including a Stepper object.
        The Stepper defines the iterated parameter and its range.
        Floats define the other parameters, that will be held constant.

    Yields
    ------
    list of floats
        Next generated list of parameters.
        Each yielded list is a new list object, and the input list is not modified.

    Raises
    ------
    ValueError
        If the number of Stepper objects given is greater than one.

    Examples
    --------
    Iterate across exponent values from 1 to 2, in steps of 0.1:

    >>> aps = param_iter([Stepper(1, 2, 0.1), 1])

    Iterate over center frequency values from 8 to 12 in increments of 0.25:

    >>> peaks = param_iter([Stepper(8, 12, .25), 0.5, 1])
    """

    # If input is a list of lists, flatten it. Otherwise, take a copy of the input,
    #   so that filling in stepped values below does not overwrite the caller's list
    if isinstance(params[0], list):
        params = [item for sublist in params for item in sublist]
    else:
        params = list(params)

    # Find where the Stepper object is. If there is more than one, raise an error
    stepper_inds = [ind for ind, param in enumerate(params) if isinstance(param, Stepper)]

    if len(stepper_inds) > 1:
        raise ValueError("Iteration is only supported across one parameter at a time.")

    # If no Stepper is found, the first position is used, as in the original implementation
    iter_ind = stepper_inds[0] if stepper_inds else 0
    gen = params[iter_ind]

    # Generate and yield next set of parameters
    #   Values are drawn with `next`, to go through the rounding applied by the Stepper
    while True:

        try:
            cur_value = next(gen)
        except StopIteration:
            return

        params[iter_ind] = cur_value

        # Yield a fresh list each time, so collected outputs do not alias one another
        yield list(params)
def param_sampler(params, probs=None):
    """Create a generator that randomly samples from a set of parameter options.

    Parameters
    ----------
    params : list of lists or list of float
        Possible parameter values.
    probs : list of float, optional
        Probabilities with which to sample each parameter option.
        If None, each parameter option is sampled uniformly.

    Yields
    ------
    list of float
        A randomly sampled set of parameters.

    Raises
    ------
    ValueError
        If probabilities are given, and their number does not match the number of options.

    Examples
    --------
    Sample from aperiodic definitions with high and low exponents, with 50% probability of each:

    >>> aps = param_sampler([[1, 1], [2, 1]], probs=[0.5, 0.5])

    Sample from peak definitions of alpha or alpha & beta, with 75% chance of sampling just alpha:

    >>> peaks = param_sampler([[10, 1, 1], [[10, 1, 1], [20, 0.5, 1]]], probs=[0.75, 0.25])
    """

    # When options are given as lists, pass each option through `check_flat`
    if isinstance(params[0], list):
        params = [check_flat(option) for option in params]

    # Sampling is done over positions rather than over the options themselves,
    #   as the options can be irregularly sized, which numpy's choice does not accept
    option_inds = np.array(range(len(params)))

    # If probabilities are provided, there must be exactly one per option
    if np.any(probs) and len(option_inds) != len(probs):
        raise ValueError("The number of options must match the number of probabilities.")

    # Loop without end, so the generator can be drawn from an arbitrary number of times
    while True:
        chosen = np.random.choice(option_inds, p=probs)
        yield params[chosen]
def param_jitter(params, jitters):
    """Create a generator that adds jitter to parameter definitions.

    Parameters
    ----------
    params : list of lists or list of float
        Possible parameter values.
    jitters : list of lists or list of float
        The scale of the jitter for each parameter.
        Must be the same shape and organization as `params`.

    Yields
    ------
    list of float
        A jittered set of parameters.

    Raises
    ------
    ValueError
        If the number of jitter values does not match the number of parameters.

    Notes
    -----
    - Jitter is added as random samples from a normal (gaussian) distribution.
    - The jitter specified corresponds to the standard deviation of the normal distribution.
    - For any parameter for which there should be no jitter, set the corresponding value to zero.

    Examples
    --------
    Jitter aperiodic definitions, for offset and exponent, each with the same amount of jitter:

    >>> aps = param_jitter([1, 1], [0.1, 0.1])

    Jitter center frequency of peak definitions, by different amounts for alpha & beta:

    >>> peaks = param_jitter([[10, 1, 1], [20, 1, 1]], [[0.1, 0, 0], [0.5, 0, 0]])
    """

    # Check if inputs are list of lists, and pass both through `check_flat` if so
    if isinstance(params[0], list):
        params = check_flat(params)
        jitters = check_flat(jitters)

    # Each parameter needs its own jitter value. Without this check, a mismatch would
    #   silently leave parameters unfilled, or ignore the surplus jitter values
    if len(params) != len(jitters):
        raise ValueError("The number of jitter values must match the number of parameters.")

    # While loop allows the generator to be called an arbitrary number of times
    while True:
        yield [param + np.random.normal(0, jitter) for param, jitter in zip(params, jitters)]