# Source code for hnn_core.params

"""Handling of parameters."""

# Authors: Mainak Jas <mjas@mgh.harvard.edu>
#          Sam Neymotin <samnemo@gmail.com>

import json
import fnmatch
import os.path as op
from pathlib import Path
from copy import deepcopy

from .params_default import get_params_default
from .externals.mne import _validate_type


# return number of evoked inputs (proximal, distal)
# using dictionary d (or if d is a string, first load the dictionary from
# filename d)
def _count_evoked_inputs(d):
    nprox = ndist = 0
    for k, _ in d.items():
        if k.startswith('t_'):
            if k.count('evprox') > 0:
                nprox += 1
            elif k.count('evdist') > 0:
                ndist += 1
    return nprox, ndist


def _read_json(param_data):
    """Read param values from a .json file.
    Parameters
    ----------
    param_data : str
        The data read in from the param file

    Returns
    -------
    params_input : dict
        Dictionary of parameters
    """
    return json.loads(param_data)


def _read_legacy_params(param_data):
    """Read param values from a .param file (legacy).
    Parameters
    ----------
    param_data : str
        The data read in from the param file

    Returns
    -------
    params_input : dict
        Dictionary of parameters
    """
    params_input = dict()
    for line in param_data.splitlines():
        split_line = line.lstrip().split(':')
        key, value = [field.strip() for field in split_line]
        try:
            if '.' in value or 'e' in value:
                params_input[key] = float(value)
            else:
                params_input[key] = int(value)
        except ValueError:
            params_input[key] = str(value)

    return params_input


def read_params(params_fname, file_contents=None):
    """Read param values from a file (.json or .param).

    Parameters
    ----------
    params_fname : str
        Full path to the file (.param)
    file_contents : str | None
        If file_contents are provided as a string, it is parsed into
        a dictionary. Otherwise the file at ``params_fname`` is read
        from disk.

    Returns
    -------
    params : an instance of Params
        Params containing parameter values from file
    """
    split_fname = op.splitext(params_fname)
    ext = split_fname[1]

    if ext not in ['.json', '.param']:
        raise ValueError('Unrecognized extension, expected one of' +
                         ' .json, .param. Got %s' % ext)

    if file_contents is None:
        with open(params_fname, 'r') as fp:
            file_contents = fp.read()

    # dispatch on extension: JSON vs legacy key:value format
    read_func = {'.json': _read_json, '.param': _read_legacy_params}
    params_dict = read_func[ext](file_contents)

    if len(params_dict) == 0:
        raise ValueError("Failed to read parameters from file: %s" %
                         op.normpath(params_fname))

    params = Params(params_dict)

    return params
def _long_name(short_name):
    """Map a GUI-style short cell name (e.g. 'L2Pyr') to its long form."""
    long_name = dict(L2Basket='L2_basket', L5Basket='L5_basket',
                     L2Pyr='L2_pyramidal', L5Pyr='L5_pyramidal')
    if short_name in long_name:
        return long_name[short_name]
    return short_name


def _short_name(short_name):
    """Map a long cell name (e.g. 'L2_pyramidal') to its GUI-style short form."""
    long_name = dict(L2_basket='L2Basket', L5_basket='L5Basket',
                     L2_pyramidal='L2Pyr', L5_pyramidal='L5Pyr')
    if short_name in long_name:
        return long_name[short_name]
    return short_name


def _extract_bias_specs_from_hnn_params(params, cellname_list):
    """Create 'bias specification' dicts from saved parameters"""
    bias_specs = {'tonic': {}}  # currently only 'tonic' biases known
    for cellname in cellname_list:
        short_name = _short_name(cellname)
        # all three Itonic_* keys must be present together for a cell type
        is_tonic_present = [f'Itonic_{p}_{short_name}_soma' in params
                            for p in ['A', 't0', 'T']]
        if any(is_tonic_present):
            if not all(is_tonic_present):
                raise ValueError(
                    f'Tonic input must have the amplitude, '
                    f'start time and end time specified. One '
                    f'or more parameter may be missing for '
                    f'cell type {cellname}')
            bias_specs['tonic'][cellname] = {
                'amplitude': params[f'Itonic_A_{short_name}_soma'],
                't0': params[f'Itonic_t0_{short_name}_soma'],
                'tstop': params[f'Itonic_T_{short_name}_soma']
            }
    return bias_specs


