# Source code for hnn_core.gui.gui

"""IPywidgets GUI."""

# Authors: Mainak Jas <mjas@mgh.harvard.edu>
#          Huzi Cheng <hzcheng15@icloud.com>
import codecs
import io
import logging
import multiprocessing
import sys
import urllib.parse
import urllib.request
from collections import defaultdict
from pathlib import Path
from IPython.display import IFrame, display
from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
                        BoundedIntText, Button, Dropdown, FileUpload, VBox,
                        HBox, IntText, Layout, Output, RadioButtons, Tab, Text)
from ipywidgets.embed import embed_minimal_html
import hnn_core
from hnn_core import (JoblibBackend, MPIBackend, jones_2009_model, read_params,
                      simulate_dipole)
from hnn_core.gui._logging import logger
from hnn_core.gui._viz_manager import _VizManager, _idx2figname
from hnn_core.network import pick_connection
from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json,
                             _read_legacy_params)


class _OutputWidgetHandler(logging.Handler):
    """Logging handler that publishes records into an ipywidgets ``Output``.

    Each formatted record is wrapped as a ``stdout`` stream entry and
    prepended to ``output_widget.outputs``, so the newest entry comes first
    in that tuple.

    Parameters
    ----------
    output_widget : Output
        Widget whose ``outputs`` tuple receives the log entries.
    *args, **kwargs
        Forwarded to :class:`logging.Handler` (e.g. ``level``).
    """

    def __init__(self, output_widget, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.out = output_widget

    def emit(self, record):
        """Format ``record`` and prepend it to the widget's outputs.

        Any error raised while formatting or publishing the record is sent to
        :meth:`logging.Handler.handleError`, as the standard-library handlers
        do, instead of propagating into the code that issued the log call.
        """
        try:
            formatted_record = self.format(record)
            new_output = {
                'name': 'stdout',
                'output_type': 'stream',
                'text': formatted_record + '\n'
            }
            # prepend: most recent message first
            self.out.outputs = (new_output, ) + self.out.outputs
        except RecursionError:
            # mirror logging.StreamHandler: never mask runaway recursion
            raise
        except Exception:
            self.handleError(record)


class HNNGUI:
    """HNN GUI class.

    Parameters
    ----------
    theme_color : str
        The theme color of the whole dashboard.
    total_height : int
        The height of the GUI (in pixel, same for all following parameters).
    total_width : int
        The width of the GUI.
    header_height : int
        The height of the header.
    button_height : int
        The height of buttons.
    operation_box_height : int
        The height of the operations box.
    drive_widget_width : int
        The width of GUI drive box. Note: this value is not currently used
        when building the layout.
    left_sidebar_width : int
        The width of the left sidebar.
    log_window_height : int
        The height of the logging window.
    status_height : int
        The height of the status bar.
    dpi : int
        The screen dpi.

    Attributes
    ----------
    layout : dict
        The styling configuration of GUI.
    params : dict
        The parameters to use for constructing the network.
    simulation_data : dict
        Simulation related objects, such as net and dpls.
    widget_simulation_name : Widget
        Widget holding the name (ID) of the simulation.
    widget_tstop : Widget
        Simulation stop time widget.
    widget_dt : Widget
        Simulation step size widget.
    widget_ntrials : Widget
        Widget that controls the number of trials in a single simulation.
    widget_backend_selection : Widget
        Widget that selects the backend used in simulations.
    widget_mpi_cmd : Widget
        Widget that specifies the mpi command to use when the backend is
        MPIBackend.
    widget_n_jobs : Widget
        Widget that specifies the cores in multi-trial simulations.
    widget_drive_type_selection : Widget
        Widget that is used to select the drive to be added to the network.
    widget_location_selection : Widget
        Widget that specifies the location of network drives. Could be
        proximal or distal.
    add_drive_button : Widget
        Clickable widget that is used to add a drive to the network.
    run_button : Widget
        Clickable widget that triggers simulation.
    load_data_button : Widget
        Upload widget that receives ``.txt`` data files.
    load_connectivity_button : Widget
        Upload widget that receives network connectivity parameter files
        (``.json`` / ``.param``).
    load_drives_button : Widget
        Upload widget that receives external drive parameter files
        (``.json`` / ``.param``).
    delete_drive_button : Widget
        Clickable widget that clears all existing network drives.
    plot_outputs_dict : dict
        Visualization panel outputs.
    plot_dropdown_types_dict : dict
        Dropdown menus that control the plot types in plot_outputs_dict.
    plot_sim_selections_dict : dict
        Simulation selection widgets, passed to the visualization manager as
        ``plot_sim_selections``.
    drive_widgets : list
        A list of network drive widgets added by add_drive_button.
    drive_boxes : list
        A list of network drive layouts.
    connectivity_widgets : list
        A list of boxes that control connection parameters (e.g. weights) of
        the network.
    """

    def __init__(self, theme_color="#8A2BE2",
                 total_height=800,
                 total_width=1300,
                 header_height=50,
                 button_height=30,
                 operation_box_height=60,
                 drive_widget_width=200,
                 left_sidebar_width=576,
                 log_window_height=150,
                 status_height=30,
                 dpi=96,
                 ):
        # set up styling.
        self.total_height = total_height
        self.total_width = total_width

        viz_win_width = self.total_width - left_sidebar_width
        main_content_height = self.total_height - status_height

        config_box_height = main_content_height - (log_window_height +
                                                   operation_box_height)
        self.layout = {
            "dpi": dpi,
            "header_height": f"{header_height}px",
            "theme_color": theme_color,
            "btn": Layout(height=f"{button_height}px", width='auto'),
            "btn_full_w": Layout(height=f"{button_height}px", width='100%'),
            "del_fig_btn": Layout(height=f"{button_height}px", width='auto'),
            "log_out": Layout(border='1px solid gray',
                              height=f"{log_window_height-10}px",
                              overflow='auto'),
            "viz_config": Layout(width='99%'),
            "visualization_window": Layout(
                width=f"{viz_win_width-10}px",
                height=f"{main_content_height-10}px",
                border='1px solid gray',
                overflow='scroll'),
            "visualization_output": Layout(
                width=f"{viz_win_width-50}px",
                height=f"{main_content_height-100}px",
                border='1px solid gray',
                overflow='scroll'),
            "left_sidebar": Layout(width=f"{left_sidebar_width}px",
                                   height=f"{main_content_height}px"),
            "left_tab": Layout(width=f"{left_sidebar_width}px",
                               height=f"{config_box_height}px"),
            "operation_box": Layout(width=f"{left_sidebar_width}px",
                                    height=f"{operation_box_height}px",
                                    flex_wrap="wrap",
                                    ),
            "config_box": Layout(width=f"{left_sidebar_width}px",
                                 height=f"{config_box_height-100}px"),
            "drive_widget": Layout(width="auto"),
            "drive_textbox": Layout(width='270px', height='auto'),
            # simulation status related
            "simulation_status_height": f"{status_height}px",
            "simulation_status_common": "background:gray;padding-left:10px",
            "simulation_status_running": "background:orange;padding-left:10px",
            "simulation_status_failed": "background:red;padding-left:10px",
            "simulation_status_finished": "background:green;padding-left:10px",
        }

        # HTML snippets shown in the status bar for each simulation state
        self._simulation_status_contents = {
            "not_running":
            f"""<div style='{self.layout['simulation_status_common']}; color:white;'>Not running</div>""",
            "running":
            f"""<div style='{self.layout['simulation_status_running']}; color:white;'>Running...</div>""",
            "finished":
            f"""<div style='{self.layout['simulation_status_finished']}; color:white;'>Simulation finished</div>""",
            "failed":
            f"""<div style='{self.layout['simulation_status_failed']}; color:white;'>Simulation failed</div>""",
        }

        # load default parameters
        self.params = self.load_parameters()

        # In-memory storage of all simulation and visualization related data
        self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list()))

        # Simulation parameters
        self.widget_tstop = BoundedFloatText(
            value=170, description='tstop (ms):', min=0, max=1e6, step=1,
            disabled=False)
        self.widget_dt = BoundedFloatText(
            value=0.025, description='dt (ms):', min=0, max=10, step=0.01,
            disabled=False)
        self.widget_ntrials = IntText(value=1, description='Trials:',
                                      disabled=False)
        self.widget_simulation_name = Text(value='default',
                                           placeholder='ID of your simulation',
                                           description='Name:',
                                           disabled=False)
        self.widget_backend_selection = Dropdown(options=[('Joblib', 'Joblib'),
                                                          ('MPI', 'MPI')],
                                                 value='Joblib',
                                                 description='Backend:')
        self.widget_mpi_cmd = Text(value='mpiexec',
                                   placeholder='Fill if applies',
                                   description='MPI cmd:', disabled=False)
        self.widget_n_jobs = BoundedIntText(value=1, min=1,
                                            max=multiprocessing.cpu_count(),
                                            description='Cores:',
                                            disabled=False)
        self.load_data_button = FileUpload(
            accept='.txt', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load data',
            button_style='success')

        # Drive selection
        self.widget_drive_type_selection = RadioButtons(
            options=['Evoked', 'Poisson', 'Rhythmic'],
            value='Evoked',
            description='Drive:',
            disabled=False,
            layout=self.layout['drive_widget'])
        self.widget_location_selection = RadioButtons(
            options=['proximal', 'distal'], value='proximal',
            description='Location', disabled=False,
            layout=self.layout['drive_widget'])
        self.add_drive_button = create_expanded_button(
            'Add drive', 'primary', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])

        # Dashboard level buttons
        self.run_button = create_expanded_button(
            'Run', 'success', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])

        self.load_connectivity_button = FileUpload(
            accept='.json,.param', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load local network connectivity',
            layout=self.layout['btn_full_w'], button_style='success')
        self.load_drives_button = FileUpload(
            accept='.json,.param', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load external drives', layout=self.layout['btn'],
            button_style='success')

        self.delete_drive_button = create_expanded_button(
            'Delete drives', 'success', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])

        # Plotting window

        # Visualization figure related dicts
        self.plot_outputs_dict = dict()
        self.plot_dropdown_types_dict = dict()
        self.plot_sim_selections_dict = dict()

        # Add drive section
        self.drive_widgets = list()
        self.drive_boxes = list()

        # Connectivity list
        self.connectivity_widgets = list()

        self._init_ui_components()
        self.add_logging_window_logger()

    def add_logging_window_logger(self):
        """Attach a handler that mirrors GUI log messages into the log window.

        The handler is added to the module-level ``logger`` imported from
        ``hnn_core.gui._logging``.
        """
        handler = _OutputWidgetHandler(self._log_out)
        handler.setFormatter(
            logging.Formatter('%(asctime)s - [%(levelname)s] %(message)s'))
        logger.addHandler(handler)

    def _init_ui_components(self):
        """Initialize larger UI components and dynamical output windows.

        It's not encouraged for users to modify or access attributes in this
        part.
        """
        # dynamic larger components
        self._drives_out = Output()  # tab to add new drives
        self._connectivity_out = Output()  # tab to tune connectivity.
        self._log_out = Output()

        self.viz_manager = _VizManager(self.data, self.layout)

        # detailed configuration of backends
        self._backend_config_out = Output()

        # static parts
        # Running status
        self._simulation_status_bar = HTML(
            value=self._simulation_status_contents['not_running'])

        self._log_window = HBox([self._log_out], layout=self.layout['log_out'])
        self._operation_buttons = HBox(
            [self.run_button, self.load_data_button],
            layout=self.layout['operation_box'])
        # title
        self._header = HTML(value=f"""<div style='background:{self.layout['theme_color']}; text-align:center;color:white;'> HUMAN NEOCORTICAL NEUROSOLVER</div>""")

    @property
    def analysis_config(self):
        """Provides everything viz window needs except for the data."""
        return {
            "viz_style": self.layout['visualization_output'],
            # widgets
            "plot_outputs": self.plot_outputs_dict,
            "plot_dropdowns": self.plot_dropdown_types_dict,
            "plot_sim_selections": self.plot_sim_selections_dict,
            "current_sim_name": self.widget_simulation_name.value,
        }

    @property
    def data(self):
        """Provides easy access to simulation-related data."""
        return {"simulation_data": self.simulation_data}

    @staticmethod
    def load_parameters(params_fname=None):
        """Read parameters from file.

        Parameters
        ----------
        params_fname : str | Path | None
            Parameter file to read. If falsy, ``param/default.json`` inside
            the installed ``hnn_core`` package directory is used.

        Returns
        -------
        params
            Whatever ``hnn_core.read_params`` returns for that file.
        """
        if not params_fname:
            # by default load default.json
            hnn_core_root = Path(hnn_core.__file__).parent
            params_fname = hnn_core_root / 'param/default.json'

        return read_params(params_fname)

    def _link_callbacks(self):
        """Link callbacks to UI components."""
        def _handle_backend_change(backend_type):
            return handle_backend_change(backend_type.new,
                                         self._backend_config_out,
                                         self.widget_mpi_cmd,
                                         self.widget_n_jobs)

        def _add_drive_button_clicked(b):
            return add_drive_widget(self.widget_drive_type_selection.value,
                                    self.drive_boxes, self.drive_widgets,
                                    self._drives_out, self.widget_tstop,
                                    self.widget_location_selection.value,
                                    layout=self.layout['drive_textbox'])

        def _delete_drives_clicked(b):
            self._drives_out.clear_output()
            # Empty the lists in place rather than rebinding the attributes,
            # so any other holder of a reference to these lists sees the
            # change.
            self.drive_widgets.clear()
            self.drive_boxes.clear()

        def _on_upload_connectivity(change):
            return on_upload_params_change(
                change, self.params, self.widget_tstop, self.widget_dt,
                self._log_out, self.drive_boxes, self.drive_widgets,
                self._drives_out, self._connectivity_out,
                self.connectivity_widgets, self.layout['drive_textbox'],
                "connectivity")

        def _on_upload_drives(change):
            return on_upload_params_change(
                change, self.params, self.widget_tstop, self.widget_dt,
                self._log_out, self.drive_boxes, self.drive_widgets,
                self._drives_out, self._connectivity_out,
                self.connectivity_widgets, self.layout['drive_textbox'],
                "drives")

        def _on_upload_data(change):
            return on_upload_data_change(change, self.data, self.viz_manager,
                                         self._log_out)

        def _run_button_clicked(b):
            return run_button_clicked(
                self.widget_simulation_name, self._log_out, self.drive_widgets,
                self.data, self.widget_dt, self.widget_tstop,
                self.widget_ntrials, self.widget_backend_selection,
                self.widget_mpi_cmd, self.widget_n_jobs, self.params,
                self._simulation_status_bar, self._simulation_status_contents,
                self.connectivity_widgets, self.viz_manager)

        self.widget_backend_selection.observe(_handle_backend_change, 'value')
        self.add_drive_button.on_click(_add_drive_button_clicked)
        self.delete_drive_button.on_click(_delete_drives_clicked)
        self.load_connectivity_button.observe(_on_upload_connectivity,
                                              names='value')
        self.load_drives_button.observe(_on_upload_drives, names='value')
        self.run_button.on_click(_run_button_clicked)
        self.load_data_button.observe(_on_upload_data, names='value')

    def compose(self, return_layout=True):
        """Compose widgets.

        Builds the tabbed left sidebar and the overall ``AppLayout``, links
        the widget callbacks, and populates the drive and connectivity
        widgets from ``self.params``.

        Parameters
        ----------
        return_layout : bool
            If the method returns the layout object which can be rendered by
            IPython.display.display() method.

        Returns
        -------
        app_layout : AppLayout | None
            ``self.app_layout`` when ``return_layout`` is True, else None.
        """
        simulation_box = VBox([
            VBox([
                self.widget_simulation_name, self.widget_tstop, self.widget_dt,
                self.widget_ntrials, self.widget_backend_selection,
                self._backend_config_out]),
        ], layout=self.layout['config_box'])

        # The per-cell-type connectivity widgets are rendered into
        # self._connectivity_out by load_drive_and_connectivity() below.
        connectivity_box = VBox([
            HBox([self.load_connectivity_button, ]),
            self._connectivity_out,
        ])

        drive_selections = VBox([
            self.add_drive_button, self.widget_drive_type_selection,
            self.widget_location_selection],
            layout=Layout(flex="1"))

        drives_options = VBox([
            HBox([
                VBox([self.load_drives_button, self.delete_drive_button],
                     layout=Layout(flex="1")),
                drive_selections,
            ]), self._drives_out
        ])

        config_panel, figs_output = self.viz_manager.compose()

        # Tabs for left pane
        left_tab = Tab()
        left_tab.children = [
            simulation_box, connectivity_box, drives_options,
            config_panel,
        ]
        titles = ('Simulation', 'Network connectivity', 'External drives',
                  'Visualization')
        for idx, title in enumerate(titles):
            left_tab.set_title(idx, title)

        self.app_layout = AppLayout(
            header=self._header,
            left_sidebar=VBox([
                VBox([left_tab], layout=self.layout['left_tab']),
                self._operation_buttons,
                self._log_window,
            ], layout=self.layout['left_sidebar']),
            right_sidebar=figs_output,
            footer=self._simulation_status_bar,
            pane_widths=[
                self.layout['left_sidebar'].width, '0px',
                self.layout['visualization_window'].width
            ],
            pane_heights=[
                self.layout['header_height'],
                self.layout['visualization_window'].height,
                self.layout['simulation_status_height']
            ],
        )

        self._link_callbacks()

        # initialize drive and connectivity ipywidgets
        load_drive_and_connectivity(self.params, self._log_out,
                                    self._drives_out, self.drive_widgets,
                                    self.drive_boxes, self._connectivity_out,
                                    self.connectivity_widgets,
                                    self.widget_tstop, self.layout)

        return self.app_layout if return_layout else None

    def show(self):
        """Display the GUI layout built by :meth:`compose`."""
        display(self.app_layout)

    def capture(self, width=None, height=None, extra_margin=100, render=True):
        """Take a screenshot of the current GUI.

        Parameters
        ----------
        width : int | None
            The width of iframe window use to show the snapshot.
        height : int | None
            The height of iframe window use to show the snapshot.
        extra_margin: int
            Extra margin in pixel for the GUI.
        render : bool
            Will return an IFrame object if False

        Returns
        -------
        snapshot : An iframe snapshot object that can be rendered in notebooks.
        """
        html_buffer = io.StringIO()
        embed_minimal_html(html_buffer, views=[self.app_layout], title='')
        if not width:
            width = self.total_width + extra_margin
        if not height:
            height = self.total_height + extra_margin

        # Embed the standalone HTML page directly in a data: URL.
        content = urllib.parse.quote(html_buffer.getvalue().encode('utf8'))
        data_url = f"data:text/html,{content}"
        screenshot = IFrame(data_url, width=width, height=height)
        if render:
            display(screenshot)
        else:
            return screenshot

    def run_notebook_cells(self):
        """Run all but the last cells sequentially in a Jupyter notebook.

        To properly use this function:
            1. Put this into the penultimate cell.
            2. init the HNNGUI in a single cell.
            3. Hit 'run all' button to run the whole notebook and it will
               selectively run twice.

        Returns
        -------
        js_string : str
            JavaScript source that uses the classic-notebook ``Jupyter``
            global to execute the cells one at a time.
        """
        js_string = (
            """ function sleep(ms) { return new Promise(resolve =>"""
            """ setTimeout(resolve, ms)); }"""
            """ function getRunningStatus(idx){ const htmlContent ="""
            """ Jupyter.notebook.get_cell(idx).element[0]; return"""
            """ htmlContent.childNodes[0].childNodes[0].textContent; }"""
            """ function cellContainsInitOrMarkdown(idx){ const cell ="""
            """ Jupyter.notebook.get_cell(idx); if(cell.cell_type!=='code'){ return"""
            """ true; } else{ const textVal ="""
            """ cell.element[0].childNodes[0].textContent; return"""
            """ textVal.includes('HNNGUI()') || textVal.includes( 'HNNGUI'); } }"""
            """ function cellContainsRunCells(idx){ const textVal ="""
            """ Jupyter.notebook.get_cell( idx).element[0].childNodes[0].textContent;"""
            """ return textVal.includes('run_notebook_cells()'); }"""
            """ async function runNotebook() { console.log("run notebook cell by cell");"""
            """ const cellHtmlContents = Jupyter.notebook.element[0].children[0];"""
            """ const nCells = cellHtmlContents.childElementCount;"""
            """ console.log(`In total we have ${nCells} cells`);"""
            """ for(let i=1; i<nCells-1; i++){ if(cellContainsRunCells(i)){ break }"""
            """ else if(cellContainsInitOrMarkdown(i)){"""
            """ console.log(`Skip init or markdown cell ${i}...`); continue }"""
            """ else{ console.log(`About to execute cell ${i}..`);"""
            """ Jupyter.notebook.execute_cells([i]);"""
            """ while (getRunningStatus(i).includes("*")){"""
            """ console.log("Still running, wait for another 2 secs");"""
            """ await sleep(2000); } await sleep(1000); } }"""
            """ console.log('Done'); } runNotebook(); """
        )
        return js_string

    # below are a series of methods that are used to manipulate the GUI
    def _simulate_upload_data(self, file_url):
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_data_button.set_trait('value', uploaded_value)

    def _simulate_upload_connectivity(self, file_url):
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_connectivity_button.set_trait('value', uploaded_value)

    def _simulate_upload_drives(self, file_url):
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_drives_button.set_trait('value', uploaded_value)

    def _simulate_left_tab_click(self, tab_title):
        """Select the left-sidebar tab whose title equals ``tab_title``.

        Raises
        ------
        ValueError
            If no tab has that title.
        """
        # The Tab widget is the first child of the first VBox in the sidebar.
        left_tab = self.app_layout.left_sidebar.children[0].children[0]
        # Use the public get_title() API rather than the private ``_titles``
        # dict, which only exists in ipywidgets 7.
        for tab_index in range(len(left_tab.children)):
            if left_tab.get_title(tab_index) == tab_title:
                left_tab.selected_index = tab_index
                return
        raise ValueError("Incorrect tab title")

    def _simulate_make_figure(self,):
        self._simulate_left_tab_click("Visualization")
        self.viz_manager.make_fig_button.click()

    def _simulate_viz_action(self, action_name, *args, **kwargs):
        """A shortcut to call simulated actions in _VizManager.

        Parameters
        ----------
        action_name : str
            The action to take. For example, to call `_simulate_add_fig` in
            _VizManager, you can run `_simulate_viz_action("add_fig")`
        args : list
            Optional positional parameters passed to the called method.
        kwargs: dict
            Optional keyword parameters passed to the called method.
        """
        self._simulate_left_tab_click("Visualization")
        action = getattr(self.viz_manager, f"_simulate_{action_name}")
        action(*args, **kwargs)
