"""IPywidgets GUI."""
# Authors: Mainak Jas <mjas@mgh.harvard.edu>
#          Huzi Cheng <hzcheng15@icloud.com>
import codecs
import io
import logging
import multiprocessing
import sys
import urllib.parse
import urllib.request
from collections import defaultdict
from pathlib import Path
from IPython.display import IFrame, display
from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
                        BoundedIntText, Button, Dropdown, FileUpload, VBox,
                        HBox, IntText, Layout, Output, RadioButtons, Tab, Text)
from ipywidgets.embed import embed_minimal_html
import hnn_core
from hnn_core import (JoblibBackend, MPIBackend, jones_2009_model, read_params,
                      simulate_dipole)
from hnn_core.gui._logging import logger
from hnn_core.gui._viz_manager import _VizManager, _idx2figname
from hnn_core.network import pick_connection
from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json,
                             _read_legacy_params)
class _OutputWidgetHandler(logging.Handler):
    def __init__(self, output_widget, *args, **kwargs):
        super(_OutputWidgetHandler, self).__init__(*args, **kwargs)
        self.out = output_widget
    def emit(self, record):
        formatted_record = self.format(record)
        new_output = {
            'name': 'stdout',
            'output_type': 'stream',
            'text': formatted_record + '\n'
        }
        self.out.outputs = (new_output, ) + self.out.outputs
