Source code for hnn_core.cell_response

"""CellResponse class."""

# Authors: Nick Tolley <nicholas_tolley@brown.edu>
#          Ryan Thorpe <ryan_thorpe@brown.edu>
#          Mainak Jas <mjas@mgh.harvard.edu>

import os
from glob import glob
from warnings import warn

import numpy as np

from .viz import plot_spikes_hist, plot_spikes_raster


class CellResponse:
    """The CellResponse class.

    Container for the spiking activity (and, optionally, the recorded
    section voltages, currents and calcium concentrations) of one or
    more simulation trials.

    Parameters
    ----------
    spike_times : list (n_trials,) of list (n_spikes,) of float | None
        Each element of the outer list is a trial.
        The inner list contains the time stamps of spikes.
    spike_gids : list (n_trials,) of list (n_spikes,) of int | None
        Each element of the outer list is a trial.
        The inner list contains the cell IDs of neurons that spiked.
    spike_types : list (n_trials,) of list (n_spikes,) of str | None
        Each element of the outer list is a trial.
        The inner list contains the type of spike (e.g., evprox1
        or L2_pyramidal) that occurred at the corresponding time stamp.
        Each gid corresponds to a type via Network().gid_ranges.
    times : list | numpy array | None
        Array of time points for samples in continuous data
        (e.g., vsec and isec).
    cell_type_names : list | None
        List of unique cell type names that are explicitly modeled in the
        network. If None, defaults to ``['L2_basket', 'L2_pyramidal',
        'L5_basket', 'L5_pyramidal']``.

    Attributes
    ----------
    spike_times : list (n_trials,) of list (n_spikes,) of float
        Each element of the outer list is a trial.
        The inner list contains the time stamps of spikes.
    spike_gids : list (n_trials,) of list (n_spikes,) of int
        Each element of the outer list is a trial.
        The inner list contains the cell IDs of neurons that spiked.
    spike_types : list (n_trials,) of list (n_spikes,) of str
        Each element of the outer list is a trial.
        The inner list contains the type of spike (e.g., evprox1
        or L2_pyramidal) that occurred at the corresponding time stamp.
        Each gid corresponds to a type via Network::gid_ranges.
    vsec : list (n_trials,) of dict
        Each element of the outer list is a trial.
        Dictionary indexed by gids containing voltages for cell sections.
    isec : list (n_trials,) of dict
        Each element of the outer list is a trial.
        Dictionary indexed by gids containing currents for cell sections.
    ca : list (n_trials,) of dict
        Each element of the outer list is a trial.
        Dictionary indexed by gids containing calcium concentration for
        cell sections.
    times : array-like, shape (n_times,)
        Array of time points for samples in continuous data
        (e.g., vsec and isec).

    Methods
    -------
    update_types(gid_ranges)
        Update spike types in the current instance of CellResponse.
    mean_rates(tstart, tstop, gid_ranges, mean_type='all')
        Calculate mean firing rate for each cell type. Specify
        averaging method with mean_type argument.
    plot_spikes_raster(trial_idx=None, ax=None, show=True)
        Plot the aggregate spiking activity according to cell type.
    plot_spikes_hist(trial_idx=None, ax=None, spike_types=None, ...)
        Plot the histogram of spiking activity across trials.
    to_dict()
        Return the cell response as a dict object.
    write(fname)
        Write spiking activity to a collection of spike trial files.
    """

    def __init__(self, spike_times=None, spike_gids=None, spike_types=None,
                 times=None, cell_type_names=None):
        if spike_times is None:
            spike_times = list()
        if spike_gids is None:
            spike_gids = list()
        if spike_types is None:
            spike_types = list()
        if times is None:
            times = list()
        if cell_type_names is None:
            cell_type_names = ['L2_basket', 'L2_pyramidal',
                               'L5_basket', 'L5_pyramidal']

        # Validate arguments: each must be a list of lists, and all three
        # must hold the same number of trials as 'spike_times' (checked
        # first, so it sets the reference length).
        n_trials = None
        spike_args = (('spike_times', spike_times),
                      ('spike_gids', spike_gids),
                      ('spike_types', spike_types))
        for arg_name, arg in spike_args:
            if (not isinstance(arg, list) or
                    not all(isinstance(trial, list) for trial in arg)):
                raise TypeError('%s should be a list of lists'
                                % (arg_name,))
            if n_trials is None:
                n_trials = len(arg)
            elif len(arg) != n_trials:
                raise ValueError('spike times, gids, and types should be '
                                 'lists of the same length')

        if not isinstance(times, (list, np.ndarray)):
            raise TypeError("'times' is an np.ndarray of simulation times")

        self._spike_times = spike_times
        self._spike_gids = spike_gids
        self._spike_types = spike_types
        self._vsec = list()
        self._isec = list()
        self._ca = list()
        self._times = np.array(times)
        self._cell_type_names = cell_type_names

    def __repr__(self):
        class_name = self.__class__.__name__
        n_trials = len(self._spike_times)
        return '<%s | %d simulation trials>' % (class_name, n_trials)

    def __eq__(self, other):
        if not isinstance(other, CellResponse):
            return NotImplemented
        # Spike times are compared after rounding to 3 decimals, which is
        # the precision used when they are written to file.
        times_self = [[round(time, 3) for time in trial]
                      for trial in self._spike_times]
        times_other = [[round(time, 3) for time in trial]
                       for trial in other._spike_times]
        return (times_self == times_other and
                self._spike_gids == other._spike_gids and
                self._spike_types == other._spike_types and
                self._vsec == other._vsec and
                self._isec == other._isec and
                self._ca == other._ca)

    @property
    def spike_times(self):
        return self._spike_times

    @property
    def spike_gids(self):
        return self._spike_gids

    @property
    def spike_types(self):
        return self._spike_types

    @property
    def vsec(self):
        return self._vsec

    @property
    def isec(self):
        return self._isec

    @property
    def ca(self):
        return self._ca

    @property
    def times(self):
        return self._times

    def update_types(self, gid_ranges):
        """Update spike types in the current instance of CellResponse.

        Every spike is assigned the key of the gid range that contains its
        gid. Spikes whose gid is in none of the ranges get an empty string.

        Parameters
        ----------
        gid_ranges : dict of lists or range objects
            Dictionary with keys 'evprox1', 'evdist1' etc.
            containing the range of Cell or input IDs of different
            cell or input types.

        Raises
        ------
        ValueError
            If a gid appears in more than one entry of ``gid_ranges``.
        """
        # Validate gid_ranges: comparing each range against the union of
        # all previous ones is equivalent to a pairwise disjointness check.
        seen_gids = set()
        for gids in gid_ranges.values():
            gid_set = set(gids)
            if not seen_gids.isdisjoint(gid_set):
                raise ValueError('gid_ranges should contain only disjoint '
                                 'sets of gid values')
            seen_gids |= gid_set

        # Size the string dtype so that long type names are not truncated
        type_width = max([36] + [len(str(gid_type))
                                 for gid_type in gid_ranges])
        type_dtype = '<U%d' % (type_width,)

        spike_types = list()
        for trial_times, trial_gids in zip(self._spike_times,
                                           self._spike_gids):
            # Initialise explicitly: unmatched gids must map to '' rather
            # than to whatever an uninitialised buffer happens to contain.
            trial_types = np.full(len(trial_times), '', dtype=type_dtype)
            for gid_type, gids in gid_ranges.items():
                trial_types[np.isin(trial_gids, gids)] = gid_type
            spike_types.append(list(trial_types))
        self._spike_types = spike_types

    def mean_rates(self, tstart, tstop, gid_ranges, mean_type='all'):
        """Mean spike rates (Hz) by cell type.

        Parameters
        ----------
        tstart : int | float
            Value defining the start time of all trials (ms).
        tstop : int | float
            Value defining the stop time of all trials (ms).
        gid_ranges : dict of lists or range objects
            Dictionary with keys 'evprox1', 'evdist1' etc.
            containing the range of Cell or input IDs of different
            cell or input types.
        mean_type : str
            'all' : Average over trials and cells.
                Returns a single mean rate per cell type.
            'trial' : Average over the cells of each trial.
                Returns one mean rate per trial for each cell type.
            'cell' : No averaging.
                Returns, for each cell type, one list per trial holding
                the rate of every individual cell.

        Returns
        -------
        spike_rate : dict
            Dictionary with keys 'L5_pyramidal', 'L5_basket', etc.

        Raises
        ------
        ValueError
            If ``mean_type`` is invalid, if ``tstart`` or ``tstop`` is not
            an int or float, or if ``tstop`` is not greater than
            ``tstart``.
        """
        if mean_type not in ['all', 'trial', 'cell']:
            raise ValueError("Invalid mean_type. Valid arguments include "
                             f"'all', 'trial', or 'cell'. Got {mean_type}")

        # Validate tstart, tstop
        if not isinstance(tstart, (int, float)) or not isinstance(
                tstop, (int, float)):
            raise ValueError('tstart and tstop must be of type int or float')
        elif tstop <= tstart:
            raise ValueError('tstop must be greater than tstart')

        window = tstop - tstart
        n_trials = len(self._spike_times)
        spike_rates = dict()
        for cell_type in self._cell_type_names:
            cell_type_gids = np.array(gid_ranges[cell_type])
            gid_spike_rate = np.zeros((n_trials, len(cell_type_gids)))

            trial_data = zip(self._spike_types, self._spike_gids)
            for trial_idx, (spike_types, spike_gids) in enumerate(trial_data):
                trial_type_mask = np.isin(spike_types, cell_type)
                gids, gid_counts = np.unique(
                    np.array(spike_gids)[trial_type_mask],
                    return_counts=True)
                # Spike count per window (ms), scaled by 1000 to give Hz
                gid_spike_rate[trial_idx, np.isin(cell_type_gids, gids)] = (
                    gid_counts / window) * 1000

            if mean_type == 'all':
                spike_rates[cell_type] = np.mean(
                    gid_spike_rate.mean(axis=1))
            elif mean_type == 'trial':
                spike_rates[cell_type] = np.mean(
                    gid_spike_rate, axis=1).tolist()
            else:  # mean_type == 'cell'
                spike_rates[cell_type] = [
                    gid_trial_rate.tolist()
                    for gid_trial_rate in gid_spike_rate]

        return spike_rates

    def plot_spikes_raster(self, trial_idx=None, ax=None, show=True):
        """Plot the aggregate spiking activity according to cell type.

        Parameters
        ----------
        trial_idx : int | list of int | None
            Index of trials to be plotted. If None, all trials plotted.
        ax : instance of matplotlib axis | None
            An axis object from matplotlib. If None, a new figure is created.
        show : bool
            If True, show the figure.

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure object.
        """
        return plot_spikes_raster(
            cell_response=self, trial_idx=trial_idx, ax=ax, show=show)

    def plot_spikes_hist(self, trial_idx=None, ax=None, spike_types=None,
                         color=None, show=True, **kwargs_hist):
        """Plot the histogram of spiking activity across trials.

        Parameters
        ----------
        trial_idx : int | list of int | None
            Index of trials to be plotted. If None, all trials plotted.
        ax : instance of matplotlib axis | None
            An axis object from matplotlib. If None, a new figure is created.
        spike_types: string | list | dictionary | None
            String input of a valid spike type is plotted individually.

            | Ex: ``'poisson'``, ``'evdist'``, ``'evprox'``, ...

            List of valid string inputs will plot each spike type
            individually.

            | Ex: ``['poisson', 'evdist']``

            Dictionary of valid lists will plot list elements as a group.

            | Ex: ``{'Evoked': ['evdist', 'evprox'], 'Tonic': ['poisson']}``

            If None, all input spike types are plotted individually if any
            are present. Otherwise spikes from all cells are plotted.
            Valid strings also include leading characters of spike types

            | Ex: ``'ev'`` is equivalent to ``['evdist', 'evprox']``
        color : str | list of str | dict | None
            Input defining colors of plotted histograms. If str, all
            histograms plotted with same color. If list of str provided,
            histograms for each spike type will be plotted by cycling
            through colors in the list.

            If dict, colors must be specified for all spike_types as a key.
            If a group of spike types is defined by the `spike_types`
            parameter (see dictionary example for `spike_types`), the name
            of this group must be used to specify the colors.

            | Ex: ``{'evdist': 'g', 'evprox': 'r'}``, ``{'Tonic': 'b'}``

            If None, default color cycle used.
        show : bool
            If True, show the figure.
        **kwargs_hist : dict
            Additional keyword arguments to pass to ax.hist.

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        return plot_spikes_hist(self, trial_idx=trial_idx, ax=ax,
                                spike_types=spike_types, color=color,
                                show=show, **kwargs_hist)

    @staticmethod
    def _stringify_gid_keys(trials):
        """Return a copy of per-trial dicts with keys converted to str."""
        # `int` gid keys are turned into strings for the hdf5 format
        return [{str(gid): val for gid, val in trial.items()}
                for trial in trials]

    def to_dict(self):
        """Return cell response as a dict object.

        Returns
        -------
        dict object containing the cell response
        """
        cell_response_data = dict()
        cell_response_data['spike_times'] = self.spike_times
        cell_response_data['spike_gids'] = self.spike_gids
        cell_response_data['spike_types'] = self.spike_types
        cell_response_data['vsec'] = self._stringify_gid_keys(self.vsec)
        cell_response_data['isec'] = self._stringify_gid_keys(self.isec)
        cell_response_data['ca'] = self._stringify_gid_keys(self.ca)
        cell_response_data['times'] = self.times
        return cell_response_data

    def write(self, fname):
        """Write spiking activity per trial to a collection of files.

        Parameters
        ----------
        fname : str
            String format (e.g., 'spk_%d.txt' or 'spk_{0}.txt') of the
            path to the output spike file(s). If no string format
            is provided, the trial index will be automatically
            appended to the file name (e.g., 'spk.txt' becomes
            'spk_0.txt', 'spk_1.txt', ...).

        Outputs
        -------
        A tab separated txt file for each trial where rows
            correspond to spikes, and columns correspond to
            1) spike time (ms),
            2) spike gid, and
            3) gid type
        """
        warn('Writing cell response to txt files is deprecated '
             'and will be removed in future versions. Please save '
             'cell response along with network',
             DeprecationWarning, stacklevel=2)

        fname = str(fname)
        # Work out how the trial index goes into the file name:
        # '%'-style, '{}'-style, or (no placeholder) appended to the stem.
        try:
            fname % 0
            style = 'percent'
        except TypeError:
            if fname.format(0) != fname.format(1):
                style = 'brace'
            else:
                # No placeholder: without this branch every trial would be
                # written to, and overwrite, the very same file.
                style = 'append'
                root, ext = os.path.splitext(fname)

        for trial_idx in range(len(self._spike_times)):
            if style == 'percent':
                this_fname = fname % (trial_idx,)
            elif style == 'brace':
                this_fname = fname.format(trial_idx)
            else:
                this_fname = '%s_%d%s' % (root, trial_idx, ext)

            print(f'Writing file {this_fname}')
            with open(this_fname, 'w') as f:
                for spike_idx in range(len(self._spike_times[trial_idx])):
                    f.write('{:.3f}\t{}\t{}\n'.format(
                        self._spike_times[trial_idx][spike_idx],
                        int(self._spike_gids[trial_idx][spike_idx]),
                        self._spike_types[trial_idx][spike_idx]))
def read_spikes(fname, gid_ranges=None):
    """Read spiking activity from a collection of spike trial files.

    Parameters
    ----------
    fname : str
        Wildcard expression (e.g., '<pathname>/spk_*.txt') of the
        path to the spike file(s).
    gid_ranges : dict of lists or range objects | None
        Dictionary with keys 'evprox1', 'evdist1' etc.
        containing the range of Cell or input IDs of different
        cell or input types. If None, each spike file must contain
        a 3rd column for spike type.

    Returns
    -------
    cell_response : CellResponse
        An instance of the CellResponse object. Matching files are read
        in sorted order, one trial per file.

    Raises
    ------
    ValueError
        If a non-empty spike file has no 3rd (spike type) column and
        ``gid_ranges`` is None.
    """
    warn('Reading cell response from txt files is deprecated '
         'and will be removed in future versions. Please load '
         'cell response along with simulated network',
         DeprecationWarning, stacklevel=2)

    spike_times = list()
    spike_gids = list()
    spike_types = list()
    for spike_file in sorted(glob(str(fname))):
        # ndmin=2 keeps a file holding a single spike as a (1, n_cols)
        # array; otherwise loadtxt would squeeze it to 1-D and the column
        # indexing below would fail.
        spike_trial = np.loadtxt(spike_file, dtype=str, ndmin=2)

        if spike_trial.size == 0:
            # Empty file: a trial without any spikes
            spike_times += [list()]
            spike_gids += [list()]
            spike_types += [list()]
            continue

        spike_times += [list(spike_trial[:, 0].astype(float))]
        spike_gids += [list(spike_trial[:, 1].astype(int))]

        # Note that legacy HNN 'spk.txt' files don't contain a 3rd column
        # for spike type. If reading a legacy version, ensure that
        # gid_ranges is provided.
        if spike_trial.shape[1] == 3:
            spike_types += [list(spike_trial[:, 2].astype(str))]
        else:
            if gid_ranges is None:
                raise ValueError("gid_ranges must be provided if spike "
                                 "types are unspecified in the "
                                 "file %s" % (spike_file,))
            # Placeholder; filled in by update_types() below
            spike_types += [list()]

    cell_response = CellResponse(spike_times=spike_times,
                                 spike_gids=spike_gids,
                                 spike_types=spike_types)
    if gid_ranges is not None:
        cell_response.update_types(gid_ranges)

    return cell_response