"""Class to handle the dipoles."""
# Authors: Mainak Jas <mjas@mgh.harvard.edu>
#          Sam Neymotin <samnemo@gmail.com>
import warnings
import numpy as np
from copy import deepcopy
from .viz import plot_dipole, plot_psd, plot_tfr_morlet
def simulate_dipole(net, tstop, dt=0.025, n_trials=None, record_vsoma=False,
                    record_isoma=False, postproc=False):
    """Simulate a dipole given the experiment parameters.

    Parameters
    ----------
    net : Network object
        The Network object specifying how cells are
        connected.
    tstop : float
        The simulation stop time (ms).
    dt : float
        The integration time step of h.CVode (ms)
    n_trials : int | None
        The number of trials to simulate. If None, the 'N_trials' value
        of the ``params`` used to create ``net`` is used (must be >0)
    record_vsoma : bool
        Option to record somatic voltages from cells
    record_isoma : bool
        Option to record somatic currents from cells
    postproc : bool
        If True, smoothing (``dipole_smooth_win``) and scaling
        (``dipole_scalefctr``) values are read from the parameter file, and
        applied to the dipole objects before returning. Note that this setting
        only affects the dipole waveforms, and not somatic voltages, possible
        extracellular recordings etc. The preferred way is to use the
        :meth:`~hnn_core.dipole.Dipole.smooth` and
        :meth:`~hnn_core.dipole.Dipole.scale` methods instead. Default: False.

    Returns
    -------
    dpls: list
        List of dipole objects, one per trial.

    Raises
    ------
    ValueError
        If the number of trials is < 1, or if a tonic bias has a negative
        end time or a negative duration.
    TypeError
        If ``record_vsoma`` or ``record_isoma`` is not a bool.

    Warns
    -----
    UserWarning
        If ``net`` has no connections instantiated.
    """
    from .parallel_backends import _BACKEND, JoblibBackend

    # Fall back to a serial joblib backend when no backend context is active.
    if _BACKEND is None:
        _BACKEND = JoblibBackend(n_jobs=1)

    if n_trials is None:
        n_trials = net._params['N_trials']
    if n_trials < 1:
        raise ValueError("Invalid number of simulations: %d" % n_trials)

    # Validate the recording flags before any state on ``net`` is modified,
    # so that a bad call does not leave the network half-updated.
    if not isinstance(record_vsoma, bool):
        raise TypeError("record_vsoma must be bool, got %s"
                        % type(record_vsoma).__name__)
    if not isinstance(record_isoma, bool):
        raise TypeError("record_isoma must be bool, got %s"
                        % type(record_isoma).__name__)

    if not net.connectivity:
        # Advisory only: the message suggests an alternative, it should not
        # abort the simulation (``raise Warning`` did exactly that).
        warnings.warn('No connections instantiated in network. Consider '
                      'using net = jones_2009_model() or net = '
                      'law_2021_model() to create a predefined network from '
                      'published models.', UserWarning)

    # Drives whose end time was left as None run until the simulation stops.
    for drive in net.external_drives.values():
        if 'tstop' in drive['dynamics']:
            if drive['dynamics']['tstop'] is None:
                drive['dynamics']['tstop'] = tstop

    # Same default for tonic biases, then sanity-check their time windows.
    for bias in net.external_biases.values():
        for bias_cell_type in bias.values():
            if bias_cell_type['tstop'] is None:
                bias_cell_type['tstop'] = tstop
            if bias_cell_type['tstop'] < 0.:
                raise ValueError('End time of tonic input cannot be negative')
            duration = bias_cell_type['tstop'] - bias_cell_type['t0']
            if duration < 0.:
                raise ValueError('Duration of tonic input cannot be negative')

    net._instantiate_drives(n_trials=n_trials, tstop=tstop)
    net._reset_rec_arrays()

    # Flags were validated above; store them where the backend reads them.
    net._params['record_vsoma'] = record_vsoma
    net._params['record_isoma'] = record_isoma

    if postproc:
        warnings.warn('The postproc-argument is deprecated and will be removed'
                      ' in a future release of hnn-core. Please define '
                      'smoothing and scaling explicitly using Dipole methods.',
                      DeprecationWarning)

    dpls = _BACKEND.simulate(net, tstop, dt, n_trials, postproc)
    return dpls