def _prepare_upload_file_from_url(file_url):
    """Download a file and package it like a ``FileUpload`` widget value.

    Parameters
    ----------
    file_url : str
        Location of the file; anything accepted by
        :func:`urllib.request.urlopen` (``https://``, ``file://``, ...).

    Returns
    -------
    uploaded_value : dict
        ``{file_name: {'metadata': {...}, 'content': bytes}}``, the structure
        the upload callbacks of this module read from ``change['new']``.
    """
    # Take the name from the URL *path* so query strings / fragments
    # (e.g. ``?raw=true``) do not leak into it: the upload callbacks choose
    # the parser from the name's suffix.
    params_name = urllib.parse.urlparse(file_url).path.split("/")[-1]
    # Read in one call (instead of quadratic byte concatenation) and close
    # the connection deterministically.
    with urllib.request.urlopen(file_url) as response:
        content = response.read()
    return {
        params_name: {
            'metadata': {
                'name': params_name,
                # fixed MIME type; the callbacks here only read 'name'
                'type': 'application/json',
                'size': len(content),
            },
            'content': content,
        }
    }


def create_expanded_button(description, button_style, layout, disabled=False,
                           button_color="#8A2BE2"):
    """Create a ``Button`` with a custom background color.

    Parameters
    ----------
    description : str
        Text shown on the button.
    button_style : str
        ipywidgets ``button_style`` value.
    layout : Layout
        Layout of the button.
    disabled : bool
        Whether the button starts disabled.
    button_color : str
        CSS color used as the button background.

    Returns
    -------
    button : Button
        The new button.
    """
    return Button(description=description, button_style=button_style,
                  layout=layout, style={'button_color': button_color},
                  disabled=disabled)


