From e9c0530070bbcda3db43b6ee00eed5b85cabab73 Mon Sep 17 00:00:00 2001 From: Corvince <13568919+Corvince@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:18:15 +0100 Subject: [PATCH] Fix #2452 - Check model parameters on startup, remove reactive seed (#2454) --- mesa/visualization/solara_viz.py | 48 +++++++++++++++++++------------- tests/test_solara_viz.py | 46 +++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index d5d83f437ea..fefbfb32a04 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -25,6 +25,7 @@ import asyncio import copy +import inspect from collections.abc import Callable from typing import TYPE_CHECKING, Literal @@ -299,9 +300,12 @@ def ModelCreator(model, model_params, seed=1): - The component provides an interface for adjusting user-defined parameters and reseeding the model. """ - user_params, fixed_params = split_model_params(model_params) + solara.use_effect( + lambda: _check_model_params(model.value.__class__.__init__, fixed_params), + [model.value], + ) - reactive_seed = solara.use_reactive(seed) + user_params, fixed_params = split_model_params(model_params) model_parameters, set_model_parameters = solara.use_state( { @@ -310,29 +314,35 @@ def ModelCreator(model, model_params, seed=1): } ) - def do_reseed(): - """Update the random seed for the model.""" - reactive_seed.value = model.value.random.random() - def on_change(name, value): - set_model_parameters({**model_parameters, name: value}) + new_model_parameters = {**model_parameters, name: value} + model.value = model.value.__class__(**new_model_parameters) + set_model_parameters(new_model_parameters) - def create_model(): - model.value = model.value.__class__(**model_parameters) - model.value._seed = reactive_seed.value + UserInputs(user_params, on_change=on_change) - solara.use_effect(create_model, [model_parameters, reactive_seed.value]) - with solara.Row(justify="space-between"): - solara.InputText( - label="Seed", - value=reactive_seed, - continuous_update=True, - ) +def _check_model_params(init_func, model_params): + """Check if model parameters are valid for the model's initialization function. - solara.Button(label="Reseed", color="primary", on_click=do_reseed) + Args: + init_func: Model initialization function + model_params: Dictionary of model parameters - UserInputs(user_params, on_change=on_change) + Raises: + ValueError: If a parameter is not valid for the model's initialization function + """ + model_parameters = inspect.signature(init_func).parameters + for name in model_parameters: + if ( + model_parameters[name].default == inspect.Parameter.empty + and name not in model_params + and name != "self" + ): + raise ValueError(f"Missing required model parameter: {name}") + for name in model_params: + if name not in model_parameters: + raise ValueError(f"Invalid model parameter: {name}") @solara.component diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 9276696676f..7dfeae722e0 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -3,13 +3,19 @@ import unittest import ipyvuetify as vw +import pytest import solara import mesa import mesa.visualization.components.altair_components import mesa.visualization.components.matplotlib_components from mesa.visualization.components.matplotlib_components import make_mpl_space_component -from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs +from mesa.visualization.solara_viz import ( + Slider, + SolaraViz, + UserInputs, + _check_model_params, +) class TestMakeUserInput(unittest.TestCase): # noqa: D101 @@ -152,3 +158,41 @@ def test_slider(): # noqa: D103 assert not slider_int.is_float_slider slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float) assert slider_dtype_float.is_float_slider + + +def test_model_param_checks(): # noqa: D103 + class ModelWithOptionalParams: + def __init__(self, required_param, optional_param=10): + pass + + class ModelWithOnlyRequired: + def __init__(self, param1, param2): + pass + + # Test that optional params can be omitted + _check_model_params(ModelWithOptionalParams.__init__, {"required_param": 1}) + + # Test that optional params can be provided + _check_model_params( + ModelWithOptionalParams.__init__, {"required_param": 1, "optional_param": 5} + ) + + # Test invalid parameter name raises ValueError + with pytest.raises(ValueError, match="Invalid model parameter: invalid_param"): + _check_model_params( + ModelWithOptionalParams.__init__, {"required_param": 1, "invalid_param": 2} + ) + + # Test missing required parameter raises ValueError + with pytest.raises(ValueError, match="Missing required model parameter: param2"): + _check_model_params(ModelWithOnlyRequired.__init__, {"param1": 1}) + + # Test passing extra parameters raises ValueError + with pytest.raises(ValueError, match="Invalid model parameter: extra"): + _check_model_params( + ModelWithOnlyRequired.__init__, {"param1": 1, "param2": 2, "extra": 3} + ) + + # Test empty params dict raises ValueError if required params + with pytest.raises(ValueError, match="Missing required model parameter"): + _check_model_params(ModelWithOnlyRequired.__init__, {})