Skip to content

Commit

Permalink
baseline network bugs and notes on progress
Browse files Browse the repository at this point in the history
  • Loading branch information
landoskape committed May 12, 2024
1 parent 1f7942f commit 4243e1d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion dominoes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"\n",
"# Add mechanism for printing the arguments used to build a pointernetwork so the user can see what they did. \n",
"\n",
"# Need to add baselining to the pointer networks...\n",
"# Add documentation of baseline updates and performance etc\n",
"\n",
"# :)"
]
Expand Down
18 changes: 10 additions & 8 deletions dominoes/networks/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def make_baseline_nets(nets, dataset, batch_parameters={}, significance=0.05, ma

def check_baseline_updates(nets, bl_nets):
    """Check whether each baseline network should be updated.

    Each network in ``nets`` is paired positionally with the baseline at the
    same index in ``bl_nets``, and that baseline's ``check_improvement`` is
    called with the network.

    Args:
        nets: sequence of candidate networks.
        bl_nets: indexable sequence of baseline networks, at least as long
            as ``nets``.

    Returns:
        The same ``bl_nets`` object that was passed in (all baselines, not
        just the last one checked).

    Raises:
        IndexError: if ``bl_nets`` is shorter than ``nets``.
    """
    # Index into bl_nets (rather than zip) so a missing baseline raises
    # instead of silently skipping the remaining networks.
    for inet, net in enumerate(nets):
        bl_nets[inet].check_improvement(net)
    return bl_nets


class BaselineNetwork(nn.Module):
Expand All @@ -44,7 +44,6 @@ def __init__(self, net, dataset, batch_parameters={}, significance=0.05, max_out
self.dataset = dataset
self.batch_parameters = batch_parameters
self.max_output = max_output
self.update_reference()

# set forward kwargs for use in every forward pass
self.forward_kwargs = dict(
Expand All @@ -53,14 +52,17 @@ def __init__(self, net, dataset, batch_parameters={}, significance=0.05, max_out
max_output=self.max_output,
)

# create a reference batch
self.update_reference()

# set update significance
self.set_significance(significance)

def update_reference(self):
    """Draw a new reference batch and record the baseline's rewards on it.

    Reads ``self.dataset``, ``self.batch_parameters``, ``self.net`` and
    ``self.forward_kwargs``, so it must only be called after those
    attributes have been set.

    Sets:
        self.ref_batch: batch returned by ``dataset.generate_batch``.
        self.ref_rewards: result of ``dataset.reward_function`` for the
            baseline network's choices on ``self.ref_batch``.
    """
    self.ref_batch = self.dataset.generate_batch(**self.batch_parameters)
    # forward_batch takes a list of networks; [1][0] presumably selects the
    # choices output for the single network passed in — TODO confirm against
    # forward_batch's return signature.
    ref_choices = forward_batch([self.net], self.ref_batch, **self.forward_kwargs)[1][0]
    # Only the rewards are kept; the choices are an intermediate value.
    self.ref_rewards = self.dataset.reward_function(ref_choices, self.ref_batch)

@torch.no_grad()
def update_network(self, net):
Expand All @@ -76,8 +78,8 @@ def set_significance(self, significance):
@torch.no_grad()
def check_improvement(self, net):
"""check if the network should be updated based on the reference batch"""
choices = forward_batch([net], self.batch, **self.forward_kwargs)[1][0]
rewards = self.dataset.reward_function(choices, self.batch)
choices = forward_batch([net], self.ref_batch, **self.forward_kwargs)[1][0]
rewards = self.dataset.reward_function(choices, self.ref_batch)
p = ttest_rel(rewards.view(-1).cpu().numpy(), self.ref_rewards.view(-1).cpu().numpy(), alternative="greater")[1]

if p < self.significance:
Expand Down

0 comments on commit 4243e1d

Please sign in to comment.