Skip to content

Commit

Permalink
[MRG] Add Min/Max Frequency and Add Visual Updates Ahead of 0.4 (jone…
Browse files Browse the repository at this point in the history
…scompneurolab#952)

* visual updates to gui, including placeholders for default values for min/max spectral frequency

* update visuals for adding a drive

* make min/max spectral widgets functional

* fix type in comment

* add whitespace around minus operator

* update test_gui_add_dries to reflect the capitalization of the drive location dropdown box

* add test for correct setting of min/max frequency

* update initial equivalency check for tests on default visualization parameters
  • Loading branch information
dylansdaniels authored Nov 27, 2024
1 parent b988c7f commit d867c6d
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 35 deletions.
14 changes: 8 additions & 6 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,8 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
simulation_names = tuple(data['simulations'].keys())
sim_index = 0
default_smoothing = fig_default_params['default_smoothing']
default_min_frequency = fig_default_params['default_min_frequency']
default_max_frequency = fig_default_params['default_max_frequency']
if not simulation_names:
simulation_names = ("None",)
else:
Expand Down Expand Up @@ -650,7 +652,7 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
style=analysis_style)

min_spectral_frequency = BoundedFloatText(
value=10,
value=default_min_frequency,
min=0.1,
max=1000,
description='Min Spectral Frequency (Hz):',
Expand All @@ -659,7 +661,7 @@ def _get_ax_control(widgets, data, fig_default_params, fig_idx, fig, ax):
style=analysis_style)