class HNNGUI:
    """HNN GUI class

    Parameters
    ----------
    theme_color : str
        The theme color of the whole dashboard.
    total_height : int
        The height of the GUI (in pixel, same for all following parameters).
    total_width : int
        The width of the GUI.
    header_height : int
        The height of the header.
    button_height : int
        The height of buttons.
    operation_box_height : int
        The height of the operations box holding the run and load buttons.
    drive_widget_width : int
        The width of GUI drive box. Currently not used when building the
        layout (drive widgets use an automatic width).
    left_sidebar_width : int
        The width of the left sidebar.
    log_window_height : int
        The height of the logging window.
    status_height : int
        The height of the status bar.
    dpi : int
        The screen dpi.

    Attributes
    ----------
    layout : dict
        The styling configuration of GUI.
    params : dict
        The parameters to use for constructing the network.
    simulation_data : dict
        Simulation related objects, such as net and dpls.
    widget_tstop : Widget
        Simulation stop time widget.
    widget_dt : Widget
        Simulation step size widget.
    widget_ntrials : Widget
        Widget that controls the number of trials in a single simulation.
    widget_simulation_name : Widget
        Widget that holds the name (ID) of the simulation.
    widget_backend_selection : Widget
        Widget that selects the backend used in simulations.
    widget_mpi_cmd : Widget
        Widget that specifies the mpi command to use when the backend is
        MPIBackend.
    widget_n_jobs : Widget
        Widget that specifies the cores in multi-trial simulations.
    widget_drive_type_selection : Widget
        Widget that is used to select the drive to be added to the network.
    widget_location_selection : Widget
        Widget that specifies the location of network drives. Could be
        proximal or distal.
    add_drive_button : Widget
        Clickable widget that is used to add a drive to the network.
    run_button : Widget
        Clickable widget that triggers simulation.
    load_data_button : Widget
        Upload widget that receives ``.txt`` data files.
    load_connectivity_button : Widget
        Upload widget that receives network connectivity parameter files.
    load_drives_button : Widget
        Upload widget that receives external drive parameter files.
    delete_drive_button : Widget
        Clickable widget that clears all existing network drives.
    viz_manager : _VizManager
        Manager of the visualization panel.
    plot_outputs_dict : dict
        Visualization panel outputs.
    plot_dropdown_types_dict : dict
        Dropdown menus that control the plot types in plot_outputs_dict.
    plot_sim_selections_dict : dict
        Simulation selection widgets of the visualization panel.
    drive_widgets : list
        A list of network drive widgets added by add_drive_button.
    drive_boxes : list
        A list of network drive layouts.
    connectivity_widgets : list
        A list of boxes that control the weight and probability of
        connections in the network.
    """

    def __init__(self, theme_color="#8A2BE2",
                 total_height=800,
                 total_width=1300,
                 header_height=50,
                 button_height=30,
                 operation_box_height=60,
                 drive_widget_width=200,
                 left_sidebar_width=576,
                 log_window_height=150,
                 status_height=30,
                 dpi=96,
                 ):
        # set up styling.
        # NOTE: drive_widget_width is accepted for backward compatibility
        # but is not referenced below.
        self.total_height = total_height
        self.total_width = total_width
        viz_win_width = self.total_width - left_sidebar_width
        main_content_height = self.total_height - status_height
        config_box_height = main_content_height - (log_window_height +
                                                   operation_box_height)
        self.layout = {
            "dpi": dpi,
            "header_height": f"{header_height}px",
            "theme_color": theme_color,
            "btn": Layout(height=f"{button_height}px", width='auto'),
            "btn_full_w": Layout(height=f"{button_height}px", width='100%'),
            "del_fig_btn": Layout(height=f"{button_height}px", width='auto'),
            "log_out": Layout(border='1px solid gray',
                              height=f"{log_window_height-10}px",
                              overflow='auto'),
            "viz_config": Layout(width='99%'),
            "visualization_window": Layout(
                width=f"{viz_win_width-10}px",
                height=f"{main_content_height-10}px",
                border='1px solid gray',
                overflow='scroll'),
            "visualization_output": Layout(
                width=f"{viz_win_width-50}px",
                height=f"{main_content_height-100}px",
                border='1px solid gray',
                overflow='scroll'),
            "left_sidebar": Layout(width=f"{left_sidebar_width}px",
                                   height=f"{main_content_height}px"),
            "left_tab": Layout(width=f"{left_sidebar_width}px",
                               height=f"{config_box_height}px"),
            "operation_box": Layout(width=f"{left_sidebar_width}px",
                                    height=f"{operation_box_height}px",
                                    flex_wrap="wrap",
                                    ),
            "config_box": Layout(width=f"{left_sidebar_width}px",
                                 height=f"{config_box_height-100}px"),
            "drive_widget": Layout(width="auto"),
            "drive_textbox": Layout(width='270px', height='auto'),
            # simulation status related
            "simulation_status_height": f"{status_height}px",
            "simulation_status_common": "background:gray;padding-left:10px",
            "simulation_status_running": "background:orange;padding-left:10px",
            "simulation_status_failed": "background:red;padding-left:10px",
            "simulation_status_finished": "background:green;padding-left:10px",
        }
        self._simulation_status_contents = {
            "not_running":
            f"""<div style='{self.layout['simulation_status_common']};
            color:white;'>Not running</div>""",
            "running":
            f"""<div style='{self.layout['simulation_status_running']};
            color:white;'>Running...</div>""",
            "finished":
            f"""<div style='{self.layout['simulation_status_finished']};
            color:white;'>Simulation finished</div>""",
            "failed":
            f"""<div style='{self.layout['simulation_status_failed']};
            color:white;'>Simulation failed</div>""",
        }

        # load default parameters
        self.params = self.load_parameters()

        # In-memory storage of all simulation and visualization related data
        self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list()))

        # Simulation parameters
        self.widget_tstop = BoundedFloatText(
            value=170, description='tstop (ms):', min=0, max=1e6, step=1,
            disabled=False)
        self.widget_dt = BoundedFloatText(
            value=0.025, description='dt (ms):', min=0, max=10, step=0.01,
            disabled=False)
        self.widget_ntrials = IntText(value=1, description='Trials:',
                                      disabled=False)
        self.widget_simulation_name = Text(value='default',
                                           placeholder='ID of your simulation',
                                           description='Name:',
                                           disabled=False)
        self.widget_backend_selection = Dropdown(options=[('Joblib', 'Joblib'),
                                                          ('MPI', 'MPI')],
                                                 value='Joblib',
                                                 description='Backend:')
        self.widget_mpi_cmd = Text(value='mpiexec',
                                   placeholder='Fill if applies',
                                   description='MPI cmd:', disabled=False)
        self.widget_n_jobs = BoundedIntText(value=1, min=1,
                                            max=multiprocessing.cpu_count(),
                                            description='Cores:',
                                            disabled=False)
        self.load_data_button = FileUpload(
            accept='.txt', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load data',
            button_style='success')

        # Drive selection
        self.widget_drive_type_selection = RadioButtons(
            options=['Evoked', 'Poisson', 'Rhythmic'],
            value='Evoked',
            description='Drive:',
            disabled=False,
            layout=self.layout['drive_widget'])
        self.widget_location_selection = RadioButtons(
            options=['proximal', 'distal'], value='proximal',
            description='Location', disabled=False,
            layout=self.layout['drive_widget'])
        self.add_drive_button = create_expanded_button(
            'Add drive', 'primary', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])

        # Dashboard level buttons
        self.run_button = create_expanded_button(
            'Run', 'success', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])
        self.load_connectivity_button = FileUpload(
            accept='.json,.param', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load local network connectivity',
            layout=self.layout['btn_full_w'], button_style='success')
        self.load_drives_button = FileUpload(
            accept='.json,.param', multiple=False,
            style={'button_color': self.layout['theme_color']},
            description='Load external drives', layout=self.layout['btn'],
            button_style='success')
        self.delete_drive_button = create_expanded_button(
            'Delete drives', 'success', layout=self.layout['btn'],
            button_color=self.layout['theme_color'])

        # Plotting window
        # Visualization figure related dicts
        self.plot_outputs_dict = dict()
        self.plot_dropdown_types_dict = dict()
        self.plot_sim_selections_dict = dict()

        # Add drive section
        self.drive_widgets = list()
        self.drive_boxes = list()

        # Connectivity list
        self.connectivity_widgets = list()

        self._init_ui_components()
        self.add_logging_window_logger()

    def add_logging_window_logger(self):
        """Mirror records of the GUI ``logger`` into the log window.

        NOTE(review): ``logger`` is a module-level object shared by all
        instances, so every HNNGUI instance adds one more handler to it.
        """
        handler = _OutputWidgetHandler(self._log_out)
        handler.setFormatter(
            logging.Formatter('%(asctime)s  - [%(levelname)s] %(message)s'))
        logger.addHandler(handler)

    def _init_ui_components(self):
        """Initialize larger UI components and dynamical output windows.

        It's not encouraged for users to modify or access attributes in this
        part.
        """
        # dynamic larger components
        self._drives_out = Output()  # tab to add new drives
        self._connectivity_out = Output()  # tab to tune connectivity.
        self._log_out = Output()

        self.viz_manager = _VizManager(self.data, self.layout)

        # detailed configuration of backends
        self._backend_config_out = Output()

        # static parts
        # Running status
        self._simulation_status_bar = HTML(
            value=self._simulation_status_contents['not_running'])

        self._log_window = HBox([self._log_out], layout=self.layout['log_out'])
        self._operation_buttons = HBox(
            [self.run_button, self.load_data_button],
            layout=self.layout['operation_box'])
        # title
        self._header = HTML(value=f"""<div
            style='background:{self.layout['theme_color']};
            text-align:center;color:white;'>
            HUMAN NEOCORTICAL NEUROSOLVER</div>""")

    @property
    def analysis_config(self):
        """Provides everything viz window needs except for the data."""
        return {
            "viz_style": self.layout['visualization_output'],
            # widgets
            "plot_outputs": self.plot_outputs_dict,
            "plot_dropdowns": self.plot_dropdown_types_dict,
            "plot_sim_selections": self.plot_sim_selections_dict,
            "current_sim_name": self.widget_simulation_name.value,
        }

    @property
    def data(self):
        """Provides easy access to simulation-related data."""
        return {"simulation_data": self.simulation_data}

    @staticmethod
    def load_parameters(params_fname=None):
        """Read parameters from file.

        Parameters
        ----------
        params_fname : str | Path | None
            The parameter file to read. If None (or otherwise falsy), the
            ``param/default.json`` file shipped with hnn_core is used.

        Returns
        -------
        params : dict
            The parameters returned by ``read_params``.
        """
        if not params_fname:
            # by default load default.json
            hnn_core_root = Path(hnn_core.__file__).parent
            params_fname = hnn_core_root / 'param/default.json'

        return read_params(params_fname)

    def _link_callbacks(self):
        """Link callbacks to UI components."""
        def _handle_backend_change(backend_type):
            return handle_backend_change(backend_type.new,
                                         self._backend_config_out,
                                         self.widget_mpi_cmd,
                                         self.widget_n_jobs)

        def _add_drive_button_clicked(b):
            return add_drive_widget(self.widget_drive_type_selection.value,
                                    self.drive_boxes, self.drive_widgets,
                                    self._drives_out, self.widget_tstop,
                                    self.widget_location_selection.value,
                                    layout=self.layout['drive_textbox'])

        def _delete_drives_clicked(b):
            self._drives_out.clear_output()
            # The lists are shared by reference with the other callbacks
            # and helpers, so they must be emptied in place; rebinding the
            # attributes to new lists would leave those references stale.
            while len(self.drive_widgets) > 0:
                self.drive_widgets.pop()
                self.drive_boxes.pop()

        def _on_upload_connectivity(change):
            return on_upload_params_change(
                change, self.params, self.widget_tstop, self.widget_dt,
                self._log_out, self.drive_boxes, self.drive_widgets,
                self._drives_out, self._connectivity_out,
                self.connectivity_widgets, self.layout['drive_textbox'],
                "connectivity")

        def _on_upload_drives(change):
            return on_upload_params_change(
                change, self.params, self.widget_tstop, self.widget_dt,
                self._log_out, self.drive_boxes, self.drive_widgets,
                self._drives_out, self._connectivity_out,
                self.connectivity_widgets, self.layout['drive_textbox'],
                "drives")

        def _on_upload_data(change):
            return on_upload_data_change(change, self.data, self.viz_manager,
                                         self._log_out)

        def _run_button_clicked(b):
            return run_button_clicked(
                self.widget_simulation_name, self._log_out, self.drive_widgets,
                self.data, self.widget_dt, self.widget_tstop,
                self.widget_ntrials, self.widget_backend_selection,
                self.widget_mpi_cmd, self.widget_n_jobs, self.params,
                self._simulation_status_bar, self._simulation_status_contents,
                self.connectivity_widgets, self.viz_manager)

        self.widget_backend_selection.observe(_handle_backend_change, 'value')
        self.add_drive_button.on_click(_add_drive_button_clicked)
        self.delete_drive_button.on_click(_delete_drives_clicked)
        self.load_connectivity_button.observe(_on_upload_connectivity,
                                              names='value')
        self.load_drives_button.observe(_on_upload_drives, names='value')
        self.run_button.on_click(_run_button_clicked)
        self.load_data_button.observe(_on_upload_data, names='value')

    def compose(self, return_layout=True):
        """Compose widgets.

        Parameters
        ----------
        return_layout : bool
            If the method returns the layout object which can be rendered by
            IPython.display.display() method.

        Returns
        -------
        app_layout : AppLayout | None
            The composed GUI when ``return_layout`` is True, else None. The
            layout is also stored as ``self.app_layout``.
        """
        simulation_box = VBox([
            VBox([
                self.widget_simulation_name, self.widget_tstop, self.widget_dt,
                self.widget_ntrials, self.widget_backend_selection,
                self._backend_config_out]),
        ], layout=self.layout['config_box'])

        connectivity_box = VBox([
            HBox([self.load_connectivity_button, ]),
            self._connectivity_out,
        ])

        drive_selections = VBox([
            self.add_drive_button, self.widget_drive_type_selection,
            self.widget_location_selection],
            layout=Layout(flex="1"))

        drives_options = VBox([
            HBox([
                VBox([self.load_drives_button, self.delete_drive_button],
                     layout=Layout(flex="1")),
                drive_selections,
            ]), self._drives_out
        ])

        config_panel, figs_output = self.viz_manager.compose()

        # Tabs for left pane
        left_tab = Tab()
        left_tab.children = [
            simulation_box, connectivity_box, drives_options,
            config_panel,
        ]
        titles = ('Simulation', 'Network connectivity', 'External drives',
                  'Visualization')
        for idx, title in enumerate(titles):
            left_tab.set_title(idx, title)

        self.app_layout = AppLayout(
            header=self._header,
            left_sidebar=VBox([
                VBox([left_tab], layout=self.layout['left_tab']),
                self._operation_buttons,
                self._log_window,
            ], layout=self.layout['left_sidebar']),
            right_sidebar=figs_output,
            footer=self._simulation_status_bar,
            pane_widths=[
                self.layout['left_sidebar'].width, '0px',
                self.layout['visualization_window'].width
            ],
            pane_heights=[
                self.layout['header_height'],
                self.layout['visualization_window'].height,
                self.layout['simulation_status_height']
            ],
        )

        self._link_callbacks()

        # initialize drive and connectivity ipywidgets
        load_drive_and_connectivity(self.params, self._log_out,
                                    self._drives_out, self.drive_widgets,
                                    self.drive_boxes, self._connectivity_out,
                                    self.connectivity_widgets,
                                    self.widget_tstop, self.layout)

        return self.app_layout if return_layout else None

    def show(self):
        """Display the GUI; ``compose`` must have created ``app_layout``."""
        display(self.app_layout)

    def capture(self, width=None, height=None, extra_margin=100, render=True):
        """Take a screenshot of the current GUI.

        Parameters
        ----------
        width : int | None
            The width of the iframe window used to show the snapshot.
            Defaults to the GUI width plus ``extra_margin``.
        height : int | None
            The height of the iframe window used to show the snapshot.
            Defaults to the GUI height plus ``extra_margin``.
        extra_margin: int
            Extra margin in pixel for the GUI.
        render : bool
            If True, display the snapshot. If False, return the IFrame
            object instead.

        Returns
        -------
        snapshot : IFrame | None
            An iframe snapshot object that can be rendered in notebooks;
            only returned when ``render`` is False.
        """
        file = io.StringIO()
        embed_minimal_html(file, views=[self.app_layout], title='')
        if not width:
            width = self.total_width + extra_margin
        if not height:
            height = self.total_height + extra_margin

        # Embed the exported HTML directly in a data URL.
        content = urllib.parse.quote(file.getvalue().encode('utf8'))
        data_url = f"data:text/html,{content}"
        screenshot = IFrame(data_url, width=width, height=height)
        if render:
            display(screenshot)
        else:
            return screenshot

    def run_notebook_cells(self):
        """Return JavaScript that runs notebook cells sequentially.

        The returned script targets the classic Jupyter Notebook
        (``Jupyter.notebook``) and executes all but the last cells, skipping
        markdown cells and cells that construct ``HNNGUI``, and stopping at
        the cell that calls ``run_notebook_cells()``.

        To properly use this function:
            1. Put this into the penultimate cell.
            2. init the HNNGUI in a single cell.
            3. Hit 'run all' button to run the whole notebook and it will
               selectively run twice.

        Returns
        -------
        js_string : str
            The JavaScript source.
        """
        js_string = """
        function sleep(ms) {
        return new Promise(resolve => setTimeout(resolve, ms));
        }
        function getRunningStatus(idx){
            const htmlContent = Jupyter.notebook.get_cell(idx).element[0];
            return htmlContent.childNodes[0].childNodes[0].textContent;
        }
        function cellContainsInitOrMarkdown(idx){
            const cell = Jupyter.notebook.get_cell(idx);
            if(cell.cell_type!=='code'){
                return true;
            }
            else{
                const textVal = cell.element[0].childNodes[0].textContent;
                return textVal.includes('HNNGUI()') || textVal.includes(
                    'HNNGUI');
            }
        }
        function cellContainsRunCells(idx){
            const textVal = Jupyter.notebook.get_cell(
                idx).element[0].childNodes[0].textContent;
            return textVal.includes('run_notebook_cells()');
        }
        async function runNotebook() {
            console.log("run notebook cell by cell");
            const cellHtmlContents = Jupyter.notebook.element[0].children[0];
            const nCells = cellHtmlContents.childElementCount;
            console.log(`In total we have ${nCells} cells`);
            for(let i=1; i<nCells-1; i++){
                if(cellContainsRunCells(i)){
                    break
                }
                else if(cellContainsInitOrMarkdown(i)){
                    console.log(`Skip init or markdown cell ${i}...`);
                    continue
                }
                else{
                    console.log(`About to execute cell ${i}..`);
                    Jupyter.notebook.execute_cells([i]);
                    while (getRunningStatus(i).includes("*")){
                        console.log("Still running, wait for another 2 secs");
                        await sleep(2000);
                    }
                    await sleep(1000);
                }
            }
            console.log('Done');
        }
        runNotebook();
        """
        return js_string

    # below are a series of methods that are used to manipulate the GUI
    def _simulate_upload_data(self, file_url):
        """Feed the file at ``file_url`` to the data upload widget."""
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_data_button.set_trait('value', uploaded_value)

    def _simulate_upload_connectivity(self, file_url):
        """Feed the file at ``file_url`` to the connectivity upload widget."""
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_connectivity_button.set_trait('value', uploaded_value)

    def _simulate_upload_drives(self, file_url):
        """Feed the file at ``file_url`` to the drives upload widget."""
        uploaded_value = _prepare_upload_file_from_url(file_url)
        self.load_drives_button.set_trait('value', uploaded_value)

    def _simulate_left_tab_click(self, tab_title):
        """Select the left-pane tab whose title is ``tab_title``.

        Raises
        ------
        ValueError
            If no tab has that title.
        """
        tab_index = None
        left_tab = self.app_layout.left_sidebar.children[0].children[0]
        # Use the public get_title() API (available in ipywidgets 7 and 8)
        # instead of the private ``_titles`` dict, removed in ipywidgets 8.
        for idx in range(len(left_tab.children)):
            if tab_title == left_tab.get_title(idx):
                tab_index = idx
                break

        if tab_index is None:
            raise ValueError("Incorrect tab title")
        left_tab.selected_index = tab_index

    def _simulate_make_figure(self):
        """Switch to the Visualization tab and click its make-figure button."""
        self._simulate_left_tab_click("Visualization")
        self.viz_manager.make_fig_button.click()

    def _simulate_viz_action(self, action_name, *args, **kwargs):
        """A shortcut to call simulated actions in _VizManager.

        Parameters
        ----------
        action_name : str
            The action to take. For example, to call `_simulate_add_fig` in
            _VizManager, you can run `_simulate_viz_action("add_fig")`
        args : list
            Optional positional parameters passed to the called method.
        kwargs: dict
            Optional keyword parameters passed to the called method.
        """
        self._simulate_left_tab_click("Visualization")
        action = getattr(self.viz_manager, f"_simulate_{action_name}")
        action(*args, **kwargs)