def read_dipole(fname):
    """Read dipole values from a file and create a Dipole instance.

    Parameters
    ----------
    fname : str
        Full path to the input file (.txt). The first column holds the time
        vector; the remaining column(s) hold the dipole moment in the column
        order expected by :class:`Dipole`.

    Returns
    -------
    dpl : Dipole
        The instance of Dipole class

    Raises
    ------
    ValueError
        If the file does not contain at least two columns (time + dipole).
    """
    # ndmin=2 keeps the (n_rows, n_columns) layout even for a one-row file,
    # so the column indexing below is always well defined.
    dpl_data = np.loadtxt(fname, dtype=float, ndmin=2)
    if dpl_data.shape[1] < 2:
        raise ValueError('Dipole file must contain a time column and at least '
                         'one dipole column, got %d column(s)'
                         % dpl_data.shape[1])
    dpl = Dipole(dpl_data[:, 0], dpl_data[:, 1:])
    return dpl
def average_dipoles(dpls):
    """Compute dipole averages over a list of Dipole objects.

    Parameters
    ----------
    dpls: list of Dipole objects
        Contains list of dipole objects, each with a `data` member containing
        'L2', 'L5' and 'agg' components

    Returns
    -------
    dpl: instance of Dipole
        A new dipole object with each component of `dpl.data` representing the
        average over the same components in the input list. Its ``times`` are
        taken from the first dipole and its ``nave`` is ``len(dpls)``.

    Raises
    ------
    ValueError
        If ``dpls`` is empty, contains a non-Dipole element, or contains a
        Dipole that is already an average (``nave > 1``).
    """
    # Materialize once so generators work and the list can be walked twice.
    dpls = list(dpls)
    if not dpls:
        raise ValueError("Need at least one Dipole object to average")

    for dpl_idx, dpl in enumerate(dpls):
        if not isinstance(dpl, Dipole):
            raise ValueError(
                f"All elements in the list should be instances of "
                f"Dipole. Got {type(dpl)}")
        if dpl.nave > 1:
            raise ValueError("Dipole at index %d was already an average of %d"
                             " trials. Cannot reaverage" %
                             (dpl_idx, dpl.nave))

    # Column order follows the dict order of ``data`` ('agg', 'L2', 'L5'),
    # which is the column order Dipole.__init__ expects.
    layers = list(dpls[0].data.keys())
    avg_data = np.column_stack([
        np.mean([dpl.data[layer] for dpl in dpls], axis=0)
        for layer in layers
    ])

    avg_dpl = Dipole(dpls[0].times, avg_data)
    # set nave to the number of trials averaged in this dipole
    avg_dpl.nave = len(dpls)
    return avg_dpl
def _rmse(dpl, exp_dpl, tstart=0.0, tstop=0.0, weights=None):
    """ Calculates RMSE between data in dpl and exp_dpl
    Parameters
    ----------
    dpl: instance of Dipole
        A dipole object with simulated data
    exp_dpl: instance of Dipole
        A dipole object with experimental data
    tstart | None: float
        Time at beginning of range over which to calculate RMSE
    tstop | None: float
        Time at end of range over which to calculate RMSE
    weights | None: array
        An array of weights to be applied to each point in
        simulated dpl. Must have length >= dpl.data
        If None, weights will be replaced with 1's for typical RMSE
        calculation.
    Returns
    -------
    err: float
        Weighted RMSE between data in dpl and exp_dpl
    """
    from scipy import signal
    exp_times = exp_dpl.times
    sim_times = dpl.times
    # do tstart and tstop fall within both datasets?
    # if not, use the closest data point as the new tstop/tstart
    for tseries in [exp_times, sim_times]:
        if tstart < tseries[0]:
            tstart = tseries[0]
        if tstop > tseries[-1]:
            tstop = tseries[-1]
    # make sure start and end times are valid for both dipoles
    exp_start_index = (np.abs(exp_times - tstart)).argmin()
    exp_end_index = (np.abs(exp_times - tstop)).argmin()
    exp_length = exp_end_index - exp_start_index
    sim_start_index = (np.abs(sim_times - tstart)).argmin()
    sim_end_index = (np.abs(sim_times - tstop)).argmin()
    sim_length = sim_end_index - sim_start_index
    if weights is None:
        # weighted RMSE with weights of all 1's is equivalent to
        # normal RMSE
        weights = np.ones(len(sim_times[0:sim_end_index]))
    weights = weights[sim_start_index:sim_end_index]
    dpl1 = dpl.data['agg'][sim_start_index:sim_end_index]
    dpl2 = exp_dpl.data['agg'][exp_start_index:exp_end_index]
    if (sim_length > exp_length):
        # downsample simulation timeseries to match exp data
        dpl1 = signal.resample(dpl1, exp_length)
        weights = signal.resample(weights, exp_length)
        indices = np.where(weights < 1e-4)
        weights[indices] = 0
    elif (sim_length < exp_length):
        # downsample exp timeseries to match simulation data
        dpl2 = signal.resample(dpl2, sim_length)
    return np.sqrt((weights * ((dpl1 - dpl2) ** 2)).sum() / weights.sum())