max_spectral_frequency = BoundedFloatText(
value=100,
value=default_max_frequency,
min=0.1,
max=1000,
description='Max Spectral Frequency (Hz):',
Expand Down Expand Up @@ -773,12 +775,12 @@ def _close_figure(b, widgets, data, fig_idx):
display(Label(_fig_placeholder))


def _add_axes_controls(widgets, data, fig_default_smoothing, fig, axd):
def _add_axes_controls(widgets, data, fig_default_params, fig, axd):
fig_idx = data['fig_idx']['idx']

controls = Tab()
children = [
_get_ax_control(widgets, data, fig_default_smoothing, fig_idx=fig_idx,
_get_ax_control(widgets, data, fig_default_params, fig_idx=fig_idx,
fig=fig, ax=ax)
for ax_key, ax in axd.items()
]
Expand All @@ -799,7 +801,7 @@ def _add_axes_controls(widgets, data, fig_default_smoothing, fig, axd):
widgets['axes_config_tabs'].set_title(n_tabs, _idx2figname(fig_idx))


def _add_figure(b, widgets, data, fig_default_smoothing,
def _add_figure(b, widgets, data, fig_default_params,
template_type, scale=0.95, dpi=96):
fig_idx = data['fig_idx']['idx']
viz_output_layout = data['visualization_output']
Expand Down Expand Up @@ -832,7 +834,7 @@ def _add_figure(b, widgets, data, fig_default_smoothing,
else:
display(fig.canvas)

_add_axes_controls(widgets, data, fig_default_smoothing, fig=fig, axd=axd)
_add_axes_controls(widgets, data, fig_default_params, fig=fig, axd=axd)

data['figs'][fig_idx] = fig
widgets['figs_tabs'].selected_index = n_tabs
Expand Down
149 changes: 124 additions & 25 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ipywidgets import (HTML, Accordion, AppLayout, BoundedFloatText,
BoundedIntText, Button, Dropdown, FileUpload, VBox,
HBox, IntText, Layout, Output, RadioButtons, Tab, Text,
Checkbox)
Checkbox, Box)
from ipywidgets.embed import embed_minimal_html
import hnn_core
from hnn_core import JoblibBackend, MPIBackend, simulate_dipole
Expand Down Expand Up @@ -290,7 +290,7 @@ def __init__(self, theme_color="#802989",
height=f"{operation_box_height}px",
flex_wrap="wrap",
),
"config_box": Layout(width=f"{left_sidebar_width}px",
"config_box": Layout(width=f"{left_sidebar_width - 40}px",
height=f"{config_box_height - 100}px"),
"drive_widget": Layout(width="auto"),
"drive_textbox": Layout(width='270px', height='auto'),
Expand Down Expand Up @@ -324,12 +324,37 @@ def __init__(self, theme_color="#802989",
self.simulation_data = defaultdict(lambda: dict(net=None, dpls=list()))

# Default visualization params for figures
analysis_style = {'description_width': '200px'}
layout = Layout(width="300px")

self.widget_default_smoothing = BoundedFloatText(
value=30.0, description='Smoothing:',
min=0.0, max=100.0, step=1.0, disabled=False)
min=0.0, max=100.0, step=1.0, disabled=False,
layout=layout, style=analysis_style,
)

self.widget_min_frequency = BoundedFloatText(
value=10,
min=0.1,
max=1000,
description='Min Spectral Frequency (Hz):',
disabled=False,
layout=layout,
style=analysis_style)

self.widget_max_frequency = BoundedFloatText(
value=100,
min=0.1,
max=1000,
description='Max Spectral Frequency (Hz):',
disabled=False,
layout=layout,
style=analysis_style)

self.fig_default_params = {
'default_smoothing': self.widget_default_smoothing.value
'default_smoothing': self.widget_default_smoothing.value,
'default_min_frequency': self.widget_min_frequency.value,
'default_max_frequency': self.widget_max_frequency.value,
}

# Simulation parameters
Expand Down Expand Up @@ -374,16 +399,20 @@ def __init__(self, theme_color="#802989",
description='',
layout={'width': '15%'})
# Drive selection
self.widget_drive_type_selection = RadioButtons(
self.widget_drive_type_selection = Dropdown(
options=['Evoked', 'Poisson', 'Rhythmic', 'Tonic'],
value='Evoked',
description='Drive:',
description='Drive type:',
disabled=False,
layout=self.layout['drive_widget'])
self.widget_location_selection = RadioButtons(
options=['proximal', 'distal'], value='proximal',
description='Location', disabled=False,
layout=self.layout['drive_widget'])
layout=self.layout['drive_widget'],
style={'description_width': '100px'}
)
self.widget_location_selection = Dropdown(
options=['Proximal', 'Distal'], value='Proximal',
description='Drive location:', disabled=False,
layout=self.layout['drive_widget'],
style={'description_width': '100px'},
)
self.add_drive_button = create_expanded_button(
'Add drive', 'primary', layout=self.layout['btn'],
button_color=self.layout['theme_color'])
Expand All @@ -405,7 +434,7 @@ def __init__(self, theme_color="#802989",
button_style='success')

self.delete_drive_button = create_expanded_button(
'Delete drives', 'success', layout=self.layout['btn'],
'Delete all drives', 'success', layout=self.layout['btn'],
button_color=self.layout['theme_color'])

self.cell_type_radio_buttons = RadioButtons(
Expand Down Expand Up @@ -543,9 +572,10 @@ def _handle_backend_change(backend_type):
self.widget_n_jobs)

def _add_drive_button_clicked(b):
location = self.widget_location_selection.value.lower()
return self.add_drive_widget(
self.widget_drive_type_selection.value,
self.widget_location_selection.value,
location,
)

def _delete_drives_clicked(b):
Expand Down Expand Up @@ -576,6 +606,7 @@ def _run_button_clicked(b):
self.widget_simulation_name, self._log_out, self.drive_widgets,
self.data, self.widget_dt, self.widget_tstop,
self.fig_default_params, self.widget_default_smoothing,
self.widget_min_frequency, self.widget_max_frequency,
self.widget_ntrials, self.widget_backend_selection,
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
Expand Down Expand Up @@ -677,11 +708,31 @@ def compose(self, return_layout=True):
If the method returns the layout object which can be rendered by
IPython.display.display() method.
"""
box_style = """
style="
background: gray;
color: white;
# font-weight: bold;
width: 290px;
padding: 0px 5px;
margin-bottom: 2px;
"
"""
simulation_box = VBox([
HTML(f"<div {box_style}>Simulation Parameters</div>"),
VBox([
self.widget_simulation_name, self.widget_tstop, self.widget_dt,
self.widget_ntrials, self.widget_default_smoothing,
self.widget_ntrials,
self.widget_backend_selection, self._backend_config_out]),
Box(layout=Layout(height="20px")),
HTML(
f"<div {box_style}'>Default Visualization Parameters</div>",
),
VBox([
self.widget_default_smoothing,
self.widget_min_frequency,
self.widget_max_frequency,
])
], layout=self.layout['config_box'])

connectivity_configuration = Tab()
Expand Down Expand Up @@ -1157,21 +1208,35 @@ def create_expanded_button(description, button_style, layout, disabled=False,
def _get_connectivity_widgets(conn_data):
"""Create connectivity box widgets from specified weight and probability"""

style = {'description_width': '150px'}
style = {}
style = {'description_width': '100px'}
sliders = list()
for receptor_name in conn_data.keys():
w_text_input = BoundedFloatText(
value=conn_data[receptor_name]['weight'], disabled=False,
continuous_update=False, min=0, max=1e6, step=0.01,
description="weight", style=style)
description="Weight:", style=style)

display_name = conn_data[receptor_name]['receptor'].upper()

map_display_names = {
'GABAA': 'GABA<sub>A</sub>',
'GABAB': 'GABA<sub>B</sub>',
}

if display_name in map_display_names:
display_name = map_display_names[display_name]

html_tab = '&emsp;'

conn_widget = VBox([
HTML(value=f"""<p>
Receptor: {conn_data[receptor_name]['receptor']}</p>"""),
w_text_input, HTML(value="<hr style='margin-bottom:5px'/>")
HTML(value=f"""<p style='margin:5px;'><b>{html_tab}{html_tab}
Receptor: {display_name}</b></p>"""),
w_text_input
])

# Add class to child Vboxes for targeted CSS
conn_widget.add_class('connectivity-subsection')

conn_widget._belongsto = {
"receptor": conn_data[receptor_name]['receptor'],
"location": conn_data[receptor_name]['location'],
Expand Down Expand Up @@ -1672,13 +1737,44 @@ def add_network_connectivity_tab(net, connectivity_out,
connectivity_textfields.append(
_get_connectivity_widgets(receptor_related_conn))

# Style the contents of the Connectivity Tab
# -------------------------------------------------------------------------

# define custom Vbox layout
# no_padding_layout = Layout(padding="0", margin="0") # unused

# Initialize sections within the Accordion

connectivity_boxes = [VBox(slider) for slider in connectivity_textfields]

# Add class to child Vboxes for targeted CSS
for box in connectivity_boxes:
box.add_class("connectivity-contents")

# Initialize the Accordion section

cell_connectivity = Accordion(children=connectivity_boxes)

# Add class to Accordion section for targeted CSS
cell_connectivity.add_class("connectivity-section")

for idx, connectivity_name in enumerate(connectivity_names):
cell_connectivity.set_title(idx, connectivity_name)

# Style the <div> automatically created around connectivity boxes
connectivity_out_style = HTML("""
<style>
/* CSS to style elements inside the Accordion */
.connectivity-section .jupyter-widget-Collapse-contents {
padding: 0px 0px 10px 0px !important;
margin: 0 !important;
}
</style>
""")

# Display the Accordion with styling
with connectivity_out:
display(cell_connectivity)
display(connectivity_out_style, cell_connectivity)

return net

Expand Down Expand Up @@ -1923,6 +2019,7 @@ def _init_network_from_widgets(params, dt, tstop, single_simulation_data,
def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
all_data, dt, tstop,
fig_default_params, widget_default_smoothing,
widget_min_frequency, widget_max_frequency,
ntrials, backend_selection,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
Expand Down Expand Up @@ -1974,12 +2071,14 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,

viz_manager.reset_fig_config_tabs()

# update default_smoothing in gui based on widget
# update default visualization params in gui based on widget
fig_default_params['default_smoothing'] = widget_default_smoothing.value
fig_default_params['default_min_frequency'] = widget_min_frequency.value
fig_default_params['default_max_frequency'] = widget_max_frequency.value

# change default smoothing in viz_manager to mirror gui
new_default_smoothing = fig_default_params['default_smoothing']
viz_manager.fig_default_params['default_smoothing'] = new_default_smoothing
# change default visualization params in viz_manager to mirror gui
for widget, value in fig_default_params.items():
viz_manager.fig_default_params[widget] = value

viz_manager.add_figure()
fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
Expand Down
45 changes: 41 additions & 4 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_gui_add_drives():
_ = gui.compose()

for val_drive_type in ("Poisson", "Evoked", "Rhythmic"):
for val_location in ("distal", "proximal"):
for val_location in ("Distal", "Proximal"):
gui.delete_drive_button.click()
assert len(gui.drive_widgets) == 0

Expand All @@ -334,7 +334,9 @@ def test_gui_add_drives():

assert len(gui.drive_widgets) == 1
assert gui.drive_widgets[0]['type'] == val_drive_type
assert gui.drive_widgets[0]['location'] == val_location
# note that val_location is transformed to .lower() after the
# add_drive_button.click() action
assert gui.drive_widgets[0]['location'] == val_location.lower()
assert val_drive_type in gui.drive_widgets[0]['name']
plt.close('all')

Expand Down Expand Up @@ -1128,8 +1130,7 @@ def test_default_smoothing(setup_gui):
gui_smooth_value = gui.fig_default_params['default_smoothing']
viz_smooth_value = gui.viz_manager.fig_default_params['default_smoothing']

assert gui_smooth_value == 30
assert viz_smooth_value == 30
assert gui_smooth_value == viz_smooth_value

# update simulation name
gui.widget_simulation_name.value = 'no_smoothing'
Expand Down Expand Up @@ -1174,3 +1175,39 @@ def test_default_smoothing(setup_gui):
assert gui.viz_manager.figs[figid].axes[0].has_data()

plt.close('all')


def test_default_frequencies(setup_gui):
"""Tests that default min/max frequency are inherited correctly"""
gui = setup_gui

# check that the defaults are the same everywhere after running
# the default simulation
gui.run_button.click()

gui_min = gui.fig_default_params['default_min_frequency']
viz_min = gui.viz_manager.fig_default_params['default_min_frequency']
gui_max = gui.fig_default_params['default_max_frequency']
viz_max = gui.viz_manager.fig_default_params['default_max_frequency']

assert gui_min == viz_min
assert gui_max == viz_max

# change value of default min/max frequencies in the widget
new_min = 5
new_max = 50
gui.widget_min_frequency.value = new_min
gui.widget_max_frequency.value = new_max

# update simulation name
gui.widget_simulation_name.value = 'new_defaults'
gui.run_button.click()

# check that the new default smoothing value is set everywhere
gui_min = gui.fig_default_params['default_min_frequency']
viz_min = gui.viz_manager.fig_default_params['default_min_frequency']
gui_max = gui.fig_default_params['default_max_frequency']
viz_max = gui.viz_manager.fig_default_params['default_max_frequency']

assert gui_min == viz_min == new_min
assert gui_max == viz_max == new_max

0 comments on commit d867c6d

Please sign in to comment.