forked from mllam/neural-lam
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_grid_features.py
63 lines (51 loc) · 1.87 KB
/
create_grid_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Standard library
import os
from argparse import ArgumentParser
# Third-party
import numpy as np
import torch
# First-party
from neural_lam import config
def main():
    """
    Pre-compute all static features related to the grid nodes.

    Reads ``nwp_xy.npy``, ``surface_geopotential.npy`` and
    ``border_mask.npy`` from ``data/<dataset name>/static``, where the
    dataset name is taken from the data config file given by
    ``--data_config``. The coordinates are divided by their largest absolute
    value, the geopotential is min-max rescaled to [0, 1], and both are
    concatenated with the border mask into a (N_grid, 4) tensor that is saved
    as ``grid_features.pt`` in the same directory.
    """
    parser = ArgumentParser(description="Grid feature creation arguments")
    parser.add_argument(
        "--data_config",
        type=str,
        default="neural_lam/data_config.yaml",
        help="Path to data config file (default: neural_lam/data_config.yaml)",
    )
    args = parser.parse_args()
    config_loader = config.Config.from_file(args.data_config)

    static_dir_path = os.path.join("data", config_loader.dataset.name, "static")

    def load_static(file_name):
        # Every input is a NumPy array stored in the static directory
        return np.load(os.path.join(static_dir_path, file_name))

    # -- Static grid node features --
    grid_xy = torch.tensor(load_static("nwp_xy.npy"))  # (2, N_y, N_x)
    grid_xy = grid_xy.flatten(1, 2).T  # (N_grid, 2)
    pos_max = torch.max(torch.abs(grid_xy))
    grid_xy = grid_xy / pos_max  # Divide by maximum coordinate

    geopotential = torch.tensor(
        load_static("surface_geopotential.npy")
    )  # (N_y, N_x)
    geopotential = geopotential.flatten(0, 1).unsqueeze(1)  # (N_grid, 1)
    gp_min = torch.min(geopotential)
    gp_max = torch.max(geopotential)
    # Rescale geopotential to [0,1]
    if gp_max == gp_min:
        # A constant field has zero range, so the rescaling would be 0/0 and
        # fill the feature with NaN. Use a constant zero feature instead.
        # (NaN inputs do not compare equal, so they still propagate.)
        geopotential = torch.zeros_like(geopotential)  # (N_grid, 1)
    else:
        geopotential = (geopotential - gp_min) / (
            gp_max - gp_min
        )  # (N_grid, 1)

    grid_border_mask = torch.tensor(
        load_static("border_mask.npy"),
        dtype=torch.int64,
    )  # (N_y, N_x)
    grid_border_mask = (
        grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
    )  # (N_grid, 1)

    # Concatenate grid features
    grid_features = torch.cat(
        (grid_xy, geopotential, grid_border_mask), dim=1
    )  # (N_grid, 4)

    torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))
# Only build the grid features when executed as a script, not on import
if __name__ == "__main__":
    main()