def _extract_drive_specs_from_hnn_params(
        params, cellname_list, legacy_mode=False):
    """Create 'drive specification' dicts from saved parameters"""
    # convert legacy params-dict to legacy "feeds" dicts
    p_common, p_unique = create_pext(params, params['tstop'])

    # Using 'feed' for legacy compatibility, 'drives' for new API
    drive_specs = dict()
    for ic, par in enumerate(p_common):
        # outside legacy mode, skip nulled drives (stop before start)
        if (not legacy_mode) and par['tstop'] < par['t0']:
            continue
        feed_name = f'bursty{ic + 1}'
        drive = dict()
        drive['type'] = 'bursty'
        drive['cell_specific'] = False
        drive['dynamics'] = {'tstart': par['t0'],
                             'tstart_std': par['t0_stdev'],
                             'tstop': par['tstop'],
                             'burst_rate': par['f_input'],
                             'burst_std': par['stdev'],
                             'numspikes': par['events_per_cycle'],
                             'n_drive_cells': par['n_drive_cells'],
                             'spike_isi': 10}  # not exposed in params-files
        drive['location'] = par['loc']
        drive['space_constant'] = par['lamtha']
        drive['event_seed'] = par['prng_seedcore']
        drive['weights_ampa'] = dict()
        drive['weights_nmda'] = dict()
        drive['synaptic_delays'] = dict()

        for cellname in cellname_list:
            cname_ampa = _short_name(cellname) + '_ampa'
            cname_nmda = _short_name(cellname) + '_nmda'
            if cname_ampa in par:
                ampa_weight = par[cname_ampa][0]
                ampa_delay = par[cname_ampa][1]
                drive['weights_ampa'][cellname] = ampa_weight
                # NB synaptic delay same for NMDA, read only for AMPA
                drive['synaptic_delays'][cellname] = ampa_delay
            if cname_nmda in par:
                nmda_weight = par[cname_nmda][0]
                drive['weights_nmda'][cellname] = nmda_weight

        drive_specs[feed_name] = drive

    for feed_name, par in p_unique.items():
        drive = dict()
        drive['cell_specific'] = True
        drive['weights_ampa'] = dict()
        drive['weights_nmda'] = dict()
        drive['synaptic_delays'] = dict()

        if (feed_name.startswith('evprox') or
                feed_name.startswith('evdist')):
            drive['type'] = 'evoked'
            if feed_name.startswith('evprox'):
                drive['location'] = 'proximal'
            else:
                drive['location'] = 'distal'

            cell_keys_present = [key for key in par if
                                 key in cellname_list]
            sigma = par[cell_keys_present[0]][3]  # IID for all cells!

            n_drive_cells = 'n_cells'
            if par['sync_evinput']:
                # synchronous evoked input: a single shared drive cell
                n_drive_cells = 1
                drive['cell_specific'] = False

            drive['dynamics'] = {'mu': par['t0'],
                                 'sigma': sigma,
                                 'numspikes': par['numspikes'],
                                 'n_drive_cells': n_drive_cells}
            drive['space_constant'] = par['lamtha']
            drive['event_seed'] = par['prng_seedcore']
            # XXX Force random states to be the same as HNN-gui for the
            # default parameter set after increasing the number of bursty
            # drive gids from 2 to 20
            if legacy_mode:
                drive['event_seed'] -= 18
            for cellname in cellname_list:
                if cellname in par:
                    ampa_weight = par[cellname][0]
                    nmda_weight = par[cellname][1]
                    synaptic_delays = par[cellname][2]
                    drive['weights_ampa'][cellname] = ampa_weight
                    drive['weights_nmda'][cellname] = nmda_weight
                    drive['synaptic_delays'][cellname] = synaptic_delays
        # Skip drive if not in legacy mode
        elif feed_name.startswith('extgauss'):
            if (not legacy_mode) and par[
                    'L2_basket'][3] > params['tstop']:
                continue
            drive['type'] = 'gaussian'
            drive['location'] = par['loc']
            drive['dynamics'] = {'mu': par['L2_basket'][3],  # NB IID
                                 'sigma': par['L2_basket'][4],
                                 'numspikes': 50,  # NB hard-coded in GUI!
                                 'sync_within_trial': False}
            drive['space_constant'] = par['lamtha']
            drive['event_seed'] = par['prng_seedcore']
            for cellname in cellname_list:
                if cellname in par:
                    ampa_weight = par[cellname][0]
                    synaptic_delays = par[cellname][3]
                    drive['weights_ampa'][cellname] = ampa_weight
                    drive['synaptic_delays'][cellname] = synaptic_delays
            drive['weights_nmda'] = dict()  # no NMDA weights for Gaussians
        elif feed_name.startswith('extpois'):
            if (not legacy_mode) and par['t_interval'][1] < par[
                    't_interval'][0]:
                continue
            drive['type'] = 'poisson'
            drive['location'] = par['loc']
            drive['space_constant'] = par['lamtha']
            drive['event_seed'] = par['prng_seedcore']
            rate_params = dict()
            for cellname in cellname_list:
                if cellname in par:
                    rate_params[cellname] = par[cellname][3]
                    # XXX correct for non-positive poisson rate constant that
                    # is specified in null poisson drives of legacy
                    # param files
                    if not rate_params[cellname] > 0:
                        rate_params[cellname] = 1
                    ampa_weight = par[cellname][0]
                    nmda_weight = par[cellname][1]
                    synaptic_delays = par[cellname][2]
                    drive['weights_ampa'][cellname] = ampa_weight
                    drive['weights_nmda'][cellname] = nmda_weight
                    drive['synaptic_delays'][cellname] = synaptic_delays
            # do NOT allow negative times sometimes used in param-files
            drive['dynamics'] = {'tstart': max(0, par['t_interval'][0]),
                                 'tstop': max(0, par['t_interval'][1]),
                                 'rate_constant': rate_params}
        drive_specs[feed_name] = drive

    return drive_specs


