Skip to content

Commit

Permalink
Fix #2452 - Check model parameters on startup, remove reactive seed (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince authored Nov 5, 2024
1 parent a49e40d commit e9c0530
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
48 changes: 29 additions & 19 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import asyncio
import copy
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -299,9 +300,12 @@ def ModelCreator(model, model_params, seed=1):
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
"""
user_params, fixed_params = split_model_params(model_params)
solara.use_effect(
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
[model.value],
)

reactive_seed = solara.use_reactive(seed)
user_params, fixed_params = split_model_params(model_params)

model_parameters, set_model_parameters = solara.use_state(
{
Expand All @@ -310,29 +314,35 @@ def ModelCreator(model, model_params, seed=1):
}
)

def do_reseed():
"""Update the random seed for the model."""
reactive_seed.value = model.value.random.random()

def on_change(name, value):
set_model_parameters({**model_parameters, name: value})
new_model_parameters = {**model_parameters, name: value}
model.value = model.value.__class__(**new_model_parameters)
set_model_parameters(new_model_parameters)

def create_model():
model.value = model.value.__class__(**model_parameters)
model.value._seed = reactive_seed.value
UserInputs(user_params, on_change=on_change)

solara.use_effect(create_model, [model_parameters, reactive_seed.value])

with solara.Row(justify="space-between"):
solara.InputText(
label="Seed",
value=reactive_seed,
continuous_update=True,
)
def _check_model_params(init_func, model_params):
"""Check if model parameters are valid for the model's initialization function.
solara.Button(label="Reseed", color="primary", on_click=do_reseed)
Args:
init_func: Model initialization function
model_params: Dictionary of model parameters
UserInputs(user_params, on_change=on_change)
Raises:
ValueError: If a parameter is not valid for the model's initialization function
"""
model_parameters = inspect.signature(init_func).parameters
for name in model_parameters:
if (
model_parameters[name].default == inspect.Parameter.empty
and name not in model_params
and name != "self"
):
raise ValueError(f"Missing required model parameter: {name}")
for name in model_params:
if name not in model_parameters:
raise ValueError(f"Invalid model parameter: {name}")


@solara.component
Expand Down
46 changes: 45 additions & 1 deletion tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import unittest

import ipyvuetify as vw
import pytest
import solara

import mesa
import mesa.visualization.components.altair_components
import mesa.visualization.components.matplotlib_components
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
from mesa.visualization.solara_viz import (
Slider,
SolaraViz,
UserInputs,
_check_model_params,
)


class TestMakeUserInput(unittest.TestCase): # noqa: D101
Expand Down Expand Up @@ -152,3 +158,41 @@ def test_slider(): # noqa: D103
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider


def test_model_param_checks(): # noqa: D103
class ModelWithOptionalParams:
def __init__(self, required_param, optional_param=10):
pass

class ModelWithOnlyRequired:
def __init__(self, param1, param2):
pass

# Test that optional params can be omitted
_check_model_params(ModelWithOptionalParams.__init__, {"required_param": 1})

# Test that optional params can be provided
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "optional_param": 5}
)

# Test invalid parameter name raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: invalid_param"):
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "invalid_param": 2}
)

# Test missing required parameter raises ValueError
with pytest.raises(ValueError, match="Missing required model parameter: param2"):
_check_model_params(ModelWithOnlyRequired.__init__, {"param1": 1})

# Test passing extra parameters raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: extra"):
_check_model_params(
ModelWithOnlyRequired.__init__, {"param1": 1, "param2": 2, "extra": 3}
)

# Test empty params dict raises ValueError if required params
with pytest.raises(ValueError, match="Missing required model parameter"):
_check_model_params(ModelWithOnlyRequired.__init__, {})

0 comments on commit e9c0530

Please sign in to comment.