def _prepare_upload_file_from_url(file_url):
    params_name = file_url.split("/")[-1]
    data = urllib.request.urlopen(file_url)
    content = b""
    for line in data:
        content += line
    return {
        params_name: {
            'metadata': {
                'name': params_name,
                'type': 'application/json',
                'size': len(content),
            },
            'content': content
        }
    }
def create_expanded_button(description, button_style, layout, disabled=False,
                           button_color="#8A2BE2"):
    """Create an ipywidgets ``Button`` with a custom background color.

    Parameters
    ----------
    description : str
        Label shown on the button.
    button_style : str
        ipywidgets button style (e.g. ``'primary'`` or ``'success'``).
    layout : Layout
        Layout applied to the button.
    disabled : bool
        Whether the button starts disabled.
    button_color : str
        CSS color used as the button's ``button_color`` style.

    Returns
    -------
    button : Button
        The created button.
    """
    button_kwargs = dict(description=description, button_style=button_style,
                         layout=layout, disabled=disabled)
    return Button(style={'button_color': button_color}, **button_kwargs)
def _get_connectivity_widgets(conn_data):
    """Create connectivity box widgets from the specified connection weights.

    Parameters
    ----------
    conn_data : dict
        Maps a key to a dict providing at least the ``'weight'``,
        ``'receptor'``, ``'location'``, ``'src_gids'`` and ``'target_gids'``
        entries of one connection.

    Returns
    -------
    conn_widgets : list of VBox
        One box per entry of ``conn_data``, holding a receptor label, a
        weight text box and a separator. Each box carries a ``_belongsto``
        dict identifying the connection it edits.
    """
    style = {}  # default ipywidgets description style
    conn_widgets = list()
    for conn in conn_data.values():
        w_text_input = BoundedFloatText(
            value=conn['weight'], disabled=False,
            continuous_update=False, min=0, max=1e6, step=0.01,
            description="weight", style=style)

        conn_widget = VBox([
            HTML(value=f"""<p>
            Receptor: {conn['receptor']}</p>"""),
            w_text_input, HTML(value="<hr style='margin-bottom:5px'/>")
        ])

        # Remember which connection this box edits so its weight can be
        # mapped back onto the network.
        conn_widget._belongsto = {
            "receptor": conn['receptor'],
            "location": conn['location'],
            "src_gids": conn['src_gids'],
            "target_gids": conn['target_gids'],
        }
        conn_widgets.append(conn_widget)

    return conn_widgets