def _get_connectivity_widgets(conn_data):
    """Create connectivity box widgets from the specified weights.

    Parameters
    ----------
    conn_data : dict
        Maps a receptor name to a dict with at least the keys ``'weight'``,
        ``'receptor'``, ``'location'``, ``'src_gids'`` and ``'target_gids'``.

    Returns
    -------
    sliders : list of VBox
        One box per receptor: a receptor label, a ``"weight"`` input and a
        separator. Each box carries a ``_belongsto`` dict that identifies
        the connection when the network is rebuilt.
    """
    style = {}
    sliders = list()
    for conn in conn_data.values():
        w_text_input = BoundedFloatText(
            value=conn['weight'], disabled=False, continuous_update=False,
            min=0, max=1e6, step=0.01, description="weight", style=style)
        conn_widget = VBox([
            HTML(value=f"<p> Receptor: {conn['receptor']}</p>"),
            w_text_input,
            HTML(value="<hr style='margin-bottom:5px'/>"),
        ])
        conn_widget._belongsto = {
            "receptor": conn['receptor'],
            "location": conn['location'],
            "src_gids": conn['src_gids'],
            "target_gids": conn['target_gids'],
        }
        sliders.append(conn_widget)
    return sliders


def _get_connection_widget_values(conn_widget):
    """Map the description of each numeric input in a connection box to its
    value (e.g. ``{'weight': 0.05}``)."""
    return {child.description: child.value
            for child in conn_widget.children
            if isinstance(child, BoundedFloatText)}


