From eb692162f47e1dc32e67e5bc054aaa9d82346602 Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Tue, 13 Aug 2024 11:13:19 +0200
Subject: [PATCH] Fix test

---
 aurora/batch.py        | 6 +++---
 aurora/model/aurora.py | 7 ++++---
 tests/test_headers.py  | 2 +-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/aurora/batch.py b/aurora/batch.py
index e95dc15..06ad408 100644
--- a/aurora/batch.py
+++ b/aurora/batch.py
@@ -135,6 +135,6 @@ def to(self, device: str | torch.device) -> "Batch":
         """Move the batch to another device."""
         return self._fmap(lambda x: x.to(device))
 
-    def float(self) -> "Batch":
-        """Convert everything to `float32`s."""
-        return self._fmap(lambda x: x.float())
+    def type(self, t: type) -> "Batch":
+        """Convert everything to type `t`."""
+        return self._fmap(lambda x: x.type(t))
diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 53beaa6..c23574b 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -111,11 +111,12 @@ def forward(self, batch: Batch) -> Batch:
         Returns:
             :class:`Batch`: Prediction for the batch.
         """
-        batch = batch.float()  # `float64`s will take up too much memory.
+        # Get the first parameter. We'll derive the data type and device from this parameter.
+        p = next(self.parameters())
+        batch = batch.type(p.dtype)
         batch = batch.normalise()
         batch = batch.crop(patch_size=self.patch_size)
-        # Assume that all parameters of the model are either on the CPU or GPU.
-        batch = batch.to(next(self.parameters()).device)
+        batch = batch.to(p.device)
 
         H, W = batch.spatial_shape
         patch_res: Int3Tuple = (
diff --git a/tests/test_headers.py b/tests/test_headers.py
index 4267ae3..67499c4 100644
--- a/tests/test_headers.py
+++ b/tests/test_headers.py
@@ -15,7 +15,7 @@
     relative_path = path.relative_to(_root)
 
     # Ignore a possible virtual environment.
-    if str(relative_path.parents[-2]) in {"venv"}:
+    if len(relative_path.parents) >= 2 and str(relative_path.parents[-2]) in {"venv"}:
         continue
 
     # Ignore the automatically generated version file.
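
For illustration, a minimal, self-contained sketch (not the Aurora code itself) of the casting pattern the patch adopts in `forward`: derive the target dtype and device from the model's first parameter and cast the input to match, rather than hard-coding `float32`. `TinyModel` is a hypothetical stand-in for the real model.

    import torch


    class TinyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The first parameter determines the dtype and device the input must match.
            p = next(self.parameters())
            x = x.type(p.dtype).to(p.device)
            return self.linear(x)


    model = TinyModel().double()    # Model parameters are `float64`.
    out = model(torch.randn(3, 4))  # The `float32` input is cast to `float64`.
    print(out.dtype)                # torch.float64

With this pattern, running the model in a lower precision (e.g. after `model.half()` on a GPU) automatically casts incoming batches to that precision, instead of always forcing `float32`.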