def _get_cell_specific_widgets(layout, style, location, data=None):
    """Build per-cell-type AMPA weight, NMDA weight and delay text boxes.

    Parameters
    ----------
    layout : Layout
        Layout applied to every created text box.
    style : dict
        Style applied to every created text box.
    location : str
        Drive location. When ``"distal"``, no widgets are created for
        ``L5_basket``.
    data : dict | None
        Optional overrides under the keys ``'weights_ampa'``,
        ``'weights_nmda'`` and ``'delays'``, each mapping cell type to an
        initial value. Entries that are missing or None are ignored.

    Returns
    -------
    widgets_list : list
        Section headers followed by their text boxes, in display order.
    widgets_dict : dict
        The text boxes keyed by parameter name, then by cell type.
    """
    all_cell_types = ('L5_pyramidal', 'L2_pyramidal', 'L5_basket',
                      'L2_basket')
    initial_values = {
        'weights_ampa': dict.fromkeys(all_cell_types, 0.),
        'weights_nmda': dict.fromkeys(all_cell_types, 0.),
        'delays': dict.fromkeys(all_cell_types, 0.1),
    }
    if isinstance(data, dict):
        for param_name, values in initial_values.items():
            overrides = data.get(param_name)
            if overrides is not None:
                values.update(overrides)

    text_kwargs = dict(layout=layout, style=style)
    cell_types = [cell_type for cell_type in all_cell_types
                  if not (location == "distal" and cell_type == 'L5_basket')]

    step_sizes = {'weights_ampa': 0.01, 'weights_nmda': 0.01, 'delays': 0.1}
    widgets_dict = {
        param_name: {
            cell_type: BoundedFloatText(
                value=initial_values[param_name][cell_type],
                description=f'{cell_type}:', min=0, max=1e6, step=step,
                **text_kwargs)
            for cell_type in cell_types
        }
        for param_name, step in step_sizes.items()
    }

    section_headers = {
        'weights_ampa': "<b>AMPA weights</b>",
        'weights_nmda': "<b>NMDA weights</b>",
        'delays': "<b>Synaptic delays</b>",
    }
    widgets_list = []
    for param_name, header in section_headers.items():
        widgets_list.append(HTML(value=header))
        widgets_list.extend(widgets_dict[param_name].values())

    return widgets_list, widgets_dict