def _get_cell_specific_widgets(layout, style, location, data=None):
    """Create the per-cell-type synaptic inputs of a drive.

    Parameters
    ----------
    layout : Layout
        Layout applied to every input.
    style : dict
        Style applied to every input.
    location : str
        Drive location; for ``'distal'`` no ``L5_basket`` inputs are made.
    data : dict | None
        Optional initial values under ``'weights_ampa'``, ``'weights_nmda'``
        and ``'delays'``, each mapping a cell type to a value. Missing or
        ``None`` entries keep the defaults (weights 0, delays 0.1).

    Returns
    -------
    widgets_list : list
        Section headers followed by the inputs, ready to put in a box.
    widgets_dict : dict
        ``{'weights_ampa': {cell_type: widget}, 'weights_nmda': {...},
        'delays': {...}}``.
    """
    cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket']
    default_data = {
        'weights_ampa': dict.fromkeys(cell_types, 0.),
        'weights_nmda': dict.fromkeys(cell_types, 0.),
        'delays': dict.fromkeys(cell_types, 0.1),
    }
    if isinstance(data, dict):
        for key, defaults in default_data.items():
            if data.get(key) is not None:
                defaults.update(data[key])

    if location == "distal":
        cell_types.remove('L5_basket')

    kwargs = dict(min=0, max=1e6, layout=layout, style=style)
    weights_ampa, weights_nmda, delays = dict(), dict(), dict()
    for cell_type in cell_types:
        weights_ampa[cell_type] = BoundedFloatText(
            value=default_data['weights_ampa'][cell_type],
            description=f'{cell_type}:', step=0.01, **kwargs)
        weights_nmda[cell_type] = BoundedFloatText(
            value=default_data['weights_nmda'][cell_type],
            description=f'{cell_type}:', step=0.01, **kwargs)
        delays[cell_type] = BoundedFloatText(
            value=default_data['delays'][cell_type],
            description=f'{cell_type}:', step=0.1, **kwargs)

    widgets_dict = {
        'weights_ampa': weights_ampa,
        'weights_nmda': weights_nmda,
        'delays': delays,
    }
    widgets_list = ([HTML(value="<b>AMPA weights</b>")] +
                    list(weights_ampa.values()) +
                    [HTML(value="<b>NMDA weights</b>")] +
                    list(weights_nmda.values()) +
                    [HTML(value="<b>Synaptic delays</b>")] +
                    list(delays.values()))
    return widgets_list, widgets_dict


