Skip to content

Commit

Permalink
Merge branch 'main' of github-personal:microsoft/aurora
Browse files Browse the repository at this point in the history
  • Loading branch information
Your Name committed Aug 13, 2024
2 parents 89918c7 + eb69216 commit 92620e9
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
6 changes: 3 additions & 3 deletions aurora/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,6 @@ def to(self, device: str | torch.device) -> "Batch":
"""Move the batch to another device."""
return self._fmap(lambda x: x.to(device))

def float(self) -> "Batch":
"""Convert everything to `float32`s."""
return self._fmap(lambda x: x.float())
def type(self, t: type) -> "Batch":
"""Convert everything to type `t`."""
return self._fmap(lambda x: x.type(t))
7 changes: 4 additions & 3 deletions aurora/model/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ def forward(self, batch: Batch) -> Batch:
Returns:
:class:`Batch`: Prediction for the batch.
"""
batch = batch.float() # `float64`s will take up too much memory.
# Get the first parameter. We'll derive the data type and device from this parameter.
p = next(self.parameters())
batch = batch.type(p.dtype)
batch = batch.normalise()
batch = batch.crop(patch_size=self.patch_size)
# Assume that all parameters of the model are either on the CPU or GPU.
batch = batch.to(next(self.parameters()).device)
batch = batch.to(p.device)

H, W = batch.spatial_shape
patch_res: Int3Tuple = (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
relative_path = path.relative_to(_root)

# Ignore a possible virtual environment.
if str(relative_path.parents[-2]) in {"venv"}:
if len(relative_path.parents) >= 2 and str(relative_path.parents[-2]) in {"venv"}:
continue

# Ignore the automatically generated version file.
Expand Down

0 comments on commit 92620e9

Please sign in to comment.