diff --git a/models.py b/models.py index 84bbb03d..b7abf6ac 100644 --- a/models.py +++ b/models.py @@ -696,12 +696,21 @@ def build_model(args, text_aligner, pitch_extractor, bert): def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]): state = torch.load(path, map_location='cpu') params = state['net'] + for key in model: if key in params and key not in ignore_modules: + try: + model[key].load_state_dict(params[key], strict=True) + except: + from collections import OrderedDict + state_dict = params[key] + new_state_dict = OrderedDict() + print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}') + for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()): + new_state_dict[k_m] = v_c + model[key].load_state_dict(new_state_dict, strict=True) print('%s loaded' % key) - model[key].load_state_dict(params[key], strict=False) - _ = [model[key].eval() for key in model] - + if not load_only_params: epoch = state["epoch"] iters = state["iters"]