Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Sx Lau committed Oct 13, 2024
1 parent 36c063a commit 614f800
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions test/dattri/algorithm/test_influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,58 @@ def f(params, data_target_pair):
assert (cached_train_reps.shape[0] == 4 * math.ceil(0.5 * len(train_loader)))
# Check for transformed test rep
assert (torch.allclose(ground_truth_test_rep, transformed_test_rep, atol=1e-4))

def test_ekfac_transform_test_rep(self):
"""Test for EK-FAC test representation transformation."""
def average_pairwise_correlation(tensor1, tensor2):
stacked = torch.stack([tensor1, tensor2], dim=0)
reshaped = stacked.view(2, -1)

corr_matrix = torch.corrcoef(reshaped)
pairwise_corr = corr_matrix[0, 1]

return pairwise_corr.item()

train_dataset = TensorDataset(
torch.randn(20, 1, 28, 28),
torch.randint(0, 10, (20,)),
)
train_loader = DataLoader(train_dataset, batch_size=4)

model = train_mnist_lr(train_loader)

def f(params, data_target_pair):
image, label = data_target_pair
loss = nn.CrossEntropyLoss()
yhat = torch.func.functional_call(model, params, image)
return loss(yhat, label.long())

task = AttributionTask(
loss_func=f,
model=model,
checkpoints=model.state_dict(),
)

# EK-FAC
attributor = IFAttributorEKFAC(
task=task,
device=torch.device("cpu"),
damping=0.1,
)
attributor.cache(train_loader)

attributor_gt = IFAttributorExplicit(
task=task,
layer_name=attributor.layer_name,
device=torch.device("cpu"),
regularization=1e-3,
)
attributor_gt.cache(train_loader)

test_rep = torch.randn((30, 7850), device=torch.device("cpu"))
transformed_test_rep = attributor.transform_test_rep(0, test_rep)
gt_test_rep = attributor_gt.transform_test_rep(0, test_rep[:, :7840])

# Check pair-wise correlation
corr = average_pairwise_correlation(gt_test_rep, transformed_test_rep)
assert corr > 0.98 # noqa: PLR2004

0 comments on commit 614f800

Please sign in to comment.