def _get_rhythmic_widget(name, tstop_widget, layout, style, location,
                         data=None, default_weights_ampa=None,
                         default_weights_nmda=None, default_delays=None):
    """Create the widgets of a rhythmic (bursty) drive.

    Parameters
    ----------
    name : str
        Drive name.
    tstop_widget : Widget
        Simulation stop-time widget; its value bounds the drive stop time.
    layout, style
        Applied to every input.
    location : str
        Drive location (stored as a plain str).
    data : dict | None
        Optional initial timing values overriding the defaults.
    default_weights_ampa, default_weights_nmda, default_delays : dict | None
        Optional initial per-cell-type synaptic values.

    Returns
    -------
    drive : dict
        Drive description; timing/synaptic entries are widgets, while
        ``'type'``, ``'name'`` and ``'location'`` are plain values.
    drive_box : VBox
        Box holding all of the drive's widgets.
    """
    default_data = {
        'tstart': 0.,
        'tstart_std': 0.,
        'tstop': 0.,
        'burst_rate': 7.5,
        'burst_std': 0,
        'repeats': 1,
        'seedcore': 14,
    }
    if isinstance(data, dict):
        default_data.update(data)
    kwargs = dict(layout=layout, style=style)
    tstart = BoundedFloatText(
        value=default_data['tstart'], description='Start time (ms)',
        min=0, max=1e6, **kwargs)
    tstart_std = BoundedFloatText(
        value=default_data['tstart_std'], description='Start time dev (ms)',
        min=0, max=1e6, **kwargs)
    tstop = BoundedFloatText(
        value=default_data['tstop'], description='Stop time (ms)',
        max=tstop_widget.value, **kwargs)
    burst_rate = BoundedFloatText(
        value=default_data['burst_rate'], description='Burst rate (Hz)',
        min=0, max=1e6, **kwargs)
    burst_std = BoundedFloatText(
        value=default_data['burst_std'], description='Burst std dev (Hz)',
        min=0, max=1e6, **kwargs)
    repeats = BoundedIntText(
        value=default_data['repeats'], description='Repeats',
        min=0, max=int(1e6), **kwargs)
    seedcore = IntText(value=default_data['seedcore'], description='Seed',
                       **kwargs)

    widgets_list, widgets_dict = _get_cell_specific_widgets(
        layout, style, location,
        data={
            'weights_ampa': default_weights_ampa,
            'weights_nmda': default_weights_nmda,
            'delays': default_delays,
        },
    )
    drive_box = VBox(
        [tstart, tstart_std, tstop, burst_rate, burst_std, repeats, seedcore]
        + widgets_list)
    drive = dict(type='Rhythmic', name=name, tstart=tstart,
                 tstart_std=tstart_std, burst_rate=burst_rate,
                 burst_std=burst_std, repeats=repeats, seedcore=seedcore,
                 location=location, tstop=tstop)
    drive.update(widgets_dict)
    return drive, drive_box


def _get_poisson_widget(name, tstop_widget, layout, style, location,
                        data=None, default_weights_ampa=None,
                        default_weights_nmda=None, default_delays=None):
    """Create the widgets of a Poisson drive.

    Parameters are as for ``_get_rhythmic_widget``. ``data`` may include a
    ``'rate_constant'`` dict keyed by cell type. When it does, cell types
    missing from it start at a rate of 0, and a rate of 0 drops the cell
    type when the network is built.

    Returns
    -------
    drive : dict
        Drive description (widgets plus plain ``type``/``name``/``location``).
    drive_box : VBox
        Box holding all of the drive's widgets.
    """
    default_data = {
        'tstart': 0.0,
        'tstop': 0.0,
        'seedcore': 14,
        'rate_constant': {
            'L5_pyramidal': 8.5,
            'L2_pyramidal': 8.5,
            'L5_basket': 8.5,
            'L2_basket': 8.5,
        }
    }
    if isinstance(data, dict):
        default_data.update(data)
    rate_data = default_data['rate_constant']

    tstart = BoundedFloatText(
        value=default_data['tstart'], description='Start time (ms)',
        min=0, max=1e6, layout=layout, style=style)
    tstop = BoundedFloatText(
        value=default_data['tstop'], max=tstop_widget.value,
        description='Stop time (ms)', layout=layout, style=style)
    seedcore = IntText(value=default_data['seedcore'], description='Seed',
                       layout=layout, style=style)

    cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket']
    rate_constant = dict()
    for cell_type in cell_types:
        rate_constant[cell_type] = BoundedFloatText(
            # .get(): a loaded spec may only list the driven cell types
            value=rate_data.get(cell_type, 0.),
            description=f'{cell_type}:', min=0, max=1e6, step=0.01,
            layout=layout, style=style)

    widgets_list, widgets_dict = _get_cell_specific_widgets(
        layout, style, location,
        data={
            'weights_ampa': default_weights_ampa,
            'weights_nmda': default_weights_nmda,
            'delays': default_delays,
        },
    )
    widgets_dict.update({'rate_constant': rate_constant})
    widgets_list.extend([HTML(value="<b>Rate constants</b>")] +
                        list(widgets_dict['rate_constant'].values()))

    drive_box = VBox([tstart, tstop, seedcore] + widgets_list)
    drive = dict(
        type='Poisson',
        name=name,
        tstart=tstart,
        tstop=tstop,
        rate_constant=rate_constant,
        seedcore=seedcore,
        location=location,  # note: a plain str, not a widget
    )
    drive.update(widgets_dict)
    return drive, drive_box


