diff --git a/rl4co/models/zoo/active_search/search.py b/rl4co/models/zoo/active_search/search.py
index 7aabad3b..1d13ef70 100644
--- a/rl4co/models/zoo/active_search/search.py
+++ b/rl4co/models/zoo/active_search/search.py
@@ -98,7 +98,7 @@ def setup(self, stage="fit"):
         )
         self.instance_rewards = torch.zeros(dataset_size)
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int):
         """Called before training (i.e. search) for a new batch begins.
         We re-load the original policy state dict and configure the optimizer.
         """
diff --git a/rl4co/models/zoo/eas/search.py b/rl4co/models/zoo/eas/search.py
index c22962fd..7c07165b 100644
--- a/rl4co/models/zoo/eas/search.py
+++ b/rl4co/models/zoo/eas/search.py
@@ -106,7 +106,7 @@ def setup(self, stage="fit"):
         self.instance_solutions = []
         self.instance_rewards = []
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int):
         """Called before training (i.e. search) for a new batch begins.
         We re-load the original policy state dict and configure all parameters not to require gradients.
         We do the rest in the training step.
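
For context (not part of the diff): a minimal, hypothetical sketch of the hook shape these hunks touch, in a PyTorch Lightning-style module. Names such as `ActiveSearchLike` and `original_policy_state_dict` are illustrative assumptions, not taken from the rl4co source. The dropped `-> int | None` annotation reflects that these hooks never return a value; Lightning only treats a returned -1 from `on_train_batch_start` as a signal to skip the rest of the epoch.

    # Sketch only; attribute and class names are assumptions for illustration.
    from typing import Any
    import copy

    import torch
    import torch.nn as nn


    class ActiveSearchLike(nn.Module):
        """Illustrative stand-in for the search modules touched by this diff."""

        def __init__(self, policy: nn.Module, lr: float = 1e-4):
            super().__init__()
            self.policy = policy
            self.lr = lr
            # Snapshot the pretrained weights so each batch can start fresh.
            self.original_policy_state_dict = copy.deepcopy(policy.state_dict())
            self.optimizer = None

        def on_train_batch_start(self, batch: Any, batch_idx: int):
            """Reset the policy and optimizer before searching on a new batch."""
            # Restore the original (pre-search) policy weights.
            self.policy.load_state_dict(self.original_policy_state_dict)
            # Re-create the optimizer so no momentum leaks across instances.
            self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.lr)


    if __name__ == "__main__":
        module = ActiveSearchLike(nn.Linear(4, 2))
        module.on_train_batch_start(batch=None, batch_idx=0)  # resets weights + optimizer

Resetting state per batch like this is the core of active-search-style methods: each test instance (batch) is searched starting from the same pretrained checkpoint, so no return value is needed from the hook.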