diff --git a/new_test.py b/new_test.py index 4e56e54..db2ea9e 100644 --- a/new_test.py +++ b/new_test.py @@ -140,7 +140,7 @@ def main(): state_dict = torch.load(opt["test_load_path"], map_location=device)["state_dict"] if "test_load_path_aux" in opt: - new_state_dict = torch.load(opt["test_load_path_aux"], map_location=device)["state_dict"] + aux_state_dict = torch.load(opt["test_load_path_aux"], map_location=device)["state_dict"] from collections import OrderedDict @@ -152,7 +152,9 @@ def main(): ki = k fusion_state_dict[ki] = v - for k, v in new_state_dict.items(): + for k, v in aux_state_dict.items(): + if k.startswith("frag"): + continue if k.startswith("vqa_head"): ki = k.replace("vqa", "resize") else: diff --git a/new_train.py b/new_train.py index 60e3db9..21da7ed 100755 --- a/new_train.py +++ b/new_train.py @@ -395,7 +395,36 @@ def main(): reinit=True, ) - if "load_path" in opt: + if "load_path_aux" in opt: + state_dict = torch.load(opt["load_path"], map_location=device)["state_dict"] + aux_state_dict = torch.load(opt["load_path_aux"], map_location=device)["state_dict"] + + from collections import OrderedDict + + fusion_state_dict = OrderedDict() + for k, v in state_dict.items(): + if "head" in k: + continue + if k.startswith("vqa_head"): + ki = k.replace("vqa", "fragments") + else: + ki = k + fusion_state_dict[ki] = v + + for k, v in aux_state_dict.items(): + if "head" in k: + continue + if k.startswith("frag"): + continue + if k.startswith("vqa_head"): + ki = k.replace("vqa", "resize") + else: + ki = k + fusion_state_dict[ki] = v + state_dict = fusion_state_dict + print(model.load_state_dict(state_dict)) + + elif "load_path" in opt: state_dict = torch.load(opt["load_path"], map_location=device) if "state_dict" in state_dict: diff --git a/split_train.py b/split_train.py index 1298dcb..18c99ca 100644 --- a/split_train.py +++ b/split_train.py @@ -249,7 +249,7 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s' video[key] = data[key].to(device) ## Reshape into clips b, c, t, h, w = video[key].shape - video[key] = video[key].reshape(b, c, data["num_clips"], t // data["num_clips"], h, w).permute(0,2,1,3,4,5).reshape(b * data["num_clips"], c, t // data["num_clips"], h, w) + video[key] = video[key].reshape(b, c, data["num_clips"][key], t // data["num_clips"][key], h, w).permute(0,2,1,3,4,5).reshape(b * data["num_clips"][key], c, t // data["num_clips"][key], h, w) if key + "_up" in data: video_up[key] = data[key+"_up"].to(device) ## Reshape into clips @@ -354,7 +354,6 @@ def main(): bests_ = [] - model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device) if opt.get("split_seed", -1) > 0: num_splits = 10 @@ -366,7 +365,8 @@ def main(): if opt.get("split_seed", -1) > 0: ann_path = opt["data"]["train"]["args"]["anno_file"] - for split in range(0,num_splits): + for split in range(1,num_splits): + model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device) print(split) if opt.get("split_seed", -1) > 0: split_duo = train_test_split(opt["data"]["train"]["args"]["data_prefix"], @@ -408,7 +408,35 @@ def main(): reinit=True, ) - if "load_path" in opt: + if "load_path_aux" in opt: + state_dict = torch.load(opt["load_path"], map_location=device)["state_dict"] + aux_state_dict = torch.load(opt["load_path_aux"], map_location=device)["state_dict"] + + from collections import OrderedDict + + fusion_state_dict = OrderedDict() + for k, v in state_dict.items(): + if "head" in k: + continue + if k.startswith("vqa_head"): + ki = k.replace("vqa", "fragments") + else: + ki = k + fusion_state_dict[ki] = v + + for k, v in aux_state_dict.items(): + if "head" in k: + continue + if k.startswith("frag"): + continue + if k.startswith("vqa_head"): + ki = k.replace("vqa", "resize") + else: + ki = k + fusion_state_dict[ki] = v + state_dict = fusion_state_dict + print(model.load_state_dict(state_dict,strict=False)) + elif "load_path" in opt: state_dict = torch.load(opt["load_path"], map_location=device) if "state_dict" in state_dict: @@ -494,10 +522,10 @@ def main(): model_ema if model_ema is not None else model, device, bests[key], save_model=opt["save_model"], save_name=opt["name"], suffix = key+"_s", - ) + ) if model_ema is not None: bests_n[key] = inference_set( - val_loaders[key], + val_loaders[key], model, device, bests_n[key], save_model=opt["save_model"], save_name=opt["name"], suffix = key+'_n', @@ -524,7 +552,9 @@ def main(): KROCC: {bests_n[key][2]:.4f} RMSE: {bests_n[key][3]:.4f}.""" ) - + + + for key, value in dict(model.named_children()).items(): if "backbone" in key: for param in value.parameters():