def _get_rhythmic_widget(name, tstop_widget, layout, style, location,
                         data=None, default_weights_ampa=None,
                         default_weights_nmda=None, default_delays=None):
    """Create the widgets and container for a rhythmic (bursty) drive.

    Parameters
    ----------
    name : str
        Name of the drive, stored under ``'name'`` in the returned dict.
    tstop_widget : Widget
        The simulation stop-time widget; its ``value`` at call time is the
        upper bound of the drive's stop-time box.
    layout : Layout
        Layout applied to every created text box.
    style : dict
        Style applied to every created text box.
    location : str
        Drive location, stored in the drive and forwarded to the
        cell-specific widgets.
    data : dict | None
        Optional overrides for ``tstart``, ``tstart_std``, ``tstop``,
        ``burst_rate``, ``burst_std``, ``repeats`` and ``seedcore``.
    default_weights_ampa, default_weights_nmda, default_delays : dict | None
        Optional per-cell-type initial values for the synaptic widgets.

    Returns
    -------
    drive : dict
        Drive description whose parameter entries are widgets.
    drive_box : VBox
        Container displaying all of the drive's widgets.
    """
    settings = dict(tstart=0., tstart_std=0., tstop=0., burst_rate=7.5,
                    burst_std=0, repeats=1, seedcore=14)
    if isinstance(data, dict):
        settings.update(data)

    text_opts = dict(layout=layout, style=style)
    tstart = BoundedFloatText(value=settings['tstart'],
                              description='Start time (ms)',
                              min=0, max=1e6, **text_opts)
    tstart_std = BoundedFloatText(value=settings['tstart_std'],
                                  description='Start time dev (ms)',
                                  min=0, max=1e6, **text_opts)
    tstop = BoundedFloatText(value=settings['tstop'],
                             description='Stop time (ms)',
                             max=tstop_widget.value, **text_opts)
    burst_rate = BoundedFloatText(value=settings['burst_rate'],
                                  description='Burst rate (Hz)',
                                  min=0, max=1e6, **text_opts)
    burst_std = BoundedFloatText(value=settings['burst_std'],
                                 description='Burst std dev (Hz)',
                                 min=0, max=1e6, **text_opts)
    repeats = BoundedIntText(value=settings['repeats'],
                             description='Repeats',
                             min=0, max=int(1e6), **text_opts)
    seedcore = IntText(value=settings['seedcore'], description='Seed',
                       **text_opts)

    synaptic_defaults = dict(weights_ampa=default_weights_ampa,
                             weights_nmda=default_weights_nmda,
                             delays=default_delays)
    synaptic_list, synaptic_dict = _get_cell_specific_widgets(
        layout, style, location, data=synaptic_defaults)

    timing_widgets = [tstart, tstart_std, tstop, burst_rate, burst_std,
                      repeats, seedcore]
    drive_box = VBox(timing_widgets + synaptic_list)

    drive = {
        'type': 'Rhythmic',
        'name': name,
        'tstart': tstart,
        'tstart_std': tstart_std,
        'burst_rate': burst_rate,
        'burst_std': burst_std,
        'repeats': repeats,
        'seedcore': seedcore,
        'location': location,
        'tstop': tstop,
        **synaptic_dict,
    }
    return drive, drive_box
def _get_poisson_widget(name, tstop_widget, layout, style, location, data=None,
                        default_weights_ampa=None, default_weights_nmda=None,
                        default_delays=None):
    """Create the widgets for a Poisson drive.

    Parameters
    ----------
    name : str
        Name of the drive.
    tstop_widget : widget
        Its ``value`` is used as the upper bound of the stop-time widget.
    layout : Layout
        Layout applied to every widget.
    style : dict
        Style applied to every widget.
    location : str
        Drive location, stored as-is in the returned drive dict.
    data : dict | None
        Overrides for ``'tstart'``, ``'tstop'``, ``'seedcore'`` and
        ``'rate_constant'``. ``'rate_constant'`` may be a dict keyed by cell
        type (cell types it does not list start at 0, i.e. not driven) or a
        single number applied to every cell type.
    default_weights_ampa, default_weights_nmda, default_delays : dict | None
        Initial values forwarded to ``_get_cell_specific_widgets``.

    Returns
    -------
    drive : dict
        Drive description whose values are mostly widgets.
    drive_box : VBox
        Container holding all widgets of the drive.
    """
    cell_types = ['L5_pyramidal', 'L2_pyramidal', 'L5_basket', 'L2_basket']
    default_data = {
        'tstart': 0.0,
        'tstop': 0.0,
        'seedcore': 14,
        'rate_constant': {cell_type: 8.5 for cell_type in cell_types},
    }
    if isinstance(data, dict):
        default_data.update(data)
    # ``update`` above replaces the whole rate_constant entry, so normalize
    # it: a scalar applies to every cell type, and a partial dict must not
    # raise KeyError for the cell types it omits.
    rate_constant_data = default_data['rate_constant']
    if not isinstance(rate_constant_data, dict):
        rate_constant_data = {
            cell_type: rate_constant_data for cell_type in cell_types
        }
    kwargs = dict(layout=layout, style=style)
    tstart = BoundedFloatText(
        value=default_data['tstart'], description='Start time (ms)',
        min=0, max=1e6, **kwargs)
    tstop = BoundedFloatText(
        value=default_data['tstop'],
        max=tstop_widget.value,
        description='Stop time (ms)',
        **kwargs,
    )
    seedcore = IntText(value=default_data['seedcore'],
                       description='Seed',
                       **kwargs)
    rate_constant = dict()
    for cell_type in cell_types:
        rate_constant[cell_type] = BoundedFloatText(
            value=rate_constant_data.get(cell_type, 0.0),
            description=f'{cell_type}:', min=0, max=1e6, step=0.01,
            **kwargs)
    widgets_list, widgets_dict = _get_cell_specific_widgets(
        layout,
        style,
        location,
        data={
            'weights_ampa': default_weights_ampa,
            'weights_nmda': default_weights_nmda,
            'delays': default_delays,
        },
    )
    widgets_dict.update({'rate_constant': rate_constant})
    widgets_list.extend([HTML(value="<b>Rate constants</b>")] +
                        list(widgets_dict['rate_constant'].values()))
    drive_box = VBox([tstart, tstop, seedcore] + widgets_list)
    drive = dict(
        type='Poisson',
        name=name,
        tstart=tstart,
        tstop=tstop,
        rate_constant=rate_constant,
        seedcore=seedcore,
        location=location,  # note: a plain str, not a widget
    )
    drive.update(widgets_dict)
    return drive, drive_box
def _get_evoked_widget(name, layout, style, location, data=None,
                       default_weights_ampa=None, default_weights_nmda=None,
                       default_delays=None):
    """Create the widgets for an evoked drive.

    Parameters
    ----------
    name : str
        Name of the drive.
    layout : Layout
        Layout applied to every widget.
    style : dict
        Style applied to every widget.
    location : str
        Drive location, stored as-is in the returned drive dict.
    data : dict | None
        Overrides for ``'mu'``, ``'sigma'``, ``'numspikes'`` and
        ``'seedcore'`` (defaults 0, 1, 1 and 14).
    default_weights_ampa, default_weights_nmda, default_delays : dict | None
        Initial values forwarded to ``_get_cell_specific_widgets``.

    Returns
    -------
    drive : dict
        Drive description (``type='Evoked'``, ``sync_within_trial=False``)
        whose values are mostly widgets.
    drive_box : VBox
        Container holding all widgets of the drive.
    """
    settings = {'mu': 0, 'sigma': 1, 'numspikes': 1, 'seedcore': 14}
    if isinstance(data, dict):
        settings.update(data)
    common = {'layout': layout, 'style': style}

    mu = BoundedFloatText(value=settings['mu'], description='Mean time:',
                          min=0, max=1e6, step=0.01, **common)
    sigma = BoundedFloatText(value=settings['sigma'],
                             description='Std dev time:',
                             min=0, max=1e6, step=0.01, **common)
    numspikes = IntText(value=settings['numspikes'],
                        description='No. Spikes:', **common)
    seedcore = IntText(value=settings['seedcore'],
                       description='Seed: ', **common)

    synapse_defaults = {
        'weights_ampa': default_weights_ampa,
        'weights_nmda': default_weights_nmda,
        'delays': default_delays,
    }
    cell_widgets, cell_widget_dict = _get_cell_specific_widgets(
        layout, style, location, data=synapse_defaults)

    drive_box = VBox([mu, sigma, numspikes, seedcore] + cell_widgets)
    drive = {
        'type': 'Evoked',
        'name': name,
        'mu': mu,
        'sigma': sigma,
        'numspikes': numspikes,
        'seedcore': seedcore,
        'location': location,
        'sync_within_trial': False,
    }
    drive.update(cell_widget_dict)
    return drive, drive_box