def _get_evoked_widget(name, layout, style, location, data=None,
                       default_weights_ampa=None, default_weights_nmda=None,
                       default_delays=None):
    """Create the widgets of an evoked (Gaussian) drive.

    Parameters are as for ``_get_rhythmic_widget`` (without
    ``tstop_widget``); ``data`` may override ``mu``, ``sigma``,
    ``numspikes`` and ``seedcore``.

    Returns
    -------
    drive : dict
        Drive description (widgets plus plain values).
    drive_box : VBox
        Box holding all of the drive's widgets.
    """
    default_data = {
        'mu': 0,
        'sigma': 1,
        'numspikes': 1,
        'seedcore': 14,
    }
    if isinstance(data, dict):
        default_data.update(data)
    kwargs = dict(layout=layout, style=style)
    mu = BoundedFloatText(
        value=default_data['mu'], description='Mean time:', min=0, max=1e6,
        step=0.01, **kwargs)
    sigma = BoundedFloatText(
        value=default_data['sigma'], description='Std dev time:', min=0,
        max=1e6, step=0.01, **kwargs)
    numspikes = IntText(value=default_data['numspikes'],
                        description='No. Spikes:', **kwargs)
    seedcore = IntText(value=default_data['seedcore'],
                       description='Seed: ', **kwargs)

    widgets_list, widgets_dict = _get_cell_specific_widgets(
        layout, style, location,
        data={
            'weights_ampa': default_weights_ampa,
            'weights_nmda': default_weights_nmda,
            'delays': default_delays,
        },
    )
    drive_box = VBox([mu, sigma, numspikes, seedcore] + widgets_list)
    drive = dict(type='Evoked', name=name, mu=mu, sigma=sigma,
                 numspikes=numspikes, seedcore=seedcore, location=location,
                 sync_within_trial=False)
    drive.update(widgets_dict)
    return drive, drive_box


def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
                     tstop_widget, location, layout,
                     prespecified_drive_name=None,
                     prespecified_drive_data=None,
                     prespecified_weights_ampa=None,
                     prespecified_weights_nmda=None,
                     prespecified_delays=None, render=True,
                     expand_last_drive=True, event_seed=14):
    """Add a widget for a new drive.

    Parameters
    ----------
    drive_type : str
        One of ``'Evoked'``, ``'Gaussian'``, ``'Poisson'``, ``'Rhythmic'``
        or ``'Bursty'``. Other values add nothing.
    drive_boxes, drive_widgets : list
        Appended in place with the new drive's box and description.
    drives_out : Output
        Output widget the drive accordion is (re)displayed in.
    tstop_widget : Widget
        Simulation stop-time widget.
    location : str
        Drive location.
    layout : Layout
        Layout of the inputs.
    prespecified_drive_name : str | None
        Drive name; defaults to ``drive_type`` + number of existing boxes.
    prespecified_drive_data : dict | None
        Initial drive values. It is copied, so the caller's dict is not
        modified.
    prespecified_weights_ampa, prespecified_weights_nmda,
    prespecified_delays : dict | None
        Initial per-cell-type synaptic values.
    render : bool
        Whether to display the accordion of all drives.
    expand_last_drive : bool
        Whether the accordion opens on the last drive.
    event_seed : int
        Seed stored in the drive's ``seedcore`` (at least 2).
    """
    style = {'description_width': '150px'}
    drives_out.clear_output()
    # Work on a copy so the caller's dict (e.g. drive specs parsed from a
    # parameter file) is not mutated by the seed override below.
    drive_data = dict(prespecified_drive_data or {})
    drive_data["seedcore"] = max(event_seed, 2)

    with drives_out:
        name = prespecified_drive_name or drive_type + str(len(drive_boxes))
        shared = dict(
            data=drive_data,
            default_weights_ampa=prespecified_weights_ampa,
            default_weights_nmda=prespecified_weights_nmda,
            default_delays=prespecified_delays,
        )
        if drive_type in ('Rhythmic', 'Bursty'):
            drive, drive_box = _get_rhythmic_widget(
                name, tstop_widget, layout, style, location, **shared)
        elif drive_type == 'Poisson':
            drive, drive_box = _get_poisson_widget(
                name, tstop_widget, layout, style, location, **shared)
        elif drive_type in ('Evoked', 'Gaussian'):
            drive, drive_box = _get_evoked_widget(
                name, layout, style, location, **shared)
        else:
            drive = drive_box = None

        if drive is not None:
            drive_boxes.append(drive_box)
            drive_widgets.append(drive)

        if render:
            accordion = Accordion(
                children=drive_boxes,
                selected_index=(len(drive_boxes) - 1
                                if expand_last_drive else None),
            )
            for idx, drive_spec in enumerate(drive_widgets):
                accordion.set_title(
                    idx, f"{drive_spec['name']} ({drive_spec['location']})")
            display(accordion)


def add_connectivity_tab(params, connectivity_out, connectivity_textfields):
    """Add all possible connectivity boxes to the connectivity tab.

    Parameters
    ----------
    params : dict
        Parameters used to build a ``jones_2009_model`` network whose
        connections are listed.
    connectivity_out : Output
        Output widget the accordion of connections is displayed in.
    connectivity_textfields : list
        Cleared in place, then filled with one list of receptor boxes per
        (source, target, location) combination that has connections.

    Returns
    -------
    net : Network
        The network the connections were read from.
    """
    net = jones_2009_model(params)
    cell_types = list(net.cell_types)
    receptors = ('ampa', 'nmda', 'gabaa', 'gabab')
    locations = ('proximal', 'distal', 'soma')

    # clear existing connectivity
    connectivity_out.clear_output()
    connectivity_textfields.clear()

    connectivity_names = list()
    for src_gids in cell_types:
        for target_gids in cell_types:
            for location in locations:
                # one accordion entry per (source, target, location)
                receptor_related_conn = {}
                for receptor in receptors:
                    conn_indices = pick_connection(
                        net=net, src_gids=src_gids, target_gids=target_gids,
                        loc=location, receptor=receptor)
                    if len(conn_indices) == 0:
                        continue
                    assert len(conn_indices) == 1
                    conn = net.connectivity[conn_indices[0]]
                    receptor_related_conn[receptor] = {
                        "weight": conn['nc_dict']['A_weight'],
                        "probability": conn['probability'],
                        # info used to identify the connection later
                        "receptor": receptor,
                        "location": location,
                        "src_gids": src_gids,
                        "target_gids": target_gids,
                    }
                if receptor_related_conn:
                    connectivity_names.append(
                        f"{src_gids}{target_gids} ({location})")
                    connectivity_textfields.append(
                        _get_connectivity_widgets(receptor_related_conn))

    connectivity_boxes = [VBox(slider) for slider in connectivity_textfields]
    cell_connectivity = Accordion(children=connectivity_boxes)
    for idx, connectivity_name in enumerate(connectivity_names):
        cell_connectivity.set_title(idx, connectivity_name)

    with connectivity_out:
        display(cell_connectivity)

    return net