class Params(dict):
    """Params object.

    Parameters
    ----------
    params_input : dict | None
        Dictionary of parameters. If None, use default parameters.
    """

    def __init__(self, params_input=None):
        if params_input is None:
            params_input = dict()

        if isinstance(params_input, dict):
            nprox, ndist = _count_evoked_inputs(params_input)
            # create default params templated from params_input
            params_default = get_params_default(nprox, ndist)
            for key in params_default.keys():
                if key in params_input:
                    self[key] = params_input[key]
                else:
                    self[key] = params_default[key]
        else:
            raise ValueError('params_input must be dict or None. Got %s'
                             % type(params_input))

    def __repr__(self):
        """Display the params nicely."""
        return json.dumps(self, sort_keys=True, indent=4)

    def __getitem__(self, key):
        """Return a subset of parameters."""
        keys = self.keys()
        if key in keys:
            return dict.__getitem__(self, key)
        else:
            # treat the key as an fnmatch glob pattern; return the subset
            # of matching (key, value) pairs as a new Params
            matches = fnmatch.filter(keys, key)
            if len(matches) == 0:
                return dict.__getitem__(self, key)
            params = self.copy()
            for key in keys:
                if key not in matches:
                    params.pop(key)
            return params

    def __setitem__(self, key, value):
        """Set the value for a subset of parameters."""
        keys = self.keys()
        if key in keys:
            return dict.__setitem__(self, key, value)
        else:
            # treat the key as an fnmatch glob pattern; assign value to
            # every matching existing key
            matches = fnmatch.filter(keys, key)
            if len(matches) == 0:
                return dict.__setitem__(self, key, value)
            for key in keys:
                if key in matches:
                    self.update({key: value})

    def copy(self):
        """Return a deep copy of this Params object."""
        return deepcopy(self)

    def write(self, fname):
        """Write param values to a file.

        Parameters
        ----------
        fname : str
            Full path to the output file (.json)
        """
        with open(fname, 'w') as fp:
            json.dump(self, fp)


