diff --git a/tests/conftest.py b/tests/conftest.py index ba657af5..8e589616 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,6 @@ import pvnet from pvnet.data.datamodule import DataModule -from pvnet.data.wind_datamodule import WindDataModule import pvnet.models.multimodal.encoders.encoders3d import pvnet.models.multimodal.linear_networks.networks @@ -158,21 +157,21 @@ def sample_pv_batch(): # old batches. For now we use the old batches to test the site encoder models return torch.load("tests/test_data/presaved_batches/train/000000.pt") - -@pytest.fixture() -def sample_wind_batch(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_wind_batches", - ) - batch = next(iter(dm.train_dataloader())) - return batch +# TODO update this test once we add the loading logic for the Site dataset +# @pytest.fixture() +# def sample_wind_batch(): +# dm = WindDataModule( +# configuration=None, +# batch_size=2, +# num_workers=0, +# prefetch_factor=None, +# train_period=[None, None], +# val_period=[None, None], +# test_period=[None, None], +# batch_dir="tests/test_data/sample_wind_batches", +# ) +# batch = next(iter(dm.train_dataloader())) +# return batch @pytest.fixture() diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index 41969b22..08d3cac8 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -41,14 +41,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar batch_size=8, ) - -def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): - _test_model_forward( - sample_wind_batch, - SingleAttentionNetwork, - site_encoder_sensor_model_kwargs, - batch_size=2, - ) +# TODO once we have updated the sample batches for sites include this test +# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): +# _test_model_forward( +# sample_wind_batch, +# SingleAttentionNetwork, +# site_encoder_sensor_model_kwargs, +# batch_size=2, +# ) # Test model backward on all models