-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsave_embeddings_post_aug.py
158 lines (131 loc) · 4.87 KB
/
save_embeddings_post_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import argparse
import json
import logging
import warnings
from pathlib import Path
from typing import Iterator, List, Optional, Tuple

from src.models.available_models import FINETUNED_MODELS, MODELS
from src.models.embeddings import Embeddings
from src.models.model import Model
from src.struct_probing.code_augs import available_augs, post_augs
from src.struct_probing.code_augs.aug import CodeAugmentation
from src.utils import Saver, Setup, process_model
from tasks.mlm.dataset import AugDataset
# Named logger for this script. Its level is INFO so INFO records pass the
# logger-level check; they are only emitted once a handler is configured
# (e.g. via logging.basicConfig).
log = logging.getLogger(name="save_embeddings")
log.setLevel(level=logging.INFO)
def iter_pretrained_models(args) -> Iterator[Optional[Model]]:
    """Yield the pretrained models selected by ``args.model``.

    Walks ``MODELS`` and, for every entry whose name equals ``args.model``
    (or every entry when ``args.model == "all"``), builds the model with
    ``get_model(name, debug=args.debug, task_args=args)``.

    In ``--preview`` mode nothing is built and ``None`` is yielded instead,
    so callers must be prepared to receive ``None``.
    """
    for name, entry in MODELS.items():
        if args.model != "all" and name != args.model:
            continue
        if args.preview:
            model = None
        else:
            model = entry.get_model(name, debug=args.debug, task_args=args)
        log.info("Processing: %s, %s", name, model)
        yield model
def iter_finetuned_models(args) -> Iterator[Optional[Model]]:
    """Yield the fine-tuned models selected by ``args.model``.

    Each ``FINETUNED_MODELS`` value is a ``(cls, checkpoint_path)`` pair; the
    model is built with
    ``cls.get_model(name, checkpoint_path=checkpoint_path, task_args=args)``.

    In ``--preview`` mode nothing is built and ``None`` is yielded instead.

    NOTE(review): unlike ``iter_pretrained_models``, ``args.debug`` is not
    forwarded to ``get_model`` here — confirm whether that is intentional.
    """
    for model_name, (cls, checkpoint_path) in FINETUNED_MODELS.items():
        if args.model != "all" and model_name != args.model:
            continue
        if args.preview:
            model = None
        else:
            model = cls.get_model(
                model_name, checkpoint_path=checkpoint_path, task_args=args
            )
        log.info("Processing: %s, %s", model_name, model)
        yield model
def iter_models(args) -> Iterator[Optional[Model]]:
    """Yield all selected pretrained models, then all selected fine-tuned ones.

    Yields ``None`` placeholders in ``--preview`` mode (see the two
    underlying generators).
    """
    yield from iter_pretrained_models(args)
    yield from iter_finetuned_models(args)
def get_code_augmentation(args) -> CodeAugmentation:
    """Instantiate the code augmentation named by ``args.insert``.

    ``args.insert`` is looked up in ``available_augs`` and the resulting
    entry is called with no arguments.
    """
    aug_factory = available_augs[args.insert]
    return aug_factory()
def get_embeddings_data(args, setup: Setup, dataset) -> Tuple[Saver, List[int]]:
    """Run ``process_model`` for one model setup over *dataset*.

    Forwards ``args.n_samples`` and ``args.debug``; per the annotation the
    result is a ``(saver, invalid_ids)`` pair.
    """
    result = process_model(
        dataset,
        args,
        setup,
        debug=args.debug,
        n_samples=args.n_samples,
    )
    return result
def _load_aug_dataset(args, post_aug_name: str) -> AugDataset:
    """Load the augmented dataset for ``args.task`` / ``args.insert`` / *post_aug_name*."""
    save_path = str(
        Setup.get_aug_path(
            args.task, f"{args.insert}__{post_aug_name}", data_dir=args.input_dir
        )
    )
    log.info("load data from %s", save_path)
    return AugDataset(
        Saver(save_path, mode="all").load_json(),
        type=args.task,
    )


def _save_outputs(setup: Setup, saver: Saver, invalid_ids: List[int]) -> None:
    """Save *saver* and dump *invalid_ids* to ``<setup.get_path()>/invalid_ids.json``."""
    log.info("saver: %s", (len(saver.data), saver.path, saver.mode))
    saver.save()
    invalid_ids_path = Path(setup.get_path(), "invalid_ids.json")
    with invalid_ids_path.open("w", encoding="utf-8") as f:
        json.dump(invalid_ids, f)


def main(args):
    """Compute and save embeddings of an augmented dataset for each selected model.

    1. Build the augmentation chosen by ``--insert``. If it requires a
       specific dataset, ``args.task`` is overwritten (with a warning).
    2. Load the augmented data stored under
       ``Setup.get_aug_path(task, "<insert>__<post_aug_name>", data_dir=input_dir)``.
    3. For every model from ``iter_models`` compute embeddings via
       ``get_embeddings_data``. Unless ``--debug`` is set, save them plus
       ``invalid_ids.json``. In ``--preview`` mode the models are only
       enumerated (and logged), nothing is computed.
    """
    code_aug = get_code_augmentation(args)
    required_task = code_aug.required_dataset()
    if required_task is not None:
        warnings.warn(f"changed task data: {args.task} -> {required_task}")
        args.task = required_task

    post_aug_name = args.post_aug_name
    log.info("post_aug_name: %s", post_aug_name)
    dataset = _load_aug_dataset(args, post_aug_name)

    embeddings = Embeddings(
        code_aug.required_embeddings(), pairsent=False
    )  # if pairsent is True: sent1

    for model in iter_models(args):
        if args.preview:
            # Models are None in preview mode; iterating still logs the selection.
            continue
        setup = Setup(
            dataset,
            code_aug,
            model,
            embeddings,
            data_dir=args.input_dir,
            post_aug_name=post_aug_name,
        )
        log.info("Setup: %s", setup)
        saver, invalid_ids = get_embeddings_data(args, setup, dataset)
        print(f"saver.path: {saver.path}")
        if args.debug:
            warnings.warn("Data is not saved in DEBUG mode!")
            continue
        _save_outputs(setup, saver, invalid_ids)
if __name__ == "__main__":
    # Without a configured handler the root logger stays at WARNING and every
    # INFO progress message emitted by this script would be dropped.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="all",
        help="model",
        choices=list(MODELS.keys()) + list(FINETUNED_MODELS.keys()) + ["all"],
    )
    parser.add_argument("--task", default="mlm", help="task (dataset) name")
    parser.add_argument("--input_dir", default="CodeAnalysisAug", help="data-dir")
    parser.add_argument("--n_samples", type=int, default=10000)
    parser.add_argument(
        "--insert",
        type=str,
        choices=list(available_augs.keys()),
        default="identity",
        help="data augmentation for probing tasks (bug detection)",
    )
    parser.add_argument(
        "-p",
        "--post_aug_name",
        type=str,
        choices=[post_aug().name for post_aug in post_augs],
        default="default",
        help="data augmentation (ablation)",
    )
    # Debug
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--preview", action="store_true")
    args = parser.parse_args()

    # "algo*" augmentations set lang to python and disable AST parsing;
    # "sorts" also disables AST parsing; everything else uses java.
    is_algo = args.insert.startswith("algo")
    args.parse_ast = args.insert != "sorts" and not is_algo
    args.lang = "python" if is_algo else "java"
    main(args)