def _validate_feed(p_ext_d, tstop):
    """Validate external inputs that are fed to all cells uniformly
    (i.e., rather than individually). For now, this only includes
    rhythmic inputs.

    Parameters
    ----------
    p_ext_d : dict
        The parameter set to validate and append to p_ext.
    tstop : float
        Stop time of the simulation.

    Returns
    -------
    p_ext : list
        Cumulative list of dicts with newly appended ExtFeed.
    """
    # # reset tstop if the specified tstop exceeds the
    # # simulation runtime
    if p_ext_d['tstop'] > tstop:
        p_ext_d['tstop'] = tstop

    # if stdev is zero, increase synaptic weights 5 fold to make
    # single input equivalent to 5 simultaneous input to prevent spiking
    # <<---- SN: WHAT IS THIS RULE!?!?!?
    if not p_ext_d['stdev']:
        for key in p_ext_d.keys():
            if key.endswith('Pyr'):
                p_ext_d[key] = (p_ext_d[key][0] * 5., p_ext_d[key][1])
            elif key.endswith('Basket'):
                p_ext_d[key] = (p_ext_d[key][0] * 5., p_ext_d[key][1])

    # if L5 delay is -1, use same delays as L2 unless L2 delay is 0.1 in
    # which case use 1. <<---- SN: WHAT IS THIS RULE!?!?!?
    if p_ext_d['L5Pyr_ampa'][1] == -1:
        for key in p_ext_d.keys():
            if key.startswith('L5'):
                if p_ext_d['L2Pyr'][1] != 0.1:
                    p_ext_d[key] = (p_ext_d[key][0], p_ext_d['L2Pyr'][1])
                else:
                    p_ext_d[key] = (p_ext_d[key][0], 1.)

    return p_ext_d


def check_evoked_synkeys(p, nprox, ndist):
    """Impute missing synapse-specific gbar keys for evoked inputs, in-place."""
    # make sure ampa,nmda gbar values are in the param dict for evoked
    # inputs(for backwards compatibility)
    # evoked distal target cell types
    lctprox = ['L2Pyr', 'L5Pyr', 'L2Basket', 'L5Basket']
    # evoked proximal target cell types
    lctdist = ['L2Pyr', 'L5Pyr', 'L2Basket']
    lsy = ['ampa', 'nmda']  # synapse types used in evoked inputs
    for nev, pref, lct in zip([nprox, ndist], ['evprox_', 'evdist_'],
                              [lctprox, lctdist]):
        for i in range(nev):
            skey = pref + str(i + 1)
            for sy in lsy:
                for ct in lct:
                    k = 'gbar_' + skey + '_' + ct + '_' + sy
                    # if the synapse-specific gbar not present, use the
                    # existing weight for both ampa,nmda
                    if k not in p:
                        p[k] = p['gbar_' + skey + '_' + ct]


#
def check_pois_synkeys(p):
    """Impute missing synapse-specific weight keys for Poisson inputs, in-place."""
    # make sure ampa,nmda gbar values are in the param dict for Poisson inputs
    # (for backwards compatibility)
    lct = ['L2Pyr', 'L5Pyr', 'L2Basket', 'L5Basket']  # target cell types
    lsy = ['ampa', 'nmda']  # synapse types used in Poisson inputs
    for ct in lct:
        for sy in lsy:
            k = ct + '_Pois_A_weight_' + sy
            # if the synapse-specific weight not present, set it to 0 in p
            if k not in p:
                p[k] = 0.0


