Skip to content

Commit

Permalink
integrate baseline duty cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 14, 2024
1 parent 8a3defb commit 138e95d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion dominoes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def train(nets, optimizers, dataset, **parameters):
bl_thompson = parameters.get("bl_thompson", False)
bl_significance = parameters.get("bl_significance", 0.05)
bl_batch_size = parameters.get("bl_batch_size", 1024)
bl_duty_cycle = parameters.get("bl_duty_cycle", 1)
bl_parameters = parameters.copy()
bl_parameters["batch_size"] = bl_batch_size # update batch size for baseline reference batch
bl_nets = make_baseline_nets(
Expand Down Expand Up @@ -113,7 +114,8 @@ def train(nets, optimizers, dataset, **parameters):
opt.step()

# update baseline networks if required
bl_nets = check_baseline_updates(nets, bl_nets)
if baseline and epoch % bl_duty_cycle == 0:
bl_nets = check_baseline_updates(nets, bl_nets)

# save training data
with torch.no_grad():
Expand Down

0 comments on commit 138e95d

Please sign in to comment.