Skip to content

Commit

Permalink
update all baseline arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 12, 2024
1 parent b6e1434 commit 1f7942f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions dominoes/experiments/arglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def add_network_training_metaparameters(parser):
action="store_true",
help="if used, will not use a baseline correction during training (default=False)",
)
parser.add_argument("--bl_temperature", type=float, default=1.0, help="temperature for baseline networks during training")
parser.add_argument("--bl_thompson", default=False, action="store_true", help="if used, will use Thompson sampling for baseline networks")
parser.add_argument("--bl_significance", type=float, default=0.05, help="significance level for updating baseline networks")
return parser


Expand Down
3 changes: 3 additions & 0 deletions dominoes/experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def make_train_parameters(self, dataset, train=True):
parameters["temperature"] = self.args.train_temperature if train else 1.0
parameters["thompson"] = not self.args.no_thompson if train else False
parameters["baseline"] = not self.args.no_baseline if train else False
parameters["bl_temperature"] = self.args.bl_temperature
parameters["bl_thompson"] = not self.args.bl_thompson
parameters["bl_significance"] = self.args.bl_significance
parameters["gamma"] = self.args.gamma
parameters["save_loss"] = self.args.save_loss
parameters["save_reward"] = self.args.save_reward
Expand Down
2 changes: 1 addition & 1 deletion dominoes/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def train(nets, optimizers, dataset, **parameters):
learning_mode = parameters.get("learning_mode")
temperature = parameters.get("temperature", 1.0)
thompson = parameters.get("thompson", True)
baseline = parameters.get("baseline", False)
baseline = parameters.get("baseline", True)

# process the learning_mode and save conditions
get_loss = learning_mode == "supervised" or parameters.get("save_loss", False)
Expand Down

0 comments on commit 1f7942f

Please sign in to comment.