"""Class to handle the dipoles."""
# Authors: Mainak Jas <mjas@mgh.harvard.edu>
# Sam Neymotin <samnemo@gmail.com>
import os
import warnings
from io import StringIO
import numpy as np
from copy import deepcopy
from h5io import write_hdf5, read_hdf5
from .externals.mne import _check_option
from .viz import plot_dipole, plot_psd, plot_tfr_morlet
def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsec=False,
                    record_isec=False, record_ca=False, postproc=False):
    """Run a network simulation and return the resulting dipoles.

    Parameters
    ----------
    net : Network object
        The Network object specifying how cells are
        connected.
    tstop : float
        The simulation stop time (ms).
    dt : float
        The integration time step of h.CVode (ms)
    n_trials : int | None
        The number of trials to simulate. If None, the 'N_trials' value
        of the ``params`` used to create ``net`` is used (must be >0)
    record_vsec : 'all' | 'soma' | False
        Option to record voltages from all sections ('all'), or just
        the soma ('soma'). Default: False.
    record_isec : 'all' | 'soma' | False
        Option to record synaptic currents from all sections ('all'), or just
        the soma ('soma'). Default: False.
    record_ca : 'all' | 'soma' | False
        Option to record calcium concentration from all sections ('all'),
        or just the soma ('soma'). Default: False.
    postproc : bool
        If True, smoothing (``dipole_smooth_win``) and scaling
        (``dipole_scalefctr``) values are read from the parameter file, and
        applied to the dipole objects before returning. Note that this setting
        only affects the dipole waveforms, and not somatic voltages, possible
        extracellular recordings etc. The preferred way is to use the
        :meth:`~hnn_core.dipole.Dipole.smooth` and
        :meth:`~hnn_core.dipole.Dipole.scale` methods instead. Default: False.

    Returns
    -------
    dpls: list
        List of dipole objects for each trials

    Notes
    -----
    ``net`` is modified in place: drives and tonic biases whose ``'tstop'``
    is None inherit ``tstop``, and the recording options, ``_tstop`` and
    ``_dt`` are stored on the network before the backend is invoked.
    """
    from .parallel_backends import _BACKEND, JoblibBackend

    # Fall back to a single-job joblib backend when none is active.
    backend = _BACKEND if _BACKEND is not None else JoblibBackend(n_jobs=1)

    if n_trials is None:
        n_trials = net._params['N_trials']
    if n_trials < 1:
        raise ValueError("Invalid number of simulations: %d" % n_trials)

    if not net.connectivity:
        warnings.warn('No connections instantiated in network. Consider using '
                      'net = jones_2009_model() or net = law_2021_model() to '
                      'create a predefined network from published models.',
                      UserWarning)

    # ADD DRIVE WARNINGS HERE
    if not (net.external_drives or net.external_biases):
        warnings.warn('No external drives or biases loaded', UserWarning)

    # Drives that declare a 'tstop' but leave it open run until the end.
    for drive in net.external_drives.values():
        dynamics = drive['dynamics']
        if 'tstop' in dynamics and dynamics['tstop'] is None:
            dynamics['tstop'] = tstop

    # Same default for tonic biases, followed by a sanity check of the
    # resulting time window.
    for bias in net.external_biases.values():
        for bias_config in bias.values():
            if bias_config['tstop'] is None:
                bias_config['tstop'] = tstop
            bias_tstop = bias_config['tstop']
            if bias_tstop < 0.:
                raise ValueError('End time of tonic input cannot be negative')
            if bias_tstop - bias_config['t0'] < 0.:
                raise ValueError('Duration of tonic input cannot be negative')

    net._instantiate_drives(n_trials=n_trials, tstop=tstop)
    net._reset_rec_arrays()

    # Validate and store each recording option, one after the other.
    recording_options = (('record_vsec', record_vsec),
                         ('record_isec', record_isec),
                         ('record_ca', record_ca))
    for option_name, option_value in recording_options:
        _check_option(option_name, option_value, ['all', 'soma', False])
        net._params[option_name] = option_value

    net._tstop = tstop
    net._dt = dt

    if postproc:
        warnings.warn('The postproc-argument is deprecated and will be removed'
                      ' in a future release of hnn-core. Please define '
                      'smoothing and scaling explicitly using Dipole methods.',
                      DeprecationWarning)

    return backend.simulate(net, tstop, dt, n_trials, postproc)
