Source code for hnn_core.batch_simulate

"""Batch Simulation."""

# Authors: Abdul Samad Siddiqui <abdulsamadsid1@gmail.com>
#          Nick Tolley <nicholas_tolley@brown.edu>
#          Ryan Thorpe <ryan_thorpe@brown.edu>
#          Mainak Jas <mjas@mgh.harvard.edu>

import numpy as np
import os
from joblib import Parallel, delayed, parallel_config

from .network import Network
from .externals.mne import _validate_type, _check_option
from .dipole import simulate_dipole
from .network_models import jones_2009_model


class BatchSimulate(object):
    """The BatchSimulate class.

    Parameters
    ----------
    set_params : func
        User-defined function that sets parameters in network drives.

            ``set_params(param_values, net) -> None``

        where ``param_values`` is a dictionary of the parameters that will
        be set inside the function and ``net`` is a Network object.
    net : Network object, optional
        The network model to use for simulations. Examples include the
        returned value of the following functions:

        - `jones_2009_model`: A network model based on Jones et al. (2009).
        - `law_2021_model`: A network model based on Law et al. (2021).
        - `calcium_model`: A network model incorporating calcium dynamics.

        Default is ``jones_2009_model()``.
    tstop : float, optional
        The stop time for the simulation. Default is 170 ms.
    dt : float, optional
        The time step for the simulation. Default is 0.025 ms.
    n_trials : int, optional
        The number of trials for the simulation. Default is 1.
    save_folder : str, optional
        The path to save the simulation outputs.
        Default is './sim_results'.
    batch_size : int, optional
        The maximum number of simulations saved in a single file.
        Default is 100.
    overwrite : bool, optional
        Whether to overwrite existing files and create file paths
        if they do not exist. Default is True.
    save_outputs : bool, optional
        Whether to save the simulation outputs to files. Default is False.
    save_dpl : bool, optional
        If True, save dipole results. Note, `save_outputs` must be True.
        Default: True.
    save_spiking : bool, optional
        If True, save spiking results. Note, `save_outputs` must be True.
        Default: False.
    save_lfp : bool, optional
        If True, save local field potential (LFP) results. Note,
        `save_outputs` must be True. Default: False.
    save_voltages : bool, optional
        If True, save voltage results. Note, `save_outputs` must be True.
        Default: False.
    save_currents : bool, optional
        If True, save current results. Note, `save_outputs` must be True.
        Default: False.
    save_calcium : bool, optional
        If True, save calcium concentrations. Note, `save_outputs` must be
        True. Default: False.
    record_vsec : {False, 'all', 'soma'}
        Option to record voltages from all sections ('all') or just the
        soma ('soma'). Default: False.
    record_isec : {False, 'all', 'soma'}
        Option to record synaptic currents from all sections ('all') or
        just the soma ('soma'). Default: False.
    postproc : bool, optional
        If True, smoothing (``dipole_smooth_win``) and scaling
        (``dipole_scalefctr``) values are read from the parameter file, and
        applied to the dipole objects before returning. Default: False.
    clear_cache : bool, optional
        Whether to clear the results cache after saving each batch.
        Default is False.
    summary_func : func, optional
        A function to calculate summary statistics from the simulation
        results. Default is None.

    Notes
    -----
    When ``save_outputs=True``, the saved files will appear as
    ``sim_run_{start_idx}-{end_idx}.npz`` in the specified `save_folder`
    directory. The `start_idx` and `end_idx` indicate the range of
    simulation indices contained in each file. Each file will contain a
    maximum of `batch_size` simulations, split evenly among the available
    files. If ``overwrite=True``, existing files with the same name will be
    overwritten.
    """

    def __init__(
        self,
        set_params,
        net=jones_2009_model(),
        tstop=170,
        dt=0.025,
        n_trials=1,
        save_folder="./sim_results",
        batch_size=100,
        overwrite=True,
        save_outputs=False,
        save_dpl=True,
        save_spiking=False,
        save_lfp=False,
        save_voltages=False,
        save_currents=False,
        save_calcium=False,
        record_vsec=False,
        record_isec=False,
        postproc=False,
        clear_cache=False,
        summary_func=None,
    ):
        _validate_type(net, Network, "net", "Network")
        _validate_type(tstop, types="numeric", item_name="tstop")
        _validate_type(dt, types="numeric", item_name="dt")
        _validate_type(n_trials, types="int", item_name="n_trials")
        _validate_type(save_folder, types="path-like", item_name="save_folder")
        _validate_type(batch_size, types="int", item_name="batch_size")
        _validate_type(save_outputs, types=(bool,), item_name="save_outputs")
        _validate_type(save_dpl, types=(bool,), item_name="save_dpl")
        _validate_type(save_spiking, types=(bool,), item_name="save_spiking")
        _validate_type(save_lfp, types=(bool,), item_name="save_lfp")
        _validate_type(save_voltages, types=(bool,), item_name="save_voltages")
        _validate_type(save_currents, types=(bool,), item_name="save_currents")
        _validate_type(save_calcium, types=(bool,), item_name="save_calcium")
        _check_option("record_vsec", record_vsec, ["all", "soma", False])
        _check_option("record_isec", record_isec, ["all", "soma", False])
        _validate_type(clear_cache, types=(bool,), item_name="clear_cache")

        if set_params is not None and not callable(set_params):
            raise TypeError("set_params must be a callable function")

        if summary_func is not None and not callable(summary_func):
            raise TypeError("summary_func must be a callable function")

        self.set_params = set_params
        self.net = net
        self.tstop = tstop
        self.dt = dt
        self.n_trials = n_trials
        self.save_folder = save_folder
        self.batch_size = batch_size
        self.overwrite = overwrite
        self.save_outputs = save_outputs
        self.save_dpl = save_dpl
        self.save_spiking = save_spiking
        self.save_lfp = save_lfp
        self.save_voltages = save_voltages
        self.save_currents = save_currents
        self.save_calcium = save_calcium
        self.record_vsec = record_vsec
        self.record_isec = record_isec
        self.postproc = postproc
        self.clear_cache = clear_cache
        self.summary_func = summary_func

    def run(
        self,
        param_grid,
        return_output=True,
        combinations=True,
        n_jobs=1,
        backend="loky",
        verbose=50,
    ):
        """Run batch simulations.

        Parameters
        ----------
        param_grid : dict
            Dictionary with parameter names and ranges.
        return_output : bool, optional
            Whether to return the simulation outputs. Default is True.
        combinations : bool, optional
            Whether to generate the Cartesian product of the parameter
            ranges. If False, generate combinations based on corresponding
            indices. Default is True.
        n_jobs : int, optional
            Number of parallel jobs. Default is 1.
        backend : str or joblib.parallel.ParallelBackendBase instance, optional
            The parallel backend to use. Can be one of `loky`, `threading`,
            `multiprocessing`, or `dask`. WARNING: currently only `loky` is
            fully operational; all other backends are in development.
            Default is `loky`.
        verbose : int, optional
            The verbosity level for parallel execution. Default is 50.

        Returns
        -------
        results : dict
            Dictionary containing 'summary_statistics' and optionally
            'simulated_data'. 'simulated_data' may include the keys 'dpl',
            'spiking', 'lfp', 'voltages', 'currents', 'calcium',
            'param_values', and 'net'.

        Notes
        -----
        The returned content depends on the `summary_func`,
        `return_output`, and `clear_cache` settings.
        """
        _validate_type(param_grid, types=(dict,), item_name="param_grid")
        _validate_type(n_jobs, types="int", item_name="n_jobs")
        _check_option(
            "backend", backend, ["loky", "threading", "multiprocessing", "dask"]
        )
        _validate_type(verbose, types="int", item_name="verbose")

        param_combinations = self._generate_param_combinations(
            param_grid, combinations
        )
        total_sims = len(param_combinations)
        batch_size = min(self.batch_size, total_sims)

        results = []
        simulated_data = []
        for i in range(0, total_sims, batch_size):
            start_idx = i
            end_idx = min(i + batch_size, total_sims)
            batch_results = self.simulate_batch(
                param_combinations[start_idx:end_idx],
                n_jobs=n_jobs,
                backend=backend,
                verbose=verbose,
            )

            if self.save_outputs:
                self._save(batch_results, start_idx, end_idx)

            if self.summary_func is not None:
                summary_statistics = self.summary_func(batch_results)
                results.append(summary_statistics)
                if not self.clear_cache:
                    simulated_data.append(batch_results)

            elif return_output and not self.clear_cache:
                simulated_data.append(batch_results)

            if self.clear_cache:
                del batch_results

        if return_output:
            if self.clear_cache:
                return {"summary_statistics": results}
            else:
                return {
                    "summary_statistics": results,
                    "simulated_data": simulated_data,
                }

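    # Editorial sketch (not part of hnn_core): one way to define
    # ``set_params`` and drive a small batch through ``run``. The import
    # path and ``Network.add_evoked_drive`` call are standard hnn_core API;
    # the drive name 'evprox_sweep', the grid key 'weight_pyr', and all
    # numeric values are hypothetical, chosen only for illustration.
    #
    #   from hnn_core.batch_simulate import BatchSimulate
    #
    #   def set_params(param_values, net):
    #       # Sweep the AMPA weight of a single evoked proximal drive.
    #       weights_ampa = {'L2_pyramidal': param_values['weight_pyr'],
    #                       'L5_pyramidal': param_values['weight_pyr']}
    #       net.add_evoked_drive('evprox_sweep', mu=40.0, sigma=5.0,
    #                            numspikes=1, location='proximal',
    #                            weights_ampa=weights_ampa)
    #
    #   batch = BatchSimulate(set_params=set_params, tstop=170.0)
    #   param_grid = {'weight_pyr': [1e-4, 1e-3, 1e-2]}
    #   results = batch.run(param_grid, n_jobs=2, backend='loky')
    #   # With clear_cache=False (the default) the simulated dipoles are
    #   # returned under 'simulated_data', grouped per batch.
    #   dpls = [sim['dpl'] for batch_res in results['simulated_data']
    #           for sim in batch_res]
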
    def simulate_batch(
        self,
        param_combinations,
        n_jobs=1,
        backend="loky",
        verbose=50,
    ):
        """Simulate a batch of parameter sets in parallel.

        Parameters
        ----------
        param_combinations : list
            List of parameter combinations.
        n_jobs : int, optional
            Number of parallel jobs. Default is 1.
        backend : str or joblib.parallel.ParallelBackendBase instance, optional
            The parallel backend to use. Can be one of `loky`, `threading`,
            `multiprocessing`, or `dask`. WARNING: currently only `loky` is
            fully operational; all other backends are in development.
            Default is `loky`.
        verbose : int, optional
            The verbosity level for parallel execution. Default is 50.

        Returns
        -------
        res : list
            List of dictionaries containing simulation results.
            Each dictionary contains the following keys along with their
            associated values:

            - `net`: The network model used for the simulation.
            - `dpl`: The simulated dipole.
            - `param_values`: The parameter values used for the simulation.
        """
        _validate_type(
            param_combinations, types=(list,), item_name="param_combinations"
        )
        _validate_type(n_jobs, types="int", item_name="n_jobs")
        _check_option(
            "backend", backend, ["loky", "threading", "multiprocessing", "dask"]
        )
        _validate_type(verbose, types="int", item_name="verbose")

        with parallel_config(backend=backend):
            res = Parallel(n_jobs=n_jobs, verbose=verbose)(
                delayed(self._run_single_sim)(params)
                for params in param_combinations
            )
        return res

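    # Editorial sketch (not part of hnn_core): ``simulate_batch`` can also
    # be called directly with an explicit list of parameter dictionaries,
    # bypassing the grid expansion and file handling done by ``run``. It
    # reuses the hypothetical ``batch`` object and 'weight_pyr' key from
    # the sketch above.
    #
    #   param_list = [{'weight_pyr': 1e-4}, {'weight_pyr': 1e-3}]
    #   res = batch.simulate_batch(param_list, n_jobs=2, backend='loky')
    #   first_dpl = res[0]['dpl']   # list of Dipole objects, one per trial
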
    def _run_single_sim(self, param_values):
        """Run a single simulation.

        Parameters
        ----------
        param_values : dict
            Dictionary of parameter values.

        Returns
        -------
        dict
            Dictionary containing the simulation results.
            The dictionary contains the following keys:

            - `net`: The network model used for the simulation.
            - `dpl`: The simulated dipole.
            - `param_values`: The parameter values used for the simulation.
        """
        net = self.net.copy()
        self.set_params(param_values, net)
        results = {"net": net, "param_values": param_values}

        if self.save_dpl:
            dpl = simulate_dipole(
                net,
                tstop=self.tstop,
                dt=self.dt,
                n_trials=self.n_trials,
                record_vsec=self.record_vsec,
                record_isec=self.record_isec,
                postproc=self.postproc,
            )
            results["dpl"] = dpl

        if self.save_spiking:
            results["spiking"] = {
                "spike_times": net.cell_response.spike_times,
                "spike_types": net.cell_response.spike_types,
                "spike_gids": net.cell_response.spike_gids,
            }

        if self.save_lfp:
            results["lfp"] = net.rec_arrays

        if self.save_voltages:
            results["voltages"] = net.cell_response.vsec

        if self.save_currents:
            results["currents"] = net.cell_response.isec

        if self.save_calcium:
            results["calcium"] = net.cell_response.ca

        return results

    def _generate_param_combinations(self, param_grid, combinations=True):
        """Generate combinations of parameters from the grid.

        Parameters
        ----------
        param_grid : dict
            Dictionary with parameter names and ranges.
        combinations : bool, optional
            Whether to generate the Cartesian product of the parameter
            ranges. If False, generate combinations based on corresponding
            indices. Default is True.

        Returns
        -------
        param_combinations : list
            List of parameter combinations.
        """
        from itertools import product

        keys, values = zip(*param_grid.items())
        if combinations:
            param_combinations = [
                dict(zip(keys, combination)) for combination in product(*values)
            ]
        else:
            param_combinations = [
                dict(zip(keys, combination)) for combination in zip(*values)
            ]
        return param_combinations

    def _save(self, results, start_idx, end_idx):
        """Save simulation results to files.

        Parameters
        ----------
        results : list
            The results to save.
        start_idx : int
            The index of the first simulation in this batch.
        end_idx : int
            The end index (exclusive) of the simulations in this batch.
        """
        _validate_type(results, types=(list,), item_name="results")
        _validate_type(start_idx, types="int", item_name="start_idx")
        _validate_type(end_idx, types="int", item_name="end_idx")

        if not os.path.exists(self.save_folder):
            os.makedirs(self.save_folder)

        save_data = {
            "param_values": [result["param_values"] for result in results]
        }

        attributes_to_save = [
            "dpl",
            "spiking",
            "lfp",
            "voltages",
            "currents",
            "calcium",
        ]
        for attr in attributes_to_save:
            if getattr(self, f"save_{attr}") and attr in results[0]:
                save_data[attr] = [result[attr] for result in results]

        metadata = {
            "batch_size": self.batch_size,
            "n_trials": self.n_trials,
            "tstop": self.tstop,
            "dt": self.dt,
        }
        save_data["metadata"] = metadata

        file_name = os.path.join(
            self.save_folder, f"sim_run_{start_idx}-{end_idx}.npz"
        )
        if os.path.exists(file_name) and not self.overwrite:
            raise FileExistsError(
                f"File {file_name} already exists and overwrite is set to False."
            )

        np.savez(file_name, **save_data)

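    # Editorial sketch (illustrative only): the two expansion modes
    # implemented by ``_generate_param_combinations`` above, written as a
    # standalone snippet.
    #
    #   from itertools import product
    #   grid = {'a': [1, 2], 'b': [10, 20]}
    #   keys, values = zip(*grid.items())
    #   # combinations=True -> Cartesian product (4 dicts):
    #   #   [{'a': 1, 'b': 10}, {'a': 1, 'b': 20},
    #   #    {'a': 2, 'b': 10}, {'a': 2, 'b': 20}]
    #   cartesian = [dict(zip(keys, combo)) for combo in product(*values)]
    #   # combinations=False -> element-wise pairing (2 dicts):
    #   #   [{'a': 1, 'b': 10}, {'a': 2, 'b': 20}]
    #   paired = [dict(zip(keys, combo)) for combo in zip(*values)]
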
    def load_results(self, file_path, return_data=None):
        """Load simulation results from a file.

        Parameters
        ----------
        file_path : str
            The path to the file containing the simulation results.
        return_data : list of str, optional
            List of data types to return. If None, returns the types
            specified during initialization. Default is None.

        Returns
        -------
        results : dict
            Dictionary containing the loaded results and parameter values.
        """
        if return_data is None:
            return_data = []
            if self.save_dpl:
                return_data.append("dpl")
            if self.save_spiking:
                return_data.append("spiking")
            if self.save_lfp:
                return_data.append("lfp")
            if self.save_voltages:
                return_data.append("voltages")
            if self.save_currents:
                return_data.append("currents")
            if self.save_calcium:
                return_data.append("calcium")

        data = np.load(file_path, allow_pickle=True)
        results = {
            key: data[key].tolist()
            for key in data.files
            if key in return_data or key == "param_values"
        }
        return results

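    # Editorial sketch (not part of hnn_core): reloading a single batch
    # file written by ``_save``. The file name is hypothetical but follows
    # the ``sim_run_{start_idx}-{end_idx}.npz`` pattern described above,
    # and assumes the run was made with ``save_outputs=True`` and
    # ``save_dpl=True``.
    #
    #   loaded = batch.load_results('./sim_results/sim_run_0-100.npz',
    #                               return_data=['dpl'])
    #   dpls = loaded['dpl']              # only the requested data types
    #   params = loaded['param_values']   # always included
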
    def load_all_results(self):
        """Load all simulation results from the files in `self.save_folder`.

        Returns
        -------
        all_results : list
            List of dictionaries containing all loaded simulation results.
        """
        all_results = []
        for file_name in os.listdir(self.save_folder):
            if file_name.startswith("sim_run_") and file_name.endswith(".npz"):
                file_path = os.path.join(self.save_folder, file_name)
                results = self.load_results(file_path)
                all_results.append(results)
        return all_results

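# Editorial sketch (not part of hnn_core): collecting every saved batch from
# ``save_folder`` after a run with ``save_outputs=True``; ``batch`` is the
# hypothetical object from the sketches above.
#
#   all_results = batch.load_all_results()
#   all_params = [params
#                 for batch_file in all_results
#                 for params in batch_file['param_values']]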