DEBUG FOR CONSISTENCY in V3.0.0 data/training pipeline
teowu committed Sep 24, 2022
1 parent c9554e0 commit e4143db
Showing 3 changed files with 71 additions and 10 deletions.
6 changes: 4 additions & 2 deletions new_test.py
@@ -140,7 +140,7 @@ def main():
     state_dict = torch.load(opt["test_load_path"], map_location=device)["state_dict"]
 
     if "test_load_path_aux" in opt:
-        new_state_dict = torch.load(opt["test_load_path_aux"], map_location=device)["state_dict"]
+        aux_state_dict = torch.load(opt["test_load_path_aux"], map_location=device)["state_dict"]
 
         from collections import OrderedDict
 
@@ -152,7 +152,9 @@ def main():
                 ki = k
             fusion_state_dict[ki] = v
 
-        for k, v in new_state_dict.items():
+        for k, v in aux_state_dict.items():
+            if k.startswith("frag"):
+                continue
             if k.startswith("vqa_head"):
                 ki = k.replace("vqa", "resize")
             else:
31 changes: 30 additions & 1 deletion new_train.py
@@ -395,7 +395,36 @@ def main():
             reinit=True,
         )
 
-    if "load_path" in opt:
+    if "load_path_aux" in opt:
+        state_dict = torch.load(opt["load_path"], map_location=device)["state_dict"]
+        aux_state_dict = torch.load(opt["load_path_aux"], map_location=device)["state_dict"]
+
+        from collections import OrderedDict
+
+        fusion_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            if "head" in k:
+                continue
+            if k.startswith("vqa_head"):
+                ki = k.replace("vqa", "fragments")
+            else:
+                ki = k
+            fusion_state_dict[ki] = v
+
+        for k, v in aux_state_dict.items():
+            if "head" in k:
+                continue
+            if k.startswith("frag"):
+                continue
+            if k.startswith("vqa_head"):
+                ki = k.replace("vqa", "resize")
+            else:
+                ki = k
+            fusion_state_dict[ki] = v
+        state_dict = fusion_state_dict
+        print(model.load_state_dict(state_dict))
+
+    elif "load_path" in opt:
         state_dict = torch.load(opt["load_path"], map_location=device)
 
         if "state_dict" in state_dict:
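For reference, the hunk above (and the matching change in new_test.py) implements a two-checkpoint fusion: weights from a fragments-branch checkpoint and an auxiliary resize-branch checkpoint are merged into one state dict, with branch-specific keys renamed or skipped. The sketch below is not part of the commit; it is a minimal standalone version of that pattern, with hypothetical file paths and a test-time-style treatment of the heads (renamed rather than dropped).

```python
# Minimal sketch (not part of the commit): fuse a "fragments"-branch checkpoint
# with an auxiliary "resize"-branch checkpoint into one state dict, mirroring
# the key-renaming logic in the hunk above. Paths are hypothetical placeholders.
from collections import OrderedDict

import torch

frag_ckpt = torch.load("fragments_model.pth", map_location="cpu")["state_dict"]  # hypothetical path
resize_ckpt = torch.load("resize_model.pth", map_location="cpu")["state_dict"]   # hypothetical path

fused = OrderedDict()

# Branch 1: keep everything from the fragments checkpoint; its quality head
# ("vqa_head.*") is renamed so it becomes the fragments-specific head.
for k, v in frag_ckpt.items():
    fused[k.replace("vqa", "fragments") if k.startswith("vqa_head") else k] = v

# Branch 2: take only resize-specific weights from the auxiliary checkpoint,
# skipping anything that belongs to the fragments branch, and rename its head.
for k, v in resize_ckpt.items():
    if k.startswith("frag"):
        continue
    fused[k.replace("vqa", "resize") if k.startswith("vqa_head") else k] = v

# The fused dict can then be loaded into the two-branch model; strict=False
# tolerates keys that the fused checkpoint does not cover.
# print(model.load_state_dict(fused, strict=False))
```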
44 changes: 37 additions & 7 deletions split_train.py
@@ -249,7 +249,7 @@ def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s'
                 video[key] = data[key].to(device)
                 ## Reshape into clips
                 b, c, t, h, w = video[key].shape
-                video[key] = video[key].reshape(b, c, data["num_clips"], t // data["num_clips"], h, w).permute(0,2,1,3,4,5).reshape(b * data["num_clips"], c, t // data["num_clips"], h, w)
+                video[key] = video[key].reshape(b, c, data["num_clips"][key], t // data["num_clips"][key], h, w).permute(0,2,1,3,4,5).reshape(b * data["num_clips"][key], c, t // data["num_clips"][key], h, w)
             if key + "_up" in data:
                 video_up[key] = data[key+"_up"].to(device)
                 ## Reshape into clips
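To see what this clip reshape does (now with num_clips looked up per key), here is a small standalone sketch with made-up shapes; it is illustrative only and not part of the diff.

```python
# Standalone sketch (not part of the commit): a batch of b videos with t frames
# is regrouped into b * num_clips shorter clips of t // num_clips frames each.
import torch

b, c, t, h, w = 2, 3, 32, 224, 224
num_clips = 4  # corresponds to data["num_clips"][key] in the training code

video = torch.randn(b, c, t, h, w)
clips = (
    video.reshape(b, c, num_clips, t // num_clips, h, w)
    .permute(0, 2, 1, 3, 4, 5)
    .reshape(b * num_clips, c, t // num_clips, h, w)
)
print(clips.shape)  # torch.Size([8, 3, 8, 224, 224])
```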
@@ -354,7 +354,6 @@ def main():
 
     bests_ = []
 
-    model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device)
 
     if opt.get("split_seed", -1) > 0:
         num_splits = 10
@@ -366,7 +365,8 @@
     if opt.get("split_seed", -1) > 0:
         ann_path = opt["data"]["train"]["args"]["anno_file"]
 
-    for split in range(0,num_splits):
+    for split in range(1,num_splits):
+        model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device)
         print(split)
         if opt.get("split_seed", -1) > 0:
             split_duo = train_test_split(opt["data"]["train"]["args"]["data_prefix"],
@@ -408,7 +408,35 @@ def main():
                 reinit=True,
             )
 
-        if "load_path" in opt:
+        if "load_path_aux" in opt:
+            state_dict = torch.load(opt["load_path"], map_location=device)["state_dict"]
+            aux_state_dict = torch.load(opt["load_path_aux"], map_location=device)["state_dict"]
+
+            from collections import OrderedDict
+
+            fusion_state_dict = OrderedDict()
+            for k, v in state_dict.items():
+                if "head" in k:
+                    continue
+                if k.startswith("vqa_head"):
+                    ki = k.replace("vqa", "fragments")
+                else:
+                    ki = k
+                fusion_state_dict[ki] = v
+
+            for k, v in aux_state_dict.items():
+                if "head" in k:
+                    continue
+                if k.startswith("frag"):
+                    continue
+                if k.startswith("vqa_head"):
+                    ki = k.replace("vqa", "resize")
+                else:
+                    ki = k
+                fusion_state_dict[ki] = v
+            state_dict = fusion_state_dict
+            print(model.load_state_dict(state_dict,strict=False))
+        elif "load_path" in opt:
             state_dict = torch.load(opt["load_path"], map_location=device)
 
             if "state_dict" in state_dict:
@@ -494,10 +522,10 @@ def main():
                 model_ema if model_ema is not None else model,
                 device, bests[key], save_model=opt["save_model"], save_name=opt["name"],
                 suffix = key+"_s",
-            )
+            )
             if model_ema is not None:
                 bests_n[key] = inference_set(
-                    val_loaders[key],
+                    val_loaders[key],
                     model,
                     device, bests_n[key], save_model=opt["save_model"], save_name=opt["name"],
                     suffix = key+'_n',
@@ -524,7 +552,9 @@
                        KROCC: {bests_n[key][2]:.4f}
                        RMSE: {bests_n[key][3]:.4f}."""
                )
-
+
+
+
     for key, value in dict(model.named_children()).items():
         if "backbone" in key:
             for param in value.parameters():
