diff --git a/elk/run.py b/elk/run.py index 85f244c7..26aa6d42 100644 --- a/elk/run.py +++ b/elk/run.py @@ -131,13 +131,13 @@ def prepare_data( split = ds[key].with_format("torch", device=device, dtype=torch.int16) labels = assert_type(Tensor, split["label"]) - val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) + split_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"])) with split.formatted_as("torch", device=device): has_preds = "model_logits" in split.features lm_preds = split["model_logits"] if has_preds else None - out[ds_name] = (val_h, labels.to(val_h.device), lm_preds) + out[ds_name] = (split_h, labels.to(split_h.device), lm_preds) return out diff --git a/elk/training/train.py b/elk/training/train.py index 61cfb816..e89b7baf 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -108,6 +108,40 @@ def apply_to_layer( platt_set_h = torch.cat(platt_set_h) platt_set_gt = torch.cat(platt_set_gt) + # precompute distances between all training points + train_h_dist_mats = dict() + for ds_name, (train_h, _, _) in train_dict.items(): + train_h_dist_mats[ds_name] = torch.zeros((len(train_h), len(train_h))) + for i, h1_raw in enumerate(train_h): + for j, h2_raw in enumerate(train_h): + if i == j: + train_h_dist_mats[ds_name][i, j] = 0 + elif i < j: + # compute L2 distance between h1 and h2 + # where we concatenate away the variants dimension + # Using L2 distance works better than inner product distance. + # See: + # https://arxiv.org/abs/1911.00172 + # https://arxiv.org/abs/2112.04426 + h1 = rearrange(h1_raw, "v k d -> k (v d)") + h2 = rearrange(h2_raw, "v k d -> k (v d)") + assert ( + h1.shape[0] == h2.shape[0] == 2 + ), "Only supports 2 classes" + # the distance is the minimum of the 4 pairwise distances to + # because we don't know the labels, so we don't want our metric + # to say all the nearby examples have the same label + dist = min( + [ + torch.dist(h1[0], h2[0]), # type: ignore + torch.dist(h1[0], h2[1]), # type: ignore + torch.dist(h1[1], h2[0]), # type: ignore + torch.dist(h1[1], h2[1]), # type: ignore + ] + ) + train_h_dist_mats[ds_name][i, j] = dist + train_h_dist_mats[ds_name][j, i] = dist + reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) train_subs = defaultdict(list) train_sub_credences = defaultdict(list) @@ -119,10 +153,8 @@ def apply_to_layer( # get a random subset of the training data train_sub = dict() for ds_name, (train_h, train_gt, train_lm_preds) in train_dict.items(): - # TODO: Find the KNNs for each training point in the training data - # Using L2 distance works better than inner product distance. See: - # https://arxiv.org/abs/1911.00172, https://arxiv.org/abs/2112.04426 - sub_idx = torch.randperm(train_h.shape[0])[:num_train] + # Find the KNNs for each training point in the training data + sub_idx = torch.argsort(train_h_dist_mats[ds_name][i])[:num_train] train_h_sub = train_h[sub_idx] train_gt_sub = train_gt[sub_idx] train_lm_preds_sub = (