def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
                  layout):
    """Replace the drive widgets with the drives described by ``params``.

    Parameters
    ----------
    params : dict
        Network parameters containing the drive definitions.
    drives_out : Output
        Output widget the drive accordion is displayed in.
    drive_widgets, drive_boxes : list
        Cleared in place, then refilled with one entry per drive.
    tstop : Widget
        Simulation stop-time widget.
    layout : Layout
        Layout of the inputs.
    """
    net = jones_2009_model(params)
    drive_specs = _extract_drive_specs_from_hnn_params(
        params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode)

    # clear before adding drives
    drives_out.clear_output()
    drive_widgets.clear()
    drive_boxes.clear()

    drive_names = sorted(drive_specs.keys())
    for idx, drive_name in enumerate(drive_names):  # order matters
        specs = drive_specs[drive_name]
        # render the accordion only once, after the last drive is added
        should_render = idx == (len(drive_names) - 1)
        add_drive_widget(
            specs['type'].capitalize(),
            drive_boxes,
            drive_widgets,
            drives_out,
            tstop,
            specs['location'],
            layout=layout,
            prespecified_drive_name=drive_name,
            prespecified_drive_data=specs['dynamics'],
            prespecified_weights_ampa=specs['weights_ampa'],
            prespecified_weights_nmda=specs['weights_nmda'],
            prespecified_delays=specs['synaptic_delays'],
            render=should_render,
            expand_last_drive=False,
            event_seed=specs['event_seed'],
        )


def load_drive_and_connectivity(params, log_out, drives_out,
                                drive_widgets, drive_boxes, connectivity_out,
                                connectivity_textfields, tstop, layout):
    """Add drive and connectivity ipywidgets from params."""
    log_out.clear_output()
    with log_out:
        # Add connectivity
        add_connectivity_tab(params, connectivity_out,
                             connectivity_textfields)
        # Add drives
        add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
                      layout)


def _data_name_from_filename(fname):
    """Return ``fname`` without a trailing ``.txt`` extension.

    ``str.rstrip('.txt')`` would strip any trailing run of '.', 't' and 'x'
    characters (``'test.txt'`` -> ``'tes'``), hence the explicit check.
    """
    suffix = '.txt'
    if fname.endswith(suffix):
        return fname[:-len(suffix)]
    return fname


def _reset_upload_widget(upload_widget):
    """Reset the counter and value traits of an upload widget."""
    upload_widget.set_trait('_counter', 0)
    upload_widget.set_trait('value', {})


def on_upload_data_change(change, data, viz_manager, log_out):
    """Load an uploaded dipole file and plot it in a new figure.

    Parameters
    ----------
    change : dict
        Trait-change notification of the upload widget; ``change['new']``
        maps the file name to ``{'metadata': {...}, 'content': ...}``.
    data : dict
        GUI data; the dipole is stored in ``data['simulation_data']`` under
        the file name without its ``.txt`` extension.
    viz_manager : _VizManager
        Visualization manager in which a new figure is created.
    log_out : Output
        Output widget used for log messages.
    """
    if len(change['owner'].value) == 0:
        logger.info("Empty change")
        return

    key = list(change['new'].keys())[0]
    data_fname = _data_name_from_filename(
        change['new'][key]['metadata']['name'])
    if data_fname in data['simulation_data'].keys():
        logger.error(f"Found existing data: {data_fname}.")
        return

    ext_content = codecs.decode(change['new'][key]['content'],
                                encoding="utf-8")
    with log_out:
        data['simulation_data'][data_fname] = {
            'net': None,
            'dpls': [hnn_core.read_dipole(io.StringIO(ext_content))],
        }
        logger.info(f'External data {data_fname} loaded.')

        viz_manager.reset_fig_config_tabs(template_name='single figure')
        viz_manager.add_figure()
        fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
        ax_plots = [("ax0", "current dipole")]
        for ax_name, plot_type in ax_plots:
            viz_manager._simulate_edit_figure(
                fig_name, ax_name, data_fname, plot_type, {}, "plot")


def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes,
                            drive_widgets, drives_out, connectivity_out,
                            connectivity_textfields, layout, load_type):
    """Load an uploaded ``.json``/``.param`` file into the GUI.

    The file's parameters update ``params`` (and the ``tstop``/``dt``
    widgets when present). Then either the connectivity tab or the drive
    tab is rebuilt, and the upload widget is reset.

    Raises
    ------
    ValueError
        If ``load_type`` is neither ``'connectivity'`` nor ``'drives'``.
    """
    if len(change['owner'].value) == 0:
        logger.info("Empty change")
        return
    # validate before anything (params, widgets) gets modified
    if load_type not in ('connectivity', 'drives'):
        raise ValueError(f"load_type must be 'connectivity' or 'drives', "
                         f"got {load_type!r}")

    logger.info(f"Loading {load_type}...")
    key = list(change['new'].keys())[0]
    params_fname = change['new'][key]['metadata']['name']
    param_data = codecs.decode(change['new'][key]['content'],
                               encoding="utf-8")
    ext = Path(params_fname).suffix
    read_func = {'.json': _read_json, '.param': _read_legacy_params}
    if ext not in read_func:
        logger.error(f"Unsupported parameter file {params_fname}: expected "
                     f"one of {', '.join(read_func)}.")
        _reset_upload_widget(change['owner'])
        return
    params_network = read_func[ext](param_data)

    # update simulation settings and params
    log_out.clear_output()
    with log_out:
        if 'tstop' in params_network:
            tstop.value = params_network['tstop']
        if 'dt' in params_network:
            dt.value = params_network['dt']

        params.update(params_network)
        # init network, add drives & connectivity
        if load_type == 'connectivity':
            add_connectivity_tab(params, connectivity_out,
                                 connectivity_textfields)
        else:
            add_drive_tab(params, drives_out, drive_widgets, drive_boxes,
                          tstop, layout)

    _reset_upload_widget(change['owner'])


