Skip to content

Commit

Permalink
[Tests] add testing for search methods
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Sep 18, 2023
1 parent aa0cb26 commit 9cd32e4
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import pytest

from rl4co.envs import PDPEnv, TSPEnv
from rl4co.models import AttentionModel, HeterogeneousAttentionModel, PPOModel, SymNCO
from rl4co.models import (
ActiveSearch,
AttentionModel,
AutoregressivePolicy,
EASEmb,
EASLay,
HeterogeneousAttentionModel,
PPOModel,
SymNCO,
)
from rl4co.utils import RL4COTrainer


Expand Down Expand Up @@ -50,3 +59,15 @@ def test_ham():
trainer = RL4COTrainer(max_epochs=1)
trainer.fit(model)
trainer.test(model)


@pytest.mark.parametrize("SearchMethod", [ActiveSearch, EASEmb, EASLay])
def test_search_methods(SearchMethod):
env = TSPEnv(num_loc=20)
batch_size = 2 if SearchMethod not in [ActiveSearch] else 1
dataset = env.dataset(2)
policy = AutoregressivePolicy(env)
model = SearchMethod(env, policy, dataset, max_iters=2, batch_size=batch_size)
trainer = RL4COTrainer(max_epochs=1)
trainer.fit(model)
trainer.test(model)

0 comments on commit 9cd32e4

Please sign in to comment.