def add_drive_widget(drive_type, drive_boxes, drive_widgets, drives_out,
                     tstop_widget, location, layout,
                     prespecified_drive_name=None,
                     prespecified_drive_data=None,
                     prespecified_weights_ampa=None,
                     prespecified_weights_nmda=None,
                     prespecified_delays=None, render=True,
                     expand_last_drive=True, event_seed=14):
    """Add a widget for a new drive.

    Parameters
    ----------
    drive_type : str
        ``'Rhythmic'``/``'Bursty'``, ``'Poisson'`` or ``'Evoked'``/
        ``'Gaussian'``. Any other value adds no drive (the accordion is still
        re-rendered when ``render`` is True).
    drive_boxes : list
        Drive containers; the new one is appended in place.
    drive_widgets : list of dict
        Drive widget dicts; the new one is appended in place.
    drives_out : Output
        Cleared, then used to display the accordion of all drives.
    tstop_widget : widget
        Passed to the rhythmic/Poisson builders to bound their stop time.
    location : str
        Drive location, also shown in the accordion title.
    layout : Layout
        Layout applied to the drive widgets.
    prespecified_drive_name : str | None
        Drive name; defaults to ``drive_type`` followed by the number of
        existing drive boxes.
    prespecified_drive_data : dict | None
        Overrides for the drive dynamics. The caller's dict is not modified.
    prespecified_weights_ampa, prespecified_weights_nmda : dict | None
        Initial synaptic weights.
    prespecified_delays : dict | None
        Initial synaptic delays.
    render : bool
        Whether to (re-)display the accordion of all drives.
    expand_last_drive : bool
        Whether the last accordion entry starts expanded.
    event_seed : int
        Seed for the drive; values below 2 are raised to 2.
    """
    style = {'description_width': '150px'}
    drives_out.clear_output()
    # Work on a copy so the caller's dict (e.g. a drive spec's 'dynamics')
    # is not mutated by the seed injection below.
    prespecified_drive_data = dict(prespecified_drive_data or {})
    prespecified_drive_data.update({"seedcore": max(event_seed, 2)})
    with drives_out:
        if not prespecified_drive_name:
            name = drive_type + str(len(drive_boxes))
        else:
            name = prespecified_drive_name
        if drive_type in ('Rhythmic', 'Bursty'):
            drive, drive_box = _get_rhythmic_widget(
                name,
                tstop_widget,
                layout,
                style,
                location,
                data=prespecified_drive_data,
                default_weights_ampa=prespecified_weights_ampa,
                default_weights_nmda=prespecified_weights_nmda,
                default_delays=prespecified_delays,
            )
        elif drive_type == 'Poisson':
            drive, drive_box = _get_poisson_widget(
                name,
                tstop_widget,
                layout,
                style,
                location,
                data=prespecified_drive_data,
                default_weights_ampa=prespecified_weights_ampa,
                default_weights_nmda=prespecified_weights_nmda,
                default_delays=prespecified_delays,
            )
        elif drive_type in ('Evoked', 'Gaussian'):
            drive, drive_box = _get_evoked_widget(
                name,
                layout,
                style,
                location,
                data=prespecified_drive_data,
                default_weights_ampa=prespecified_weights_ampa,
                default_weights_nmda=prespecified_weights_nmda,
                default_delays=prespecified_delays,
            )
        # Only the recognised types above define ``drive``/``drive_box``.
        if drive_type in [
                'Evoked', 'Poisson', 'Rhythmic', 'Bursty', 'Gaussian'
        ]:
            drive_boxes.append(drive_box)
            drive_widgets.append(drive)
        if render:
            accordion = Accordion(
                children=drive_boxes,
                selected_index=(len(drive_boxes) - 1
                                if expand_last_drive else None),
            )
            for idx, drive in enumerate(drive_widgets):
                accordion.set_title(idx,
                                    f"{drive['name']} ({drive['location']})")
            display(accordion)
def add_connectivity_tab(params, connectivity_out,
                         connectivity_textfields):
    """Populate the connectivity tab with one accordion entry per connection.

    A network is built from ``params`` with ``jones_2009_model``. For every
    (source cell type, target cell type, location) triple that has at least
    one connection among the AMPA/NMDA/GABAA/GABAB receptors, a widget group
    is created with ``_get_connectivity_widgets``.

    Parameters
    ----------
    params : dict
        Network parameters.
    connectivity_out : Output
        Cleared, then used to display the accordion.
    connectivity_textfields : list
        Emptied in place, then filled with one widget group per entry.

    Returns
    -------
    net : Network
        The network the connections were read from.
    """
    net = jones_2009_model(params)
    cell_types = list(net.cell_types.keys())
    receptors = ('ampa', 'nmda', 'gabaa', 'gabab')
    locations = ('proximal', 'distal', 'soma')

    # Reset the display and empty the shared list in place so every holder
    # of a reference to it sees the change.
    connectivity_out.clear_output()
    while connectivity_textfields:
        connectivity_textfields.pop()

    titles = list()
    for src_gids in cell_types:
        for target_gids in cell_types:
            for location in locations:
                # one accordion entry groups all receptors of this triple
                conn_info = {}
                for receptor in receptors:
                    picked = pick_connection(net=net,
                                             src_gids=src_gids,
                                             target_gids=target_gids,
                                             loc=location,
                                             receptor=receptor)
                    if len(picked) == 0:
                        continue
                    assert len(picked) == 1
                    conn = net.connectivity[picked[0]]
                    conn_info[receptor] = {
                        "weight": conn['nc_dict']['A_weight'],
                        "probability": conn['probability'],
                        # identifies the connection when reading back
                        "receptor": receptor,
                        "location": location,
                        "src_gids": src_gids,
                        "target_gids": target_gids,
                    }
                if not conn_info:
                    continue
                titles.append(f"{src_gids}→{target_gids} ({location})")
                connectivity_textfields.append(
                    _get_connectivity_widgets(conn_info))

    accordion = Accordion(
        children=[VBox(group) for group in connectivity_textfields])
    for idx, title in enumerate(titles):
        accordion.set_title(idx, title)
    with connectivity_out:
        display(accordion)
    return net
def add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
                  layout):
    """Rebuild the drives tab from the drive specifications in ``params``.

    Parameters
    ----------
    params : dict
        Network parameters the drive specs are extracted from.
    drives_out : Output
        Cleared, then used (via ``add_drive_widget``) to display the drives.
    drive_widgets : list of dict
        Emptied in place, then refilled with one entry per drive.
    drive_boxes : list
        Emptied in place, then refilled with one container per drive.
    tstop : widget
        Simulation stop-time widget, forwarded to ``add_drive_widget``.
    layout : Layout
        Layout applied to the drive widgets.
    """
    net = jones_2009_model(params)
    drive_specs = _extract_drive_specs_from_hnn_params(
        params, list(net.cell_types.keys()), legacy_mode=net._legacy_mode)
    # Clear before adding drives. Each list is cleared on its own so that a
    # length mismatch can neither raise IndexError nor leave stale entries.
    drives_out.clear_output()
    drive_widgets.clear()
    drive_boxes.clear()
    drive_names = sorted(drive_specs.keys())
    last_idx = len(drive_names) - 1
    for idx, drive_name in enumerate(drive_names):  # order matters
        specs = drive_specs[drive_name]
        # render the accordion only once, after the final drive is added
        add_drive_widget(
            specs['type'].capitalize(),
            drive_boxes,
            drive_widgets,
            drives_out,
            tstop,
            specs['location'],
            layout=layout,
            prespecified_drive_name=drive_name,
            prespecified_drive_data=specs['dynamics'],
            prespecified_weights_ampa=specs['weights_ampa'],
            prespecified_weights_nmda=specs['weights_nmda'],
            prespecified_delays=specs['synaptic_delays'],
            render=idx == last_idx,
            expand_last_drive=False,
            event_seed=specs['event_seed'],
        )
