Skip to content

Commit

Permalink
KNN
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexTMallen committed May 8, 2023
1 parent f47b380 commit d670cc4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
4 changes: 2 additions & 2 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ def prepare_data(

split = ds[key].with_format("torch", device=device, dtype=torch.int16)
labels = assert_type(Tensor, split["label"])
val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))
split_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

with split.formatted_as("torch", device=device):
has_preds = "model_logits" in split.features
lm_preds = split["model_logits"] if has_preds else None

out[ds_name] = (val_h, labels.to(val_h.device), lm_preds)
out[ds_name] = (split_h, labels.to(split_h.device), lm_preds)

return out

Expand Down
40 changes: 36 additions & 4 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,40 @@ def apply_to_layer(
platt_set_h = torch.cat(platt_set_h)
platt_set_gt = torch.cat(platt_set_gt)

# precompute distances between all training points
train_h_dist_mats = dict()
for ds_name, (train_h, _, _) in train_dict.items():
train_h_dist_mats[ds_name] = torch.zeros((len(train_h), len(train_h)))
for i, h1_raw in enumerate(train_h):
for j, h2_raw in enumerate(train_h):
if i == j:
train_h_dist_mats[ds_name][i, j] = 0
elif i < j:
# compute L2 distance between h1 and h2
# where we concatenate away the variants dimension
# Using L2 distance works better than inner product distance.
# See:
# https://arxiv.org/abs/1911.00172
# https://arxiv.org/abs/2112.04426
h1 = rearrange(h1_raw, "v k d -> k (v d)")
h2 = rearrange(h2_raw, "v k d -> k (v d)")
assert (
h1.shape[0] == h2.shape[0] == 2
), "Only supports 2 classes"
# the distance is the minimum of the 4 pairwise distances to
# because we don't know the labels, so we don't want our metric
# to say all the nearby examples have the same label
dist = min(
[
torch.dist(h1[0], h2[0]), # type: ignore
torch.dist(h1[0], h2[1]), # type: ignore
torch.dist(h1[1], h2[0]), # type: ignore
torch.dist(h1[1], h2[1]), # type: ignore
]
)
train_h_dist_mats[ds_name][i, j] = dist
train_h_dist_mats[ds_name][j, i] = dist

reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
train_subs = defaultdict(list)
train_sub_credences = defaultdict(list)
Expand All @@ -119,10 +153,8 @@ def apply_to_layer(
# get a random subset of the training data
train_sub = dict()
for ds_name, (train_h, train_gt, train_lm_preds) in train_dict.items():
# TODO: Find the KNNs for each training point in the training data
# Using L2 distance works better than inner product distance. See:
# https://arxiv.org/abs/1911.00172, https://arxiv.org/abs/2112.04426
sub_idx = torch.randperm(train_h.shape[0])[:num_train]
# Find the KNNs for each training point in the training data
sub_idx = torch.argsort(train_h_dist_mats[ds_name][i])[:num_train]
train_h_sub = train_h[sub_idx]
train_gt_sub = train_gt[sub_idx]
train_lm_preds_sub = (
Expand Down

0 comments on commit d670cc4

Please sign in to comment.