Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
a_remix_user committed Mar 14, 2023
1 parent 69ac469 commit 947e85d
Show file tree
Hide file tree
Showing 20 changed files with 359 additions and 128 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
data/*
archive.zip
squeezenet.pth
*.zip
*.zip
bad old values/*
GOOD shap/*
pics/*
shap_records/*
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ Organization:
- `datasets.py` configures the [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) and [FairFace](https://github.com/joojs/fairface) datasets and dataloaders.
- `model.py` configures a [SqueezeNet](https://arxiv.org/abs/1602.07360) model for the binary gender detection task. It also defines an `epoch` function, which is a generic train/test loop.
- `train_squeeze.py` fine-tunes an ImageNet-pretrained SqueezeNet on a gender detection task from the CelebA dataset, and evaluates the resulting model on 500 images from each Black/White Male/Female split of FairFace.
- `mab_shapley.py` finds shapley values for each filter on the FairFace dataset using the multi-armed bandit algorithm described by Ghorbani et al., and outputs approximations of filter shapley values to `shapley_values.pkl`.
- `evaluate.py` evaluates the accuracy on FairFace (decomposed by race and gender) when removing filters with negative shapley values.
- `mab.py` finds Shapley values for each filter on the FairFace dataset using the multi-armed bandit algorithm described by Ghorbani et al., and outputs approximations of the filter Shapley values to `shapley_values.pkl`.
- `eval.py` evaluates the accuracy on FairFace (decomposed by race and gender) when removing filters with negative shapley values.
- `squeezenet.pth` stores the weights of a SqueezeNet after two epochs of fine tuning on the CelebA gender detection task.
- `shapley_values.pkl` stores the shapley values obtained after 195 iterations of the MAB algorithm.
- `shapley_values.pkl` stores the shapley values obtained after 424 iterations of the MAB algorithm.

Necessary Datasets:
- `./data/celeba` should contain the CelebA dataset, stored as `*.jpg` files within `test`, `train`, and `val` subfolders. The attributes files should be stored in `list_landmarks_align_celeba.csv`
- `./data/fairface` should contain the FairFace dataset, with `train` and `val` subfolders and a `fairface_label_val.csv` label file.

*Reproducibility note: running `evaluate.py` will reproduce the above figure using the precomputed shapley values in `shaple_values.pkl` on a small subset of the CelebA and FairFace dataset we've included in this repository. Replicating the MAB algorithm and our full results require dataset downloads.*
*Reproducibility note: running `eval.py` will reproduce the above figure using the precomputed Shapley values in `shapley_values.pkl` on a small subset of the CelebA and FairFace datasets we've included in this repository. Replicating the MAB algorithm and our full results requires downloading the full datasets.*


## Citation
Expand Down
Binary file modified __pycache__/ablation.cpython-38.pyc
Binary file not shown.
Binary file modified __pycache__/datasets.cpython-38.pyc
Binary file not shown.
Binary file added __pycache__/eval.cpython-38.pyc
Binary file not shown.
Binary file modified __pycache__/model.cpython-38.pyc
Binary file not shown.
32 changes: 18 additions & 14 deletions ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import torch as t
import torch.nn as nn
import torch.nn.functional as f
import numpy as np
import math
import time
from functools import partial
from tqdm import tqdm
import pickle
Expand All @@ -33,22 +30,21 @@

# %%

# Map each convolutional module to its position in `convs`, so the hooks can
# look up that module's entry in the per-layer means list.
conv_dict = {conv: position for position, conv in enumerate(convs)}

# get means of every filter, calculated from val_loader CelebA data

def mean_hook(module, input, output, conv_means):
    """Forward hook that records the mean post-ReLU activation of `module`.

    The mean of `output.relu()` is taken over dim 0 (presumably the batch
    dimension -- confirm against the loader) with that dim kept as size 1,
    and stored in `conv_means` at the index `conv_dict` assigns to `module`.

    Each call overwrites the stored entry rather than accumulating, so the
    slot holds the mean of the most recent forward pass only.

    Returns `output` itself; `relu()` here is not in-place, so the module's
    output is passed on unchanged.
    """
    batch_mean = output.relu().mean(dim=0, keepdim=True)
    conv_means[conv_dict[module]] = batch_mean
    return output

# calculating the means for every filter takes 16 seconds, so we store it in a file
def forward_pass_store_means(loader):
conv_means = [t.zeros((1,)) for c in convs]

Expand All @@ -67,6 +63,7 @@ def forward_pass_store_means(loader):
with open("means.pkl", "wb") as f:
pickle.dump(conv_means, f)

# Read back the per-layer means previously pickled to "means.pkl".
def load_conv_means():
    """Return the object unpickled from ``means.pkl`` (path relative to the cwd)."""
    with open("means.pkl", "rb") as means_file:
        stored_means = pickle.load(means_file)
    return stored_means
Expand All @@ -78,27 +75,34 @@ def mean_ablation_hook(module, input, output, ablations, conv_means):
output[:, ablations] = conv_means[conv_dict[module]][:, ablations]
return output

# Forward hook that zero-ablates the channels selected by `ablations`.
def zero_ablation_hook(module, input, output, ablations):
    """Set the selected channels of `output` to zero in place and return it.

    `ablations` is used as an index into dim 1 of `output`; the caller in
    this file binds it with `functools.partial` to a boolean mask sliced per
    conv layer. `module` and `input` are unused; they are part of the hook
    signature.
    """
    channel_selector = (slice(None), ablations)
    output[channel_selector] = 0
    return output

# forward pass with ablation: evaluates over `loader` via `epoch` when `full_data` is True, otherwise on *a single batch* from loader
# 1 means ablate the filter. 0 means don't ablate
def forward_pass(ablate_mask, conv_means, loader):
def forward_pass(ablate_mask, conv_means, loader, full_data=True, zero_ablate=False):
start_idx = 0
handlers = []
for conv in convs:
ablations = ablate_mask[start_idx : start_idx + conv.out_channels] == 1
handlers.append(
conv.register_forward_hook(
partial(mean_ablation_hook, ablations=ablations, conv_means=conv_means)
partial(mean_ablation_hook, ablations=ablations, conv_means=conv_means) if not zero_ablate else partial(zero_ablation_hook, ablations=ablations)
)
)
start_idx += conv.out_channels

with t.no_grad():
data, target = next(iter(loader))
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
acc = (pred == target).sum().item() / len(pred)
if full_data:
acc = epoch(model, loader, nn.CrossEntropyLoss(), None, device, False)
else:
with t.no_grad():
data, target = next(iter(loader))
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
acc = (pred == target).sum().item() / len(pred)

for h in handlers:
h.remove()
Expand Down
Binary file removed cb.pkl
Binary file not shown.
42 changes: 24 additions & 18 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_celeb_data_loaders(batch_size=32, num_workers=2):
return train_loader, val_loader


def get_gender_datasets(num_images=500):
def get_gender_datasets(num_images=200, val=False):
# Load the fairface dataset, filtered to be in 20-59 age range and only white/black
df = pd.read_csv("data/fairface/fairface_label_val.csv", index_col="file")
ages = (
Expand All @@ -83,18 +83,22 @@ def get_gender_datasets(num_images=500):
* (df["age"] != "60-69")
* (df["age"] != "more than 70")
)
white_men = df.loc[(df["race"] == "White") * (df["gender"] == "Male") * ages][
:num_images
]
white_women = df.loc[(df["race"] == "White") * (df["gender"] == "Female") * ages][
:num_images
]
black_men = df.loc[(df["race"] == "Black") * (df["gender"] == "Male") * ages][
:num_images
]
black_women = df.loc[(df["race"] == "Black") * (df["gender"] == "Female") * ages][
:num_images
]

white_men = df.loc[(df["race"] == "White") * (df["gender"] == "Male") * ages]
white_women = df.loc[(df["race"] == "White") * (df["gender"] == "Female") * ages]
black_men = df.loc[(df["race"] == "Black") * (df["gender"] == "Male") * ages]
black_women = df.loc[(df["race"] == "Black") * (df["gender"] == "Female") * ages]

if val:
white_men = white_men[-num_images:]
white_women = white_women[-num_images:]
black_men = black_men[-num_images:]
black_women = black_women[-num_images:]
else:
white_men = white_men[:num_images]
white_women = white_women[:num_images]
black_men = black_men[:num_images]
black_women = black_women[:num_images]

white_men_ds = FairfaceDataset("data/fairface", transform, white_men)
white_women_ds = FairfaceDataset("data/fairface", transform, white_women)
Expand All @@ -104,10 +108,11 @@ def get_gender_datasets(num_images=500):
return white_men_ds, white_women_ds, black_men_ds, black_women_ds


def get_gender_dataloaders(batch_size=128, num_workers=2):
def get_gender_dataloaders(batch_size=128, num_images_each=200, num_workers=2, val=False):
white_men_ds, white_women_ds, black_men_ds, black_women_ds = get_gender_datasets(
num_images=500
)
num_images=num_images_each,
val=val
)

white_men_loader = DataLoader(
white_men_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
Expand All @@ -125,9 +130,10 @@ def get_gender_dataloaders(batch_size=128, num_workers=2):
return white_men_loader, white_women_loader, black_men_loader, black_women_loader


def get_combined_gender_loader(batch_size=128, num_workers=2):
def get_combined_gender_loader(batch_size=128, num_images_each=200, num_workers=2, val=False):
white_men_ds, white_women_ds, black_men_ds, black_women_ds = get_gender_datasets(
num_images=500
num_images=num_images_each,
val=val
)

combined_dataset = ConcatDataset(
Expand Down
130 changes: 60 additions & 70 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,29 @@
# %%
# Evaluation file for ablating neurons selected by highest Shapley values

from ablation import filt_len, celeba_val_loader, load_conv_means, forward_pass, get_gender_dataloaders

import numpy as np
from tqdm import tqdm
import pickle

import matplotlib.pyplot as plt

# %%
# load in shapley values
with open("shapley_values.pkl", "rb") as pickle_file:
shapley_values = pickle.load(pickle_file)
# # load in shapley values
# with open("shapley_values.pkl", "rb") as pickle_file:
# shapley_values = pickle.load(pickle_file)

shapley_values.sort()
print(sum(shapley_values))
print(sum(shapley_values[-100:]))

# %%
# shapley_values.sort()
# print(sum(shapley_values))
# print(sum(shapley_values[-100:]))

# batch is entire dataset
white_men_dataloader, white_women_dataloader, black_men_dataloader, black_women_dataloader = get_gender_dataloaders()

with open("shapley_values_first_iter.pkl", "rb") as pickle_file:
shapley_values = pickle.load(pickle_file)

neurons_sorted = np.argsort(shapley_values) # ascending order

# %%

def check_score(ablate_mask):
conv_means = load_conv_means()
# check the score on FairFace for a given ablation of the filters
def check_score(ablate_mask, conv_means):
wm_acc = forward_pass(ablate_mask, conv_means, white_men_dataloader)
ww_acc = forward_pass(ablate_mask, conv_means, white_women_dataloader)
bm_acc = forward_pass(ablate_mask, conv_means, black_men_dataloader)
Expand All @@ -39,66 +35,60 @@ def check_score(ablate_mask):
# print("Black women", bw_acc)
# print("overall", overall_acc)
# print(len(celeba_val_loader))
celeba_acc = forward_pass(
ablate_mask, conv_means, celeba_val_loader
) # TODO uses new batch every iteration, weird
celeba_acc = 0
# celeba_acc = forward_pass(
# ablate_mask, conv_means, celeba_val_loader
# )
# print("celeba", celeba_acc)

return wm_acc, ww_acc, bm_acc, bw_acc, overall_acc, celeba_acc

# plot accuracy vs number of top filters ablated
def save_plots(iters=False):
conv_means = load_conv_means()

ablate_mask = np.zeros(filt_len)
check_score(ablate_mask)

wm_accs, ww_accs, bm_accs, bw_accs, overall_accs, celeba_accs = [], [], [], [], [], []
for i in tqdm(range(30)):
ablate_mask[neurons_sorted[-i - 1]] = 1
wm_acc, ww_acc, bm_ac, bw_acc, overall_acc, celeba_acc = check_score(ablate_mask)
wm_accs.append(wm_acc)
ww_accs.append(ww_acc)
bm_accs.append(bm_ac)
bw_accs.append(bw_acc)
overall_accs.append(overall_acc)
celeba_accs.append(celeba_acc)

# %%

import matplotlib.pyplot as plt

plt.hist(shapley_values, bins=100)
plt.show()

plt.plot(celeba_accs, label="celeba")
plt.plot(wm_accs, label="white men")
plt.plot(ww_accs, label="white women")
plt.plot(bm_accs, label="black men")
plt.plot(bw_accs, label="black women")
plt.plot(overall_accs, label="overall")
plt.legend()
plt.xlabel("Number of filters ablated")
plt.ylabel("Test Accuracy (%)")

with open(f"shap_records/shapley_values/{iters}.pkl" if iters else "shapley_values.pkl", "rb") as pickle_file:
shapley_values = pickle.load(pickle_file)

neurons_sorted = np.argsort(shapley_values) # ascending order

# print(sum(shapley_values[neurons_sorted[-100:]]))
# print(shapley_values[neurons_sorted[-1]])

ablate_mask = np.zeros(filt_len)

check_score(ablate_mask, conv_means)
wm_accs, ww_accs, bm_accs, bw_accs, overall_accs, celeba_accs = [], [], [], [], [], []
for i in tqdm(range(30)):
# print("shapley value", shapley_values[neurons_sorted[-i - 1]])
ablate_mask[neurons_sorted[-i - 1]] = 1
wm_acc, ww_acc, bm_ac, bw_acc, overall_acc, celeba_acc = check_score(ablate_mask, conv_means)
wm_accs.append(wm_acc)
ww_accs.append(ww_acc)
bm_accs.append(bm_ac)
bw_accs.append(bw_acc)
overall_accs.append(overall_acc)
celeba_accs.append(celeba_acc)

plt.hist(shapley_values, bins=100)
plt.show()
plt.savefig(f"pics/dist/{iters}.png")

# plt.plot(celeba_accs, label="celeba")
plt.plot(wm_accs, label="white men")
plt.plot(ww_accs, label="white women")
plt.plot(bm_accs, label="black men")
plt.plot(bw_accs, label="black women")
plt.plot(overall_accs, label="overall")
plt.legend()
# plt.ylim(.5, 1.05)
plt.xlabel("Number of filters ablated")
plt.ylabel("Test Accuracy (%)")
plt.show()
plt.savefig(f"pics/accs/{iters}.png")

# %%

# load in iterations.pkl
with open("iterations.pkl", "rb") as pickle_file:
iterations_pk = pickle.load(pickle_file)
print(iterations_pk)

if __name__ == "__main__":
save_plots()

# %%

# # samples = total number of samples
# def moving_average(samples, prev_mean, x_n):
# prev_mean[relevant_neurons]
# return ((samples-1) * prev_mean + x_n) / samples

# def moving_variance(samples, prev_mean, prev_var, new_mean, x_n):
# return ((samples-1) * (prev_var + np.square(new_mean - prev_mean)) + np.square(x_n - new_mean)) / samples

# def confidence_bounds(samples, variances, delta):
# return np.sqrt(2 * variances * np.log(2 / delta) / samples) + 7/3 * np.log(2 / delta) / (samples-1)

# shapley_values += value_updates
# variances[relevant_neurons] = ((samples[relevant_neurons] - 1) * (variances[relevant_neurons] + np.square(value_updates[relevant_neurons])) + np.square(differential[relevant_neurons] - shapley_values[relevant_neurons])) / samples[relevant_neurons]
Loading

0 comments on commit 947e85d

Please sign in to comment.