def load_drive_and_connectivity(params, log_out, drives_out,
                                drive_widgets, drive_boxes, connectivity_out,
                                connectivity_textfields, tstop, layout):
    """Rebuild both the connectivity and drives tabs from ``params``.

    The log output is cleared first; any output produced while rebuilding
    (connectivity first, then drives) is captured in ``log_out``.
    """
    log_out.clear_output()
    with log_out:
        add_connectivity_tab(params=params,
                             connectivity_out=connectivity_out,
                             connectivity_textfields=connectivity_textfields)
        add_drive_tab(params=params, drives_out=drives_out,
                      drive_widgets=drive_widgets, drive_boxes=drive_boxes,
                      tstop=tstop, layout=layout)
def on_upload_data_change(change, data, viz_manager, log_out):
    """Load an uploaded dipole text file and plot it in a new figure.

    Parameters
    ----------
    change : dict
        Traitlets change notification from a ``FileUpload`` widget. The
        first uploaded file is read from ``change['new']``.
    data : dict
        GUI data store. The dipole is stored in
        ``data['simulation_data'][<file name without .txt>]``.
    viz_manager : _VizManager
        Used to add a figure and plot the loaded dipole.
    log_out : Output
        Captures output produced while loading and plotting.
    """
    if len(change['owner'].value) == 0:
        logger.info("Empty change")
        return
    key = list(change['new'].keys())[0]
    data_fname = change['new'][key]['metadata']['name']
    # Drop a trailing '.txt' extension. ``str.rstrip('.txt')`` would strip
    # any trailing '.', 't' or 'x' characters ("test.txt" -> "tes").
    suffix = '.txt'
    if data_fname.endswith(suffix):
        data_fname = data_fname[:-len(suffix)]
    if data_fname in data['simulation_data'].keys():
        logger.error(f"Found existing data: {data_fname}.")
        return
    ext_content = change['new'][key]['content']
    ext_content = codecs.decode(ext_content, encoding="utf-8")
    with log_out:
        data['simulation_data'][data_fname] = {'net': None, 'dpls': [
            hnn_core.read_dipole(io.StringIO(ext_content))
        ]}
        logger.info(f'External data {data_fname} loaded.')
        viz_manager.reset_fig_config_tabs(template_name='single figure')
        viz_manager.add_figure()
        fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
        ax_plots = [("ax0", "current dipole")]
        for ax_name, plot_type in ax_plots:
            viz_manager._simulate_edit_figure(
                fig_name, ax_name, data_fname, plot_type, {}, "plot")
def on_upload_params_change(change, params, tstop, dt, log_out, drive_boxes,
                            drive_widgets, drives_out, connectivity_out,
                            connectivity_textfields, layout, load_type):
    """Load an uploaded parameter file into the drives or connectivity tab.

    Parameters
    ----------
    change : dict
        Traitlets change notification from a ``FileUpload`` widget. The
        first uploaded file is read from ``change['new']``, and the widget in
        ``change['owner']`` is reset afterwards.
    params : dict
        Network parameters, updated in place with the file contents.
    tstop, dt : widgets
        Updated when the file defines ``tstop`` / ``dt``.
    log_out : Output
        Cleared, then captures output while the simulation settings are
        updated.
    drive_boxes, drive_widgets, drives_out
        Drive containers used when ``load_type == 'drives'``.
    connectivity_out, connectivity_textfields
        Connectivity containers used when ``load_type == 'connectivity'``.
    layout : Layout
        Layout applied to the drive widgets.
    load_type : {'connectivity', 'drives'}
        Which tab to rebuild from the loaded parameters.

    Raises
    ------
    ValueError
        If ``load_type`` is not recognized, or the file extension is neither
        ``.json`` nor ``.param``.
    """
    if len(change['owner'].value) == 0:
        logger.info("Empty change")
        return
    # Validate before touching any state so a bad call cannot leave
    # ``params`` and the tstop/dt widgets half-updated.
    if load_type not in ('connectivity', 'drives'):
        raise ValueError(
            f"load_type must be 'connectivity' or 'drives', got {load_type!r}")
    logger.info("Loading connectivity...")
    key = list(change['new'].keys())[0]
    params_fname = change['new'][key]['metadata']['name']
    ext = Path(params_fname).suffix
    read_func = {'.json': _read_json, '.param': _read_legacy_params}
    if ext not in read_func:
        raise ValueError(
            f"Unsupported parameter file {params_fname!r}: expected one of "
            f"{sorted(read_func)} extensions, got {ext!r}")
    param_data = change['new'][key]['content']
    param_data = codecs.decode(param_data, encoding="utf-8")
    params_network = read_func[ext](param_data)
    # update simulation settings and params
    log_out.clear_output()
    with log_out:
        if 'tstop' in params_network.keys():
            tstop.value = params_network['tstop']
        if 'dt' in params_network.keys():
            dt.value = params_network['dt']
        params.update(params_network)
    # init network, add drives & connectivity
    if load_type == 'connectivity':
        add_connectivity_tab(params, connectivity_out, connectivity_textfields)
    else:
        add_drive_tab(params, drives_out, drive_widgets, drive_boxes, tstop,
                      layout)
    # reset the upload widget so the same file can be uploaded again
    change['owner'].set_trait('_counter', 0)
    change['owner'].set_trait('value', {})
