run_integrated_gradients.py
"""
Script to run integrated gradients.
Usage:
$python run_integrated_gradients.py --config ./configs/integrated_gradients_bert_cloze.yaml
"""
import argparse
import os
import pickle as pkl

import torch
from transformers import AutoTokenizer

# Wildcard imports register datasets, models, trainers, preprocessors, and
# tokenizers with configmapper.
from src.datasets import *
from src.models import *
from src.trainers import *
from src.modules.preprocessors import *
from src.modules.tokenizers import *
from src.utils.configuration import Config
from src.utils.integrated_gradients import MyIntegratedGradients
from src.utils.mapper import configmapper

dirname = os.path.dirname(__file__)
## Parse config paths from the command line
parser = argparse.ArgumentParser(
    prog="run_integrated_gradients.py",
    description="Run integrated gradients on a model.",
)
parser.add_argument(
    "--config",
    type=str,
    action="store",
    required=True,
    help="Path to the integrated-gradients configuration file",
)
parser.add_argument(
    "--model",
    type=str,
    action="store",
    required=True,
    help="Path to the model configuration file",
)
parser.add_argument(
    "--data",
    type=str,
    action="store",
    required=True,
    help="Path to the data configuration file",
)
args = parser.parse_args()

ig_config = Config(path=args.config)
model_config = Config(path=args.model)
data_config = Config(path=args.data)
# Preprocessor, Dataset, Model
preprocessor = configmapper.get_object(
    "preprocessors", data_config.main.preprocessor.name
)(data_config)
model, train_data, val_data = preprocessor.preprocess(model_config, data_config)

tokenizer = AutoTokenizer.from_pretrained(
    model_config.params.pretrained_model_name_or_path
)

# Load the fine-tuned weights from the checkpoint given in the IG config
model.load_state_dict(torch.load(ig_config.checkpoint_path))
# Initialize MyIntegratedGradients over the validation data
big = MyIntegratedGradients(ig_config, model, val_data, tokenizer)

print("### Running IG ###")
(
    samples,
    word_importances,
    token_importances,
) = big.get_all_importances()
print("### Saving the Scores ###")
# print(samples)
# with open(os.path.join(ig_config.store_dir, "samples"), "wb") as out_file:
# pkl.dump(samples, out_file)
with open(os.path.join(ig_config.store_dir, "token_importances"), "wb") as out_file:
pkl.dump(token_importances, out_file)
with open(os.path.join(ig_config.store_dir, "word_importances"), "wb") as out_file:
pkl.dump(word_importances, out_file)
print("### Finished ###")