forked from boostcampaitech2/mrc-level2-nlp-10
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodelList.py
38 lines (28 loc) · 1.55 KB
/
modelList.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import AutoConfig, AutoTokenizer
from model import RobertaQA, BertQA, ElectraQA, ConvModel
def init(model_args):
    """Build the (tokenizer, model) pair for extractive-QA fine-tuning.

    Loads the config and fast tokenizer from ``model_args.model_name_or_path``,
    registers extra user-defined special tokens, picks the QA model class that
    matches the checkpoint name, and resizes the model's token embeddings to
    the grown vocabulary.

    Args:
        model_args: object exposing a ``model_name_or_path`` attribute
            (one of the checkpoint names handled below).

    Returns:
        tuple: ``(tokenizer, model)``.

    Raises:
        ValueError: if ``model_name_or_path`` is not a supported checkpoint.
    """
    # Load config and (fast) tokenizer for the requested checkpoint.
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, use_fast=True
    )

    # Insert special tokens ([UNK1..99], [unused500..699]) so the tokenizer
    # keeps them as single pieces instead of splitting / mapping to [UNK].
    user_defined_symbols = [f"[UNK{i}]" for i in range(1, 100)]
    user_defined_symbols += [f"[unused{i}]" for i in range(500, 700)]
    special_tokens_dict = {"additional_special_tokens": user_defined_symbols}
    tokenizer.add_special_tokens(special_tokens_dict)

    # A different model class is loaded depending on model_name_or_path.
    name = model_args.model_name_or_path
    if name == "klue/bert-base":
        model = BertQA.from_pretrained(name, config=config)
    elif name == "klue/roberta-large":
        model = RobertaQA.from_pretrained(name, config=config)
    elif name == 'ConvModel':
        # ConvModel is a project-local architecture initialized from the
        # klue/roberta-large checkpoint weights.
        model = ConvModel.from_pretrained('klue/roberta-large', config=config)
    elif name == "monologg/koelectra-base-v3-discriminator":
        model = ElectraQA.from_pretrained(name, config=config)
    else:
        # Previously an unknown name fell through and crashed later with
        # UnboundLocalError; fail fast with a clear message instead.
        raise ValueError(f"Unsupported model_name_or_path: {name}")

    # Resize token embeddings for the added special tokens. len(tokenizer)
    # counts added tokens and skips duplicates, unlike
    # tokenizer.vocab_size + len(user_defined_symbols), which over-counts
    # when some symbols already exist in the vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model