def test_ekfac_transform_test_rep(self):
    """Compare EK-FAC's transformed test representation with the explicit one.

    An EK-FAC attributor and an explicit attributor are built on the same
    task and cached on the same training loader.  The EK-FAC attributor
    transforms a random test representation, the explicit attributor
    transforms the first 7840 columns of it, and the correlation coefficient
    between the two flattened results must exceed 0.98.
    """
    cpu = torch.device("cpu")
    min_corr = 0.98

    def average_pairwise_correlation(tensor1, tensor2):
        """Return the correlation coefficient of two equally shaped tensors.

        The tensors are stacked and each is flattened into one row, so the
        off-diagonal entry of the 2x2 ``torch.corrcoef`` matrix is the
        correlation between them.
        """
        rows = torch.stack([tensor1, tensor2], dim=0).view(2, -1)
        return torch.corrcoef(rows)[0, 1].item()

    # Random stand-in data: 20 single-channel 28x28 images, labels in [0, 10).
    images = torch.randn(20, 1, 28, 28)
    labels = torch.randint(0, 10, (20,))
    train_loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = train_mnist_lr(train_loader)

    def cross_entropy_loss(params, data_target_pair):
        """Cross-entropy of ``model`` evaluated functionally with ``params``."""
        image, label = data_target_pair
        logits = torch.func.functional_call(model, params, image)
        return nn.CrossEntropyLoss()(logits, label.long())

    task = AttributionTask(
        loss_func=cross_entropy_loss,
        model=model,
        checkpoints=model.state_dict(),
    )

    # Attributor under test (EK-FAC).
    ekfac_attributor = IFAttributorEKFAC(
        task=task,
        device=cpu,
        damping=0.1,
    )
    ekfac_attributor.cache(train_loader)

    # Reference attributor, restricted to the layers EK-FAC selected.
    explicit_attributor = IFAttributorExplicit(
        task=task,
        layer_name=ekfac_attributor.layer_name,
        device=cpu,
        regularization=1e-3,
    )
    explicit_attributor.cache(train_loader)

    # NOTE(review): 7850 vs. 7840 presumably is weights+bias vs. weights only
    # of a 784->10 linear layer -- confirm against train_mnist_lr.
    test_rep = torch.randn((30, 7850), device=cpu)
    ekfac_rep = ekfac_attributor.transform_test_rep(0, test_rep)
    explicit_rep = explicit_attributor.transform_test_rep(
        0,
        test_rep[:, :7840],
    )

    # Check pair-wise correlation
    corr = average_pairwise_correlation(explicit_rep, ekfac_rep)
    assert corr > min_corr