"""Network class."""
# Authors: Mainak Jas <mjas@mgh.harvard.edu>
#          Sam Neymotin <samnemo@gmail.com>
#          Blake Caldwell <blake_caldwell@brown.edu>
#          Christopher Bailey <cjb@cfin.au.dk>
#          Nick Tolley <nicholas_tolley@brown.edu>
#          Ryan Thorpe <ryan_thorpe@brown.edu>
import itertools as it
from copy import deepcopy
from collections import OrderedDict, defaultdict
from typing import Dict
import numpy as np
import warnings
from .cell_response import read_spikes
from .drives import _drive_cell_event_times
from .drives import _get_target_properties, _add_drives_from_params
from .drives import _check_drive_parameter_values, _check_poisson_rates
from .cells_default import pyramidal, basket
from .params import _long_name
from .viz import plot_cells
from .externals.mne import _validate_type, _check_option
from .extracellular import ExtracellularArray
from .check import _check_gids, _gid_to_type, _string_input_to_list
from .hnn_io import write_network_configuration, network_to_dict
from .externals.mne import copy_doc
from .utils import _replace_dict_identifier
def _create_cell_coords(n_pyr_x, n_pyr_y, z_coord, inplane_distance):
    """Creates coordinate grid and place cells in it.
    Parameters
    ----------
    n_pyr_x : int
        The number of Pyramidal cells in x direction.
    n_pyr_y : int
        The number of Pyramidal cells in y direction.
    z_coord : float
        Expressed as a positive DEPTH of L2 relative to L5 pyramidal cell
        somas, where L5 is defined to lie at z==0. Interlaminar weight/delay
        calculations (lamtha) are not affected. The basket cells are
        arbitrarily placed slightly above (L5) and slightly below (L2) their
        respective pyramidal cell layers.
    inplane_distance : float
        The grid spacing of pyramidal cells (in um). Note that basket cells are
        placed in an uneven formation. Each one of them lies on a grid point
        together with a pyramidal cell, though (overlapping).
    Returns
    -------
    layer_dict : dict of list of tuple (x, y, z)
        Dictionary containing coordinate positions of 'layers'. After calling
        '_create_cell_coords', user can create their 'Network.pos_dict' by mapping
        'origin' and their celltypes onto the different layers in 'layer_dict'. Keys are
        'L2_bottom', 'L2_mid', 'L5_bottom', 'L5_mid', and 'origin'.
    Notes
    -----
    Common positions are all located at origin.
    Sort of a hack because of redundancy.
    """
    def _calc_pyramidal_coord(xxrange, yyrange, z_coord):
        list_coords = [pos for pos in it.product(xxrange, yyrange, [z_coord])]
        return list_coords
    def _calc_basket_coord(n_x, n_y, z_coord, inplane_distance, weight):
        xzero = np.arange(0, n_x, 3) * inplane_distance
        xone = np.arange(1, n_x, 3) * inplane_distance
        # split even and odd y vals
        yeven = np.arange(0, n_y, 2) * inplane_distance
        yodd = np.arange(1, n_y, 2) * inplane_distance
        # create general list of x,y coords and sort it
        coords = [pos for pos in it.product(xzero, yeven)] + [
            pos for pos in it.product(xone, yodd)
        ]
        coords_sorted = sorted(coords, key=lambda pos: pos[1])
        # append the z value for position
        list_coords = [
            (pos_xy[0], pos_xy[1], weight * z_coord) for pos_xy in coords_sorted
        ]
        return list_coords
    def _calc_origin(xxrange, yyrange, z_coord):
        # origin's z component isn't used in calculating distance functions.
        # will be used for adding external drives.
        origin_x = xxrange[int((len(xxrange) - 1) // 2)]
        origin_y = yyrange[int((len(yyrange) - 1) // 2)]
        origin_z = np.floor(z_coord / 2)
        origin = (origin_x, origin_y, origin_z)
        return origin
    # Calculate distances
    xxrange = np.arange(n_pyr_x) * inplane_distance
    yyrange = np.arange(n_pyr_y) * inplane_distance
    # Create layer dictionary with anatomical layer positions
    layer_dict = {
        "L5_bottom": _calc_pyramidal_coord(xxrange, yyrange, z_coord=0),
        "L2_bottom": _calc_pyramidal_coord(xxrange, yyrange, z_coord=z_coord),
        "L5_mid": _calc_basket_coord(
            n_pyr_x,
            n_pyr_y,
            z_coord=z_coord,
            inplane_distance=inplane_distance,
            weight=0.2,
        ),
        "L2_mid": _calc_basket_coord(
            n_pyr_x,
            n_pyr_y,
            z_coord=z_coord,
            inplane_distance=inplane_distance,
            weight=0.8,
        ),
        "origin": _calc_origin(xxrange, yyrange, z_coord),
    }
    return layer_dict
def _compare_lists(s, t):
    """
    Compares lists for equality
    From https://stackoverflow.com/a/7829388
    """
    t = list(t)  # make a mutable copy
    try:
        for elem in s:
            t.remove(elem)
    except ValueError:
        return False
    return not t
def _connection_probability(conn, probability, conn_seed=None):
    """Remove/keep a random subset of connections.

    Parameters
    ----------
    conn : Instance of _Connectivity object
        Object specifying the biophysical parameters and src target pairs
        of a specific connection class. Function modifies conn in place:
        ``conn["gid_pairs"]`` (src gid -> list of target gids) is pruned, and
        src gids left without any target are removed from it.
    probability : float
        Probability of connection between any src-target pair. Must lie
        strictly between 0 and 1. ``round(n_pairs * probability)`` pairs are
        kept, chosen uniformly at random without replacement.
    conn_seed : int | None
        Optional initial seed for random number generator (default: None).
        Used to randomly remove connections.

    Raises
    ------
    ValueError
        If ``probability`` is not in the open interval (0, 1).

    Notes
    -----
    num_srcs and num_targets are not updated after pruning connections.
    These variables are meant to describe the set of original connections
    before they are randomly removed.

    This function does not record ``probability`` on ``conn``. If a caller
    stores it, the stored value only reflects the most recent call and does
    not describe the connection probability of the original set after
    successive calls.
    """
    # Random number generator for random connection selection
    rng = np.random.default_rng(conn_seed)
    _validate_type(probability, float, "probability")
    if probability <= 0.0 or probability >= 1.0:
        raise ValueError("probability must be in the range (0,1)")

    gid_pairs = conn["gid_pairs"]
    # Pairs are numbered 0..n_total-1 in the iteration order of gid_pairs.
    # Counting them directly (instead of concatenating) also works when
    # gid_pairs is empty.
    n_total = sum(len(targets) for targets in gid_pairs.values())
    n_keep = int(np.round(n_total * probability))

    # Select a random subset of pair indices to retain. A set gives O(1)
    # membership tests; testing membership in the returned ndarray is O(n)
    # per pair.
    keep = set(rng.choice(range(n_total), n_keep, replace=False).tolist())

    emptied_srcs = list()
    pair_idx = 0
    for src_gid, targets in gid_pairs.items():
        kept_targets = list()
        for target_gid in targets:
            if pair_idx in keep:
                kept_targets.append(target_gid)
            pair_idx += 1
        # Replacing values of existing keys is safe during iteration;
        # deletions are deferred until the loop ends.
        if kept_targets:
            gid_pairs[src_gid] = kept_targets
        else:
            emptied_srcs.append(src_gid)

    # Remove src_gids with no targets
    for src_gid in emptied_srcs:
        gid_pairs.pop(src_gid)
def pick_connection(net, src_gids=None, target_gids=None, loc=None, receptor=None):
    """Return the indices of connections that match the search parameters.

    Parameters
    ----------
    net : Instance of Network object
        The Network object whose ``connectivity`` list is searched.
    src_gids : str | int | range | list of int | None
        Identifier for source cells. Passing str arguments
        ('L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket') is
        equivalent to passing a list of gids for the relevant cell type.
        Any key of ``net.gid_ranges`` (including drive names) is accepted.
    target_gids : str | int | range | list of int | None
        Identifier for targets of source cells. Passing str arguments
        ('L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket') is
        equivalent to passing a list of gids for the relevant cell type.
        Only keys of ``net.cell_types`` are accepted as names.
    loc : str | list of str | None
        Location of synapse on target cell. Must be
        'proximal', 'distal', or 'soma'. Note that inhibitory synapses
        (receptor='gabaa' or 'gabab') of L2 pyramidal neurons are only
        valid loc='soma'.
    receptor : str | list of str | None
        Synaptic receptor of connection. Must be one of:
        'ampa', 'nmda', 'gabaa', or 'gabab'.

    Returns
    -------
    conn_indices : list of int
        Sorted indices of items in net.connectivity. A connection is
        included only if it matches every parameter that was provided
        (see Notes); an empty list is returned when nothing matches.

    Notes
    -----
    Passing a list of values to a single parameter corresponds to a
    logical OR operation across indices. For example,
    loc=['distal', 'proximal'] returns all connections that target
    distal or proximal dendrites.

    Passing multiple parameters corresponds to a logical AND operation.
    For example, net.pick_connection(loc='distal', receptor='ampa')
    returns only the indices of connections that target the distal
    dendrites and have ampa receptors.
    """
    # Validate/normalize inputs. Drives are valid sources but not targets.
    src_names = list(net.gid_ranges.keys())
    target_names = list(net.cell_types.keys())
    src_gids_checked = _check_gids(
        src_gids, net.gid_ranges, src_names, "src_gids", same_type=False
    )
    target_gids_checked = _check_gids(
        target_gids, net.gid_ranges, target_names, "target_gids", same_type=False
    )
    _validate_type(loc, (str, list, None), "loc", "str, list, or None")
    _validate_type(receptor, (str, list, None), "receptor", "str, list, or None")
    loc_list = _string_input_to_list(loc, ["proximal", "distal", "soma"], "loc")
    receptor_list = _string_input_to_list(
        receptor, ["ampa", "nmda", "gabaa", "gabab"], "receptor"
    )

    # Reverse indices: search value -> connection indices carrying it.
    by_src, by_target = defaultdict(list), defaultdict(list)
    by_loc, by_receptor = defaultdict(list), defaultdict(list)
    for idx, conn in enumerate(net.connectivity):
        for gid in conn["src_gids"]:
            by_src[gid].append(idx)
        for gid in conn["target_gids"]:
            by_target[gid].append(idx)
        by_loc[conn["loc"]].append(idx)
        by_receptor[conn["receptor"]].append(idx)

    criteria = (
        (src_gids_checked, by_src),
        (target_gids_checked, by_target),
        (loc_list, by_loc),
        (receptor_list, by_receptor),
    )
    matches = None  # stays None until the first provided criterion
    for terms, lookup in criteria:
        if not terms:
            continue  # parameter not provided: no constraint
        # OR within one parameter
        hits = set()
        for term in terms:
            hits.update(lookup.get(term, list()))
        if not hits:
            return list()
        # AND across parameters
        matches = hits if matches is None else matches & hits
        if not matches:
            return list()

    if matches is None:
        return list()
    return sorted(matches)
def _get_cell_index_by_synapse_type(net):
    """Returns the indices of excitatory and inhibitory cells in the Network.
    This function extracts the source GIDs (cell ID) of excitatory and inhibitory cells
    based on their connection types. Excitatory and inhibitory cells are identified by
    their electrophysiological metadata values. This does *not* return GIDs of external
    drives.
    Parameters
    ----------
    net : Instance of Network object
        The Network object
    Returns
    -------
    e_cell_gids : list
        The source GIDs of excitatory cells.
    i_cell_gids : list
        The source GIDs of inhibitory cells.
    """
    e_cell_gids = list()
    i_cell_gids = list()
    for cell_type_name, cell_data in net.cell_types.items():
        if cell_data["cell_metadata"].get("electro_type") == "excitatory":
            e_cell_gids.extend(net.gid_ranges[cell_type_name])
        elif cell_data["cell_metadata"].get("electro_type") == "inhibitory":
            i_cell_gids.extend(net.gid_ranges[cell_type_name])
    return e_cell_gids, i_cell_gids
class Network:
    """The Network class.
    Parameters
    ----------
    params : dict
        The parameters to use for constructing the network.
    add_drives_from_params : bool, default=False
        If True, add drives as defined in the params-dict. NB this is mainly
        for backward-compatibility with HNN GUI, and will be deprecated in a
        future release.
    legacy_mode : bool, default=False
        Set to True to enable matching HNN GUI output when drives
        are added suitably. Will be deprecated in a future release.
    mesh_shape : tuple of int (default: (10, 10))
        Defines the (n_x, n_y) shape of the grid of pyramidal cells.
    pos_dict : dict of list of tuple (x, y, z), optional
        Dictionary containing the coordinate positions of all cells.
        Keys are 'L2_pyramidal', 'L5_pyramidal', 'L2_basket', 'L5_basket',
        or any external drive name.
    cell_types : dict of dict of (Cell | dict), optional
        Dictionary containing names of real cell types in the network (e.g. 'L2_basket')
        as keys and child-dictionaries describing the cell type. The child-dictionary
        contains two keys: "cell_object" and "cell_metadata". The value of "cell_object"
        is the corresponding Cell instance of the cell type being described, and this
        instance is used as a template for the other cells of its type in the
        population. The value of "cell_metadata" is a dictionary containing several
        key-value pairs that describe different aspects of the cell type, described
        below:
            - "morpho_type" : either "basket" or "pyramidal"
            - "electro_type" : either "inhibitory" or "excitatory"
            - "layer" : either "2" or "5"
            - "measure_dipole" : either True or False
            - "reference": optional
    Attributes
    ----------
    cell_types : dict of dict of (Cell | dict)
        Dictionary containing names of real cell types in the network (e.g. 'L2_basket')
        as keys and child-dictionaries describing the cell type. The child-dictionary
        contains two keys: "cell_object" and "cell_metadata". The value of "cell_object"
        is the corresponding Cell instance of the cell type being described, and this
        instance is used as a template for the other cells of its type in the
        population. The value of "cell_metadata" is a dictionary containing several
        key-value pairs that describe different aspects of the cell type, described
        below:
            - "morpho_type" : either "basket" or "pyramidal"
            - "electro_type" : either "inhibitory" or "excitatory"
            - "layer" : either "2" or "5"
            - "measure_dipole" : either True or False
            - "reference": optional
    gid_ranges : dict
        A dictionary of unique identifiers of each real and artificial cell
        in the network. Every cell type is represented by a key read from
        cell_types, followed by keys read from external_drives. The value
        of each key is a range of ints, one for each cell in given category.
        Examples: 'L2_basket': range(0, 270), 'evdist1': range(272, 542), etc
    pos_dict : dict of list of tuple (x, y, z)
        Dictionary containing the coordinate positions of all cells.
        Keys are 'L2_pyramidal', 'L5_pyramidal', 'L2_basket', 'L5_basket',
        or any external drive name.
    cell_response : CellResponse
        An instance of the CellResponse object.
    external_drives : dict (keys: drive names) of dict (keys: parameters)
        The external driving inputs to the network. Drives are added by
        defining their spike-time dynamics, and their connectivity to the real
        cells of the network. Event times are instantiated before simulation,
        and are stored under the ``'events'``-key (list of list; first
        index for trials, second for event time lists for each drive cell).
    external_biases : dict of dict (bias parameters for each cell type)
        The parameters of bias inputs to cell somata, e.g., tonic current clamp
    connectivity : list of dict
        List of dictionaries specifying each cell-cell and drive-cell
        connection
    rec_arrays : dict
        Stores electrode position information and voltages recorded by them
        for extracellular potential measurements. Multiple electrode arrays
        may be defined as unique keys. The values of the dictionary are
        instances of :class:`hnn_core.extracellular.ExtracellularArray`.
    threshold : float
        Firing threshold of all cells.
    delay : float
        Synaptic delay in ms.
    Notes
    -----
    ``net = jones_2009_model(params)`` is the recommended path for creating a
    network. Instantiating the network as ``net = Network(params)`` will
    produce a network with no cell-to-cell connections. As such,
    connectivity information contained in ``params`` will be ignored.
    """
    def __init__(
        self,
        params,
        add_drives_from_params=False,
        legacy_mode=False,
        mesh_shape=(10, 10),
        pos_dict=None,
        cell_types=None,
    ):
        """Initialize the Network (parameters are documented on the class).

        Raises
        ------
        KeyError
            If ``params`` has no ``"threshold"`` entry.
        ValueError
            If either element of ``mesh_shape`` is smaller than 1.
        """
        # Save the parameters used to create the Network
        _validate_type(params, dict, "params")
        self._params = params
        # Initialise a dictionary of cell ID's, which get used when the
        # network is constructed ('built') in NetworkBuilder
        # We want it to remain in each Network object, so that the user can
        # interrogate a built and simulated net. In addition, CellResponse is
        # attached to a Network during simulation---Network is the natural
        # place to keep this information. Order matters: cell gids first, then
        # artificial drive cells
        self.gid_ranges = OrderedDict()
        self._n_gids = 0  # utility: keep track of last GID

        # XXX this can be removed once tests are made independent of HNN GUI
        # creates nc_dict-entries for ALL cell types
        self._legacy_mode = legacy_mode
        if self._legacy_mode:
            # stacklevel=2 attributes the warning to the code constructing
            # the Network instead of to this line.
            warnings.warn(
                "Legacy mode is used solely to maintain compatibility with "
                ".param files of the old HNN GUI. This feature will be "
                "deprecated in future releases.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.cell_response = None

        # external drives and biases
        self.external_drives = dict()
        self.external_biases = dict()

        # network connectivity
        self.connectivity = list()
        self.threshold = self._params["threshold"]
        self.delay = 1.0

        # extracellular recordings (if applicable)
        self.rec_arrays = dict()

        # contents of pos_dict determines all downstream inferences of
        # cell counts, real and artificial
        self._n_cells = 0  # used in tests and MPIBackend checks
        self.pos_dict = dict()
        self.cell_types = dict()

        # set the mesh shape
        _validate_type(mesh_shape, tuple, "mesh_shape")
        _validate_type(mesh_shape[0], int, "mesh_shape[0]")
        _validate_type(mesh_shape[1], int, "mesh_shape[1]")
        if mesh_shape[0] < 1 or mesh_shape[1] < 1:
            raise ValueError(
                f"mesh_shape must be a tuple of positive integers, got: {mesh_shape}"
            )
        self._N_pyr_x = mesh_shape[0]
        self._N_pyr_y = mesh_shape[1]
        self._inplane_distance = 1.0  # XXX hard-coded default
        self._layer_separation = 1307.4  # XXX hard-coded default

        # Handle positions and cell types
        # NOTE(review): the custom branch is taken only when BOTH pos_dict and
        # cell_types are given; passing just one of them silently falls back
        # to the default network below.
        if pos_dict is not None and cell_types is not None:
            # Use provided positions and cell types
            _validate_type(pos_dict, dict, "pos_dict")
            _validate_type(cell_types, dict, "cell_types")
            self.pos_dict = deepcopy(pos_dict)
            # Add cell types from provided dictionary; cell types without an
            # entry in pos_dict are skipped.
            for cell_name, cell_template in cell_types.items():
                if cell_name in self.pos_dict:
                    self._add_cell_type(
                        cell_name, self.pos_dict[cell_name], cell_template=cell_template
                    )
        else:
            # Default behavior - create standard network
            cell_types_default = {
                "L2_basket": {
                    "cell_object": basket(cell_name="L2_basket"),
                    "cell_metadata": {
                        "morpho_type": "basket",
                        "electro_type": "inhibitory",
                        "layer": "2",
                        "measure_dipole": False,
                        "reference": "https://doi.org/10.7554/eLife.51214",
                    },
                },
                "L2_pyramidal": {
                    "cell_object": pyramidal(cell_name="L2_pyramidal"),
                    "cell_metadata": {
                        "morpho_type": "pyramidal",
                        "electro_type": "excitatory",
                        "layer": "2",
                        "measure_dipole": True,
                        "reference": "https://doi.org/10.7554/eLife.51214",
                    },
                },
                "L5_basket": {
                    "cell_object": basket(cell_name="L5_basket"),
                    "cell_metadata": {
                        "morpho_type": "basket",
                        "electro_type": "inhibitory",
                        "layer": "5",
                        "measure_dipole": False,
                        "reference": "https://doi.org/10.7554/eLife.51214",
                    },
                },
                "L5_pyramidal": {
                    "cell_object": pyramidal(cell_name="L5_pyramidal"),
                    "cell_metadata": {
                        "morpho_type": "pyramidal",
                        "electro_type": "excitatory",
                        "layer": "5",
                        "measure_dipole": True,
                        "reference": "https://doi.org/10.7554/eLife.51214",
                    },
                },
            }
            self.set_cell_positions(
                inplane_distance=self._inplane_distance,
                layer_separation=self._layer_separation,
            )
            # populates self.gid_ranges for the 1st time: order matters for
            # NetworkBuilder!
            for cell_name, cell_template in cell_types_default.items():
                self._add_cell_type(
                    cell_name,
                    self.pos_dict[cell_name],
                    cell_template=cell_template,
                )

        if add_drives_from_params:
            _add_drives_from_params(self)

        self._tstop = None
        self._dt = None
    def __repr__(self):
        class_name = self.__class__.__name__
        # Dynamically create the description based on the current cell types
        descriptions = []
        for cell_name in self.cell_types:
            cell_count = len(self.pos_dict.get(cell_name, []))
            descriptions.append(f"{cell_count} {cell_name} cells")
        # Combine all descriptions into a single string
        description_str = "\n".join(descriptions)
        return f"<{class_name} | {description_str}>" 
    def __eq__(self, other):
        if not isinstance(other, Network):
            return NotImplemented
        # Check connectivity
        if (len(self.connectivity) != len(other.connectivity)) or not (
            _compare_lists(self.connectivity, other.connectivity)
        ):
            return False
        # Check all other attributes
        attrs_to_ignore = ["connectivity"]
        for attr in vars(self).keys():
            if attr.startswith("_") or attr in attrs_to_ignore:
                continue
            if hasattr(self, attr) and hasattr(other, attr):
                if getattr(self, attr) != getattr(other, attr):
                    return False
            else:
                # Does not have the same set of attributes
                return False
        return True
    def set_cell_positions(self, *, inplane_distance=None, layer_separation=None):
        """Set relative positions of cells arranged in a square grid

        Either parameter may be omitted (None, the default) to keep its
        current value, so a subset of the parameters can be changed.
        ``pos_dict`` is rebuilt with entries for 'L5_pyramidal',
        'L2_pyramidal', 'L5_basket', 'L2_basket' and 'origin'. Every
        external drive is re-placed at the origin.

        Parameters
        ----------
        inplane_distance : float
            The in plane-distance (in um) between pyramidal cell somas in the
            square grid. Note that this parameter does not affect the amplitude
            of the dipole moment.
        layer_separation : float
            The separation of pyramidal cell soma layers 2/3 and 5. Note that
            this parameter does not affect the amplitude of the dipole moment.

        Raises
        ------
        ValueError
            If a resolved value is not strictly positive (NaN included).
        """

        def _resolve_positive(value, attr_name, arg_name, label):
            # Fall back to the stored value only when the caller passed None.
            if value is None:
                value = getattr(self, attr_name)
            _validate_type(value, (float, int), arg_name)
            # ``not value > 0`` (rather than ``value <= 0``) also rejects NaN.
            if not value > 0.0:
                raise ValueError(f"{label} must be positive, got: {value}")
            return value

        inplane_distance = _resolve_positive(
            inplane_distance, "_inplane_distance", "inplane_distance", "In-plane distance"
        )
        layer_separation = _resolve_positive(
            layer_separation, "_layer_separation", "layer_separation", "Layer separation"
        )

        layers = _create_cell_coords(
            n_pyr_x=self._N_pyr_x,
            n_pyr_y=self._N_pyr_y,
            z_coord=layer_separation,
            inplane_distance=inplane_distance,
        )

        # Default cell type -> layer mapping
        cell_type_to_layer = (
            ("L5_pyramidal", "L5_bottom"),
            ("L2_pyramidal", "L2_bottom"),
            ("L5_basket", "L5_mid"),
            ("L2_basket", "L2_mid"),
            ("origin", "origin"),
        )
        self.pos_dict = {
            cell_type: layers[layer] for cell_type, layer in cell_type_to_layer
        }

        # Every drive cell sits at the network origin.
        origin = self.pos_dict["origin"]
        for drive_name, drive in self.external_drives.items():
            self.pos_dict[drive_name] = [origin] * drive["n_drive_cells"]

        self._inplane_distance = inplane_distance
        self._layer_separation = layer_separation
    def copy(self):
        """Return a copy of the Network instance

        The copy is a deep copy, so it keeps the intrinsic connectivity
        between cells as well as that of any external drives or biases. Drive
        dynamics parameters are kept, but the instantiated ``events`` of the
        drives are cleared. This allows iterating over the values defining
        drive dynamics without re-defining connectivity. Extracellular
        recording arrays are kept but cleared of existing data.

        Returns
        -------
        net_copy : instance of Network
            A copy of the instance with previous simulation results and
            ``events`` of external drives removed.
        """
        duplicate = deepcopy(self)
        # Clear per-simulation state on the copy only; self is left untouched.
        duplicate._reset_drives()
        duplicate._reset_rec_arrays()
        return duplicate
    def add_evoked_drive(
        self,
        name,
        *,
        mu,
        sigma,
        numspikes,
        location,
        n_drive_cells="n_cells",
        cell_specific=True,
        weights_ampa=None,
        weights_nmda=None,
        space_constant=3.0,
        synaptic_delays=0.1,
        probability=1.0,
        event_seed=2,
        conn_seed=3,
    ):
        """Add an 'evoked' external drive to the network
        Parameters
        ----------
        name : str
            Unique name for the drive
        mu : float
            Mean of Gaussian event time distribution
        sigma : float
            Standard deviation of event time distribution
        numspikes : int
            Number of spikes at each target cell
        location : str
            Target location of synapses. Must be an element of
            `Cell.sect_loc` such as 'proximal' or 'distal', which defines a
            group of sections, or an existing section such as 'soma' or
            'apical_tuft' (defined in `Cell.sections` for all targeted cells).
            The parameter `legacy_mode` of the `Network` must be set to `False`
            to target specific sections.
        n_drive_cells : int | 'n_cells'
            The number of drive cells that each contribute an independently
            sampled synaptic spike to the network according to the Gaussian
            time distribution (mu, sigma). If n_drive_cells='n_cells'
            (default) and cell_specific=True, a drive cell gets assigned to
            each available simulated cell in the network with 1-to-1
            connectivity. Otherwise, drive cells are assigned with
            all-to-all connectivity. If you wish to synchronize the timing of
            this evoked drive across the network in a given trial with one
            spike, set n_drive_cells=1 and cell_specific=False.
        cell_specific : bool
            Whether each artificial drive cell has 1-to-1 (True, default) or
            all-to-all (False) connection parameters. Note that 1-to-1
            connectivity requires that n_drive_cells='n_cells', where 'n_cells'
            denotes the number of all available cells that this drive can
            target in the network.
        weights_ampa : dict or None
            Synaptic weights (in uS) of AMPA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        weights_nmda : dict or None
            Synaptic weights (in uS) of NMDA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        synaptic_delays : dict or float
            Synaptic delay (in ms) at the column origin, dispersed laterally as
            a function of the space_constant. If float, applies to all target
            cell types. Use dict to create delay->cell mapping.
        space_constant : float
            Describes lateral dispersion (from the column origin) of synaptic
            weights and delays within the simulated column. The constant is
            measured in the units of ``inplane_distance`` of
            :class:`~hnn_core.Network`. For example, for ``space_constant=3``,
            the weights are modulated by the factor
            ``exp(-(x / (3 * inplane_distance)) ** 2)``, where x is the
            physical distance (in um) between the connected cells in the xy
            plane (delays are modulated by the inverse of this factor).
        probability : dict or float (default: 1.0)
            Probability of connection between any src-target pair.
            Use dict to create probability->cell mapping. If float, applies to
            all target cell types
        event_seed : int
            Optional initial seed for random number generator (default: 2).
            Used to generate event times for drive cells.
            Not fixed across trials (see Notes)
        conn_seed : int
            Optional initial seed for random number generator (default: 3).
            Used to randomly remove connections when probability < 1.0.
            Fixed across trials (see Notes)
        Notes
        -----
        Random seeding behavior across trials is different for event_seed
        and conn_seed (n_trials > 1 in simulate_dipole(..., n_trials):
        - event_seed
            Across trials, the random seed is incremented such that
            the exact spike times are different
        - conn_seed
            The random seed does not change across trials. This means for
            probability < 1.0, the random subset of gids targeted is the same.
        """
        if not self._legacy_mode:
            _check_drive_parameter_values("evoked", sigma=sigma, numspikes=numspikes)
        drive = _NetworkDrive()
        drive["type"] = "evoked"
        drive["location"] = location
        if name == "extgauss":
            drive["type"] = "gaussian"  # XXX needed to pass legacy tests!
        drive["n_drive_cells"] = n_drive_cells
        drive["event_seed"] = event_seed
        drive["conn_seed"] = conn_seed
        drive["dynamics"] = dict(mu=mu, sigma=sigma, numspikes=numspikes)
        drive["events"] = list()
        # Need to save this information
        drive["weights_ampa"] = weights_ampa
        drive["weights_nmda"] = weights_nmda
        drive["synaptic_delays"] = synaptic_delays
        drive["probability"] = probability
        self._attach_drive(
            name,
            drive,
            weights_ampa,
            weights_nmda,
            location,
            space_constant,
            synaptic_delays,
            n_drive_cells,
            cell_specific,
            probability,
        ) 
[docs]
    def add_poisson_drive(
        self,
        name,
        *,
        tstart=0,
        tstop=None,
        rate_constant,
        location,
        n_drive_cells="n_cells",
        cell_specific=True,
        weights_ampa=None,
        weights_nmda=None,
        space_constant=100.0,
        synaptic_delays=0.1,
        probability=1.0,
        event_seed=2,
        conn_seed=3,
    ):
        """Add a Poisson-distributed external drive to the network
        Parameters
        ----------
        name : str
            Unique name for the drive
        tstart : float
            Start time of Poisson-distributed spike train (default: 0)
        tstop : float
            End time of the spike train (defaults to None: tstop is set to the
            end of the simulation)
        rate_constant : float or dict of floats
            Rate constant (lambda > 0) of the renewal-process generating the
            samples. If a float is provided, the same rate constant is applied
            to each target cell type. Cell type-specific values may be
            provided as a dictionary, in which a key must be present for each
            cell type with non-zero AMPA or NMDA weights.
        location : str
            Target location of synapses. Must be an element of
            `Cell.sect_loc` such as 'proximal' or 'distal', which defines a
            group of sections, or an existing section such as 'soma' or
            'apical_tuft' (defined in `Cell.sections` for all targeted cells).
            The parameter `legacy_mode` of the `Network` must be set to `False`
            to target specific sections.
        n_drive_cells : int | 'n_cells'
            The number of drive cells that each contribute an independently
            sampled synaptic spike to the network according to a Poisson
            process. If n_drive_cells='n_cells' (default) and
            cell_specific=True, a drive cell gets assigned to each available
            simulated cell in the network with 1-to-1 connectivity. Otherwise,
            drive cells are assigned with all-to-all connectivity. If you wish
            to synchronize the timing of Poisson drive across the network in a
            given trial, set n_drive_cells=1 and cell_specific=False.
        cell_specific : bool
            Whether each artificial drive cell has 1-to-1 (True, default) or
            all-to-all (False) connection parameters. Note that 1-to-1
            connectivity requires that n_drive_cells='n_cells', where 'n_cells'
            denotes the number of all available cells that this drive can
            target in the network.
        weights_ampa : dict or None
            Synaptic weights (in uS) of AMPA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        weights_nmda : dict or None
            Synaptic weights (in uS) of NMDA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        synaptic_delays : dict or float
            Synaptic delay (in ms) at the column origin, dispersed laterally as
            a function of the space_constant. If float, applies to all target
            cell types. Use dict to create delay->cell mapping.
        space_constant : float
            Describes lateral dispersion (from the column origin) of synaptic
            weights and delays within the simulated column. The constant is
            measured in the units of ``inplane_distance`` of
            :class:`~hnn_core.Network`. For example, for ``space_constant=3``,
            the weights are modulated by the factor
            ``exp(-(x / (3 * inplane_distance)) ** 2)``, where ``x`` is the
            physical distance (in um) between the connected cells in the xy
            plane (delays are modulated by the inverse of this factor).
        probability : dict or float (default: 1.0)
            Probability of connection between any src-target pair.
            Use dict to create probability->cell mapping. If float, applies to
            all target cell types.
        event_seed : int
            Optional initial seed for random number generator (default: 2).
            Used to generate event times for drive cells.
        conn_seed : int
            Optional initial seed for random number generator (default: 3).
            Used to randomly remove connections when probability < 1.0.
        Raises
        ------
        ValueError
            If ``rate_constant`` is a dict while ``cell_specific=False``
            (further validation errors come from the drive helper checks and
            from the drive attachment step).
        """
        _check_drive_parameter_values("Poisson", tstart=tstart, tstop=tstop)
        # The targeted cell types are needed up front to validate
        # rate_constant against them.
        target_populations, _, _, _ = _get_target_properties(
            weights_ampa,
            weights_nmda,
            synaptic_delays,
            location,
            self.cell_types,
            probability=probability,
        )
        _check_poisson_rates(rate_constant, target_populations, self.cell_types.keys())
        if isinstance(rate_constant, dict):
            if not cell_specific:
                raise ValueError(
                    f"Drives specific to cell types are only "
                    f"possible with cell_specific=True and "
                    f"n_drive_cells='n_cells'. Got "
                    f"cell_specific={cell_specific} and "
                    f"n_drive_cells={n_drive_cells}."
                )
        elif isinstance(rate_constant, (float, int)):
            # With cell_specific=True the rate is stored per targeted cell
            # type, so expand the scalar into one entry per type.
            if cell_specific:
                rate_constant = {
                    cell_type: rate_constant for cell_type in target_populations
                }
        drive = _NetworkDrive()
        drive["type"] = "poisson"
        drive["location"] = location
        drive["n_drive_cells"] = n_drive_cells
        drive["event_seed"] = event_seed
        drive["conn_seed"] = conn_seed
        drive["dynamics"] = dict(
            tstart=tstart, tstop=tstop, rate_constant=rate_constant
        )
        drive["events"] = list()
        # Need to save this information
        drive["weights_ampa"] = weights_ampa
        drive["weights_nmda"] = weights_nmda
        drive["synaptic_delays"] = synaptic_delays
        drive["probability"] = probability
        self._attach_drive(
            name,
            drive,
            weights_ampa,
            weights_nmda,
            location,
            space_constant,
            synaptic_delays,
            n_drive_cells,
            cell_specific,
            probability,
        )
[docs]
    def add_bursty_drive(
        self,
        name,
        *,
        tstart=0,
        tstart_std=0,
        tstop=None,
        location,
        burst_rate,
        burst_std=0,
        numspikes=2,
        spike_isi=10,
        n_drive_cells=1,
        cell_specific=False,
        weights_ampa=None,
        weights_nmda=None,
        synaptic_delays=0.1,
        space_constant=100.0,
        probability=1.0,
        event_seed=2,
        conn_seed=3,
    ):
        """Add a bursty (rhythmic) external drive to the network
        The targeted cell types are those given non-zero AMPA or NMDA weights
        (see ``weights_ampa`` and ``weights_nmda``).
        Parameters
        ----------
        name : str
            Unique name for the drive
        tstart : float
            Start time of the burst trains (default: 0)
        tstart_std : float
            If greater than 0, randomize start time with standard deviation
            tstart_std (unit: ms). Effectively jitters start time across
            multiple trials.
        tstop : float
            End time of burst trains (defaults to None: tstop is set to the
            end of the simulation)
        location : str
            Target location of synapses. Must be an element of
            `Cell.sect_loc` such as 'proximal' or 'distal', which defines a
            group of sections, or an existing section such as 'soma' or
            'apical_tuft' (defined in `Cell.sections` for all targeted cells).
            The parameter `legacy_mode` of the `Network` must be set to `False`
            to target specific sections.
        burst_rate : float
            The mean rate at which cyclic bursts occur (unit: Hz)
        burst_std : float
            The standard deviation of the burst occurrence on each cycle
            (unit: ms). Default: 0 ms
        numspikes : int
            The number of spikes in a burst. This is the spikes/burst parameter
            in the GUI. Default: 2 (doublet)
        spike_isi : float
            Time between spike events within a cycle (ISI). Default: 10 ms
        n_drive_cells : int | 'n_cells'
            The number of drive cells that contribute an independently sampled
            burst at each cycle. If n_drive_cells='n_cells' and
            cell_specific=True, a drive cell gets assigned to
            each available simulated cell in the network with 1-to-1
            connectivity. Otherwise (default: 1), drive cells are assigned with
            all-to-all connectivity and provide synchronous input to cells in
            the network.
        cell_specific : bool
            Whether each artificial drive cell has 1-to-1 (True) or all-to-all
            (False, default) connection parameters. Note that 1-to-1
            connectivity requires that n_drive_cells='n_cells', where 'n_cells'
            denotes the number of all available cells that this drive can
            target in the network.
        weights_ampa : dict or None
            Synaptic weights (in uS) of AMPA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        weights_nmda : dict or None
            Synaptic weights (in uS) of NMDA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        synaptic_delays : dict or float
            Synaptic delay (in ms) at the column origin, dispersed laterally as
            a function of the space_constant. If float, applies to all target
            cell types. Use dict to create delay->cell mapping.
        space_constant : float
            Describes lateral dispersion (from the column origin) of synaptic
            weights and delays within the simulated column. The constant is
            measured in the units of ``inplane_distance`` of
            :class:`~hnn_core.Network`. For example, for ``space_constant=3``,
            the weights are modulated by the factor
            ``exp(-(x / (3 * inplane_distance)) ** 2)``, where ``x`` is the
            physical distance (in um) between the connected cells in the xy
            plane (delays are modulated by the inverse of this factor).
        probability : dict or float (default: 1.0)
            Probability of connection between any src-target pair.
            Use dict to create probability->cell mapping. If float, applies to
            all target cell types.
        event_seed : int
            Optional initial seed for random number generator (default: 2).
            Used to generate event times for drive cells.
        conn_seed : int
            Optional initial seed for random number generator (default: 3).
            Used to randomly remove connections when probability < 1.0.
        """
        # Parameter value checks are skipped in legacy mode.
        if not self._legacy_mode:
            # start-time jitter (tstart_std) is validated as a "sigma"
            _check_drive_parameter_values(
                "bursty",
                tstart=tstart,
                tstop=tstop,
                sigma=tstart_std,
                location=location,
            )
            # burst-timing jitter (burst_std) is validated as a "sigma" too,
            # hence the separate call
            _check_drive_parameter_values(
                "bursty",
                sigma=burst_std,
                numspikes=numspikes,
                spike_isi=spike_isi,
                burst_rate=burst_rate,
            )
        drive = _NetworkDrive()
        drive["type"] = "bursty"
        drive["location"] = location
        drive["n_drive_cells"] = n_drive_cells
        drive["event_seed"] = event_seed
        drive["conn_seed"] = conn_seed
        drive["dynamics"] = dict(
            tstart=tstart,
            tstart_std=tstart_std,
            tstop=tstop,
            burst_rate=burst_rate,
            burst_std=burst_std,
            numspikes=numspikes,
            spike_isi=spike_isi,
        )
        drive["events"] = list()
        # Need to save this information
        drive["weights_ampa"] = weights_ampa
        drive["weights_nmda"] = weights_nmda
        drive["synaptic_delays"] = synaptic_delays
        drive["probability"] = probability
        self._attach_drive(
            name,
            drive,
            weights_ampa,
            weights_nmda,
            location,
            space_constant,
            synaptic_delays,
            n_drive_cells,
            cell_specific,
            probability,
        )
[docs]
    def add_spike_train_drive(
        self,
        name,
        *,
        spike_data,
        location,
        weights_ampa=None,
        weights_nmda=None,
        synaptic_delays=0.1,
        space_constant=3.0,
        probability=1.0,
        conn_seed=None,
    ):
        """Add an external drive from explicitly defined spike trains.
        This method enables the target network to receive spike trains from a source
        network (e.g., another HNN simulation) or external data, driving activity in the
        target network's cells.
        Parameters
        ----------
        name : str
            Unique name for the drive (e.g., 'drive_from_NetA').
        spike_data : dict or list of tuple or str
            Spike train data from the **source network** (or external source) in one of
            three formats:
            - *Format 1 (dictionary)*: Keys are unique identifiers (str) for source
            cells and values are lists of spike times in milliseconds. The keys
            can be any string that helps identify the source.
            Example:
            ```
            {"NetA_L2_pyramidal_GID0": [10.2, 25.3], "NetA_L5_pyramidal_GID1": [15.1, 30.4]}
            ```
            - *Format 2 (tuples)*: A list of (time, gid) tuples, where each tuple
            contains a spike time (float, in ms) and a GID (int). The GIDs can be
            any integers that uniquely identify different source cells.
            Example:
            ```
            [(10.2, 0), (25.3, 1), (15.1, 0), (30.4, 1)]
            ```
            - *Format 3*: String path (or glob pattern) to spike files that can be loaded
            with :func:`~hnn_core.read_spikes`. Example: "path/to/spk_*.txt"
            Note: The source identifiers in Formats 1 and 2 (dict keys or tuple
            GIDs) refer to the source cells in the originating network (or
            external data). These are arbitrary identifiers that will be
            remapped internally to sequential drive cell IDs (0 to n-1) in the
            target network. Different identifiers should be used for different
            source cells.
        location : str
            Target location of synapses in the target network. Must be 'proximal', 'distal', or
            'soma', or a specific section name (when legacy_mode=False).
        weights_ampa : dict or None
            Synaptic weights (in uS) of AMPA receptors for each targeted cell type (dict keys).
            Cell types omitted are set to zero.
        weights_nmda : dict or None
            Synaptic weights (in uS) of NMDA receptors for each targeted cell type (dict keys).
            Cell types omitted are set to zero.
        synaptic_delays : dict or float
            Synaptic delay (in ms) at the column origin, dispersed laterally as
            a function of the space_constant. If float, applies to all target
            cell types. Use dict to create delay->cell mapping.
        space_constant : float
            Lateral dispersion constant (in units of inplane_distance) for synaptic weights and
            delays within the target network. Default: 3.0
        probability : float or dict
            Connection probability between source and target cells. Default: 1.0 (all-to-all).
        conn_seed : int | None
            Optional seed for the random number generator used to randomly
            remove connections when probability < 1.0. If None (default), the
            seed 3 is used, matching the default of the other drive methods;
            an integer is required because a per-target-cell-type offset is
            added to it when the drive is attached.
        Notes
        -----
        Drive cells are always connected all-to-all (``cell_specific=False``)
        and their number is inferred from ``spike_data``; neither can be set
        by the caller.
        """
        if not self._legacy_mode:
            warnings.warn(
                "Spike train drives can only target sections defined in "
                "`Cell.sect_loc` when `legacy_mode=False`.",
                UserWarning,
            )
        # _attach_drive computes ``conn_seed + seed_increment``; a None seed
        # would raise TypeError there, so fall back to the drives' default.
        if conn_seed is None:
            conn_seed = 3
        # Create the drive object
        drive = _NetworkDrive()
        drive["type"] = "spike_train"
        drive["location"] = location
        drive["events"] = list()  # Will be populated during instantiation
        # Process spike_data into a standardized format
        standardized_data, n_drive_cells, source_to_gid_map = (
            self._standardize_spike_data(spike_data)
        )
        # Set drive properties
        drive["dynamics"] = standardized_data
        drive["n_drive_cells"] = n_drive_cells
        if source_to_gid_map is not None:
            drive["source_to_gid_map"] = source_to_gid_map
        drive["conn_seed"] = conn_seed
        # Not used for spike train, but included for consistency
        drive["event_seed"] = 0
        # Save connection parameters
        drive["weights_ampa"] = weights_ampa
        drive["weights_nmda"] = weights_nmda
        drive["synaptic_delays"] = synaptic_delays
        drive["probability"] = probability
        # Attach the drive to network cells
        self._attach_drive(
            name,
            drive,
            weights_ampa=weights_ampa,
            weights_nmda=weights_nmda,
            location=location,
            space_constant=space_constant,
            synaptic_delays=synaptic_delays,
            n_drive_cells=drive["n_drive_cells"],
            cell_specific=False,
            probability=probability,
        )
    def _attach_drive(
        self,
        name,
        drive,
        weights_ampa,
        weights_nmda,
        location,
        space_constant,
        synaptic_delays,
        n_drive_cells,
        cell_specific,
        probability,
    ):
        """Attach a drive to network based on connectivity information
        Parameters
        ----------
        name : str
            Name of drive (must be unique)
        drive : instance of _NetworkDrive
            Collection of parameters defining the dynamics of the drive
        weights_ampa : dict or None
            Synaptic weights (in uS) of AMPA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        weights_nmda : dict or None
            Synaptic weights (in uS) of NMDA receptors on each targeted cell
            type (dict keys). Cell types omitted from the dict are set to zero.
        location : str
            Target location of synapses. Must be an element of
            `Cell.sect_loc` such as 'proximal' or 'distal', which defines a
            group of sections, or an existing section such as 'soma' or
            'apical_tuft' (defined in `Cell.sections` for all targeted cells).
            The parameter `legacy_mode` of the `Network` must be set to `False`
            to target specific sections.
        space_constant : float
            Describes lateral dispersion (from the column origin) of synaptic
            weights and delays within the simulated column. The constant is
            measured in the units of ``inplane_distance`` of
            :class:`~hnn_core.Network`. For example, for ``space_constant=3``,
            the weights are modulated by the factor
            ``exp(-(x / (3 * inplane_distance)) ** 2)``, where ``x`` is the
            physical distance (in um) between the connected cells in the xy
            plane (delays are modulated by the inverse of this factor).
        synaptic_delays : dict or float
            Synaptic delay (in ms) at the column origin, dispersed laterally as
            a function of the space_constant
        n_drive_cells : int | 'n_cells'
            The number of drive cells (i.e., ArtificialCell objects) that
            contribute to this drive. If n_drive_cells='n_cells' and
            cell_specific=True, an artificial drive cell gets assigned to each
            available cell in the network with 1-to-1 connectivity (completely
            unsynchronous). Otherwise, drive cells get assigned with all-to-all
            connectivity. If you wish to synchronize the timing of this
            drive across the network in a given trial with one spike, set
            n_drive_cells=1 and cell_specific=False.
        cell_specific : bool
            Whether each artificial drive cell has 1-to-1 (True) or all-to-all
            (False) connection parameters. Note that 1-to-1
            connectivity requires that n_drive_cells='n_cells', where 'n_cells'
            denotes the number of all available cells that this drive can
            target in the network.
        probability : dict or float (default: 1.0)
            Probability of connection between any src-target pair.
            Use dict to create probability->cell mapping. If float, applies to
            all target cell types
        Raises
        ------
        ValueError
            If ``name`` is already used by another drive, if a targeted cell
            type is not part of the network, if no target cell type is
            specified, or if ``n_drive_cells`` is inconsistent with
            ``cell_specific``.
        Notes
        -----
        The drive is stored in ``self.external_drives[name]``. Its drive cells
        are registered as a new cell type via ``self._add_cell_type``, after
        which ``self.gid_ranges[name]`` holds their gids. One connection per
        (target cell type, receptor) pair is then added to the network.
        """
        if name in self.external_drives:
            raise ValueError(f"Drive {name} already defined")
        _validate_type(probability, (float, dict), "probability", "float or dict")
        # allow passing weights as None, convert to dict here
        (target_populations, weights_by_type, delays_by_type, probability_by_type) = (
            _get_target_properties(
                weights_ampa,
                weights_nmda,
                synaptic_delays,
                location,
                self.cell_types,
                probability=probability,
            )
        )
        # weights passed must correspond to cells in the network
        if not target_populations.issubset(set(self.cell_types.keys())):
            # a single string argument, so the message is not shown as a tuple
            raise ValueError(
                "Allowed drive target cell types are: " f"{self.cell_types.keys()}"
            )
        # enforce the same order as in self.cell_types - necessary for
        # consistent source gid assignment
        target_populations = [
            cell_type
            for cell_type in self.cell_types.keys()
            if cell_type in target_populations
        ]
        # Must precede the location check: set.intersection() called with no
        # sets raises TypeError, which would otherwise mask this error.
        if len(target_populations) == 0:
            raise ValueError(
                "No target populations have been specified for this drive."
            )
        # Ensure location exists for all target cells
        cell_sections = [
            set(self.cell_types[cell_type]["cell_object"].sections.keys())
            for cell_type in target_populations
        ]
        sect_locs = [
            set(self.cell_types[cell_type]["cell_object"].sect_loc.keys())
            for cell_type in target_populations
        ]
        valid_cell_sections = set.intersection(*cell_sections)
        valid_sect_locs = set.intersection(*sect_locs)
        valid_loc = list(valid_cell_sections) + list(valid_sect_locs)
        _check_option(
            "location",
            location,
            valid_loc,
            extra=(
                f" (the location '{location}' is not defined "
                "for one of the targeted cells)"
            ),
        )
        if self._legacy_mode:
            # Legacy tests must match HNN GUI output, so keep the original gid
            # assignment convention: every cell type is targeted, and types
            # without user-supplied values get a zero AMPA weight, a 0.1 ms
            # delay and a connection probability of 1.0.
            target_populations = list(self.cell_types.keys())
            for target_type in target_populations:
                if target_type not in weights_by_type:
                    weights_by_type.update({target_type: {"ampa": 0.0}})
                if target_type not in delays_by_type:
                    delays_by_type.update({target_type: 0.1})
                if target_type not in probability_by_type:
                    probability_by_type.update({target_type: 1.0})
        if cell_specific and n_drive_cells != "n_cells":
            raise ValueError(
                f"If cell_specific is True, n_drive_cells must"
                f" equal 'n_cells'. Got {n_drive_cells}."
            )
        elif not cell_specific:
            if not isinstance(n_drive_cells, int):
                raise ValueError(
                    f"If cell_specific is False, n_drive_cells "
                    f"must be of type int. Got "
                    f"{type(n_drive_cells)}."
                )
            if not n_drive_cells > 0:
                raise ValueError(
                    "Number of drive cells must be greater than "
                    f"0. Got {n_drive_cells}."
                )
        drive["name"] = name  # for easier for-looping later
        drive["target_types"] = target_populations  # for _connect_celltypes
        drive["cell_specific"] = cell_specific
        if n_drive_cells == "n_cells":
            # one drive cell per cell of every targeted type
            n_drive_cells = 0
            for cell_type in target_populations:
                n_drive_cells += len(self.gid_ranges[cell_type])
        drive["n_drive_cells"] = n_drive_cells
        self.external_drives[name] = drive
        pos = [self.pos_dict["origin"]] * n_drive_cells
        self._add_cell_type(name, pos)
        # Starting index into this drive's gids for cell-specific source gids;
        # advanced by the number of target cells of each cell type.
        src_idx = 0
        # seed_increment increased by 1 for each target cell type,
        # added to conn_seed to ensure statistical independence of random
        # connections when probability < 1.0
        for seed_increment, target_cell_type in enumerate(target_populations):
            target_gids = list(self.gid_ranges[target_cell_type])
            delays = delays_by_type[target_cell_type]
            type_probability = probability_by_type[target_cell_type]
            if cell_specific:
                # 1-to-1: the next len(target_gids) drive cells each target
                # exactly one cell of this type
                src_idx_end = src_idx + len(target_gids)
                src_gids = list(self.gid_ranges[name])[src_idx:src_idx_end]
                src_idx = src_idx_end
                conn_target_gids = [[target_gid] for target_gid in target_gids]
            else:
                # all-to-all: every drive cell targets every cell of this type
                src_gids = name
                conn_target_gids = target_gids
            receptor_weights = weights_by_type[target_cell_type]
            for receptor_idx, receptor in enumerate(receptor_weights):
                self.add_connection(
                    src_gids=src_gids,
                    target_gids=conn_target_gids,
                    loc=location,
                    receptor=receptor,
                    weight=receptor_weights[receptor],
                    delay=delays,
                    lamtha=space_constant,
                    probability=type_probability,
                    conn_seed=drive["conn_seed"] + seed_increment,
                )
                # Ensure that AMPA/NMDA connections target the same gids
                # when probability < 1
                if receptor_idx > 0:
                    self.connectivity[-1]["src_gids"] = self.connectivity[-2][
                        "src_gids"
                    ]
    def _reset_drives(self):
        # reset every time called again, e.g., from dipole.py or in self.copy()
        for drive_name in self.external_drives.keys():
            self.external_drives[drive_name]["events"] = list()
    def _reset_rec_arrays(self):
        # clear the data in rec_arrays
        for arr in self.rec_arrays.values():
            arr._reset()
    def _instantiate_drives(self, tstop, n_trials=1):
        """Create event time vectors for all drives across trials.
        Parameters
        ----------
        tstop : float
            The simulation stop time (ms)
        n_trials : int
            Number of trials to create events for (default: 1)
        Notes
        -----
        This must be a separate method because dipole.py:simulate_dipole
        accepts an n_trials argument, which overrides the N_trials parameter
        used at initialisation time. Only the event times need to be
        recalculated; all gids etc. remain the same.
        For each trial, one list of event-time vectors is appended to
        ``self.external_drives[name]["events"]``. A non-cell-specific drive
        contributes one vector per drive cell; a cell-specific drive
        contributes one vector per distinct target type its drive cell
        connects to.
        """
        self._reset_drives()
        for trial_idx in range(n_trials):
            # each trial needs unique event time vectors
            for drive in self.external_drives.values():
                drive_name = drive["name"]
                drive_gids = self.gid_ranges[drive_name]
                trial_events = []  # fresh list for each trial and drive
                for gid in drive_gids:
                    # arguments shared by every event-time draw of this cell;
                    # the gid is passed as its offset within the drive
                    shared_kwargs = dict(
                        trial_idx=trial_idx,
                        drive_cell_gid=gid - drive_gids[0],
                        event_seed=drive["event_seed"],
                        tstop=tstop,
                        trial_seed_offset=self._n_gids,
                    )
                    if drive["cell_specific"]:
                        # one event-time vector per target cell type that
                        # this drive cell is connected to
                        conn_idxs = pick_connection(self, src_gids=gid)
                        target_types = {
                            self.connectivity[conn_idx]["target_type"]
                            for conn_idx in conn_idxs
                        }
                    else:
                        target_types = ["any"]
                    for target_type in target_types:
                        trial_events.append(
                            _drive_cell_event_times(
                                drive["type"],
                                drive["dynamics"],
                                target_type=target_type,
                                **shared_kwargs,
                            )
                        )
                self.external_drives[drive_name]["events"].append(trial_events)
[docs]
    def add_tonic_bias(
        self,
        *,
        cell_type=None,
        section="soma",
        bias_name="tonic",
        amplitude,
        t0=0,
        tstop=None,
    ):
        """Attach parameters of a tonic bias input to the given cell types.

        Parameters
        ----------
        cell_type : str | None
            Deprecated. The name of a single cell type to add a tonic input
            to. When supplied, a float (or int) value must be provided with
            the `amplitude` keyword. Valid inputs are those listed in
            `net.cell_types`. Prefer passing a dictionary to `amplitude`.
        section : str
            Name of the cell section the bias should be applied to.
            See ``net.cell_types[cell_type]["cell_object"].sections.keys()``.
        bias_name : str
            The name of the bias.
        amplitude : dict | float
            A dictionary of cell type keys (str) to amplitude values (float).
            Valid inputs for cell types are those listed in `net.cell_types`.
            If `cell_type` is not None, `amplitude` should instead be a float
            indicating the amplitude of the tonic input for that cell type.
        t0 : float
            The start time of tonic input (in ms). Default: 0 (beginning of
            simulation). This value will be applied to all the tonic biases if
            multiple are specified with the `amplitude` keyword.
        tstop : float | None
            The end time of tonic input (in ms). Default: end of simulation.
            This value will be applied to all the tonic biases if
            multiple are specified with the `amplitude` keyword.

        Warns
        -----
        DeprecationWarning
            If `cell_type` is passed.
        UserWarning
            If `amplitude` is an empty dictionary; no bias is added.
        """
        # old functionality single cell type - amplitude
        if cell_type is not None:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(
                "cell_type argument will be deprecated and "
                "removed in future releases. Use amplitude as a "
                "cell_type:str,amplitude:float dictionary. "
                "Read the function docstring for more information",
                DeprecationWarning,
                stacklevel=2,
            )
            _validate_type(amplitude, (float, int), "amplitude")
            _add_cell_type_bias(
                network=self,
                cell_type=cell_type,
                section=section,
                bias_name=bias_name,
                amplitude=float(amplitude),
                t_0=t0,
                t_stop=tstop,
            )
            return

        _validate_type(amplitude, dict, "amplitude")
        if len(amplitude) == 0:
            warnings.warn(
                "No bias have been defined, no action taken",
                UserWarning,
                stacklevel=2,
            )
            return
        # Same section / bias name / timing for every requested cell type.
        for _cell_type, _amplitude in amplitude.items():
            _add_cell_type_bias(
                network=self,
                cell_type=_cell_type,
                section=section,
                bias_name=bias_name,
                amplitude=_amplitude,
                t_0=t0,
                t_stop=tstop,
            )
    def _add_cell_type(self, cell_name, pos, cell_template=None):
        """Add cell type by updating pos_dict and gid_ranges."""
        ll = self._n_gids
        self._n_gids += len(pos)
        self.gid_ranges[cell_name] = range(ll, self._n_gids)
        self.pos_dict[cell_name] = pos
        if cell_template is not None:
            self.cell_types.update({cell_name: cell_template})
            self._n_cells += len(pos)
    def _rename_cell_types(self, name_mapping: Dict[str, str]):
        """Renames cell types in the network.

        XXX: All HNN functionality is not supported, such as Dipole calculation

        Parameters
        ----------
        name_mapping: dict[str, str]
            Dictionary of what cell type names to change, and what to change
            them to. Keys are existing cell type name strings, and values are
            what string to change each key to. Note that both elements must be
            strings.

        Raises
        ------
        ValueError
            If a key is not an existing cell type, or a value is already a
            cell type name (including names produced by earlier entries of
            the mapping). The whole mapping is validated before anything is
            modified, so a failing call leaves the network unchanged.

        Notes
        -----
        Type checks on ``name_mapping`` and on each key/value are delegated
        to ``_validate_type``.
        """
        _validate_type(name_mapping, dict, "name_mapping")

        # Validate the full mapping first by simulating the renames in order,
        # so an invalid later entry cannot leave the network half-renamed.
        simulated_names = set(self.cell_types.keys())
        for original_name, new_name in name_mapping.items():
            _validate_type(original_name, str, "name_mapping key")
            _validate_type(new_name, str, "name_mapping value")
            if original_name not in simulated_names:
                raise ValueError(f"'{original_name}' is not in cell_types!")
            if new_name in simulated_names:
                raise ValueError(f"'{new_name}' is already in cell_types!")
            simulated_names.remove(original_name)
            simulated_names.add(new_name)

        # Store original cell object names to preserve them
        original_cell_names = {}
        for original_name in name_mapping:
            cell_type_data = self.cell_types.get(original_name)
            if isinstance(cell_type_data, dict) and "cell_object" in cell_type_data:
                original_cell_names[original_name] = cell_type_data["cell_object"].name

        for original_name, new_name in name_mapping.items():
            for attr_name in [
                "cell_types",
                "pos_dict",
                "external_biases",
                "external_drives",
                "gid_ranges",
            ]:
                attr = getattr(self, attr_name)
                if isinstance(attr, dict):
                    updated_attr = _replace_dict_identifier(
                        attr, original_name, new_name
                    )
                    setattr(self, attr_name, updated_attr)
            # Update Network.connectivity
            for connection in self.connectivity:
                if connection["src_type"] == original_name:
                    connection["src_type"] = new_name
                if connection["target_type"] == original_name:
                    connection["target_type"] = new_name
            # Restore original cell object name to preserve cell template identity
            renamed_data = self.cell_types.get(new_name)
            if (
                isinstance(renamed_data, dict)
                and "cell_object" in renamed_data
                and original_name in original_cell_names
            ):
                renamed_data["cell_object"].name = original_cell_names[original_name]
[docs]
    def gid_to_type(self, gid):
        """Reverse lookup of gid to type.

        Delegates to ``_gid_to_type`` with this network's ``gid_ranges``.
        Callers in this class treat a ``None`` result as "gid not in any
        range".
        """
        ranges = self.gid_ranges
        return _gid_to_type(gid, ranges)
[docs]
    def add_connection(
        self,
        src_gids,
        target_gids,
        loc,
        receptor,
        weight,
        delay,
        lamtha,
        threshold=None,
        gain=1.0,
        allow_autapses=True,
        probability=1.0,
        conn_seed=None,
    ):
        """Appends connections to connectivity list

        Parameters
        ----------
        src_gids : str | int | range | list of int
            Identifier for source cells. Passing str arguments ('evdist1',
            'L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket', etc.) is
            equivalent to passing a list of gids for the relevant cell type.
            source - target connections are made in an all-to-all pattern.
        target_gids : str | int | range | list of int
            Identifier for targets of source cells. Passing str arguments
            ('L2_pyramidal', 'L2_basket', 'L5_pyramidal', 'L5_basket') is
            equivalent to passing a list of gids for the relevant cell type.
            source - target connections are made in an all-to-all pattern.
        loc : str
            Target location of synapses. Must be an element of
            `Cell.sect_loc` such as 'proximal' or 'distal', which defines a
            group of sections, or an existing section such as 'soma' or
            'apical_tuft' (defined in `Cell.sections` for all targeted cells).
            The parameter `legacy_mode` of the `Network` must be set to `False`
            to target specific sections.
        receptor : str
            Synaptic receptor of connection. Must be one of:
            'ampa', 'nmda', 'gabaa', or 'gabab'.
        weight : float
            Synaptic weight on target cell.
        delay : float
            Synaptic delay in ms.
        lamtha : float
            Space constant.
        threshold : float, default=None
            Firing threshold of cells for connection. If None (the default), inherit the
            threshold from the Network object.
        gain : float, default=1.0
            Multiplicative factor for synaptic weight.
        allow_autapses : bool, default=True
            If True, allow connecting neuron to itself.
        probability : float, default=1.0
            Probability of connection between any src-target pair.
            Defaults to 1.0 producing an all-to-all pattern.
        conn_seed : int, default=None
            Optional initial seed for random number generator (default: None).
            Used to randomly remove connections when probability < 1.0.

        Raises
        ------
        AssertionError
            If `target_gids` does not provide one list per source gid, a
            target gid is not in `net.gid_ranges`, or the target gids span
            more than one cell type.
        ValueError
            If `target_gids` resolves to no gids at all, or `gain` is
            negative.

        Argument types, `loc` and `receptor` are additionally checked via
        `_validate_type` / `_check_option`.

        Notes
        -----
        Connections are stored in ``net.connectivity[idx]['gid_pairs']``, a
        dictionary indexed by src gids with the format:
        {src_gid: [target_gids, ...], ...} where each src_gid indexes a list of
        all its targets.
        """
        conn = _Connectivity()
        # Threshold's value is validated later below with the rest of nc_dict
        if threshold is None:
            threshold = self.threshold
        _validate_type(
            target_gids,
            (int, list, range, str),
            "target_gids",
            "int, list, range or str",
        )
        _validate_type(allow_autapses, bool, "allow_autapses", "bool")
        valid_source_cells = list(self.gid_ranges.keys())
        # Convert src_gids to list
        src_gids = _check_gids(
            src_gids, self.gid_ranges, valid_source_cells, "src_gids"
        )
        # Convert target_gids to list of list, one element for each src_gid
        valid_target_cells = list(self.cell_types.keys())
        if isinstance(target_gids, int):
            target_gids = [[target_gids] for _ in range(len(src_gids))]
        elif isinstance(target_gids, str):
            _check_option("target_gids", target_gids, valid_target_cells)
            target_gids = [
                list(self.gid_ranges[_long_name(target_gids)])
                for _ in range(len(src_gids))
            ]
        elif isinstance(target_gids, range):
            target_gids = [list(target_gids) for _ in range(len(src_gids))]
        elif isinstance(target_gids, list) and all(
            isinstance(t_gid, int) for t_gid in target_gids
        ):
            target_gids = [target_gids for _ in range(len(src_gids))]
        # Check the pairing before indexing into target_gids below.
        if len(target_gids) != len(src_gids):
            raise AssertionError("target_gids must have a list for each src.")
        # Validate each target list - src pairs.
        # set() used to avoid redundant checks.
        target_set = set()
        for target_src_pair in target_gids:
            _validate_type(target_src_pair, list, "target_gids[idx]", "list or range")
            target_set.update(target_src_pair)
        if not target_set:
            raise ValueError("target_gids must contain at least one target gid.")
        for target_gid in target_set:
            _validate_type(target_gid, int, "target_gid", "int")
        # The first listed target gid defines the target type; every other
        # target gid must belong to that same type.
        first_target_gid = next(gid for pair in target_gids for gid in pair)
        target_type = self.gid_to_type(first_target_gid)
        for target_gid in target_set:
            # Ensure gids in range of Network.gid_ranges
            gid_type = self.gid_to_type(target_gid)
            if gid_type is None:
                raise AssertionError(f"target_gid {target_gid} not in net.gid_ranges")
            elif gid_type != target_type:
                raise AssertionError("All target_gids must be of the same type")
        conn["target_type"] = target_type
        conn["target_gids"] = target_set
        conn["num_targets"] = len(target_set)
        # Format gid_pairs and add to conn dictionary
        gid_pairs = dict()
        for src_gid, target_src_pair in zip(src_gids, target_gids):
            if not allow_autapses:
                mask = np.isin(target_src_pair, src_gid, invert=True)
                target_src_pair = np.array(target_src_pair)[mask].tolist()
            gid_pairs[src_gid] = target_src_pair
        conn["src_type"] = self.gid_to_type(src_gids[0])
        conn["src_gids"] = set(src_gids)
        conn["num_srcs"] = len(src_gids)
        conn["gid_pairs"] = gid_pairs
        # Validate string inputs
        _validate_type(loc, str, "loc")
        _validate_type(receptor, str, "receptor")
        target_sect_loc = self.cell_types[target_type]["cell_object"].sect_loc
        target_sections = self.cell_types[target_type]["cell_object"].sections
        valid_loc = list(target_sect_loc.keys()) + list(target_sections.keys())
        _check_option(
            "loc",
            loc,
            valid_loc,
            extra=(f" (the loc '{loc}' is not defined for '{target_type}' cells)"),
        )
        conn["loc"] = loc
        # `loc` specifies a group of sections, all must contain the synapse
        # specified by `receptor`
        if loc in target_sect_loc:
            for sec_name in target_sect_loc[loc]:
                valid_receptor = target_sections[sec_name].syns
                _check_option(
                    "receptor",
                    receptor,
                    valid_receptor,
                    extra=f" (the '{receptor}' receptor is not "
                    f"defined for the '{sec_name}' of "
                    f"'{target_type}' cells)",
                )
        # `loc` specifies an individual section
        else:
            valid_receptor = target_sections[loc].syns
            _check_option(
                "receptor",
                receptor,
                valid_receptor,
                extra=f" (the '{receptor}' receptor is not "
                f"defined for the '{loc}' of "
                f"'{target_type}' cells)",
            )
        conn["receptor"] = receptor
        # Create and validate nc_dict
        conn["nc_dict"] = dict()
        arg_names = ["delay", "weight", "lamtha", "threshold", "gain"]
        nc_dict_keys = ["A_delay", "A_weight", "lamtha", "threshold", "gain"]
        nc_conn_items = [delay, weight, lamtha, threshold, gain]
        for key, arg_name, item in zip(nc_dict_keys, arg_names, nc_conn_items):
            _validate_type(item, (int, float), arg_name, "int or float")
            if arg_name == "gain" and item < 0.0:
                raise ValueError(
                    "Synaptic gains must be non-negative. "
                    f"Got {gain} for connection "
                    f"{conn['src_type']}->{conn['target_type']}."
                )
            conn["nc_dict"][key] = item
        # Probabilistically define connections
        if probability != 1.0:
            _connection_probability(conn, probability, conn_seed)
        conn["probability"] = probability
        conn["allow_autapses"] = allow_autapses
        self.connectivity.append(deepcopy(conn))
[docs]
    def clear_connectivity(self):
        """Remove all connections defined in Network.connectivity.

        Connections whose source is an external drive are kept; only
        connections originating from other cell types are dropped.
        """
        drive_names = self.external_drives.keys()
        self.connectivity = [
            conn for conn in self.connectivity if conn["src_type"] in drive_names
        ]
[docs]
    def clear_drives(self):
        """Remove all external drives from the network.

        Drops every connection whose source is an external drive, deletes
        each drive's entries from ``gid_ranges`` and ``pos_dict`` while
        decrementing ``_n_gids`` by the number of gids the drive occupied,
        and finally empties ``external_drives``.
        """
        drive_names = self.external_drives
        self.connectivity = [
            conn for conn in self.connectivity if conn["src_type"] not in drive_names
        ]
        # Collect first: gid_ranges must not change size while being iterated.
        drives_with_gids = [name for name in self.gid_ranges if name in drive_names]
        for drive_name in drives_with_gids:
            self._n_gids -= len(self.gid_ranges.pop(drive_name))
            del self.pos_dict[drive_name]
        self.external_drives = dict()
[docs]
    def add_electrode_array(
        self, name, electrode_pos, *, conductivity=0.3, method="psa", min_distance=0.5
    ):
        """Register an extracellular electrode array under a unique name.

        Parameters
        ----------
        name : str
            Unique name of the array.
        electrode_pos : tuple | list of tuple
            Coordinates specifying the position for extracellular electrodes in
            the form of (x, y, z) (in um).
        conductivity : float
            Extracellular conductivity, in S/m, of the assumed infinite,
            homogeneous volume conductor that the cell and electrode are in.
        method : str
            Approximation to use. ``'psa'`` (point source approximation) treats
            each segment junction as a point extracellular current source.
            ``'lsa'`` (line source approximation) treats each segment as a line
            source of current, which extends from the previous to the next
            segment center point: /---x---/, where x is the current segment
            flanked by /.
        min_distance : float (default: 0.5; unit: um)
            To avoid numerical errors in calculating potentials, apply a
            minimum distance limit between the electrode contacts and the
            active neuronal membrane elements that act as sources of current.
            The default value of 0.5 um corresponds to 1 um diameter dendrites.

        Raises
        ------
        ValueError
            If an array called `name` is already registered.
        """
        _validate_type(name, str, "name")
        if name in self.rec_arrays:
            raise ValueError(f"{name} already exists, use another name!")
        # ExtracellularArray performs all remaining argument checks.
        electrode_array = ExtracellularArray(
            electrode_pos,
            conductivity=conductivity,
            method=method,
            min_distance=min_distance,
        )
        self.rec_arrays[name] = electrode_array
[docs]
    def set_global_synaptic_gains(
        self, e_e=None, e_i=None, i_e=None, i_i=None, copy=False
    ):
        """Change the synaptic gains of the celltypes in the Network.

        Parameters
        ----------
        e_e : float, default=None
            Synaptic gain of excitatory to excitatory connections
        e_i : float, default=None
            Synaptic gain of excitatory to inhibitory connections
        i_e : float, default=None
            Synaptic gain of inhibitory to excitatory connections
        i_i : float, default=None
            Synaptic gain of inhibitory to inhibitory connections
        copy : bool, default=False
            If True, returns a copy of the network. If False,
            the network is updated in place with a return of None.

        Returns
        -------
        net : instance of Network
            A copy of the instance with updated synaptic gains if copy=True.

        Raises
        ------
        ValueError
            If any provided gain is negative. All gains are validated before
            any connection is modified, so the network is left unchanged on
            error.

        Notes
        -----
        Synaptic gains must be non-negative. The synaptic gains will only be
        updated if a float value is provided. If None is provided
        (the default), the synaptic gain will remain unchanged.
        This does **not** change the synaptic gains of external drives.
        """
        _validate_type(copy, bool, "copy")
        requested_gains = {"e_e": e_e, "e_i": e_i, "i_e": i_e, "i_i": i_i}
        # Validate every provided gain up front so an invalid later value
        # cannot leave the network partially updated.
        for conn_type, gain in requested_gains.items():
            if gain is None:
                continue
            _validate_type(gain, (int, float), conn_type, "int or float")
            if gain < 0.0:
                raise ValueError(
                    f"Synaptic gains must be non-negative. Got {gain} for '{conn_type}'."
                )

        net = self.copy() if copy else self
        e_gids, i_gids = _get_cell_index_by_synapse_type(self)
        # Source / target gid groups for each connection type
        gid_groups = {
            "e_e": (e_gids, e_gids),
            "e_i": (e_gids, i_gids),
            "i_e": (i_gids, e_gids),
            "i_i": (i_gids, i_gids),
        }
        for conn_type, gain in requested_gains.items():
            if gain is None:
                continue
            src_gids, target_gids = gid_groups[conn_type]
            conn_indices = pick_connection(
                net, src_gids=src_gids, target_gids=target_gids
            )
            for conn_idx in conn_indices:
                net.connectivity[conn_idx]["nc_dict"]["gain"] = gain
        if copy:
            return net
[docs]
    def get_global_synaptic_gains(self):
        """Retrieve gain values for different celltype connections in the Network.

        Excitatory and inhibitory cells are identified, and the `gain` of the
        first matching connection is reported for each connection type:
            - excitatory to excitatory (e_e)
            - excitatory to inhibitory (e_i)
            - inhibitory to excitatory (i_e)
            - inhibitory to inhibitory (i_i)
        The gain is assumed to be uniform across all instances of each connection type
        (for example, between AMPA and NMDA, and between `L2_pyramidal->L2_pyramidal` and
        `L2_pyramidal->L5_pyramidal`, etc.). Types with no matching connection
        are omitted from the result.
        This does **not** return the synaptic gains of external drives.

        Returns
        -------
        values : dict
            A dictionary with the connection types ('e_e', 'e_i', 'i_e', 'i_i') as keys
            and their corresponding gain values.
        """
        e_gids, i_gids = _get_cell_index_by_synapse_type(self)
        type_specs = (
            ("e_e", e_gids, e_gids),
            ("e_i", e_gids, i_gids),
            ("i_e", i_gids, e_gids),
            ("i_i", i_gids, i_gids),
        )
        gains = dict()
        for conn_type, src_gids, target_gids in type_specs:
            conn_idxs = pick_connection(self, src_gids=src_gids, target_gids=target_gids)
            if not conn_idxs:
                continue
            # Only the first matching connection's gain is reported.
            gains[conn_type] = self.connectivity[conn_idxs[0]]["nc_dict"]["gain"]
        # Presumably warns when gains are not uniform within a type (the
        # original comment says it writes the warning to stdout).
        _check_global_synaptic_gains_uniformity(self)
        return gains
[docs]
    def plot_cells(self, ax=None, show=True):
        """Plot the cells using Network.pos_dict.

        Parameters
        ----------
        ax : instance of matplotlib Axes3D | None
            An axis object from matplotlib. If None,
            a new figure is created.
        show : bool
            If True, show the figure.

        Returns
        -------
        fig : instance of matplotlib Figure
            The matplotlib figure handle.
        """
        # Resolves to the module-level plot_cells imported from .viz
        fig = plot_cells(net=self, ax=ax, show=show)
        return fig
    def to_dict(self, write_output=False):
        """Convert this network with ``network_to_dict``.

        Parameters
        ----------
        write_output : bool
            Forwarded unchanged to ``network_to_dict``; presumably controls
            whether simulation output is included (TODO confirm).

        Returns
        -------
        Whatever ``network_to_dict`` returns (presumably a dict).
        """
        net_dict = network_to_dict(self, write_output=write_output)
        return net_dict
[docs]
    @copy_doc(write_network_configuration)
    def write_configuration(self, fname, overwrite=True):
        # The public docstring is supplied by ``copy_doc`` from
        # ``write_network_configuration``; adding a docstring here would
        # presumably alter the copied text (verify copy_doc before doing so).
        # Thin wrapper: writes this network to ``fname``; ``overwrite`` is
        # forwarded positionally. Returns None.
        write_network_configuration(self, fname, overwrite) 
[docs]
    def filter_cell_types(self, **metadata_filters):
        """
        Filter cell types based on cell_metadata criteria

        Returns the names of the cell types (in ``cell_types`` order) whose
        ``cell_metadata`` contains every given key with an equal value. With
        no filters, every cell type name is returned.
        """

        def _matches(cell_metadata):
            # Every filter key must be present with an equal value.
            for key, value in metadata_filters.items():
                if key not in cell_metadata or cell_metadata[key] != value:
                    return False
            return True

        return [
            name
            for name, data in self.cell_types.items()
            if _matches(data["cell_metadata"])
        ]
    def _standardize_spike_data(self, spike_data):
        """Standardize spike data to internal format with 'times' and 'gids' keys.
        Parameters
        ----------
        spike_data : dict or list or str
            Input spike data in one of three formats:
            - Format 1: Dictionary where keys are source identifiers and values are
              lists of spike times in ms.
              Example: {"NetA_L2_pyramidal_GID0": [10.2, 25.3], ...}
            - Format 2: List of (time, gid) tuples where time is the spike time in ms
              and gid identifies the source cell.
              Example: [(10.2, 0), (15.6, 1), (25.3, 0)]
            - Format 3: String path (or glob pattern) to spike files that can be loaded
              with hnn_core.read_spikes(), like "path/to/spk_*.txt"
        Returns
        -------
        standardized_data : dict
            Dictionary with 'times' and 'gids' keys containing spike information
            in standardized internal format
        n_drive_cells : int
            Number of unique source cells detected
        source_to_gid_map : dict or None
            Mapping from source identifiers to sequential GIDs (for Format 1),
            or None (for Format 2 or Format 3)
        """
        source_to_gid_map = None
        if isinstance(spike_data, dict):
            # Format 1: {source_id: [spike_times], ...}
            source_ids = list(spike_data.keys())
            n_drive_cells = len(source_ids)
            # Transform to standardized format
            all_times = []
            all_gids = []
            # Map source IDs to sequential gids
            source_to_gid_map = {src_id: i for i, src_id in enumerate(source_ids)}
            # Collect all spike times and corresponding gids
            for src_id, times in spike_data.items():
                gid = source_to_gid_map[src_id]
                if isinstance(times, (list, np.ndarray)):
                    all_times.extend(times)
                    all_gids.extend([gid] * len(times))
                else:
                    raise ValueError(
                        f"Spike times for source '{src_id}' must be a list or array. "
                        f"Got {type(times)}."
                    )
            standardized_data = {
                "times": all_times,
                "gids": all_gids,
            }
        elif isinstance(spike_data, list) and all(
            isinstance(x, tuple) and len(x) == 2 for x in spike_data
        ):
            # Format 2: List of (time, gid) tuples
            times = [pair[0] for pair in spike_data]
            gids = [pair[1] for pair in spike_data]
            # Count unique drive cells
            unique_gids = np.unique(gids)
            n_drive_cells = len(unique_gids)
            # Ensure gids are sequential from 0 to n-1
            if len(unique_gids) > 0:
                if (
                    np.min(unique_gids) != 0
                    or np.max(unique_gids) != len(unique_gids) - 1
                ):
                    # Reindex gids to be 0-based sequential integers
                    gid_map = {
                        old_gid: new_gid
                        for new_gid, old_gid in enumerate(sorted(unique_gids))
                    }
                    new_gids = [gid_map[gid] for gid in gids]
                    standardized_data = {
                        "times": times,
                        "gids": new_gids,
                    }
                else:
                    standardized_data = {
                        "times": times,
                        "gids": gids,
                    }
            else:
                standardized_data = {"times": [], "gids": []}
        elif isinstance(spike_data, str):
            # Format 3: Handle string input as file path
            try:
                # Read spike data from file
                cell_response = read_spikes(spike_data)
            except Exception as e:
                raise ValueError(
                    f"Error loading spike data from file '{spike_data}': {str(e)}"
                )
            # By default, use the first trial
            trial_idx = 0
            if trial_idx >= len(cell_response.spike_times):
                raise ValueError(
                    f"Trial index {trial_idx} exceeds available trials "
                    f"({len(cell_response.spike_times)})"
                )
            # Extract spike data from specified trial
            spike_times = cell_response.spike_times[trial_idx]
            spike_gids = cell_response.spike_gids[trial_idx]
            spike_types = cell_response.spike_types[trial_idx]
            # Convert to dictionary format (Format 1)
            spike_data_dict = {}
            for t, g, cell_type in zip(spike_times, spike_gids, spike_types):
                src_id = f"{cell_type}_GID{g}"
                if src_id not in spike_data_dict:
                    spike_data_dict[src_id] = []
                spike_data_dict[src_id].append(t)
            # Recursively call this function with the dictionary data
            return self._standardize_spike_data(spike_data_dict)
        else:
            raise ValueError(
                "spike_data must be either:\n"
                "1. A dictionary {source_id: [spike_times], ...}\n"
                "2. A list of (time, gid) tuples\n"
                "3. A file path string loadable with read_spikes()\n"
                f"Got {type(spike_data)}."
            )
        return standardized_data, n_drive_cells, source_to_gid_map 
class _Connectivity(dict):
    """A class for containing the connectivity details of the network
    Class instances are essentially dictionaries, with the keys described below
    as 'attributes'.
    Attributes
    ----------
    src_type : str
        Cell type of source gids.
    target_type : str
        Cell type of target gids.
    gid_pairs : dict
        dict indexed by src gids with the format:
        {src_gid: [target_gids, ...], ...}
        where each src_gid indexes a list of all its targets.
    num_srcs : int
        Number of unique source gids.
    num_targets : int
        Number of unique target gids.
    src_gids : set of int
        Set of unique source gids in connection.
    target_gids : set of int
        Set of unique target gids in connection.
    loc : str
        Location of synapse on target cell. Must be
        'proximal', 'distal', or 'soma'. Note that inhibitory synapses
        (receptor='gabaa' or 'gabab') of L2 pyramidal neurons are only
        valid loc='soma'.
    receptor : str
        Synaptic receptor of connection. Must be one of:
        'ampa', 'nmda', 'gabaa', or 'gabab'.
    nc_dict : dict
        Dictionary containing details of synaptic connection.
        Elements include:
        A_weight : float
            Synaptic weight on target cell.
        A_delay : float
            Synaptic delay in ms.
        lamtha : float
            Space constant.
        threshold : float
            Firing threshold of cells for connection.
        gain : float
            Multiplicative factor for synaptic weight.
    probability : float
        Probability of connection between any src-target pair.
        Defaults to 1.0 producing an all-to-all pattern.
    Notes
    -----
    The len() of src_range or target_range will not match
    num_srcs and num_targets for probability < 1.0.
    """
    def __repr__(self):
        entr = f"{self['src_type']} -> {self['target_type']}"
        entr += f"\ncell counts: {self['num_srcs']} srcs, "
        entr += f"{self['num_targets']} targets"
        entr += f"\nconnection probability: {self['probability']} "
        entr += f"\nloc: '{self['loc']}'; receptor: '{self['receptor']}'"
        entr += f"\nweight: {self['nc_dict']['A_weight']}; "
        entr += f"delay: {self['nc_dict']['A_delay']}; "
        entr += f"lamtha: {self['nc_dict']['lamtha']}"
        entr += f"threshold: {self['nc_dict']['threshold']}"
        entr += f"gain: {self['nc_dict']['gain']}"
        entr += "\n "
        return entr
class _NetworkDrive(dict):
    """A class for containing the parameters of external drives
    Class instances are essentially dictionaries, with keys described below
    as 'attributes'. For example, drive['events'] contains the spike times of
    exogeneous inputs.
    Attributes
    ----------
    name : str
        Name of drive (must be unique)
    location : str
        Target location of synapses ('distal' or 'proximal').
    type : str
        Examples: 'evoked', 'gaussian', 'poisson', 'bursty'
    events : list of lists
        List of spike time lists. First index is of length n_trials. Second
        index is over the 'artificial' cells associated with this drive.
    n_drive_cells : int
        The number of drive cells that contribute to this drive.
    cell_specific : bool
        Whether each cell has unique connection parameters (default: True)
        or all cells have common connections to a global (single) drive.
    event_seed : int
        Optional initial seed for random number generator used for event times.
        Each artificial drive cell has seed = event_seed + gid
    conn_seed : int
        Optional initial seed for random number generator.
        Used to randomly remove connections when probability < 1.0.
    target_types : set or list of str
        Names of cell types targeted by this drive (must be subset of
        net.cell_types.keys()).
    dynamics : dict
        Parameters describing how the temporal dynamics of spike trains in the
        drive. The keys are specific to the type of drive ('evoked', 'bursty',
        etc.). See the drive add-methods in Network for details.
    """
    def __repr__(self):
        entr = f"<External drive '{self['name']}'"
        if "type" in self.keys():
            entr += f"\ndrive class: {self['type']}"
            entr += f"\ntarget location: {self['location']}"
            entr += f"\ntarget cell types: {self['target_types']}"
            entr += f"\nnumber of drive cells: {self['n_drive_cells']}"
            entr += f"\ncell-specific: {self['cell_specific']}"
            entr += "\ndynamic parameters:"
            for key, val in self["dynamics"].items():
                entr += f"\n\t{key}: {val}"
        if len(self["events"]) > 0:
            plurl = "s" if len(self["events"]) > 1 else ""
            entr += f"\nevent times instantiated for {len(self['events'])} trial{plurl}"
        entr += ">"
        return entr
def _add_cell_type_bias(
    network: Network,
    amplitude: float,
    cell_type: str,
    section="soma",
    bias_name="tonic",
    t_0=0,
    t_stop=None,
):
    """Add a tonic bias to a specific cell type in the network.
    Parameters
    ----------
    network : Network
        The network to which the tonic bias is added.
    amplitude : float
        The amplitude of the tonic input (in nA) applied to the specified
        `cell_type`.
    cell_type : str
        The cell type to which the bias is applied.
    section : str, default 'soma'
        The section of the cell where the bias is applied (e.g., 'soma',
        'apical_tuft').
    bias_name : str, default 'tonic'
        A name identifier for the bias configuration, allowing multiple biases
        to be applied.
    t_0 : float, default 0
        The start time of the tonic input in milliseconds.
    t_stop : float, optional
        The end time of the tonic input in milliseconds. If None, the bias
        continues until the end of the simulation.
    """
    # Validate cell_type value
    if cell_type not in network.cell_types:
        raise ValueError(
            f"cell_type must be one of "
            f"{list(network.cell_types.keys())}. "
            f"Got {cell_type}"
        )
    if bias_name not in network.external_biases:
        network.external_biases[bias_name] = dict()
    if cell_type in network.external_biases[bias_name]:
        raise ValueError(f"Bias named {bias_name} already defined for {cell_type}")
    cell_type_bias = {
        "amplitude": amplitude,
        "t0": t_0,
        "tstop": t_stop,
        "section": section,
    }
    sections = list(network.cell_types[cell_type]["cell_object"].sections.keys())
    # error when section is defined that doesn't exist.
    if section not in sections:
        raise ValueError(f"section must be one of {sections}. Got {section}.")
    else:
        cell_type_bias["section"] = section
    network.external_biases[bias_name][cell_type] = cell_type_bias
def _check_global_synaptic_gains_uniformity(net):
    """Check whether gain values are uniform within their type of connection.
    This function identifies excitatory and inhibitory cells in the Network
    and retrieves the gain value for each type of synaptic connection:
    - excitatory to excitatory (e_e)
    - excitatory to inhibitory (e_i)
    - inhibitory to excitatory (i_e)
    - inhibitory to inhibitory (i_i)
    The gain is then checked to see if it is uniform for all instances within each
    connection type (for example, between AMPA and NMDA, and between different
    connections like L2_pyramidal->L2_pyramidal and L2_pyramidal->L5_pyramidal). This
    does **not** check the synaptic gains of external drives. Connection types
    with no matching connections are skipped, since they are trivially uniform.
    Gains are compared with ``np.isclose`` against the first connection of
    each type. A warning is printed for each connection type found to be
    non-uniform.
    Parameters
    ----------
    net : Network
        The network whose ``connectivity`` gains are inspected.
    Returns
    -------
    output_indicator : bool
        A truth value indicating whether the synaptic gains are uniform within their
        connection type (True) or non-uniform (False).
    """
    e_gids, i_gids = _get_cell_index_by_synapse_type(net)
    # Define the connection types and source/target cell indexes
    conn_types = {
        "e_e": (e_gids, e_gids),
        "e_i": (e_gids, i_gids),
        "i_e": (i_gids, e_gids),
        "i_i": (i_gids, i_gids),
    }
    output_indicator = True
    # Retrieve the gain value for each connection type
    for src_idxs, target_idxs in conn_types.values():
        picks = pick_connection(net, src_gids=src_idxs, target_gids=target_idxs)
        # A network may lack a connection type entirely (e.g. no
        # inhibitory-to-inhibitory synapses); there is nothing to compare,
        # and indexing picks[0] would raise IndexError.
        if len(picks) == 0:
            continue
        first_value = net.connectivity[picks[0]]["nc_dict"]["gain"]
        for pick in picks[1:]:
            if not bool(
                np.isclose(
                    net.connectivity[pick]["nc_dict"]["gain"],
                    first_value,
                )
            ):
                output_indicator = False
                print(
                    """
                    WARNING: Your imported Network uses custom synaptic gain values. Global synaptic gain values such as "Excitatory-to-Inhibitory" etc. will NOT be read or displayed properly. This is because Global synaptic gain values assume that initially, all gains are the same. If you continue to modify your Global synaptic gain values, double-check each connection's final synaptic gain value. To stop this warning, change your synaptic weights instead of your synaptic gains.
                    """
                )
                break
    return output_indicator