Official code implementation of "GEX: A flexible method for approximating influence via Geometric Ensemble" (NeurIPS 2023)

sungyubkim/gex

Geometric Ensemble for sample eXplanation (GEX)

How to use this repo?

Pull the Docker image that provides the dependencies

docker pull sungyubkim/jax:ntk-0.4.2

Run the Docker image

docker run -p 8080:8080/tcp -it --rm --gpus all \
--ipc=host -v $PWD:/root -w /root \
sungyubkim/jax:ntk-0.4.2

To run a single Python module:

# to pre-train NN
python3 -m gex.pretrain.main \
    --dataset=mnist \
    --model=vgg \
    --corruption_ratio=0.1
# to estimate influence of pre-trained NN
python3 -m gex.noisy.main \
    --dataset=mnist \
    --model=vgg \
    --corruption_ratio=0.1 \
    --num_ens=8 \
    --ft_lr=0.05 \
    --ft_step=800 \
    --ft_lr_sched=cosine \
    --if_method=la_fge
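
To compare several post-hoc methods on the same pre-trained network, the second command can be generated once per `--if_method` value. The following is a minimal sketch that only builds and prints the command lines; it reuses the flags shown above and the method names listed under the post-hoc methods section. Passing the ensemble flags (`--num_ens`, `--ft_lr`, `--ft_step`, `--ft_lr_sched`) only for `la_fge` is an assumption, since this README does not say whether the other methods accept or ignore them.

```python
import shlex

# Method names as listed in this README.
IF_METHODS = ["randproj", "tracinrp", "arnoldi", "la_kfac", "la_fge"]
SHARED_FLAGS = ["--dataset=mnist", "--model=vgg", "--corruption_ratio=0.1"]
# Assumption: these flags are only relevant to the Geometric Ensemble method.
ENSEMBLE_FLAGS = ["--num_ens=8", "--ft_lr=0.05", "--ft_step=800", "--ft_lr_sched=cosine"]


def build_command(if_method):
    """Return the argument list for one influence-estimation run."""
    cmd = ["python3", "-m", "gex.noisy.main", *SHARED_FLAGS]
    if if_method == "la_fge":
        cmd += ENSEMBLE_FLAGS
    cmd.append(f"--if_method={if_method}")
    return cmd


commands = [build_command(method) for method in IF_METHODS]
for cmd in commands:
    # Print a copy-pasteable shell line instead of executing it.
    print(shlex.join(cmd))
```

Each list can be handed to `subprocess.run(cmd, check=True)` inside the container to execute the runs one after another.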

To run multiple Python files at once, use ./gex/{task}/total.sh:

bash gex/mnist/total.sh

By default, result files (e.g., logs, checkpoints, plots) are saved in

./gex/{task}/result/{pretrain_hyperparameter_settings}/{posthoc_hyperparameter_settings}

Motivation: Identifying and Resolving Distributional Bias in Influence

Problem

Because the sample-wise gradient ($g_z$) follows a stable distribution (e.g., Gaussian, Cauchy, or Lévy), the bilinear self-influence ($g_z M g_z$) follows a unimodal distribution (e.g., $\chi^2$).
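
The Gaussian case can be checked numerically. The sketch below is a toy simulation, not code from this repository: it assumes that every gradient coordinate is an independent standard normal and that $M$ is the identity, in which case $g_z M g_z$ is a sum of `dim` squared normals, i.e. a $\chi^2$ variable with `dim` degrees of freedom (mean `dim`, variance `2 * dim`).

```python
import random
import statistics

rng = random.Random(0)
dim = 10            # toy gradient dimension (hypothetical)
num_samples = 20_000

# Toy assumption: g_z ~ N(0, I) and M = I, so g_z M g_z = sum of squared coordinates.
self_influence = [
    sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(dim))
    for _ in range(num_samples)
]

mean = statistics.fmean(self_influence)
variance = statistics.pvariance(self_influence)
# For a chi-squared variable with `dim` degrees of freedom,
# the mean is `dim` and the variance is `2 * dim`.
print(f"mean={mean:.2f} (theory {dim}), variance={variance:.2f} (theory {2 * dim})")
```

All simulated self-influence values are drawn from one single-peaked distribution, which is the behaviour the statement above describes.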

Key Idea

The influence function can be interpreted as the expected product of linearized sample-loss deviations (or, more simply, their covariance) when the parameters are sampled from the Laplace approximation.

$$ \mathcal{I}(z,z') = \mathbb{E}[ \Delta \ell^\mathrm{lin}(z, \psi) \cdot \Delta \ell^\mathrm{lin}(z', \psi)] = \mathrm{Cov}[\ell^\mathrm{lin}(z,\psi), \ell^\mathrm{lin}(z', \psi)]. $$
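
A small Monte Carlo check can make this identity concrete. The sketch below is illustrative only and rests on two assumptions that are not spelled out here: the linearized deviation is the sample gradient dotted with the parameter perturbation, $\Delta \ell^\mathrm{lin}(z, \psi) = g_z^\top (\psi - \theta^*)$, and the Laplace approximation is a Gaussian centred at $\theta^*$ with covariance $\Sigma$. Under those assumptions the expectation reduces to the bilinear form $g_z^\top \Sigma g_{z'}$. All numbers are hypothetical toy values.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(mat, vec):
    return [dot(row, vec) for row in mat]


# Hypothetical toy quantities (not taken from the repository).
# chol is lower triangular, so the sampling covariance is Sigma = chol @ chol^T.
chol = [[1.0, 0.0, 0.0],
        [0.5, 1.0, 0.0],
        [0.0, 0.3, 0.8]]
g_z = [1.0, -0.5, 0.2]
g_z_prime = [0.3, 0.8, -0.1]

# Closed form: g_z^T Sigma g_z' = (chol^T g_z) . (chol^T g_z')
chol_t = [list(col) for col in zip(*chol)]
exact = dot(matvec(chol_t, g_z), matvec(chol_t, g_z_prime))

# Monte Carlo: average the product of linearized deviations over Gaussian samples.
rng = random.Random(0)
num_samples = 50_000
total = 0.0
for _ in range(num_samples):
    eps = [rng.gauss(0.0, 1.0) for _ in range(3)]
    delta = matvec(chol, eps)  # perturbation psi - theta*, distributed as N(0, Sigma)
    total += dot(g_z, delta) * dot(g_z_prime, delta)
monte_carlo = total / num_samples

print(f"closed form: {exact:.4f}, Monte Carlo: {monte_carlo:.4f}")
```

Because the linearized deviations have zero mean under this Gaussian, the expected product and the covariance coincide, which is why both forms appear in the equation.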

Solution

(1) Remove the linearization in the sample-loss deviation, and (2) replace the Laplace approximation with a Geometric Ensemble to mitigate the singularity of the Hessian.
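
The two changes can be sketched as an estimator that averages products of actual (non-linearized) loss deviations over ensemble members. This is a minimal illustration, not the repository's implementation: the function name and data layout are hypothetical, the deviation is assumed to be measured relative to the loss at the pre-trained parameters, and how the ensemble members are obtained is outside the sketch. The loss values are made up.

```python
def ensemble_influence(ref_a, ens_a, ref_b, ens_b):
    """Average, over ensemble members, the product of loss deviations.

    ref_a[i]    : loss of sample i at the pre-trained parameters
    ens_a[m][i] : loss of sample i at ensemble member m
    (same layout for the second sample set b)
    """
    num_members = len(ens_a)
    kernel = [[0.0 for _ in ref_b] for _ in ref_a]
    for member_a, member_b in zip(ens_a, ens_b):
        # Non-linearized deviation: actual loss change, not a gradient approximation.
        dev_a = [loss - ref for loss, ref in zip(member_a, ref_a)]
        dev_b = [loss - ref for loss, ref in zip(member_b, ref_b)]
        for i, da in enumerate(dev_a):
            for j, db in enumerate(dev_b):
                kernel[i][j] += da * db / num_members
    return kernel


# Toy losses, made up for illustration: 2 train samples, 1 test sample, 2 members.
train_ref = [0.2, 1.0]
train_ens = [[0.3, 0.5], [0.1, 1.5]]
test_ref = [0.4]
test_ens = [[0.6], [0.2]]

# Train-test kernel of shape (N_tr, N_te).
kernel = ensemble_influence(train_ref, train_ens, test_ref, test_ens)
# Self-influence is the diagonal of the train-train kernel.
self_kernel = ensemble_influence(train_ref, train_ens, train_ref, train_ens)
self_influence = [self_kernel[i][i] for i in range(len(train_ref))]
print(kernel, self_influence)
```

Only forward passes through each ensemble member are needed in this form; no Hessian is inverted, which is where the singularity issue would otherwise arise.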

Supported post-hoc methods

from gex.influence.estimate import compute_influence

# Influence kernel of shape (N_tr, N_te) between train and test samples
influence_kernel = compute_influence(trainer, dataset_tr, dataset_te, dataset_opt, self_influence=False)
# Self-influence of shape (N_tr,) for the train dataset
self_influence = compute_influence(trainer, dataset_tr, dataset_te, dataset_opt, self_influence=True)

  • Random Projection (--if_method=randproj)
  • TracIn Random Projection (--if_method=tracinrp)
  • Arnoldi (--if_method=arnoldi)
  • Laplace approximation with K-FAC (--if_method=la_kfac)
  • Geometric Ensemble (--if_method=la_fge)
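
Whichever method is selected, the self-influence scores returned with `self_influence=True` can be used to rank training samples. The sketch below shows one plausible downstream use for the noisy-label setting (`--corruption_ratio`): flag the highest-scoring fraction of samples as suspects. It assumes that the scores have been converted to a plain list of floats and that a higher self-influence marks a more suspicious sample; the helper name and the score values are hypothetical.

```python
def flag_suspects(self_influence, ratio):
    """Return indices of the top `ratio` fraction of samples by self-influence."""
    num_flagged = max(1, round(ratio * len(self_influence)))
    # Sort sample indices from highest to lowest score.
    order = sorted(
        range(len(self_influence)),
        key=lambda i: self_influence[i],
        reverse=True,
    )
    return order[:num_flagged]


# Toy scores for 10 training samples, made up for illustration.
scores = [0.1, 2.5, 0.3, 0.2, 1.9, 0.05, 0.4, 0.15, 0.25, 0.35]
suspects = flag_suspects(scores, ratio=0.2)
print(suspects)
```

When ground-truth corruption flags are available, comparing them against `suspects` gives a simple detection precision for each `--if_method`.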

Acknowledgement

This work was supported by the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)).
