diff --git a/pysr/logger_specs.py b/pysr/logger_specs.py index 354a21de..f4c0a3df 100644 --- a/pysr/logger_specs.py +++ b/pysr/logger_specs.py @@ -19,6 +19,11 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None: """Write hyperparameters to the logger.""" pass # pragma: no cover + @abstractmethod + def close(self, logger: AnyValue) -> None: + """Close the logger instance.""" + pass # pragma: no cover + @dataclass class TensorBoardLoggerSpec(AbstractLoggerSpec): @@ -74,3 +79,7 @@ def write_hparams(self, logger: AnyValue, hparams: dict[str, Any]) -> None: ], ), ) + + def close(self, logger: AnyValue) -> None: + base_logger = jl.SymbolicRegression.get_logger(logger) + jl.close(base_logger) diff --git a/pysr/sr.py b/pysr/sr.py index aa9492e2..3cae1735 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -703,6 +703,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): stored as an array of uint8, produced by Julia's Serialization.serialize function. julia_options_stream_ : ndarray The serialized julia options, stored as an array of uint8, + logger_ : AnyValue | None + The logger instance used for this fit, if any. expression_spec_ : AbstractExpressionSpec The expression specification used for this fit. This is equal to `self.expression_spec` if provided, or `ExpressionSpec()` otherwise. 
@@ -765,6 +767,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): output_directory_: str julia_state_stream_: NDArray[np.uint8] | None julia_options_stream_: NDArray[np.uint8] | None + logger_: AnyValue | None equation_file_contents_: list[pd.DataFrame] | None show_pickle_warnings_: bool @@ -1917,7 +1920,12 @@ def _run( jl.seval(self.complexity_mapping) if self.complexity_mapping else None ) - logger = self.logger_spec.create_logger() if self.logger_spec else None + if hasattr(self, "logger_") and self.logger_ is not None and self.warm_start: + logger = self.logger_ + else: + logger = self.logger_spec.create_logger() if self.logger_spec else None + + self.logger_ = logger # Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl @@ -2058,6 +2066,8 @@ def _run( ) if self.logger_spec is not None: self.logger_spec.write_hparams(logger, self.get_params()) + if not self.warm_start: + self.logger_spec.close(logger) self.julia_state_stream_ = jl_serialize(out) diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 3264383a..52772042 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -647,41 +647,52 @@ def test_tensorboard_logger(self): self.skipTest("TensorBoard not installed. 
Skipping test.") y = self.X[:, 0] - with tempfile.TemporaryDirectory() as tmpdir: - logger_spec = TensorBoardLoggerSpec( - log_dir=tmpdir, log_interval=2, overwrite=True - ) - model = PySRRegressor( - **self.default_test_kwargs, - logger_spec=logger_spec, - early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1", - ) - model.fit(self.X, y) + for warm_start in [False, True]: + with tempfile.TemporaryDirectory() as tmpdir: + logger_spec = TensorBoardLoggerSpec( + log_dir=tmpdir, log_interval=2, overwrite=True + ) + model = PySRRegressor( + **self.default_test_kwargs, + logger_spec=logger_spec, + early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1", + warm_start=warm_start, + ) + model.fit(self.X, y) + logger = model.logger_ + # Should restart from same logger if warm_start is True + model.fit(self.X, y) + logger2 = model.logger_ + + if warm_start: + self.assertEqual(logger, logger2) + else: + self.assertNotEqual(logger, logger2) - # Verify log directory exists and contains TensorBoard files - log_dir = Path(tmpdir) - assert log_dir.exists() - files = list(log_dir.glob("events.out.tfevents.*")) - assert len(files) == 1 + # Verify log directory exists and contains TensorBoard files + log_dir = Path(tmpdir) + assert log_dir.exists() + files = list(log_dir.glob("events.out.tfevents.*")) + assert len(files) == (1 if warm_start else 2) - # Load and verify TensorBoard events - event_acc = EventAccumulator(str(log_dir)) - event_acc.Reload() + # Load and verify TensorBoard events + event_acc = EventAccumulator(str(log_dir)) + event_acc.Reload() - # Check that we have the expected scalar summaries - scalars = event_acc.Tags()["scalars"] - self.assertIn("search/data/summaries/pareto_volume", scalars) - self.assertIn("search/data/summaries/min_loss", scalars) + # Check that we have the expected scalar summaries + scalars = event_acc.Tags()["scalars"] + self.assertIn("search/data/summaries/pareto_volume", scalars) + 
self.assertIn("search/data/summaries/min_loss", scalars) - # Check that we have multiple events for each summary - pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume") - min_loss_events = event_acc.Scalars("search/data/summaries/min_loss") + # Check that we have multiple events for each summary + pareto_events = event_acc.Scalars("search/data/summaries/pareto_volume") + min_loss_events = event_acc.Scalars("search/data/summaries/min_loss") - self.assertGreater(len(pareto_events), 0) - self.assertGreater(len(min_loss_events), 0) + self.assertGreater(len(pareto_events), 0) + self.assertGreater(len(min_loss_events), 0) - # Verify model still works as expected - self.assertLessEqual(model.get_best()["loss"], 1e-4) + # Verify model still works as expected + self.assertLessEqual(model.get_best()["loss"], 1e-4) def manually_create_model(equations, feature_names=None):