# creates the external feed params based on individual simulation params p
def create_pext(p, tstop):
    """Indexable Python list of param dicts for parallel.

    Turn off individual feeds by commenting out relevant line here.
    always valid, no matter the length.

    Parameters
    ----------
    p : dict
        The parameters returned by ExpParams(f_psim).return_pdict()
    """
    p_common = list()

    # p_unique is a dict of input param types that end up going to each cell
    # uniquely
    p_unique = dict()

    # default params for common proximal inputs
    feed_prox = {
        'f_input': p['f_input_prox'],
        't0': p['t0_input_prox'],
        'tstop': p['tstop_input_prox'],
        'stdev': p['f_stdev_prox'],
        'L2Pyr_ampa': (p['input_prox_A_weight_L2Pyr_ampa'],
                       p['input_prox_A_delay_L2']),
        'L2Pyr_nmda': (p['input_prox_A_weight_L2Pyr_nmda'],
                       p['input_prox_A_delay_L2']),
        'L5Pyr_ampa': (p['input_prox_A_weight_L5Pyr_ampa'],
                       p['input_prox_A_delay_L5']),
        'L5Pyr_nmda': (p['input_prox_A_weight_L5Pyr_nmda'],
                       p['input_prox_A_delay_L5']),
        'L2Basket_ampa': (p['input_prox_A_weight_L2Basket_ampa'],
                          p['input_prox_A_delay_L2']),
        'L2Basket_nmda': (p['input_prox_A_weight_L2Basket_nmda'],
                          p['input_prox_A_delay_L2']),
        'L5Basket_ampa': (p['input_prox_A_weight_L5Basket_ampa'],
                          p['input_prox_A_delay_L5']),
        'L5Basket_nmda': (p['input_prox_A_weight_L5Basket_nmda'],
                          p['input_prox_A_delay_L5']),
        'events_per_cycle': p['events_per_cycle_prox'],
        'prng_seedcore': int(p['prng_seedcore_input_prox']),
        'lamtha': 100.,
        'loc': 'proximal',
        'n_drive_cells': p['repeats_prox'],
        't0_stdev': p['t0_input_stdev_prox'],
        'threshold': p['threshold']
    }

    # ensures time interval makes sense
    p_common.append(_validate_feed(feed_prox, tstop))

    # default params for common distal inputs
    feed_dist = {
        'f_input': p['f_input_dist'],
        't0': p['t0_input_dist'],
        'tstop': p['tstop_input_dist'],
        'stdev': p['f_stdev_dist'],
        'L2Pyr_ampa': (p['input_dist_A_weight_L2Pyr_ampa'],
                       p['input_dist_A_delay_L2']),
        'L2Pyr_nmda': (p['input_dist_A_weight_L2Pyr_nmda'],
                       p['input_dist_A_delay_L2']),
        'L5Pyr_ampa': (p['input_dist_A_weight_L5Pyr_ampa'],
                       p['input_dist_A_delay_L5']),
        'L5Pyr_nmda': (p['input_dist_A_weight_L5Pyr_nmda'],
                       p['input_dist_A_delay_L5']),
        'L2Basket_ampa': (p['input_dist_A_weight_L2Basket_ampa'],
                          p['input_dist_A_delay_L2']),
        'L2Basket_nmda': (p['input_dist_A_weight_L2Basket_nmda'],
                          p['input_dist_A_delay_L2']),
        'events_per_cycle': p['events_per_cycle_dist'],
        'prng_seedcore': int(p['prng_seedcore_input_dist']),
        'lamtha': 100.,
        'loc': 'distal',
        'n_drive_cells': p['repeats_dist'],
        't0_stdev': p['t0_input_stdev_dist'],
        'threshold': p['threshold']
    }

    p_common.append(_validate_feed(feed_dist, tstop))

    nprox, ndist = _count_evoked_inputs(p)
    # print('nprox,ndist evoked inputs:', nprox, ndist)

    # NEW: make sure all evoked synaptic weights present
    # (for backwards compatibility)
    # could cause differences between output of param files
    # since some nmda weights should be 0 while others > 0
    # XXX dangerzone: params are modified in-place, values are imputed if
    # deemed missing (e.g. if 'gbar_evprox_1_L2Pyr_nmda' is not defined, the
    # code adds it to the p-dict with value: p['gbar_evprox_1_L2Pyr'])
    check_evoked_synkeys(p, nprox, ndist)

    # Create proximal evoked response parameters
    # f_input needs to be defined as 0
    for i in range(nprox):
        skey = 'evprox_' + str(i + 1)
        p_unique['evprox' + str(i + 1)] = {
            't0': p['t_' + skey],
            'L2_pyramidal': (p['gbar_' + skey + '_L2Pyr_ampa'],
                             p['gbar_' + skey + '_L2Pyr_nmda'],
                             0.1, p['sigma_t_' + skey]),
            'L2_basket': (p['gbar_' + skey + '_L2Basket_ampa'],
                          p['gbar_' + skey + '_L2Basket_nmda'],
                          0.1, p['sigma_t_' + skey]),
            'L5_pyramidal': (p['gbar_' + skey + '_L5Pyr_ampa'],
                             p['gbar_' + skey + '_L5Pyr_nmda'],
                             1., p['sigma_t_' + skey]),
            'L5_basket': (p['gbar_' + skey + '_L5Basket_ampa'],
                          p['gbar_' + skey + '_L5Basket_nmda'],
                          1., p['sigma_t_' + skey]),
            'prng_seedcore': int(p['prng_seedcore_' + skey]),
            'lamtha': 3.,
            'loc': 'proximal',
            'sync_evinput': p['sync_evinput'],
            'threshold': p['threshold'],
            'numspikes': p['numspikes_' + skey]
        }

    # Create distal evoked response parameters
    # f_input needs to be defined as 0
    for i in range(ndist):
        skey = 'evdist_' + str(i + 1)
        p_unique['evdist' + str(i + 1)] = {
            't0': p['t_' + skey],
            'L2_pyramidal': (p['gbar_' + skey + '_L2Pyr_ampa'],
                             p['gbar_' + skey + '_L2Pyr_nmda'],
                             0.1, p['sigma_t_' + skey]),
            'L5_pyramidal': (p['gbar_' + skey + '_L5Pyr_ampa'],
                             p['gbar_' + skey + '_L5Pyr_nmda'],
                             0.1, p['sigma_t_' + skey]),
            'L2_basket': (p['gbar_' + skey + '_L2Basket_ampa'],
                          p['gbar_' + skey + '_L2Basket_nmda'],
                          0.1, p['sigma_t_' + skey]),
            'prng_seedcore': int(p['prng_seedcore_' + skey]),
            'lamtha': 3.,
            'loc': 'distal',
            'sync_evinput': p['sync_evinput'],
            'threshold': p['threshold'],
            'numspikes': p['numspikes_' + skey]
        }

    # this needs to create many feeds
    # (amplitude, delay, mu, sigma). ordered this way to preserve compatibility
    # NEW: note double weight specification since only use ampa for gauss
    # inputs
    p_unique['extgauss'] = {
        'stim': 'gaussian',
        'L2_basket': (p['L2Basket_Gauss_A_weight'],
                      p['L2Basket_Gauss_A_weight'],
                      1., p['L2Basket_Gauss_mu'],
                      p['L2Basket_Gauss_sigma']),
        'L2_pyramidal': (p['L2Pyr_Gauss_A_weight'],
                         p['L2Pyr_Gauss_A_weight'],
                         0.1, p['L2Pyr_Gauss_mu'],
                         p['L2Pyr_Gauss_sigma']),
        'L5_basket': (p['L5Basket_Gauss_A_weight'],
                      p['L5Basket_Gauss_A_weight'],
                      1., p['L5Basket_Gauss_mu'],
                      p['L5Basket_Gauss_sigma']),
        'L5_pyramidal': (p['L5Pyr_Gauss_A_weight'],
                         p['L5Pyr_Gauss_A_weight'],
                         1., p['L5Pyr_Gauss_mu'],
                         p['L5Pyr_Gauss_sigma']),
        'lamtha': 100.,
        'prng_seedcore': int(p['prng_seedcore_extgauss']),
        'loc': 'proximal',
        'threshold': p['threshold']
    }

    check_pois_synkeys(p)

    # Poisson distributed inputs to proximal
    # NEW: setting up AMPA and NMDA for Poisson inputs; why delays differ?
    p_unique['extpois'] = {
        'stim': 'poisson',
        'L2_basket': (p['L2Basket_Pois_A_weight_ampa'],
                      p['L2Basket_Pois_A_weight_nmda'],
                      1., p['L2Basket_Pois_lamtha']),
        'L2_pyramidal': (p['L2Pyr_Pois_A_weight_ampa'],
                         p['L2Pyr_Pois_A_weight_nmda'],
                         0.1, p['L2Pyr_Pois_lamtha']),
        'L5_basket': (p['L5Basket_Pois_A_weight_ampa'],
                      p['L5Basket_Pois_A_weight_nmda'],
                      1., p['L5Basket_Pois_lamtha']),
        'L5_pyramidal': (p['L5Pyr_Pois_A_weight_ampa'],
                         p['L5Pyr_Pois_A_weight_nmda'],
                         1., p['L5Pyr_Pois_lamtha']),
        'lamtha': 100.,
        'prng_seedcore': int(p['prng_seedcore_extpois']),
        't_interval': (p['t0_pois'], p['T_pois']),
        'loc': 'proximal',
        'threshold': p['threshold']
    }

    return p_common, p_unique


