Skip to content

Commit

Permalink
Merge branch 'main' into sdk-update
Browse files: browse the repository at this point in the history
  • Loading branch information
keighrim committed Jun 26, 2024
2 parents 4e372ca + 89ee167 commit 2507442
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions modeling/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,6 +161,26 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y
p_scores = []
r_scores = []
f_scores = []
loss = nn.CrossEntropyLoss(reduction="none")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if num_splits == 1, validation is empty. single fold training.
if configs['num_splits'] == 1:
train_guids = set(guids)
validation_guids = set([])
for block in configs['block_guids_train']:
train_guids.discard(block)
# prepare_datasets seems to work fine with empty validation set
train, valid, labelset_size = prepare_datasets(indir, train_guids, validation_guids, configs)
train_loader = DataLoader(train, batch_size=len(guids), shuffle=True)
export_model_file = f"{outdir}/{train_id}.pt"
model = train_model(
get_net(train.feat_dim, labelset_size, configs['num_layers'], configs['dropouts']),
loss, device, train_loader, configs)
torch.save(model.state_dict(), export_model_file)
p_config = Path(f'{outdir}/{train_id}.yml')
export_kfold_config(config_file, configs, p_config)
return
# otherwise, do k-fold training, where k = 'num_splits'
for i in range(0, configs['num_splits']):
validation_guids = set(guids[i*len_val:(i+1)*len_val])
train_guids = set(guids) - validation_guids
Expand All @@ -178,13 +198,11 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y
continue
train_loader = DataLoader(train, batch_size=40, shuffle=True)
valid_loader = DataLoader(valid, batch_size=len(valid), shuffle=True)
loss = nn.CrossEntropyLoss(reduction="none")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'Split {i}: training on {len(train_guids)} videos, validating on {validation_guids}')
export_csv_file = f"{outdir}/{train_id}.kfold_{i:03d}.csv"
export_model_file = f"{outdir}/{train_id}.kfold_{i:03d}.pt"
model = train_model(
get_net(train.feat_dim, labelset_size, configs['num_layers'], configs['dropouts']),
get_net(train.feat_dim, labelset_size, configs['num_layers'], configs['dropouts']),
loss, device, train_loader, configs)
torch.save(model.state_dict(), export_model_file)
p, r, f = evaluate(model, valid_loader, pretraining_binned_label(config), export_fname=export_csv_file)
Expand Down

0 comments on commit 2507442

Please sign in to comment.