def _read_dipole_txt(fname, extension='.txt'):
"""Read dipole values from a txt file and create a Dipole instance.
Parameters
----------
fname : str or io.StringIO
Full path to the input file (.txt or .csv) or
Content of file in memory as a StringIO
Returns
-------
dpl : Dipole
The instance of Dipole class
"""
if extension == '.csv':
# read from a csv file ignoring the headers
dpl_data = np.genfromtxt(fname, delimiter=',',
skip_header=1, dtype=float)
else:
dpl_data = np.loadtxt(fname, dtype=float)
ncols = dpl_data.shape[1]
if ncols not in (2, 4):
raise ValueError(
f'Data are supposed to have 2 or 4 columns while we have {ncols}.')
dpl = Dipole(dpl_data[:, 0], dpl_data[:, 1:])
return dpl
def _read_dipole_hdf5(fname):
    """Load a Dipole that was stored in an hdf5 file.

    Parameters
    ----------
    fname : str
        Full path to the input file (.hdf5)

    Returns
    -------
    dpl : Dipole
        The instance of Dipole class

    Raises
    ------
    NameError
        If the file content carries no ``'object_type'`` entry.
    ValueError
        If the stored ``'object_type'`` is not ``'Dipole'``.
    """
    contents = read_hdf5(fname)

    if 'object_type' not in contents:
        raise NameError('The given file is not compatible. '
                        'The file should contain information'
                        ' about object type to be read.')

    object_type = contents['object_type']
    if object_type != 'Dipole':
        raise ValueError('The object should be of type Dipole. '
                         'The file contains object of '
                         'type %s' % (object_type,))

    dpl = Dipole(times=contents['times'],
                 data=contents['data'],
                 nave=contents['nave'])
    # The constructor derives sfreq and resets scale_applied; restore the
    # values that were saved alongside the data.
    for attr_name in ('sfreq', 'scale_applied'):
        setattr(dpl, attr_name, contents[attr_name])
    return dpl
def read_dipole(fname):
    """Create a Dipole instance from a txt or hdf5 file on disk.

    Parameters
    ----------
    fname : str | Path object
        Full path to the input file (.txt or .hdf5)

    Returns
    -------
    dpl : Dipole
        The instance of Dipole class

    Raises
    ------
    FileNotFoundError
        If nothing exists at ``fname``.
    NameError
        If the file extension is neither ``.txt`` nor ``.hdf5``.
    """
    fname = str(fname)
    if not os.path.exists(fname):
        raise FileNotFoundError('File not found at path %s.' % (fname,))

    _, file_extension = os.path.splitext(fname)

    # Dispatch on the extension; anything unknown is rejected below.
    if file_extension == '.hdf5':
        return _read_dipole_hdf5(fname)
    if file_extension == '.txt':
        return _read_dipole_txt(fname)

    raise NameError('File extension should be either txt or hdf5, but the '
                    'given extension is %s' % (file_extension,))
def average_dipoles(dpls):
    """Compute dipole averages over a list of Dipole objects.

    Parameters
    ----------
    dpls : list of Dipole objects
        Contains list of dipole objects, each with a `data` member containing
        'L2', 'L5' and 'agg' components

    Returns
    -------
    dpl : instance of Dipole
        A new dipole object with each component of `dpl.data` representing the
        average over the same components in the input list. Its time vector
        is taken from the first dipole in the list and its ``nave`` is set to
        the number of dipoles averaged.

    Raises
    ------
    ValueError
        If ``dpls`` is empty, if an element is not a Dipole, or if an element
        is itself already an average (``nave > 1``).
    RuntimeError
        If the dipoles do not all share the same ``scale_applied``.
    """
    if len(dpls) == 0:
        raise ValueError('Cannot average an empty list of dipoles.')

    scale_applied = None
    for dpl_idx, dpl in enumerate(dpls):
        # Check the type first so that a wrong element is reported as such
        # rather than through an AttributeError on a missing attribute.
        if not isinstance(dpl, Dipole):
            raise ValueError(
                f"All elements in the list should be instances of "
                f"Dipole. Got {type(dpl)}")
        if dpl_idx == 0:
            scale_applied = dpl.scale_applied
        elif dpl.scale_applied != scale_applied:
            raise RuntimeError('All dipoles must be scaled equally!')
        if dpl.nave > 1:
            raise ValueError("Dipole at index %d was already an average of %d"
                             " trials. Cannot reaverage" %
                             (dpl_idx, dpl.nave))

    # Average layer by layer and keep the result keyed by layer name, so the
    # outcome does not depend on the iteration order of the data dicts.
    avg_data = {
        layer: np.mean(np.array([dpl.data[layer] for dpl in dpls]), axis=0)
        for layer in dpls[0].data
    }
    avg_dpl = Dipole(dpls[0].times, avg_data)

    # The averaged scale should equal all scales in the input dpl list.
    avg_dpl.scale_applied = scale_applied

    # set nave to the number of trials averaged in this dipole
    avg_dpl.nave = len(dpls)

    return avg_dpl
