-
Notifications
You must be signed in to change notification settings - Fork 88
/
Copy patheval_operator.py
84 lines (70 loc) · 3.15 KB
/
eval_operator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import yaml
import torch
from torch.utils.data import DataLoader
from models import FNO3d, FNO2d
from train_utils import NSLoader, get_forcing, DarcyFlow
from train_utils.eval_3d import eval_ns
from train_utils.eval_2d import eval_darcy
from argparse import ArgumentParser
def test_3d(config):
    """Evaluate an FNO3d model on the NSLoader dataset described by ``config``.

    Builds the data loader from ``config['data']``, constructs the network from
    ``config['model']``, optionally restores weights from ``config['test']['ckpt']``,
    prints the resolution, and delegates metric computation to ``eval_ns``.

    Args:
        config: Parsed YAML configuration dict with ``data``, ``model`` and
            ``test`` sections (``test.ckpt`` is optional).
    """
    # Integer 0 selects the first CUDA device; otherwise fall back to CPU.
    device = 0 if torch.cuda.is_available() else 'cpu'
    data_config = config['data']
    model_config = config['model']
    test_config = config['test']

    loader = NSLoader(datapath1=data_config['datapath'],
                      nx=data_config['nx'], nt=data_config['nt'],
                      sub=data_config['sub'], sub_t=data_config['sub_t'],
                      N=data_config['total_num'],
                      t_interval=data_config['time_interval'])
    # NOTE(review): the 'shuffle' flag is forwarded as make_loader's `train`
    # argument; presumably it toggles shuffling there — confirm in train_utils.
    eval_loader = loader.make_loader(n_sample=data_config['n_sample'],
                                     batch_size=test_config['batchsize'],
                                     start=data_config['offset'],
                                     train=data_config['shuffle'])

    model = FNO3d(modes1=model_config['modes1'],
                  modes2=model_config['modes2'],
                  modes3=model_config['modes3'],
                  fc_dim=model_config['fc_dim'],
                  layers=model_config['layers']).to(device)

    # A missing, null or empty 'ckpt' entry means "no checkpoint to restore".
    ckpt_path = test_config.get('ckpt')
    if ckpt_path:
        # Map storages to CPU so checkpoints saved on a GPU also load on CPU-only
        # machines; load_state_dict copies the tensors onto the model's device.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # The checkpoint is expected to be a dict holding the state dict under 'model'.
        model.load_state_dict(ckpt['model'])
        print(f'Weights loaded from {ckpt_path}')

    print(f'Resolution : {loader.S}x{loader.S}x{loader.T}')
    forcing = get_forcing(loader.S).to(device)
    eval_ns(model,
            loader,
            eval_loader,
            forcing,
            config,
            device=device)
def test_2d(config):
    """Evaluate an FNO2d model on the DarcyFlow dataset described by ``config``.

    Builds a non-shuffled DataLoader over ``DarcyFlow`` from ``config['data']``,
    constructs the network from ``config['model']``, optionally restores weights
    from ``config['test']['ckpt']``, and delegates evaluation to ``eval_darcy``.

    Args:
        config: Parsed YAML configuration dict with ``data``, ``model`` and
            ``test`` sections (``test.ckpt`` is optional).
    """
    # Integer 0 selects the first CUDA device; otherwise fall back to CPU.
    device = 0 if torch.cuda.is_available() else 'cpu'
    data_config = config['data']
    model_config = config['model']
    test_config = config['test']

    dataset = DarcyFlow(data_config['datapath'],
                        nx=data_config['nx'], sub=data_config['sub'],
                        offset=data_config['offset'], num=data_config['n_sample'])
    dataloader = DataLoader(dataset, batch_size=test_config['batchsize'], shuffle=False)
    print(device)

    model = FNO2d(modes1=model_config['modes1'],
                  modes2=model_config['modes2'],
                  fc_dim=model_config['fc_dim'],
                  layers=model_config['layers'],
                  act=model_config['act']).to(device)

    # A missing, null or empty 'ckpt' entry means "no checkpoint to restore".
    ckpt_path = test_config.get('ckpt')
    if ckpt_path:
        # Map storages to CPU so checkpoints saved on a GPU also load on CPU-only
        # machines; load_state_dict copies the tensors onto the model's device.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # The checkpoint is expected to be a dict holding the state dict under 'model'.
        model.load_state_dict(ckpt['model'])
        print(f'Weights loaded from {ckpt_path}')

    eval_darcy(model, dataloader, config, device)
if __name__ == '__main__':
    parser = ArgumentParser(description='Basic paser')
    # required=True: without it a missing flag surfaces later as open(None) -> TypeError.
    parser.add_argument('--config_path', type=str, required=True,
                        help='Path to the configuration file')
    # NOTE(review): --log is parsed but not consumed anywhere in this script.
    parser.add_argument('--log', action='store_true', help='Turn on the wandb')
    options = parser.parse_args()

    # NOTE(review): yaml.load with FullLoader is kept for compatibility with
    # existing configs; switch to yaml.safe_load if config files may come from
    # untrusted sources.
    with open(options.config_path, 'r', encoding='utf-8') as stream:
        config = yaml.load(stream, yaml.FullLoader)

    # Darcy configs use the 2-D evaluator; anything else goes to test_3d
    # (NSLoader / eval_ns).
    if config['data'].get('name') == 'Darcy':
        test_2d(config)
    else:
        test_3d(config)