-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
73 lines (55 loc) · 2.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from collections.abc import Mapping

from jax import config
from omegaconf import DictConfig, OmegaConf

from jax_sph.defaults import defaults
def check_subset(superset, subset, full_key=""):
    """Check that the keys of 'subset' are a subset of 'superset'.

    Nested mappings are checked recursively, so a nested key is only accepted
    if its full dotted path (e.g. ``"case.dim"``) also exists in ``superset``.
    When ``superset`` holds a non-mapping value (e.g. ``None``) where
    ``subset`` provides a sub-mapping, the key itself exists and the subtree
    is accepted without further checks.

    Args:
        superset: Reference mapping (e.g. the defaults config).
        subset: Mapping whose keys must all appear in ``superset``.
        full_key: Dotted prefix of the current recursion level; only used to
            build the error message.

    Raises:
        AssertionError: If a key of ``subset`` (at any depth) is missing from
            ``superset``. Raised explicitly so the check survives ``python -O``.
    """
    for k, v in subset.items():
        key = f"{full_key}{k}"
        # Test membership first: an unknown key must produce the readable
        # error below, not a KeyError from indexing ``superset[k]``.
        if k not in superset:
            raise AssertionError(
                f"cli_args must be a subset of the defaults. Wrong cli key: '{key}'"
            )
        # Use the Mapping ABC rather than ``dict``: OmegaConf's DictConfig is a
        # MutableMapping but not a dict subclass, so an ``isinstance(v, dict)``
        # test never descended into nested config groups.
        if isinstance(v, Mapping):
            sub_superset = superset[k]
            if isinstance(sub_superset, Mapping):
                check_subset(sub_superset, v, key + ".")
def _to_plain_dict(cfg):
    """Return an independent, unresolved plain-Python copy of ``cfg``.

    ``OmegaConf.to_container`` builds fresh builtin containers, so mutating the
    result can never affect ``cfg`` (unlike ``DictConfig.copy()``, which is
    shallow). ``resolve=False`` keeps interpolation strings such as
    ``"${case.dim}"`` as-is, because their targets may only exist once all
    configs of the ``extends`` chain are merged.
    """
    if not OmegaConf.is_config(cfg):
        cfg = OmegaConf.create(cfg)
    return OmegaConf.to_container(cfg, resolve=False)


def _check_against_defaults(cfgs):
    """Assert that every config in ``cfgs`` only uses keys present in ``defaults``.

    ``case.special`` is removed from the checked copies because it is
    case-specific. The configs in ``cfgs`` themselves are not modified.
    """
    reference = _to_plain_dict(defaults)
    for cfg in cfgs:
        plain = _to_plain_dict(cfg)
        case = plain.get("case") if isinstance(plain, dict) else None
        if isinstance(case, dict):
            case.pop("special", None)
        check_subset(reference, plain)


def load_embedded_configs(cli_args: DictConfig) -> DictConfig:
    """Loads all 'extends' embedded configs and merge them with the cli overwrites.

    Starting from the file at ``cli_args.config``, each loaded config may name
    a parent through its ``extends`` key. The parent is either another config
    path or the sentinel ``"JAX_SPH_DEFAULTS"``, which stands for the package
    defaults and ends the chain. The ``extends`` key is removed from every
    loaded config so it does not appear in the result.

    If the chain reaches the defaults, all inherited config files and
    ``cli_args`` are checked to only use keys that exist in the defaults
    (``case.special`` excepted).

    Args:
        cli_args: Command-line overrides; must contain ``config``, the path of
            the top-level config file.

    Returns:
        The merged config. Priority from lowest to highest: the defaults (if
        reached), the most distant parent, ..., the file at
        ``cli_args.config``, and finally ``cli_args``.

    Raises:
        ValueError: If the ``extends`` chain loops back to an already loaded file.
        AssertionError: If a config uses a key that the defaults do not define.
    """
    cfgs = [OmegaConf.load(cli_args.config)]
    # Paths already loaded; a repeat means the chain would loop forever.
    visited = {os.path.realpath(str(cli_args.config))}
    while "extends" in cfgs[0]:
        extends_path = cfgs[0]["extends"]
        # Drop the key so it does not leak into the merged config.
        del cfgs[0]["extends"]

        if extends_path == "JAX_SPH_DEFAULTS":
            # The defaults terminate the chain; validate everything against them.
            cfgs = [defaults] + cfgs
            _check_against_defaults(cfgs[1:] + [cli_args])
            break

        # go to parents configs until the defaults are reached
        resolved = os.path.realpath(str(extends_path))
        if resolved in visited:
            raise ValueError(
                f"Cyclic 'extends' chain: '{extends_path}' was already loaded."
            )
        visited.add(resolved)
        cfgs = [OmegaConf.load(extends_path)] + cfgs

    # merge all embedded configs and give highest priority to cli_args
    return OmegaConf.merge(*cfgs, cli_args)
if __name__ == "__main__":
    cli_args = OmegaConf.from_cli()
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if "config" not in cli_args:
        raise SystemExit("A configuration file must be specified.")
    cfg = load_embedded_configs(cli_args)

    # Specify cuda device. These setting must be done before importing jax-md.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152 from TensorFlow
    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
    # Compare as strings: ``gpu=-1`` given on the CLI or in YAML is parsed by
    # OmegaConf as the int -1, which never equals the string "-1".
    if str(cfg.gpu) == "-1":
        os.environ["JAX_PLATFORMS"] = "cpu"
        # NOTE(review): jax is already imported at the top of this file, so
        # also set the option through jax.config in case the environment
        # variable has already been read by then.
        config.update("jax_platforms", "cpu")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.xla_mem_fraction)

    # for reproducibility
    # os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
    # os.environ["TF_DETERMINISTIC_OPS"] = "1"

    if cfg.no_jit:
        config.update("jax_disable_jit", True)
    if cfg.dtype == "float64":
        config.update("jax_enable_x64", True)

    # Imported late so that the environment/JAX settings above are in place
    # before the simulation code is loaded.
    from jax_sph.simulate import simulate

    simulate(cfg)