Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2452 - handle solara viz model params better #2454

Merged
merged 7 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import asyncio
import copy
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -299,9 +300,12 @@ def ModelCreator(model, model_params, seed=1):
- The component provides an interface for adjusting user-defined parameters and reseeding the model.

"""
user_params, fixed_params = split_model_params(model_params)
solara.use_effect(
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
[model.value],
)

reactive_seed = solara.use_reactive(seed)
user_params, fixed_params = split_model_params(model_params)

model_parameters, set_model_parameters = solara.use_state(
{
Expand All @@ -310,29 +314,35 @@ def ModelCreator(model, model_params, seed=1):
}
)

def do_reseed():
"""Update the random seed for the model."""
reactive_seed.value = model.value.random.random()

def on_change(name, value):
set_model_parameters({**model_parameters, name: value})
new_model_parameters = {**model_parameters, name: value}
model.value = model.value.__class__(**new_model_parameters)
set_model_parameters(new_model_parameters)

def create_model():
model.value = model.value.__class__(**model_parameters)
model.value._seed = reactive_seed.value
UserInputs(user_params, on_change=on_change)

solara.use_effect(create_model, [model_parameters, reactive_seed.value])

with solara.Row(justify="space-between"):
solara.InputText(
label="Seed",
value=reactive_seed,
continuous_update=True,
)
def _check_model_params(init_func, model_params):
"""Check if model parameters are valid for the model's initialization function.

solara.Button(label="Reseed", color="primary", on_click=do_reseed)
Args:
init_func: Model initialization function
model_params: Dictionary of model parameters

UserInputs(user_params, on_change=on_change)
Raises:
ValueError: If a parameter is not valid for the model's initialization function
"""
model_parameters = inspect.signature(init_func).parameters
for name in model_parameters:
if (
model_parameters[name].default == inspect.Parameter.empty
and name not in model_params
and name != "self"
):
raise ValueError(f"Missing required model parameter: {name}")
for name in model_params:
if name not in model_parameters:
raise ValueError(f"Invalid model parameter: {name}")


@solara.component
Expand Down
46 changes: 45 additions & 1 deletion tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import unittest

import ipyvuetify as vw
import pytest
import solara

import mesa
import mesa.visualization.components.altair_components
import mesa.visualization.components.matplotlib_components
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
from mesa.visualization.solara_viz import (
Slider,
SolaraViz,
UserInputs,
_check_model_params,
)


class TestMakeUserInput(unittest.TestCase): # noqa: D101
Expand Down Expand Up @@ -152,3 +158,41 @@ def test_slider(): # noqa: D103
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider


def test_model_param_checks(): # noqa: D103
class ModelWithOptionalParams:
def __init__(self, required_param, optional_param=10):
pass

class ModelWithOnlyRequired:
def __init__(self, param1, param2):
pass

# Test that optional params can be omitted
_check_model_params(ModelWithOptionalParams.__init__, {"required_param": 1})

# Test that optional params can be provided
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "optional_param": 5}
)

# Test invalid parameter name raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: invalid_param"):
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "invalid_param": 2}
)

# Test missing required parameter raises ValueError
with pytest.raises(ValueError, match="Missing required model parameter: param2"):
_check_model_params(ModelWithOnlyRequired.__init__, {"param1": 1})

# Test passing extra parameters raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: extra"):
_check_model_params(
ModelWithOnlyRequired.__init__, {"param1": 1, "param2": 2, "extra": 3}
)

# Test empty params dict raises ValueError if required params
with pytest.raises(ValueError, match="Missing required model parameter"):
_check_model_params(ModelWithOnlyRequired.__init__, {})
Loading