def _rmse(dpl, exp_dpl, tstart=0.0, tstop=0.0, weights=None):
""" Calculates RMSE between data in dpl and exp_dpl
Parameters
----------
dpl : instance of Dipole
A dipole object with simulated data
exp_dpl : instance of Dipole
A dipole object with experimental data
tstart : None | float
Time at beginning of range over which to calculate RMSE
tstop : None | float
Time at end of range over which to calculate RMSE
weights : None | array
An array of weights to be applied to each point in
simulated dpl. Must have length >= dpl.data
If None, weights will be replaced with 1's for typical RMSE
calculation.
Returns
-------
err : float
Weighted RMSE between data in dpl and exp_dpl
"""
from scipy import signal
exp_times = exp_dpl.times
sim_times = dpl.times
# do tstart and tstop fall within both datasets?
# if not, use the closest data point as the new tstop/tstart
for tseries in [exp_times, sim_times]:
if tstart < tseries[0]:
tstart = tseries[0]
if tstop > tseries[-1]:
tstop = tseries[-1]
# make sure start and end times are valid for both dipoles
exp_start_index = (np.abs(exp_times - tstart)).argmin()
exp_end_index = (np.abs(exp_times - tstop)).argmin()
exp_length = exp_end_index - exp_start_index
sim_start_index = (np.abs(sim_times - tstart)).argmin()
sim_end_index = (np.abs(sim_times - tstop)).argmin()
sim_length = sim_end_index - sim_start_index
if weights is None:
# weighted RMSE with weights of all 1's is equivalent to
# normal RMSE
weights = np.ones(len(sim_times[0:sim_end_index]))
weights = weights[sim_start_index:sim_end_index]
dpl1 = dpl.data['agg'][sim_start_index:sim_end_index]
dpl2 = exp_dpl.data['agg'][exp_start_index:exp_end_index]
if (sim_length > exp_length):
# downsample simulation timeseries to match exp data
dpl1 = signal.resample(dpl1, exp_length)
weights = signal.resample(weights, exp_length)
indices = np.where(weights < 1e-4)
weights[indices] = 0
elif (sim_length < exp_length):
# downsample exp timeseries to match simulation data
dpl2 = signal.resample(dpl2, sim_length)
return np.sqrt((weights * ((dpl1 - dpl2) ** 2)).sum() / weights.sum())
class Dipole(object):
    """Dipole class.

    An instance of the ``Dipole``-class contains the simulated dipole moment
    timecourses for L2 and L5 pyramidal cells, as well as their aggregate
    (``'agg'``). The units of the dipole moment are in ``nAm``
    (1e-9 Ampere-meters).

    Parameters
    ----------
    times : array (n_times,)
        The time vector (in ms)
    data : array, shape (n_times x n_layers) | dict of array
        The data. The first column represents 'agg' (the total dipole),
        the second 'L2' layer and the last one 'L5' layer. For experimental
        data, it can contain only one column. A dict mapping layer names to
        arrays is stored as-is.
    nave : int
        Number of trials that were averaged to produce this Dipole. Defaults
        to 1

    Attributes
    ----------
    times : array-like
        The time vector (in ms)
    sfreq : float
        The sampling frequency (in Hz)
    data : dict of array
        Dipole moment timecourse arrays with keys 'agg', 'L2' and 'L5'
    nave : int
        Number of trials that were averaged to produce this Dipole
    scale_applied : int or float
        The total factor by which the dipole has been scaled (using
        :meth:`~hnn_core.dipole.Dipole.scale`).
    """

    def __init__(self, times, data, nave=1):  # noqa: D102
        self.times = np.array(times)

        if isinstance(data, dict):
            self.data = data
        else:
            # accept any array-like; a no-op for arrays
            data = np.asarray(data)
            if data.ndim == 1:
                data = data[:, None]
            # Fail loudly instead of leaving ``self.data`` undefined.
            if data.ndim != 2 or data.shape[1] not in (1, 3):
                raise ValueError(
                    'data must have 1 column (agg) or 3 columns '
                    f'(agg, L2, L5), got an array of shape {data.shape}.')
            if data.shape[1] == 3:
                self.data = {'agg': data[:, 0], 'L2': data[:, 1],
                             'L5': data[:, 2]}
            else:
                self.data = {'agg': data[:, 0]}
        self.nave = nave
        self.sfreq = 1000. / (times[1] - times[0])  # NB assumes len > 1
        self.scale_applied = 1  # for visualisation

    def copy(self):
        """Return a copy of the Dipole instance

        Returns
        -------
        dpl_copy : instance of Dipole
            A copy of the Dipole instance.
        """
        return deepcopy(self)

    def _post_proc(self, window_len, fctr):
        """Apply scaling and smoothing from param-files (DEPRECATE)

        Parameters
        ----------
        window_len : int
            Smoothing window in ms
        fctr : int
            Scaling factor
        """
        self.scale(fctr)

        if window_len > 0:  # this is to allow param-files with len==0
            self.smooth(window_len)

    def _convert_fAm_to_nAm(self):
        """The NEURON simulator output is in fAm, convert to nAm

        NB! Must be run `after` :meth:`Dipole._baseline_renormalize`
        """
        for key in self.data.keys():
            self.data[key] *= 1e-6

    def scale(self, factor):
        """Scale (multiply) the dipole moment by a fixed factor

        The attribute ``Dipole.scale_applied`` is updated to reflect factors
        applied and displayed in plots.

        Parameters
        ----------
        factor : int
            Scaling factor, applied to the data in-place.

        Returns
        -------
        dpl : instance of Dipole
            The modified Dipole instance itself (not a copy).
        """
        for key in self.data.keys():
            self.data[key] *= factor
        self.scale_applied *= factor
        return self

    def smooth(self, window_len):
        """Smooth the dipole waveform using Hamming-windowed convolution

        Note that this method operates in-place, i.e., it will alter the data.
        If you prefer a filtered copy, consider using the
        :meth:`~hnn_core.dipole.Dipole.copy`-method.

        Parameters
        ----------
        window_len : float
            The length (in ms) of a `~numpy.hamming` window to convolve the
            data with.

        Returns
        -------
        dpl : instance of Dipole
            The modified Dipole instance itself (not a copy).
        """
        from .utils import smooth_waveform

        for key in self.data.keys():
            self.data[key] = smooth_waveform(self.data[key], window_len,
                                             self.sfreq)
        return self

    def savgol_filter(self, h_freq):
        """Smooth the dipole waveform using Savitzky-Golay filtering

        Note that this method operates in-place, i.e., it will alter the data.
        If you prefer a filtered copy, consider using the
        :meth:`~hnn_core.dipole.Dipole.copy`-method. The high-frequency cutoff
        value of a Savitzky-Golay filter is approximate; see the SciPy
        reference: :func:`~scipy.signal.savgol_filter`.

        Parameters
        ----------
        h_freq : float
            Approximate high cutoff frequency in Hz. Note that this
            is not an exact cutoff, since Savitzky-Golay filtering
            is done using polynomial fits
            instead of FIR/IIR filtering. This parameter is thus used to
            determine the length of the window over which a 5th-order
            polynomial smoothing is applied. Must lie between 0 and half
            the sampling frequency.

        Returns
        -------
        dpl : instance of Dipole
            The modified Dipole instance itself (not a copy).

        Raises
        ------
        ValueError
            If ``h_freq`` is negative or larger than half the sample rate.
        """
        from .utils import _savgol_filter

        if h_freq < 0:
            raise ValueError('h_freq cannot be negative')
        elif h_freq > 0.5 * self.sfreq:
            raise ValueError(
                'h_freq must be less than half the sample rate')
        for key in self.data.keys():
            self.data[key] = _savgol_filter(self.data[key],
                                            h_freq,
                                            self.sfreq)
        return self

    def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
             color='k', show=True):
        """Simple layer-specific plot function.

        All arguments are forwarded to ``plot_dipole``.

        Parameters
        ----------
        tmin : float or None
            Start of the time range (forwarded to ``plot_dipole``).
        tmax : float or None
            End of the time range (forwarded to ``plot_dipole``).
        layer : str
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        decim : int
            Factor by which to decimate the raw dipole traces (optional)
        ax : instance of matplotlib figure | None
            The matplotlib axis
        color : tuple of float
            RGBA value to use for plotting. By default, 'k' (black)
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of plt.fig
            The matplotlib figure handle.
        """
        return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
                           decim=decim, color=color, show=show)

    def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
                 color=None, label=None, ax=None, show=True):
        """Plot power spectral density (PSD) of dipole time course

        Applies `~scipy.signal.periodogram` from SciPy with
        ``window='hamming'``.
        Note that no spectral averaging is applied across time, as most
        ``hnn_core`` simulations are short-duration. However, passing a list of
        `Dipole` instances will plot their average (Hamming-windowed) power,
        which resembles the `Welch`-method applied over time.

        Parameters
        ----------
        fmin : float
            Minimum frequency to plot (in Hz). Default: 0 Hz
        fmax : float
            Maximum frequency to plot (in Hz). Default: None (plot up to
            Nyquist)
        tmin : float or None
            Start time of data to include (in ms). If None, use entire
            simulation.
        tmax : float or None
            End time of data to include (in ms). If None, use entire
            simulation.
        layer : str, default 'agg'
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        color : str | tuple | None
            The line color of PSD
        label : str | None
            Line label for PSD
        ax : instance of matplotlib figure | None
            The matplotlib axis.
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        return plot_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
                        layer=layer, color=color, label=label, ax=ax,
                        show=show)

    def plot_tfr_morlet(self, freqs, n_cycles=7., tmin=None, tmax=None,
                        layer='agg', decim=None, padding='zeros', ax=None,
                        colormap='inferno', colorbar=True,
                        colorbar_inside=False, show=True):
        """Plot Morlet time-frequency representation of dipole time course

        NB: Calls `~mne.time_frequency.tfr_array_morlet`, so ``mne`` must be
        installed.

        Parameters
        ----------
        freqs : array
            Frequency range of interest.
        n_cycles : float or array of float, default 7.0
            Number of cycles. Fixed number or one per frequency.
        tmin : float or None
            Start time of plot in milliseconds. If None, plot entire
            simulation.
        tmax : float or None
            End time of plot in milliseconds. If None, plot entire simulation.
        layer : str, default 'agg'
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        decim : int or list of int or None (default)
            Optional (integer) factor by which to decimate the raw dipole
            traces. The SciPy function :func:`~scipy.signal.decimate` is used,
            which recommends values <13. To achieve higher decimation factors,
            a list of ints can be provided. These are applied successively.
        padding : str or None
            Optional padding of the dipole time course beyond the plotting
            limits. Possible values are: 'zeros' for padding with 0's
            (default), 'mirror' for mirror-image padding.
        ax : instance of matplotlib figure | None
            The matplotlib axis
        colormap : str
            The name of a matplotlib colormap, e.g., 'viridis'. Default:
            'inferno'
        colorbar : bool
            If True (default), adjust figure to include colorbar.
        colorbar_inside: bool, default False
            Put the color inside the heatmap if True.
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        return plot_tfr_morlet(
            self, freqs, n_cycles=n_cycles, tmin=tmin, tmax=tmax,
            layer=layer, decim=decim, padding=padding, ax=ax,
            colormap=colormap, colorbar=colorbar,
            colorbar_inside=colorbar_inside, show=show)

    def _baseline_renormalize(self, N_pyr_x, N_pyr_y):
        """Only baseline renormalize if the units are fAm.

        Subtracts fixed per-cell offsets from the 'L2' and 'L5' data
        in-place and recomputes 'agg' as their sum.

        Parameters
        ----------
        N_pyr_x : int
            Nr of cells (x)
        N_pyr_y : int
            Nr of cells (y)
        """
        # N_pyr cells in grid. This is PER LAYER
        N_pyr = N_pyr_x * N_pyr_y
        # dipole offset calculation: increasing number of pyr
        # cells (L2 and L5, simultaneously)
        # with no inputs resulted in an aggregate dipole over the
        # interval [50., 1000.] ms that
        # eventually plateaus at -48 fAm. The range over this interval
        # is something like 3 fAm
        # so the resultant correction is here, per dipole
        # dpl_offset = N_pyr * 50.207
        dpl_offset = {
            # these values will be subtracted
            'L2': N_pyr * 0.0443,
            'L5': N_pyr * -49.0502
            # 'L5': N_pyr * -48.3642,
            # will be calculated next, this is a placeholder
            # 'agg': None,
        }
        # L2 dipole offset can be roughly baseline shifted over
        # the entire range of t
        self.data['L2'] -= dpl_offset['L2']
        # L5 dipole offset should be different for interval [50., 500.]
        # and then it can be offset
        # slope (m) and intercept (b) params for L5 dipole offset
        # uncorrected for N_cells
        # these values were fit over the range [37., 750.)
        m = 3.4770508e-3
        b = -51.231085
        # these values were fit over the range [750., 5000]
        t1 = 750.
        m1 = 1.01e-4
        b1 = -48.412078

        # piecewise normalization over three consecutive time segments
        early = self.times <= 37.
        middle = (self.times > 37.) & (self.times < t1)
        late = self.times >= t1
        self.data['L5'][early] -= dpl_offset['L5']
        self.data['L5'][middle] -= N_pyr * (m * self.times[middle] + b)
        self.data['L5'][late] -= N_pyr * (m1 * self.times[late] + b1)
        # recalculate the aggregate dipole based on the baseline
        # normalized ones
        self.data['agg'] = self.data['L2'] + self.data['L5']

    def _write_txt(self, fname):
        """Write dipole values to a file.

        Parameters
        ----------
        fname : str
            Full path to the output file (.txt)

        Outputs
        -------
        A tab separated txt file where rows correspond
            to samples and columns correspond to
            1) time (ms),
            2) aggregate current dipole (scaled nAm),
            3) L2/3 current dipole (scaled nAm), and
            4) L5 current dipole (scaled nAm)
        """
        warnings.warn('Writing dipole to txt file is deprecated '
                      'and will be removed in future versions. '
                      'Please use hdf5', DeprecationWarning, stacklevel=2)

        if self.nave > 1:
            warnings.warn("Saving Dipole to file that is an average of %d"
                          " trials" % self.nave)

        # Emit the layers in the documented order (agg, L2, L5) regardless
        # of how the data dict happens to be ordered; any other keys follow.
        canonical = ('agg', 'L2', 'L5')
        layers = [key for key in canonical if key in self.data]
        layers += [key for key in self.data if key not in canonical]

        columns = [self.times] + [self.data[key] for key in layers]
        fmt = ['%3.3f'] + ['%5.4f'] * len(layers)
        np.savetxt(fname, np.column_stack(columns), fmt=fmt, delimiter='\t')

    def _write_hdf5(self, fname):
        """Write dipole values to a hdf5 file.

        An existing file at ``fname`` is overwritten.

        Parameters
        ----------
        fname : str
            Full path to the output file (.hdf5)

        Outputs
        -------
        A hdf5 file containing the dipole object
        """
        print(f'Writing file {fname}')
        dpl_data = dict()
        dpl_data['object_type'] = "Dipole"
        dpl_data['times'] = self.times
        dpl_data['sfreq'] = self.sfreq
        dpl_data['nave'] = self.nave
        dpl_data['data'] = self.data
        dpl_data['scale_applied'] = self.scale_applied
        write_hdf5(fname, dpl_data, overwrite=True)

    def write(self, fname, overwrite=True):
        """Write dipole values to a txt or hdf5 file.

        Parameters
        ----------
        fname : str | Path object
            Full path to the output file (.txt or .hdf5)
        overwrite : bool
            If False, refuse to write when a file already exists at
            ``fname``. Default: True.

        Outputs
        -------
        A tab separated txt file where rows correspond
            to samples and columns correspond to
            1) time (ms),
            2) aggregate current dipole (scaled nAm),
            3) L2/3 current dipole (scaled nAm), and
            4) L5 current dipole (scaled nAm)
        OR
        A hdf5 file containing the dipole object

        Raises
        ------
        FileExistsError
            If ``overwrite`` is False and ``fname`` already exists.
        NameError
            If the file extension is neither ``.txt`` nor ``.hdf5``.
        """
        # For supporting tests in test_gui.py
        if isinstance(fname, StringIO):
            return self._write_txt(fname)

        fname = str(fname)
        if overwrite is False and os.path.exists(fname):
            raise FileExistsError('File already exists at path %s. Rename '
                                  'the file or set overwrite=True.' % (fname,))

        file_extension = os.path.splitext(fname)[-1]
        if file_extension == '.txt':
            self._write_txt(fname)
        elif file_extension == '.hdf5':
            self._write_hdf5(fname)
        else:
            raise NameError('File extension should be either txt or hdf5, but '
                            'the given extension is %s.' % (file_extension,))