# Takes two dictionaries (d1 and d2) and compares the keys in d1 to those in d2
# if any match, updates the (key, value) pair of d1 to match that of d2
# not real happy with variable names, but will have to do for now
def compare_dictionaries(d1, d2):
    # iterate over intersection of key sets (i.e. any common keys)
    # NOTE(review): `d1.keys() and d2.keys()` evaluates to d2.keys() (both
    # views are truthy unless empty), so this copies ALL of d2 into d1, not
    # just the intersection as the comment claims — confirm intended behavior
    for key in d1.keys() and d2.keys():
        # update d1 to have same (key, value) pair as d2
        d1[key] = d2[key]

    return d1


def _any_positive_weights(drive):
    """ Checks a drive for any positive weights. """
    weights = (list(drive['weights_ampa'].values()) +
               list(drive['weights_nmda'].values()))
    if any([val > 0 for val in weights]):
        return True
    else:
        return False


def remove_nulled_drives(net):
    """Removes drives from network if they have been given null parameters.

    Legacy param files contained parameter placeholders for non-functional
    drives. These drives were nulled by assigning values outside typical
    ranges. This function removes drives on the following conditions:
    1. Start time is larger than stop time
    2. All weights are non-positive

    Parameters
    ----------
    net : Network object

    Returns
    -------
    net : Network object
    """
    from .network import pick_connection

    net = deepcopy(net)

    drives_copy = net.external_drives.copy()
    # record per-drive connectivity attributes before clearing the drives,
    # so they can be reattached with the same space constant / probability
    extras = dict()
    for drive_name, drive in net.external_drives.items():
        conn_indices = pick_connection(net, src_gids=drive_name)
        space_constant = net.connectivity[conn_indices[0]]['nc_dict']['lamtha']
        probability = net.connectivity[conn_indices[0]]['probability']
        extras[drive_name] = {'space_constant': space_constant,
                              'probability': probability}
    net.clear_drives()

    for drive_name, drive in drives_copy.items():
        # Do not add drive if tstart is > tstop, or negative
        t_start = drive['dynamics'].get('tstart')
        t_stop = drive['dynamics'].get('tstop')
        if (t_start is not None and t_stop is not None and
                ((t_start > t_stop) or (t_start < 0) or (t_stop < 0))):
            continue
        # Do not add if all 0 weights
        elif not _any_positive_weights(drive):
            continue
        else:
            # Set n_drive_cells to 'n_cells' if equal to max number of cells
            if drive['cell_specific']:
                drive['n_drive_cells'] = 'n_cells'

            net._attach_drive(drive['name'], drive,
                              drive['weights_ampa'],
                              drive['weights_nmda'],
                              drive['location'],
                              extras[drive_name]['space_constant'],
                              drive['synaptic_delays'],
                              drive['n_drive_cells'],
                              drive['cell_specific'],
                              extras[drive_name]['probability'])

    return net


