forked from batubayk/enc_dec_sum
push_to_repo.py
from transformers import AutoTokenizer, AutoConfig, EncoderDecoderModel, AutoModelForSeq2SeqLM, MBartTokenizerFast

# Hub repository to push to and the local checkpoint to load from.
model_repo_name = "combined_tr_berturk32k_cased_summary"
model_name_or_path = "outputs/checkpoint-82325"

config = AutoConfig.from_pretrained(model_name_or_path)

# mBART needs explicit source/target language codes; other checkpoints use the standard fast tokenizer.
if "mbart" in model_name_or_path:
    tokenizer = MBartTokenizerFast.from_pretrained(
        model_name_or_path,
        src_lang="tr_TR",
        tgt_lang="tr_TR",
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        use_fast=True,
        strip_accents=False,
        lowercase=False,
    )

if "bert" in model_name_or_path:
    # BERT tokenizers define no BOS/EOS tokens; reuse [CLS] and [SEP] in their place.
    tokenizer.bos_token = tokenizer.cls_token
    tokenizer.eos_token = tokenizer.sep_token

if "bert" in model_name_or_path:
    # BERT checkpoints were fine-tuned as a BERT2BERT encoder-decoder model.
    model = EncoderDecoderModel.from_pretrained(model_name_or_path)
    # Set the special tokens required for generation.
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    # The encoder-decoder config has no top-level vocab_size; copy it from the decoder config.
    model.config.vocab_size = model.config.decoder.vocab_size
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name_or_path,
        config=config,
    )
    if "mbart" in model_name_or_path:
        model.config.decoder_start_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id

if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

# Upload both the model and the tokenizer to the Hugging Face Hub.
model.push_to_hub(model_repo_name)
tokenizer.push_to_hub(model_repo_name)
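
# Usage note (a sketch, not part of the original script): once the push finishes,
# the checkpoint can be reloaded from the Hub for inference. The repo id below is
# an assumption: replace <username> with the account the model was pushed under.
#
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# tokenizer = AutoTokenizer.from_pretrained("<username>/combined_tr_berturk32k_cased_summary")
# model = AutoModelForSeq2SeqLM.from_pretrained("<username>/combined_tr_berturk32k_cased_summary")
# inputs = tokenizer("Haber metni ...", return_tensors="pt", truncation=True)
# summary_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], num_beams=4, max_length=128)
# print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))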