-
Notifications
You must be signed in to change notification settings - Fork 6
/
model_debug.py
33 lines (23 loc) · 869 Bytes
/
model_debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import dotenv
import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities.distributed import rank_zero_info
from lasaft.models.conditioned.scripts import trainer as trainer
from lasaft.utils.functions import mkdir_if_not_exists, print_config
# Load variables from a .env file at import time, before Hydra composes the
# config; override=True asks python-dotenv to replace values already present
# in the process environment.
dotenv.load_dotenv(override=True)
def main(cfg: DictConfig):
    """Instantiate the model described by ``cfg`` so it can be inspected.

    The composed configuration is first logged as YAML through
    ``rank_zero_info``; if ``cfg`` carries a truthy ``print_config`` entry it
    is additionally pretty-printed with ``print_config(cfg, resolve=True)``.
    The model is then built with ``hydra.utils.instantiate(cfg['model'])``.

    Args:
        cfg: Composed Hydra configuration. Must contain a ``model`` entry
            that ``hydra.utils.instantiate`` can build.

    Returns:
        The object instantiated from ``cfg['model']``.
    """
    # Log the full config as YAML.
    rank_zero_info(OmegaConf.to_yaml(cfg))

    # Pretty print config using Rich library
    if cfg.get("print_config"):
        print_config(cfg, resolve=True)

    model = hydra.utils.instantiate(cfg['model'])

    # Return the model instead of binding a throwaway local after it: callers
    # (or a debugger breakpoint on this line) can inspect the built object.
    return model
@hydra.main(config_path="conf", config_name="model_debug")
def hydra_entry(cfg: DictConfig) -> None:
    """Hydra entry point: hand the composed config over to ``main``.

    The decorator points Hydra at the ``conf`` config directory and the
    ``model_debug`` config name; the resulting ``cfg`` is forwarded unchanged.
    """
    main(cfg)
# Only launch the Hydra entry point when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    hydra_entry()