Skip to content

Commit

Permalink
fix num_train bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed May 7, 2023
1 parent 8672104 commit f47b380
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ def apply_to_layer(
# get a random subset of the training data
train_sub = dict()
for ds_name, (train_h, train_gt, train_lm_preds) in train_dict.items():
sub_idx = torch.randperm(train_h.shape[0])[:num_platt]
# TODO: Find the KNNs for each training point in the training data
# Using L2 distance works better than inner product distance. See:
# https://arxiv.org/abs/1911.00172, https://arxiv.org/abs/2112.04426
sub_idx = torch.randperm(train_h.shape[0])[:num_train]
train_h_sub = train_h[sub_idx]
train_gt_sub = train_gt[sub_idx]
train_lm_preds_sub = (
Expand Down

0 comments on commit f47b380

Please sign in to comment.