Skip to content

Commit

Permalink
Merge pull request #13 from ModECI/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
rimjhimittal authored Aug 27, 2024
2 parents 3115cba + 7cc6f57 commit b218236
Show file tree
Hide file tree
Showing 3 changed files with 7,374 additions and 64 deletions.
140 changes: 76 additions & 64 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from modeci_mdf.mdf import Model, Graph, Node, Parameter, OutputPort
from modeci_mdf.utils import load_mdf_json, load_mdf, load_mdf_yaml
from modeci_mdf.execution_engine import EvaluableGraph, EvaluableOutput
import json
import json, yaml, bson
import numpy as np
import requests
st.set_page_config(layout="wide", page_icon="logo.png", page_title="Model Description Format", menu_items={
'Report a bug': "https://github.com/ModECI/MDF/",
st.set_page_config(layout="wide", page_icon="page_icon.png", page_title="Model Description Format", menu_items={
'Report a bug': "https://github.com/ModECI/MDF-UI/",
'About': "ModECI (Model Exchange and Convergence Initiative) is a multi-investigator collaboration that aims to develop a standardized format for exchanging computational models across diverse software platforms and domains of scientific research and technology development, with a particular focus on neuroscience, Machine Learning and Artificial Intelligence. Refer to https://modeci.org/ for more."
})

Expand All @@ -24,9 +24,6 @@ def run_simulation(param_inputs, mdf_model, stateful):
if stateful:
duration = param_inputs["Simulation Duration (s)"]
dt = param_inputs["Time Step (s)"]



for node in nodes:
eg = EvaluableGraph(mod_graph, verbose=False)
t = 0
Expand Down Expand Up @@ -59,7 +56,7 @@ def run_simulation(param_inputs, mdf_model, stateful):
for node in nodes:
eg = EvaluableGraph(mod_graph, verbose=False)
eg.evaluate()
all_node_results[node.id] = pd.DataFrame({op.value: [float(eg.enodes[node.id].evaluable_outputs[op.id].curr_value)] for op in node.output_ports})
all_node_results[node.id] = pd.DataFrame({op.value: [eg.enodes[node.id].evaluable_outputs[op.id].curr_value] for op in node.output_ports})

return all_node_results
def show_simulation_results(all_node_results, stateful_nodes):
Expand All @@ -71,24 +68,26 @@ def show_simulation_results(all_node_results, stateful_nodes):
st.session_state.selected_columns = {node_id: {col: True for col in chart_data.columns}}
elif node_id not in st.session_state.selected_columns:
st.session_state.selected_columns[node_id] = {col: True for col in chart_data.columns}
columns = chart_data.columns
for column in columns:
st.checkbox(
f"{column}",
value=st.session_state.selected_columns[node_id][column],
key=f"checkbox_{node_id}_{column}",
on_change=update_selected_columns,
args=(node_id, column,)
)


# Filter the data based on selected checkboxes
filtered_data = chart_data[[col for col, selected in st.session_state.selected_columns[node_id].items() if selected]]

# Display the line chart with filtered data
st.line_chart(filtered_data, use_container_width=True, height=400)
columns = chart_data.columns
checks = st.columns(8)
if len(columns) > 0 and len(st.session_state.selected_columns[node_id])>1:
for l, column in enumerate(columns):
with checks[l]:
st.checkbox(
f"{column}",
value=st.session_state.selected_columns[node_id][column],
key=f"checkbox_{node_id}_{column}",
on_change=update_selected_columns,
args=(node_id, column,)
)
else:
st.write(all_node_results[node_id])

for col in chart_data.columns:
st.write(f"{col}: {chart_data[col][0]}")

def update_selected_columns(node_id, column):
st.session_state.selected_columns[node_id][column] = st.session_state[f"checkbox_{node_id}_{column}"]
Expand All @@ -99,55 +98,54 @@ def show_mdf_graph(mdf_model):
image_path = mdf_model.id + ".png"
st.image(image_path, caption="Model Graph Visualization")

def show_json_output(mdf_model):
def show_json_model(mdf_model):
st.subheader("JSON Model")
st.json(mdf_model.to_json())

# st.cache_data()
def view_tabs(mdf_model, param_inputs, stateful): # view
tab1, tab2, tab3 = st.tabs(["Simulation Results", "MDF Graph", "Json Model"])
with tab1:
if stateful:
if 'simulation_results' not in st.session_state:
st.session_state.simulation_results = None

if st.session_state.simulation_results is not None:
show_simulation_results(st.session_state.simulation_results, stateful)
else:
st.write("Run the simulation to see results.") # model
if 'simulation_run' not in st.session_state or not st.session_state.simulation_run:
st.write("Run the simulation to see results.")
elif st.session_state.simulation_results is not None:
show_simulation_results(st.session_state.simulation_results, stateful)
else:
if 'simulation_results' not in st.session_state:
st.session_state.simulation_results = None

if st.session_state.simulation_results is not None:
show_simulation_results(st.session_state.simulation_results, stateful)
else:
st.write("Stateless.")

st.write("No simulation results available.")
with tab2:
show_mdf_graph(mdf_model) # view
with tab3:
show_json_output(mdf_model) # view
show_json_model(mdf_model) # view

def display_and_edit_array(array, key):
if isinstance(array, list):
array = np.array(array)

rows, cols = array.shape if array.ndim > 1 else (1, len(array))

edited_array = []
for i in range(rows):
row = []
for j in range(cols):
value = array[i][j] if array.ndim > 1 else array[i]
edited_value = st.text_input(f"[{i}][{j}]", value=str(value), key=f"{key}_{i}_{j}")
try:
row.append(float(edited_value))
except ValueError:
st.error(f"Invalid input for [{i}][{j}]. Please enter a valid number.")
edited_array.append(row)

return np.array(edited_array)
if rows*cols > 10:
st.write(array)
st.write("Array Shape:", array.shape)
else:
edited_array = []
if rows == 1:
for j in range(cols):
value = array[j] if array.ndim > 1 else array[j]
edited_value = st.text_input(f"[{j}]", value=str(value), key=f"{key}_{j}")
try:
edited_array.append(float(edited_value))
except ValueError:
st.error(f"Invalid input for [{j}]. Please enter a valid number.")
else:
for i in range(rows):
row = []
for j in range(cols):
value = array[i][j] if array.ndim > 1 else array[i]
edited_value = st.text_input(f"[{i}][{j}]", value=str(value), key=f"{key}_{i}_{j}")
try:
row.append(float(edited_value))
except ValueError:
st.error(f"Invalid input for [{i}][{j}]. Please enter a valid number.")
edited_array.append(row)

return np.array(edited_array)

def parameter_form_to_update_model_and_view(mdf_model):
mod_graph = mdf_model.graphs[0]
Expand Down Expand Up @@ -186,8 +184,13 @@ def parameter_form_to_update_model_and_view(mdf_model):

# Create four columns for each node
col1, col2, col3, col4 = st.columns(4)

parameter_list = []
for i, param in enumerate(node.parameters):
if isinstance(param.value, str) or param.value is None:
continue
else:
parameter_list.append(param)
for i, param in enumerate(parameter_list):
if isinstance(param.value, str) or param.value is None:
continue
key = f"{param.id}_{node_index}_{i}"
Expand Down Expand Up @@ -233,32 +236,36 @@ def parameter_form_to_update_model_and_view(mdf_model):
valid_inputs = False

run_button = st.form_submit_button("Run Simulation")

if run_button:
if valid_inputs:
for node in nodes:
for param in node.parameters:
if param.id in param_inputs:
param.value = param_inputs[param.id]
st.session_state.simulation_results = run_simulation(param_inputs, mdf_model, stateful)

view_tabs(mdf_model, param_inputs, stateful_nodes)
st.session_state.simulation_run = True
else:
st.error("Please correct the invalid inputs before running the simulation.")
view_tabs(mdf_model, param_inputs, stateful_nodes)


def upload_file_and_load_to_model():

uploaded_file = st.sidebar.file_uploader("Choose a JSON/YAML/BSON file", type=["json", "yaml", "bson"])
github_url = st.sidebar.text_input("Enter GitHub raw file URL:", placeholder="Enter GitHub raw file URL")
example_models = {
"Newton Cooling Model": "./examples/NewtonCoolingModel.json",
# "ABCD": "./examples/ABCD.json",
"ABCD": "./examples/ABCD.json",
"FN": "./examples/FN.mdf.json",
"States": "./examples/States.json",
"Swicthed RLC Circuit": "./examples/switched_rlc_circuit.json",
"Switched RLC Circuit": "./examples/switched_rlc_circuit.json",
"Simple":"./examples/Simple.json",
# "Arrays":"./examples/Arrays.json",
# "RNN":"./examples/RNNs.json",
"Arrays":"./examples/Arrays.json",
# "RNN":"./examples/RNNs.json", # some issue
"IAF":"./examples/IAFs.json",
"Izhikevich Test":"./examples/IzhikevichTest.mdf.json"
"Izhikevich Test":"./examples/IzhikevichTest.mdf.json",
"Keras to MDF IRIS":"./examples/keras_to_MDF.json",
}
selected_model = st.sidebar.selectbox("Choose an example model", list(example_models.keys()), index=None, placeholder="Dont have an MDF Model? Try some sample examples here!")

Expand Down Expand Up @@ -290,7 +297,11 @@ def load_model_from_content(file_content, file_extension):
json_data = json.loads(file_content)
mdf_model = Model.from_dict(json_data)
elif file_extension in ['yaml', 'yml']:
mdf_model = load_mdf_yaml(io.BytesIO(file_content))
yaml_data = yaml.safe_load(file_content)
mdf_model = Model.from_dict(yaml_data)
elif file_extension == 'bson':
bson_data = bson.decode(file_content)
mdf_model = Model.from_dict(bson_data)
else:
st.error("Unsupported file format. Please use JSON or YAML files.")
return None
Expand All @@ -311,6 +322,7 @@ def main():
mdf_model = upload_file_and_load_to_model() # controller

if mdf_model:
st.session_state.current_model = mdf_model
header1, header2 = st.columns([1, 8], vertical_alignment="top")
with header1:
with st.container():
Expand Down
Loading

0 comments on commit b218236

Please sign in to comment.