def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
                               drive_widgets, connectivity_textfields,
                               add_drive=True):
    """Construct network and add drives.

    The network is built with ``jones_2009_model`` without the drives from
    ``params``. Its connection weights and probabilities are then
    overwritten with the values shown in the connectivity widgets. Unless
    ``add_drive is False``, one drive is added per entry of
    ``drive_widgets``.

    Parameters
    ----------
    params : dict
        Network parameters; ``'dt'`` and ``'tstop'`` are overwritten in
        place with the widget values.
    dt, tstop : widgets
        Simulation settings read via ``.value``.
    single_simulation_data : dict
        The constructed network is stored under its ``'net'`` key.
    drive_widgets : list of dict
        Drive widget dicts with ``'type'``, ``'name'``, ``'location'``, the
        ``'weights_ampa'``/``'weights_nmda'``/``'delays'`` widget dicts and
        type-specific widgets.
    connectivity_textfields : list
        Groups of connectivity VBoxes. Each VBox carries a ``_belongsto``
        dict identifying its connection; its ``children[1]`` and
        ``children[2]`` hold the weight and probability.
    add_drive : bool
        When exactly ``False``, return after adjusting connectivity.
    """
    print("init network")
    params['dt'] = dt.value
    params['tstop'] = tstop.value
    net = jones_2009_model(params, add_drives_from_params=False)
    single_simulation_data['net'] = net

    # Overwrite connection weights/probabilities with the GUI values.
    for conn_group in connectivity_textfields:
        for vbox in conn_group:
            conn_id = vbox._belongsto
            matches = pick_connection(
                net=net,
                src_gids=conn_id['src_gids'],
                target_gids=conn_id['target_gids'],
                loc=conn_id['location'],
                receptor=conn_id['receptor'])
            if len(matches) == 0:
                continue
            assert len(matches) == 1
            conn = net.connectivity[matches[0]]
            conn['nc_dict']['A_weight'] = vbox.children[1].value
            conn['probability'] = vbox.children[2].value

    if add_drive is False:
        return

    for drive in drive_widgets:
        weights_ampa = {
            cell: widget.value
            for cell, widget in drive['weights_ampa'].items()
        }
        weights_nmda = {
            cell: widget.value
            for cell, widget in drive['weights_nmda'].items()
        }
        synaptic_delays = {
            cell: widget.value for cell, widget in drive['delays'].items()
        }
        print(
            f"drive type is {drive['type']}, location={drive['location']}")
        drive_type = drive['type']

        if drive_type == 'Poisson':
            # Cell types with a non-positive rate are not targeted, so their
            # weights are dropped too.
            rate_constant = {
                cell: widget.value
                for cell, widget in drive['rate_constant'].items()
                if widget.value > 0
            }
            weights_ampa = {
                cell: weight for cell, weight in weights_ampa.items()
                if cell in rate_constant
            }
            weights_nmda = {
                cell: weight for cell, weight in weights_nmda.items()
                if cell in rate_constant
            }
            net.add_poisson_drive(
                name=drive['name'],
                tstart=drive['tstart'].value,
                tstop=drive['tstop'].value,
                rate_constant=rate_constant,
                location=drive['location'],
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                space_constant=100.0,
                event_seed=drive['seedcore'].value)
        elif drive_type in ('Evoked', 'Gaussian'):
            net.add_evoked_drive(
                name=drive['name'],
                mu=drive['mu'].value,
                sigma=drive['sigma'].value,
                numspikes=drive['numspikes'].value,
                location=drive['location'],
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                space_constant=3.0,
                event_seed=drive['seedcore'].value)
        elif drive_type in ('Rhythmic', 'Bursty'):
            net.add_bursty_drive(
                name=drive['name'],
                tstart=drive['tstart'].value,
                tstart_std=drive['tstart_std'].value,
                burst_rate=drive['burst_rate'].value,
                burst_std=drive['burst_std'].value,
                location=drive['location'],
                tstop=drive['tstop'].value,
                weights_ampa=weights_ampa,
                weights_nmda=weights_nmda,
                synaptic_delays=synaptic_delays,
                event_seed=drive['seedcore'].value)
def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
                       all_data, dt, tstop, ntrials, backend_selection,
                       mpi_cmd, n_jobs, params, simulation_status_bar,
                       simulation_status_contents, connectivity_textfields,
                       viz_manager):
    """Run the simulation and plot outputs.

    Parameters
    ----------
    widget_simulation_name : widget
        Its ``value`` is the name the new simulation is stored under.
    log_out : Output
        Cleared, then captures progress messages and errors.
    drive_widgets : list of dict
        Drive widgets forwarded to ``_init_network_from_widgets``.
    all_data : dict
        Its ``"simulation_data"`` entry maps simulation names to dicts with
        ``'net'`` and ``'dpls'`` keys. Entries with no dipoles are removed.
    dt, tstop, ntrials : widgets
        Simulation settings read via ``.value``.
    backend_selection : widget
        ``"MPI"`` selects ``MPIBackend``; any other value ``JoblibBackend``.
    mpi_cmd, n_jobs : widgets
        Backend settings read via ``.value``.
    params : dict
        Network parameters forwarded to ``_init_network_from_widgets``.
    simulation_status_bar : widget
        Its ``value`` is set from ``simulation_status_contents``.
    simulation_status_contents : dict
        Must contain ``'running'``, ``'finished'`` and ``'failed'``.
    connectivity_textfields : list
        Connectivity widgets forwarded to ``_init_network_from_widgets``.
    viz_manager : _VizManager
        Used to add and fill a figure once the simulation has finished.
    """
    log_out.clear_output()
    simulation_data = all_data["simulation_data"]
    # Set only once the simulation has finished. Under IPython the Output
    # widget's context manager shows and suppresses exceptions, so without
    # this flag a failed run would still fall through to the plotting below.
    succeeded = False
    with log_out:
        # clear empty trash simulations
        for _name in tuple(simulation_data.keys()):
            if len(simulation_data[_name]['dpls']) == 0:
                del simulation_data[_name]
        _sim_name = widget_simulation_name.value
        if simulation_data[_sim_name]['net'] is not None:
            print("Simulation with the same name exists!")
            simulation_status_bar.value = simulation_status_contents[
                'failed']
            return
        try:
            _init_network_from_widgets(params, dt, tstop,
                                       simulation_data[_sim_name],
                                       drive_widgets,
                                       connectivity_textfields)
            print("start simulation")
            if backend_selection.value == "MPI":
                # leave one core free, but never request zero processes
                n_procs = max(multiprocessing.cpu_count() - 1, 1)
                backend = MPIBackend(n_procs=n_procs, mpi_cmd=mpi_cmd.value)
            else:
                backend = JoblibBackend(n_jobs=n_jobs.value)
                print(f"Using Joblib with {n_jobs.value} core(s).")
            with backend:
                simulation_status_bar.value = simulation_status_contents[
                    'running']
                simulation_data[_sim_name]['dpls'] = simulate_dipole(
                    simulation_data[_sim_name]['net'],
                    tstop=tstop.value,
                    dt=dt.value,
                    n_trials=ntrials.value)
                simulation_status_bar.value = simulation_status_contents[
                    'finished']
        except Exception:
            # don't leave the status bar stuck on "running"
            simulation_status_bar.value = simulation_status_contents[
                'failed']
            raise
        succeeded = True
    if not succeeded:
        return
    viz_manager.reset_fig_config_tabs()
    viz_manager.add_figure()
    fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
    ax_plots = [("ax0", "input histogram"), ("ax1", "current dipole")]
    for ax_name, plot_type in ax_plots:
        viz_manager._simulate_edit_figure(fig_name, ax_name, _sim_name,
                                          plot_type, {}, "plot")
def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs):
    """Show the configuration widget that matches the selected backend.

    ``backend_config`` is cleared. Then ``mpi_cmd`` is displayed for
    ``"MPI"``, ``n_jobs`` for ``"Joblib"``, and nothing for any other value.
    """
    backend_config.clear_output()
    config_widgets = {"MPI": mpi_cmd, "Joblib": n_jobs}
    with backend_config:
        if backend_type in config_widgets:
            display(config_widgets[backend_type])
def launch():
    """Launch voila with hnn_widget.ipynb.

    The notebook shipped next to this module is passed to voila's ``main``,
    followed by any extra command-line arguments, so voila options can be
    given as usual.
    """
    from voila.app import main
    notebook = (Path(__file__).parent / 'hnn_widget.ipynb').resolve()
    voila_args = [str(notebook)] + sys.argv[1:]
    main(voila_args)