Skip to content

Commit

Permalink
Merge pull request #113 from openclimatefix/end2end_test
Browse files Browse the repository at this point in the history
Add end-to-end training test
  • Loading branch information
dfulu authored Dec 21, 2023
2 parents 82127e7 + eed5241 commit ffbfa04
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
50 changes: 50 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import glob
import tempfile

import pytest
import pandas as pd
Expand All @@ -16,6 +18,7 @@
import pvnet.models.multimodal.encoders.encoders3d
import pvnet.models.multimodal.linear_networks.networks
import pvnet.models.multimodal.site_encoders.encoders
from pvnet.models.multimodal.multimodal import Model


xr.set_options(keep_attrs=True)
Expand Down Expand Up @@ -91,6 +94,41 @@ def sat_data():
return ds


@pytest.fixture()
def sample_train_val_datamodule():
# duplicate the sample batcnes for more training/val data
n_duplicates = 10

with tempfile.TemporaryDirectory() as tmpdirname:
os.makedirs(f"{tmpdirname}/train")
os.makedirs(f"{tmpdirname}/val")

file_n = 0

for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
batch = torch.load(file)

for i in range(n_duplicates):
# Save fopr both train and val
torch.save(batch, f"{tmpdirname}/train/{file_n:06}.pt")
torch.save(batch, f"{tmpdirname}/val/{file_n:06}.pt")

file_n += 1

dm = DataModule(
configuration=None,
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
block_nwp_and_sat=False,
batch_dir=f"{tmpdirname}",
)
yield dm


@pytest.fixture()
def sample_datamodule():
dm = DataModule(
Expand Down Expand Up @@ -212,3 +250,15 @@ def multimodal_model_kwargs(model_minutes_kwargs):

kwargs.update(model_minutes_kwargs)
return kwargs


@pytest.fixture()
def multimodal_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.fixture()
def multimodal_quantile_model(multimodal_model_kwargs):
model = Model(output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs)
return model
13 changes: 0 additions & 13 deletions tests/models/multimodal/test_multimodal.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
from pvnet.models.multimodal.multimodal import Model
from torch.optim import SGD
import pytest


@pytest.fixture()
def multimodal_model(multimodal_model_kwargs):
model = Model(**multimodal_model_kwargs)
return model


@pytest.fixture()
def multimodal_quantile_model(multimodal_model_kwargs):
model = Model(output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs)
return model


def test_model_forward(multimodal_model, sample_batch):
y = multimodal_model(sample_batch)

Expand Down
6 changes: 6 additions & 0 deletions tests/test_end2end.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import lightning


def test_model_trainer_fit(multimodal_model, sample_train_val_datamodule):
trainer = lightning.pytorch.trainer.trainer.Trainer(fast_dev_run=True)
trainer.fit(model=multimodal_model, datamodule=sample_train_val_datamodule)

0 comments on commit ffbfa04

Please sign in to comment.