def convert_to_json(params_fname, out_fname, include_drives=True,
                    overwrite=True):
    """Converts legacy json or param format to hierarchical json format

    Parameters
    ----------
    params_fname : str or Path
        Path to file
    out_fname: str
        Path to output
    include_drives: bool, default=True
        Include drives from params file
    overwrite: bool, default=True
        Overwrite file

    Returns
    -------
    None
    """
    from .network_models import jones_2009_model

    # Validate inputs
    _validate_type(params_fname, (str, Path), 'params_fname')
    _validate_type(out_fname, (str, Path), 'out_fname')

    # Convert to Path
    params_fname = Path(params_fname)
    out_fname = Path(out_fname)

    params_suffix = params_fname.suffix.lower().split('.')[-1]

    # Add suffix if not supplied
    if out_fname.suffix != '.json':
        out_fname = out_fname.with_suffix('.json')

    # legacy .param files require legacy_mode for faithful conversion
    net = jones_2009_model(params=read_params(params_fname),
                           add_drives_from_params=include_drives,
                           legacy_mode=(True if params_suffix == 'param'
                                        else False),
                           )

    # Remove drives that have null attributes
    net = remove_nulled_drives(net)

    net.write_configuration(fname=out_fname, overwrite=overwrite,)

    return


# debug test function
if __name__ == '__main__':
    fparam = 'param/debug.param'