"""Network io"""
# Authors: Rajat Partani <rajatpartani@gmail.com>
# Mainak Jas <mjas@mgh.harvard.edu>
# Nick Tolley <nicholas_tolley@brown.edu>
# George Dang <george_dang@brown.edu>
import os
import json
import numpy as np
from collections import OrderedDict
from pathlib import Path
from .cell import Cell, Section
from .cell_response import CellResponse
from .externals.mne import fill_doc
def _convert_np_array_to_list(obj):
"""Returns object with np.arrays converted to lists
Converts np.arrays to lists. Dicts or lists with nested np.arrays will
have nested arrays converted to lists.
"""
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, dict):
return {k: _convert_np_array_to_list(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_convert_np_array_to_list(item) for item in obj]
else:
return obj
def _cell_response_to_dict(net, write_output):
"""Returns a dict of cell response data."""
# Write cell_response as dict
if (not net.cell_response) or (not write_output):
return dict()
else:
return net.cell_response.to_dict()
def _rec_array_to_dict(value, write_output):
"""Returns a dict of rec_array data."""
rec_array_copy = value.copy()
if not write_output:
rec_array_copy._reset()
rec_array_copy_dict = rec_array_copy.to_dict()
return rec_array_copy_dict
def _conn_to_dict(conn):
"""Converts a Connectivity object parameters to a dict format."""
conn_data = {
"target_type": conn["target_type"],
"target_gids": list(conn["target_gids"]),
"num_targets": conn["num_targets"],
"src_type": conn["src_type"],
"src_gids": list(conn["src_gids"]),
"num_srcs": conn["num_srcs"],
"gid_pairs": {str(key): val for key, val in conn["gid_pairs"].items()},
"loc": conn["loc"],
"receptor": conn["receptor"],
"nc_dict": conn["nc_dict"],
"allow_autapses": int(conn["allow_autapses"]),
"probability": conn["probability"],
}
return conn_data
def _external_drive_to_dict(drive, write_output):
"""Returns dict of drive data from a Drive object."""
drive_data = dict()
for key in drive.keys():
# Cannot store sets with hdf5
if isinstance(drive[key], set):
drive_data[key] = list(drive[key])
else:
drive_data[key] = drive[key]
if not write_output:
drive_data["events"] = list()
return drive_data
def _str_to_node(node_string):
"""Returns tuple of node values from a comma-separated string format."""
node_tuple = node_string.split(",")
node_tuple[1] = int(node_tuple[1])
node = (node_tuple[0], node_tuple[1])
return node
def _read_cell_types(cell_types_data):
"""Returns a dict of Cell objects from json encoded data
This function handles both legacy format (direct cell data) and
new format (cell object with metadata) for backwards compatibility.
Parameters
----------
cell_types_data : dict
Dictionary containing cell type information in either:
- Legacy format: {cell_name: cell_dict} where cell_dict contains
cell properties directly
- New format: {cell_name: {"cell_object": cell_dict, "cell_metadata": metadata}}
where cell_dict contains cell properties and metadata contains additional info
Returns
-------
cell_types : dict
Dictionary with cell names as keys and dicts containing:
- "cell_object": Cell instance
- "cell_metadata": dict of metadata (empty dict for legacy format)
"""
cell_types = dict()
for cell_name in cell_types_data:
# Determine format and extract cell_data and metadata accordingly
if (
"cell_object" in cell_types_data[cell_name]
and "cell_metadata" in cell_types_data[cell_name]
):
# Format post-commit-6388f9f:
# Extract cell properties from nested "cell_object"
cell_data = cell_types_data[cell_name]["cell_object"]
cell_metadata = cell_types_data[cell_name]["cell_metadata"]
else:
# Format pre-commit-6388f9f:
# Treat the entire cell_data as the cell information
cell_data = cell_types_data[cell_name]
if cell_name == "L2_basket":
cell_metadata = {
"morpho_type": "basket",
"electro_type": "inhibitory",
"layer": "2",
"measure_dipole": False,
"reference": "https://doi.org/10.7554/eLife.51214",
}
elif cell_name == "L2_pyramidal":
cell_metadata = {
"morpho_type": "pyramidal",
"electro_type": "excitatory",
"layer": "2",
"measure_dipole": True,
"reference": "https://doi.org/10.7554/eLife.51214",
}
elif cell_name == "L5_basket":
cell_metadata = {
"morpho_type": "basket",
"electro_type": "inhibitory",
"layer": "5",
"measure_dipole": False,
"reference": "https://doi.org/10.7554/eLife.51214",
}
elif cell_name == "L5_pyramidal":
cell_metadata = {
"morpho_type": "pyramidal",
"electro_type": "excitatory",
"layer": "5",
"measure_dipole": True,
"reference": "https://doi.org/10.7554/eLife.51214",
}
# Now cell_data contains the cell properties regardless of format
sections = dict()
sections_data = cell_data["sections"]
for section_name in sections_data:
section_data = sections_data[section_name]
sections[section_name] = Section(
L=section_data["L"],
diam=section_data["diam"],
cm=section_data["cm"],
Ra=section_data["Ra"],
end_pts=section_data["end_pts"],
)
# Set section attributes
sections[section_name].syns = section_data["syns"]
sections[section_name].mechs = section_data["mechs"]
# cell_tree
cell_tree = None
if cell_data["cell_tree"] is not None:
cell_tree = dict()
for parent, children in cell_data["cell_tree"].items():
key = _str_to_node(parent)
value = list()
for child in children:
value.append(_str_to_node(child))
cell_tree[key] = value
cell_object = Cell(
name=cell_data["name"],
pos=tuple(cell_data["pos"]),
sections=sections,
synapses=cell_data["synapses"],
cell_tree=cell_tree,
sect_loc=cell_data["sect_loc"],
gid=cell_data["gid"],
)
# Set additional cell attributes
cell_object.dipole_pp = cell_data["dipole_pp"]
cell_object.vsec = cell_data["vsec"]
cell_object.isec = cell_data["isec"]
cell_object.tonic_biases = cell_data["tonic_biases"]
# Store in the unified format with cell_metadata
cell_types[cell_name] = {
"cell_object": cell_object,
"cell_metadata": cell_metadata,
}
return cell_types
def _read_cell_response(cell_response_data, read_output):
"""Returns CellResponse from json encoded data"""
if (not cell_response_data) or (not read_output):
return None
cell_response = CellResponse(
cell_type_names=cell_response_data["cell_type_names"],
spike_times=cell_response_data["spike_times"],
spike_gids=cell_response_data["spike_gids"],
spike_types=cell_response_data["spike_types"],
)
cell_response._times = cell_response_data["times"]
cell_response._vsec = list()
for trial in cell_response_data["vsec"]:
trial = dict((int(key), val) for key, val in trial.items())
cell_response._vsec.append(trial)
cell_response._isec = list()
for trial in cell_response_data["isec"]:
trial = dict((int(key), val) for key, val in trial.items())
cell_response._isec.append(trial)
return cell_response
def _set_from_cell_specific(drive_data):
"""Returns number of drive cells based on cell_specific bool
The n_drive_cells keyword for add_poisson_drive and add_bursty_drive
methods accept either an int or string (n_cells). If the bool keyword
cell_specific = True, n_drive_cells must be 'n_cells'.
"""
if drive_data["cell_specific"]:
return "n_cells"
return drive_data["n_drive_cells"]
def _read_external_drive(net, drive_data, read_output):
"""Adds drives encoded in json data to a Network"""
if (drive_data["type"] == "evoked") or (drive_data["type"] == "gaussian"):
# Skipped n_drive_cells here
net.add_evoked_drive(
name=drive_data["name"],
mu=drive_data["dynamics"]["mu"],
sigma=drive_data["dynamics"]["sigma"],
numspikes=drive_data["dynamics"]["numspikes"],
location=drive_data["location"],
n_drive_cells=_set_from_cell_specific(drive_data),
cell_specific=drive_data["cell_specific"],
weights_ampa=drive_data["weights_ampa"],
weights_nmda=drive_data["weights_nmda"],
synaptic_delays=drive_data["synaptic_delays"],
probability=drive_data["probability"],
event_seed=drive_data["event_seed"],
conn_seed=drive_data["conn_seed"],
)
elif drive_data["type"] == "poisson":
net.add_poisson_drive(
name=drive_data["name"],
tstart=drive_data["dynamics"]["tstart"],
tstop=drive_data["dynamics"]["tstop"],
rate_constant=(drive_data["dynamics"]["rate_constant"]),
location=drive_data["location"],
n_drive_cells=(_set_from_cell_specific(drive_data)),
cell_specific=drive_data["cell_specific"],
weights_ampa=drive_data["weights_ampa"],
weights_nmda=drive_data["weights_nmda"],
synaptic_delays=drive_data["synaptic_delays"],
probability=drive_data["probability"],
event_seed=drive_data["event_seed"],
conn_seed=drive_data["conn_seed"],
)
elif drive_data["type"] == "bursty":
net.add_bursty_drive(
name=drive_data["name"],
tstart=drive_data["dynamics"]["tstart"],
tstart_std=drive_data["dynamics"]["tstart_std"],
tstop=drive_data["dynamics"]["tstop"],
burst_rate=drive_data["dynamics"]["burst_rate"],
burst_std=drive_data["dynamics"]["burst_std"],
numspikes=drive_data["dynamics"]["numspikes"],
spike_isi=drive_data["dynamics"]["spike_isi"],
location=drive_data["location"],
n_drive_cells=_set_from_cell_specific(drive_data),
cell_specific=drive_data["cell_specific"],
weights_ampa=drive_data["weights_ampa"],
weights_nmda=drive_data["weights_nmda"],
synaptic_delays=drive_data["synaptic_delays"],
probability=drive_data["probability"],
event_seed=drive_data["event_seed"],
conn_seed=drive_data["conn_seed"],
)
net.external_drives[drive_data["name"]]["events"] = drive_data["events"]
if not read_output:
net.external_drives[drive_data["name"]]["events"] = list()
def _read_connectivity(net, conns_data):
"""Adds connections to a Network from json encoded connectivity"""
# Overwrite drive connections
net.connectivity = list()
for conn_data in conns_data:
src_gids = [int(s) for s in conn_data["gid_pairs"].keys()]
target_gids_nested = [
target_gid for target_gid in conn_data["gid_pairs"].values()
]
conn_data["allow_autapses"] = bool(conn_data["allow_autapses"])
net.add_connection(
src_gids=src_gids,
target_gids=target_gids_nested,
loc=conn_data["loc"],
receptor=conn_data["receptor"],
weight=conn_data["nc_dict"]["A_weight"],
delay=conn_data["nc_dict"]["A_delay"],
lamtha=conn_data["nc_dict"]["lamtha"],
allow_autapses=conn_data["allow_autapses"],
probability=conn_data["probability"],
)
def _read_rec_arrays(net, rec_arrays_data, read_output):
"""Adds rec arrays to Network from json data."""
for key in rec_arrays_data:
rec_array = rec_arrays_data[key]
net.add_electrode_array(
name=key,
electrode_pos=[tuple(pos) for pos in rec_array["positions"]],
conductivity=rec_array["conductivity"],
method=rec_array["method"],
min_distance=rec_array["min_distance"],
)
net.rec_arrays[key]._times = rec_array["times"]
net.rec_arrays[key]._data = rec_array["voltages"]
if not read_output:
net.rec_arrays[key]._reset()
def _read_pos_dict(pos_dict):
"""Returns position dictionary with nested positions converted to tuple."""
pos_dict_converted = dict()
for key, value in pos_dict.items():
if key == "origin":
pos_dict_converted[key] = tuple(value)
else:
pos_dict_converted[key] = [tuple(position) for position in value]
return pos_dict_converted
[docs]
def network_to_dict(net, write_output=False):
"""Returns a dict of parameters and outputs from Network.
Parameters
----------
net : Network
hnn-core Network object
write_output : bool
Includes simulation-associated data.
Returns
-------
dict
"""
# cell_types serialization, support both old and new formats
cell_types_data = {}
for name, template in net.cell_types.items():
if isinstance(template, dict) and "cell_object" in template:
# New format with cell_metadata
cell_types_data[name] = {
"cell_object": template["cell_object"].to_dict(),
"cell_metadata": template["cell_metadata"],
}
else:
# Legacy format, template is a Cell object directly
cell_types_data[name] = template.to_dict()
net_data = {
"object_type": "Network",
"legacy_mode": net._legacy_mode,
"N_pyr_x": net._N_pyr_x,
"N_pyr_y": net._N_pyr_y,
"celsius": net._params["celsius"],
"cell_types": cell_types_data,
"gid_ranges": {
cell: {"start": c_range.start, "stop": c_range.stop}
for cell, c_range in net.gid_ranges.items()
},
"pos_dict": {cell: pos for cell, pos in net.pos_dict.items()},
"cell_response": _cell_response_to_dict(net, write_output),
"external_drives": {
drive: _external_drive_to_dict(params, write_output)
for drive, params in net.external_drives.items()
},
"external_biases": net.external_biases,
"connectivity": [_conn_to_dict(conn) for conn in net.connectivity],
"rec_arrays": {
ra_name: _rec_array_to_dict(ex_array, write_output)
for ra_name, ex_array in net.rec_arrays.items()
},
"threshold": net.threshold,
"delay": net.delay,
}
return net_data
[docs]
@fill_doc
def write_network_configuration(net, output, overwrite=True):
"""Writes network configuration to a json file.
Writes network configurations as a hierarchical json similar to the Network
object's structure. Outputs recorded during simulation such as currents and
voltages are not saved due to size.
Parameters
----------
net : Network
hnn-core Network object
output : str, Path, or StringIO
Path or buffer to write outputs
overwrite : bool
Overwrite file if it exists. Default: True
Returns
-------
None
"""
net_data = net.to_dict(write_output=False)
net_data_converted = _convert_np_array_to_list(net_data)
if isinstance(output, (str, Path)):
if overwrite is False and os.path.exists(output):
raise FileExistsError(
"File already exists at path %s. Rename "
"the file or set overwrite=True." % (output,)
)
# Saving file
with open(output, "w", encoding="utf-8") as f:
json.dump(net_data_converted, f, ensure_ascii=False, indent=4)
else:
# Write to StringIO buffer
json.dump(net_data_converted, output, ensure_ascii=False, indent=4)
output.seek(0) # Reset buffer position to the start
def _order_drives(gid_ranges, external_drives):
"""Returns an ordered dict of external drives by ascending gid ranges
Drive data from hdf5 are ordered alphabetically by name. This function
reorders the external drives by ascending gid ranges.
Parameters
----------
gid_ranges : dict (keys: names) of range
Dictionary with connection or drive names as keys and ranges as values.
external_drives: dict (keys: drive names) of dict (keys: parameters)
Dictionary with drive name as keys and instances of _NetworkDrive as
values.
Returns
-------
OrderedDict : dict (keys: drive names) of dict (keys: parameters)
Ordered dict with drives by ascending gid ranges
"""
ordered_drives = OrderedDict()
min_gid_to_drive = {
min(gid_range): name
for (name, gid_range) in gid_ranges.items()
if name in external_drives.keys()
}
min_gid_sorted = sorted(list(min_gid_to_drive.keys()))
for min_gid in min_gid_sorted:
drive_name = min_gid_to_drive[min_gid]
ordered_drives[drive_name] = external_drives[drive_name]
return ordered_drives
[docs]
def dict_to_network(net_data, read_drives=True, read_external_biases=True):
"""Converts a dict of network configurations to a Network
Parameters
----------
net_data : dict
Dictionary containing network configurations.
read_drives : bool, optional
Read-in drives to Network object. Default is True.
read_external_biases : bool, optional
Read-in external biases to Network object. Default is True.
Returns : Network
-------
"""
# Importing Network.
# Cannot do this globally due to circular import.
from .network import Network
params = dict()
params["celsius"] = net_data["celsius"]
params["threshold"] = net_data["threshold"]
mesh_shape = (net_data["N_pyr_x"], net_data["N_pyr_y"])
# Instantiating network
net = Network(params, mesh_shape=mesh_shape, legacy_mode=net_data["legacy_mode"])
# Setting attributes
# Set cell types
net.cell_types = _read_cell_types(net_data["cell_types"])
# Set gid ranges
gid_ranges_data = dict()
for key in net_data["gid_ranges"]:
start = net_data["gid_ranges"][key]["start"]
stop = net_data["gid_ranges"][key]["stop"]
gid_ranges_data[key] = range(start, stop)
net.gid_ranges = OrderedDict(gid_ranges_data)
# Set pos_dict
net.pos_dict = _read_pos_dict(net_data["pos_dict"])
# Set cell_response
net.cell_response = _read_cell_response(
net_data["cell_response"], read_output=False
)
# Set external drives
external_drive_data = _order_drives(net.gid_ranges, net_data["external_drives"])
for key in external_drive_data.keys():
_read_external_drive(net, external_drive_data[key], read_output=False)
# Set external biases
if read_external_biases:
net.external_biases = net_data["external_biases"]
# Set connectivity
_read_connectivity(net, net_data["connectivity"])
# Set rec_arrays
_read_rec_arrays(net, net_data["rec_arrays"], read_output=False)
# Set threshold
net.threshold = net_data["threshold"]
# Set delay
net.delay = net_data["delay"]
if not read_drives:
net.clear_drives()
return net
[docs]
def read_network_configuration(fname, read_drives=True, read_external_biases=True):
"""Read network from a json configuration file.
Parameters
----------
fname : str or Path
Path to configuration file
read_drives : bool
Read-in drives to Network object
read_external_biases
Read-in external biases to Network object
Returns : Network
-------
"""
with open(fname, "r") as file:
net_data = json.load(file)
if net_data.get("object_type") != "Network":
raise ValueError(
"The json should encode a Network object. "
"The file contains object of "
"type %s" % (net_data.get("object_type"))
)
net = dict_to_network(net_data, read_drives, read_external_biases)
return net