class Dipole(object):
    """Dipole class.

    An instance of the ``Dipole``-class contains the simulated dipole moment
    timecourses for L2 and L5 pyramidal cells, as well as their aggregate
    (``'agg'``). The units of the dipole moment are in ``nAm``
    (1e-9 Ampere-meters).

    Parameters
    ----------
    times : array (n_times,)
        The time vector (in ms)
    data : array, shape (n_times x n_layers)
        The data. The first column represents 'agg' (the total dipole),
        the second 'L2' layer and the last one 'L5' layer. For experimental
        data, it can contain only one column (or be one-dimensional). The
        array is copied (as float), so in-place methods such as
        :meth:`~hnn_core.dipole.Dipole.scale` never modify the caller's array.
    nave : int
        Number of trials that were averaged to produce this Dipole. Defaults
        to 1

    Attributes
    ----------
    times : array-like
        The time vector (in ms)
    sfreq : float
        The sampling frequency (in Hz)
    data : dict of array
        Dipole moment timecourse arrays with keys 'agg', 'L2' and 'L5'
        (or only 'agg' for single-column data)
    nave : int
        Number of trials that were averaged to produce this Dipole
    scale_applied : int or float
        The total factor by which the dipole has been scaled (using
        :meth:`~hnn_core.dipole.Dipole.scale`).

    Raises
    ------
    ValueError
        If ``data`` is not 1D and does not have exactly 1 or 3 columns.
    """

    def __init__(self, times, data, nave=1):  # noqa: D102
        self.times = np.array(times)
        # Copy so in-place operations never write into the caller's array,
        # and cast to float so scaling by non-integer factors always works.
        data = np.array(data, dtype=float)
        if data.ndim == 1:
            data = data[:, None]
        if data.ndim != 2 or data.shape[1] not in (1, 3):
            raise ValueError(
                'data must be 1D or have shape (n_times, 3) or (n_times, 1), '
                f'got shape {data.shape}')
        if data.shape[1] == 3:
            self.data = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]}
        else:
            self.data = {'agg': data[:, 0]}
        self.nave = nave
        # NB assumes len(times) > 1 and uniform sampling
        self.sfreq = 1000. / (self.times[1] - self.times[0])
        self.scale_applied = 1  # for visualisation

    def copy(self):
        """Return a copy of the Dipole instance

        Returns
        -------
        dpl_copy : instance of Dipole
            A (deep) copy of the Dipole instance.
        """
        return deepcopy(self)

    def _post_proc(self, window_len, fctr):
        """Apply scaling and smoothing from param-files (DEPRECATE)

        Parameters
        ----------
        window_len : int
            Smoothing window in ms; no smoothing is applied if <= 0.
        fctr : int
            Scaling factor
        """
        self.scale(fctr)
        if window_len > 0:  # this is to allow param-files with len==0
            self.smooth(window_len)

    def _convert_fAm_to_nAm(self):
        """Convert the data in-place from fAm to nAm (multiply by 1e-6).

        NB! Must be run `after` :meth:`Dipole._baseline_renormalize`, whose
        offsets are expressed in fAm.
        """
        for key in self.data.keys():
            self.data[key] *= 1e-6

    def scale(self, factor):
        """Scale (multiply) the dipole moment by a fixed factor

        The attribute ``Dipole.scale_applied`` is updated to reflect factors
        applied and displayed in plots.

        Parameters
        ----------
        factor : int or float
            Scaling factor, applied to the data in-place.

        Returns
        -------
        dpl : instance of Dipole
            The same (modified) Dipole instance, for chaining.
        """
        for key in self.data.keys():
            self.data[key] *= factor
        self.scale_applied *= factor
        return self

    def smooth(self, window_len):
        """Smooth the dipole waveform using Hamming-windowed convolution

        Note that this method operates in-place, i.e., it will alter the data.
        If you prefer a filtered copy, consider using the
        :meth:`~hnn_core.dipole.Dipole.copy`-method.

        Parameters
        ----------
        window_len : float
            The length (in ms) of a `~numpy.hamming` window to convolve the
            data with.

        Returns
        -------
        dpl : instance of Dipole
            The same (modified) Dipole instance, for chaining.
        """
        from .utils import smooth_waveform

        for key in self.data.keys():
            self.data[key] = smooth_waveform(self.data[key], window_len,
                                             self.sfreq)
        return self

    def savgol_filter(self, h_freq):
        """Smooth the dipole waveform using Savitzky-Golay filtering

        Note that this method operates in-place, i.e., it will alter the data.
        If you prefer a filtered copy, consider using the
        :meth:`~hnn_core.dipole.Dipole.copy`-method. The high-frequency cutoff
        value of a Savitzky-Golay filter is approximate; see the SciPy
        reference: :func:`~scipy.signal.savgol_filter`.

        Parameters
        ----------
        h_freq : float
            Approximate high cutoff frequency in Hz. Note that this
            is not an exact cutoff, since Savitzky-Golay filtering
            is done using polynomial fits
            instead of FIR/IIR filtering. This parameter is thus used to
            determine the length of the window over which a 5th-order
            polynomial smoothing is applied.

        Returns
        -------
        dpl : instance of Dipole
            The same (modified) Dipole instance, for chaining.

        Raises
        ------
        ValueError
            If ``h_freq`` is negative or above half the sampling rate.
        """
        from .utils import _savgol_filter

        if h_freq < 0:
            raise ValueError('h_freq cannot be negative')
        elif h_freq > 0.5 * self.sfreq:
            raise ValueError(
                'h_freq must be less than half the sample rate')
        for key in self.data.keys():
            self.data[key] = _savgol_filter(self.data[key],
                                            h_freq,
                                            self.sfreq)
        return self

    def plot(self, tmin=None, tmax=None, layer='agg', decim=None, ax=None,
             color=None, show=True):
        """Simple layer-specific plot function.

        Parameters
        ----------
        tmin : float or None
            Start time of plot (in ms). If None, plot entire simulation.
        tmax : float or None
            End time of plot (in ms). If None, plot entire simulation.
        layer : str
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        decim : int
            Factor by which to decimate the raw dipole traces (optional)
        ax : instance of matplotlib axis | None
            The matplotlib axis
        color : tuple of float
            RGBA value to use for plotting (optional)
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of plt.fig
            The matplotlib figure handle.
        """
        return plot_dipole(self, tmin=tmin, tmax=tmax, ax=ax, layer=layer,
                           decim=decim, color=color, show=show)

    def plot_psd(self, fmin=0, fmax=None, tmin=None, tmax=None, layer='agg',
                 ax=None, show=True):
        """Plot power spectral density (PSD) of dipole time course

        Applies `~scipy.signal.periodogram` from SciPy with
        ``window='hamming'``.
        Note that no spectral averaging is applied across time, as most
        ``hnn_core`` simulations are short-duration.

        Parameters
        ----------
        fmin : float
            Minimum frequency to plot (in Hz). Default: 0 Hz
        fmax : float
            Maximum frequency to plot (in Hz). Default: None (plot up to
            Nyquist)
        tmin : float or None
            Start time of data to include (in ms). If None, use entire
            simulation.
        tmax : float or None
            End time of data to include (in ms). If None, use entire
            simulation.
        layer : str, default 'agg'
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        ax : instance of matplotlib axis | None
            The matplotlib axis.
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        return plot_psd(self, fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
                        layer=layer, ax=ax, show=show)

    def plot_tfr_morlet(self, freqs, n_cycles=7., tmin=None, tmax=None,
                        layer='agg', decim=None, padding='zeros', ax=None,
                        colormap='inferno', colorbar=True, show=True):
        """Plot Morlet time-frequency representation of dipole time course

        NB: Calls `~mne.time_frequency.tfr_array_morlet`, so ``mne`` must be
        installed.

        Parameters
        ----------
        freqs : array
            Frequency range of interest.
        n_cycles : float or array of float, default 7.0
            Number of cycles. Fixed number or one per frequency.
        tmin : float or None
            Start time of plot in milliseconds. If None, plot entire
            simulation.
        tmax : float or None
            End time of plot in milliseconds. If None, plot entire simulation.
        layer : str, default 'agg'
            The layer to plot. Can be one of 'agg', 'L2', and 'L5'
        decim : int or list of int or None (default)
            Optional (integer) factor by which to decimate the raw dipole
            traces. The SciPy function :func:`~scipy.signal.decimate` is used,
            which recommends values <13. To achieve higher decimation factors,
            a list of ints can be provided. These are applied successively.
        padding : str or None
            Optional padding of the dipole time course beyond the plotting
            limits. Possible values are: 'zeros' for padding with 0's
            (default), 'mirror' for mirror-image padding.
        ax : instance of matplotlib axis | None
            The matplotlib axis
        colormap : str
            The name of a matplotlib colormap, e.g., 'viridis'. Default:
            'inferno'
        colorbar : bool
            If True (default), adjust figure to include colorbar.
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        return plot_tfr_morlet(
            self, freqs, n_cycles=n_cycles, tmin=tmin, tmax=tmax,
            layer=layer, decim=decim, padding=padding, ax=ax,
            colormap=colormap, colorbar=colorbar, show=show)

    def _baseline_renormalize(self, N_pyr_x, N_pyr_y):
        """Subtract empirically fitted baseline offsets (data in fAm).

        Only meaningful while the data are still in fAm, i.e. before
        :meth:`Dipole._convert_fAm_to_nAm`. Requires 'L2' and 'L5' layers;
        the 'agg' dipole is recomputed as their sum.

        Parameters
        ----------
        N_pyr_x : int
            Nr of cells (x)
        N_pyr_y : int
            Nr of cells (y)
        """
        # N_pyr cells in grid. This is PER LAYER
        N_pyr = N_pyr_x * N_pyr_y
        # dipole offset calculation: increasing number of pyr
        # cells (L2 and L5, simultaneously)
        # with no inputs resulted in an aggregate dipole over the
        # interval [50., 1000.] ms that
        # eventually plateaus at -48 fAm. The range over this interval
        # is something like 3 fAm
        # so the resultant correction is here, per dipole
        # dpl_offset = N_pyr * 50.207
        dpl_offset = {
            # these values will be subtracted
            'L2': N_pyr * 0.0443,
            'L5': N_pyr * -49.0502
            # 'L5': N_pyr * -48.3642,
            # will be calculated next, this is a placeholder
            # 'agg': None,
        }
        # L2 dipole offset can be roughly baseline shifted over
        # the entire range of t
        self.data['L2'] -= dpl_offset['L2']
        # L5 dipole offset should be different for interval [50., 500.]
        # and then it can be offset
        # slope (m) and intercept (b) params for L5 dipole offset
        # uncorrected for N_cells
        # these values were fit over the range [37., 750.)
        m = 3.4770508e-3
        b = -51.231085
        # these values were fit over the range [750., 5000]
        t1 = 750.
        m1 = 1.01e-4
        b1 = -48.412078
        # piecewise normalization: constant offset up to 37 ms, then two
        # linear fits in time
        times = self.times
        self.data['L5'][times <= 37.] -= dpl_offset['L5']
        mid = (times > 37.) & (times < t1)
        self.data['L5'][mid] -= N_pyr * (m * times[mid] + b)
        late = times >= t1
        self.data['L5'][late] -= N_pyr * (m1 * times[late] + b1)
        # recalculate the aggregate dipole based on the baseline
        # normalized ones
        self.data['agg'] = self.data['L2'] + self.data['L5']

    def write(self, fname):
        """Write dipole values to a file.

        Parameters
        ----------
        fname : str
            Full path to the output file (.txt)

        Outputs
        -------
        A tab-separated txt file where rows correspond
            to samples and columns correspond to
            1) time (ms),
            2) aggregate current dipole (scaled nAm),
            3) L2/3 current dipole (scaled nAm), and
            4) L5 current dipole (scaled nAm)
        Columns 3) and 4) are only present when the Dipole has those layers.
        """
        if self.nave > 1:
            warnings.warn("Saving Dipole to file that is an average of %d"
                          " trials" % self.nave)
        X = [self.times]
        fmt = ['%3.3f']
        for data in self.data.values():
            X.append(data)
            fmt.append('%5.4f')
        X = np.r_[X].T
        np.savetxt(fname, X, fmt=fmt, delimiter='\t')