def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
                               drive_widgets, connectivity_textfields,
                               add_drive=True):
    """Construct the network from the GUI widgets and add its drives.

    Parameters
    ----------
    params : dict
        Network parameters; ``'dt'`` and ``'tstop'`` are overwritten from
        the widgets.
    dt, tstop : Widget
        Simulation step and stop-time widgets.
    single_simulation_data : dict
        Receives the new network under ``'net'``.
    drive_widgets : list of dict
        Drive descriptions produced by the ``_get_*_widget`` helpers.
    connectivity_textfields : list of list of VBox
        Connection boxes produced by ``_get_connectivity_widgets``.
    add_drive : bool
        If ``False``, drives are not added.
    """
    print("init network")
    params['dt'] = dt.value
    params['tstop'] = tstop.value
    net = jones_2009_model(params, add_drives_from_params=False)
    single_simulation_data['net'] = net

    # adjust connectivity according to the connectivity tab
    for connectivity_boxes in connectivity_textfields:
        for vbox in connectivity_boxes:
            belongsto = vbox._belongsto
            conn_indices = pick_connection(
                net=net,
                src_gids=belongsto['src_gids'],
                target_gids=belongsto['target_gids'],
                loc=belongsto['location'],
                receptor=belongsto['receptor'])
            if len(conn_indices) == 0:
                continue
            assert len(conn_indices) == 1
            conn = net.connectivity[conn_indices[0]]
            # Look inputs up by description: the box has no probability
            # input, and children[2] is the HTML separator, whose markup
            # string must not be written into 'probability'.
            values = _get_connection_widget_values(vbox)
            conn['nc_dict']['A_weight'] = values['weight']
            if 'probability' in values:
                conn['probability'] = values['probability']

    if add_drive is False:
        return

    # add drives to network
    for drive in drive_widgets:
        weights_ampa = {k: v.value for k, v in drive['weights_ampa'].items()}
        weights_nmda = {k: v.value for k, v in drive['weights_nmda'].items()}
        synaptic_delays = {k: v.value for k, v in drive['delays'].items()}
        print(
            f"drive type is {drive['type']}, location={drive['location']}")
        if drive['type'] == 'Poisson':
            # only cell types with a positive rate constant are driven
            rate_constant = {
                k: v.value
                for k, v in drive['rate_constant'].items() if v.value > 0
            }
            weights_ampa = {
                k: v for k, v in weights_ampa.items() if k in rate_constant
            }
            weights_nmda = {
                k: v for k, v in weights_nmda.items() if k in rate_constant
            }
            net.add_poisson_drive(
                name=drive['name'],
                tstart=drive['tstart'].value,
                tstop=drive['tstop'].value,
                rate_constant=rate_constant,
                location=drive['location'],
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                space_constant=100.0,
                event_seed=drive['seedcore'].value)
        elif drive['type'] in ('Evoked', 'Gaussian'):
            net.add_evoked_drive(
                name=drive['name'],
                mu=drive['mu'].value,
                sigma=drive['sigma'].value,
                numspikes=drive['numspikes'].value,
                location=drive['location'],
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                space_constant=3.0,
                event_seed=drive['seedcore'].value)
        elif drive['type'] in ('Rhythmic', 'Bursty'):
            # NOTE(review): the 'repeats' widget is collected by
            # _get_rhythmic_widget but not forwarded here; confirm whether
            # it should map to an add_bursty_drive argument.
            net.add_bursty_drive(
                name=drive['name'],
                tstart=drive['tstart'].value,
                tstart_std=drive['tstart_std'].value,
                burst_rate=drive['burst_rate'].value,
                burst_std=drive['burst_std'].value,
                location=drive['location'],
                tstop=drive['tstop'].value,
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                event_seed=drive['seedcore'].value)


def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
                       all_data, dt, tstop, ntrials, backend_selection,
                       mpi_cmd, n_jobs, params, simulation_status_bar,
                       simulation_status_contents, connectivity_textfields,
                       viz_manager):
    """Run the simulation and plot outputs.

    Simulations without dipoles are first discarded from
    ``all_data["simulation_data"]``. A run is refused (status "failed") if
    the chosen name already holds a network. The status bar shows
    "running", then "finished". If the simulation raises, it shows "failed"
    and the exception propagates.
    """
    log_out.clear_output()
    simulation_data = all_data["simulation_data"]
    with log_out:
        # clear empty trash simulations
        for _name in tuple(simulation_data.keys()):
            if len(simulation_data[_name]['dpls']) == 0:
                del simulation_data[_name]

        _sim_name = widget_simulation_name.value
        if simulation_data[_sim_name]['net'] is not None:
            print("Simulation with the same name exists!")
            simulation_status_bar.value = simulation_status_contents[
                'failed']
            return

        _init_network_from_widgets(params, dt, tstop,
                                   simulation_data[_sim_name], drive_widgets,
                                   connectivity_textfields)

        print("start simulation")
        if backend_selection.value == "MPI":
            # keep at least one process on single-core machines
            backend = MPIBackend(
                n_procs=max(multiprocessing.cpu_count() - 1, 1),
                mpi_cmd=mpi_cmd.value)
        else:
            backend = JoblibBackend(n_jobs=n_jobs.value)
            print(f"Using Joblib with {n_jobs.value} core(s).")
        with backend:
            simulation_status_bar.value = simulation_status_contents[
                'running']
            try:
                simulation_data[_sim_name]['dpls'] = simulate_dipole(
                    simulation_data[_sim_name]['net'],
                    tstop=tstop.value,
                    dt=dt.value,
                    n_trials=ntrials.value)
            except Exception:
                # do not leave the status bar stuck on "running"
                simulation_status_bar.value = simulation_status_contents[
                    'failed']
                raise
            simulation_status_bar.value = simulation_status_contents[
                'finished']

        viz_manager.reset_fig_config_tabs()
        viz_manager.add_figure()
        fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
        ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")]
        for ax_name, plot_type in ax_plots:
            viz_manager._simulate_edit_figure(fig_name, ax_name, _sim_name,
                                              plot_type, {}, "plot")


def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs):
    """Switch backends between MPI and Joblib.

    Shows ``mpi_cmd`` for ``"MPI"`` and ``n_jobs`` for ``"Joblib"`` in the
    ``backend_config`` output widget (cleared first).
    """
    backend_config.clear_output()
    with backend_config:
        if backend_type == "MPI":
            display(mpi_cmd)
        elif backend_type == "Joblib":
            display(n_jobs)


def launch():
    """Launch voila with hnn_widget.ipynb.

    You can pass voila commandline parameters as usual.
    """
    from voila.app import main
    notebook_path = Path(__file__).parent / 'hnn_widget.ipynb'
    main([str(notebook_path.resolve()), *sys.argv[1:]])