Commit 7827940 (0 parents): 8 changed files with 531 additions and 0 deletions.
@@ -0,0 +1,23 @@
# BioLaySumm
Fine-tune models on the BioLaySumm dataset.
### Initial Analysis
Mean scores over the PLOS and eLife validation sets.
#### Relevance (higher is better)
| | Rouge-1 | Rouge-2 | Rouge-L | BERTScore |
|:------ |:------- |:------- |:------- |:--------- |
| BART | 0.4786 | 0.1525 | 0.4452 | 0.8486 |
| LED | **0.4858** | **0.1552** | **0.4502** | **0.8571** |
| T5 | 0.4358 | 0.1214 | 0.4095 | 0.8398 |
#### Readability (lower is better)
| | FKGL | DCRS |
|:------ |:------- |:------- |
| BART | 12.3617 | 9.9345 |
| LED | 11.8577 | 9.8441 |
| T5 | **10.1728** | **9.1107** |
#### Factuality (higher is better)
| | BARTScore |
|:------ |:--------- |
| BART | -2.7569 |
| LED | **-2.0367** |
| T5 | -3.7528 |
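The tables above can be reproduced only approximately with off-the-shelf packages; the snippet below is a minimal sketch, not the official BioLaySumm scoring code. It assumes the `rouge`, `bert-score`, and `textstat` packages (only `rouge` appears elsewhere in this repo) and covers relevance (ROUGE, BERTScore) and readability (FKGL, DCRS); the factuality column (BARTScore) needs the separate BARTScore implementation and is omitted here.

```python
# Minimal evaluation sketch (assumption: `pip install rouge bert-score textstat`);
# this is not the official BioLaySumm scoring script.
from rouge import Rouge
from bert_score import score as bert_score
import textstat


def evaluate(predictions, references):
    # Relevance: ROUGE-1/2/L F1, averaged over all prediction/reference pairs
    rouge_scores = Rouge().get_scores(predictions, references, avg=True)
    # Relevance: BERTScore (returns per-example P/R/F1 tensors; take the mean F1)
    _, _, f1 = bert_score(predictions, references, lang="en")
    # Readability: grade-level scores of the predictions themselves (lower = easier)
    fkgl = sum(textstat.flesch_kincaid_grade(p) for p in predictions) / len(predictions)
    dcrs = sum(textstat.dale_chall_readability_score(p) for p in predictions) / len(predictions)
    return {
        "Rouge-1": rouge_scores["rouge-1"]["f"],
        "Rouge-2": rouge_scores["rouge-2"]["f"],
        "Rouge-L": rouge_scores["rouge-l"]["f"],
        "BERTScore": f1.mean().item(),
        "FKGL": fkgl,
        "DCRS": dcrs,
    }
```

Feeding it the system outputs written by the generation scripts below, together with the reference lay summaries returned by `load_task1_data`, yields a dict laid out like the tables above.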
@@ -0,0 +1,115 @@
from utils import *
import argparse
import pandas as pd

import torch
from datasets import Dataset
from transformers import BartTokenizer, BartForConditionalGeneration


class prepare_dataset(object):
    def __init__(self, file_path, nums):
        self.data = pd.read_csv(file_path)
        self.num = int(nums)

    def load(self):
        # each entry is "source<TAB>target" so both sides travel together
        data_list = []
        for i in range(self.num):
            data_list.append(str(self.data['src_sen'][i]) + '\t' + str(self.data['dst_sen'][i]))

        return {"text": data_list}


# tokenize the tab-separated source/target pairs
def tokenize_data(tokenizer, dataset, max_len):
    def convert_to_features(example_batch):
        src_texts = []
        dst_texts = []
        for example in example_batch['text']:
            term = example.split('\t', 1)
            src_texts.append(term[0])
            dst_texts.append(term[1])

        src_encodings = tokenizer.batch_encode_plus(
            src_texts,
            truncation=True,
            padding='max_length',
            max_length=max_len,
        )
        dst_encodings = tokenizer.batch_encode_plus(
            dst_texts,
            truncation=True,
            padding='max_length',
            max_length=max_len,
        )
        encodings = {
            'input_ids': src_encodings['input_ids'],
            'attention_mask': src_encodings['attention_mask'],
            'dst_ids': dst_encodings['input_ids'],
            'target_attention_mask': dst_encodings['attention_mask']
        }

        return encodings

    dataset = dataset.map(convert_to_features, batched=True)
    # Rename columns to the names that the forward method of the selected
    # model expects
    dataset = dataset.rename_column('dst_ids', 'labels')
    dataset = dataset.rename_column('target_attention_mask', 'decoder_attention_mask')

    # drop the raw text column so only model inputs remain
    dataset = dataset.remove_columns(['text'])

    # set the tensor type and the columns the dataset should return
    columns = ['input_ids', 'labels', 'attention_mask', 'decoder_attention_mask']
    dataset = dataset.with_format(type='torch', columns=columns)

    return dataset


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='PLOS')
    parser.add_argument("--datatype", type=str, default="train")
    args = parser.parse_args()
    train_article, _, _ = load_task1_data(args)
    args.datatype = "val"
    val_article, _, _ = load_task1_data(args)

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)

    # training data
    d_train = prepare_dataset('PLOS_train.csv', len(train_article))
    train_data_dic = d_train.load()
    train_dataset = Dataset.from_dict(train_data_dic, split='train')
    # validation data
    d_valid = prepare_dataset('PLOS_val.csv', len(val_article))
    valid_data_dic = d_valid.load()
    valid_dataset = Dataset.from_dict(valid_data_dic, split='test')

    train_data = tokenize_data(tokenizer, train_dataset, max_len=1024)
    valid_data = tokenize_data(tokenizer, valid_dataset, max_len=512)

    from transformers import TrainingArguments, Trainer

    training_args = TrainingArguments(
        output_dir='./results',          # output directory
        num_train_epochs=1,              # total number of training epochs
        per_device_train_batch_size=1,   # batch size per device during training
        per_device_eval_batch_size=1,    # batch size per device during evaluation
        logging_dir='./logs/rn_log',     # directory for storing logs
        learning_rate=1e-4,
        save_strategy="no",              # do not save intermediate checkpoints
        logging_steps=2,
        gradient_accumulation_steps=8,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=valid_data,
    )
    trainer.train()

    # save the fine-tuned model
    model.save_pretrained("./bart-3/")
@@ -0,0 +1,150 @@
from utils import *

import argparse
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from tqdm import tqdm
from datasets import Dataset


# load tokenizer
tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")

encoder_max_length = 6144
decoder_max_length = 512
batch_size = 1


def process_data_to_model_inputs(batch):
    inputs = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    outputs = tokenizer(
        batch["abstract"],
        padding="max_length",
        truncation=True,
        max_length=decoder_max_length,
    )
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # zero-initialised global attention mask with the same shape as input_ids
    batch["global_attention_mask"] = len(batch["input_ids"]) * [
        [0 for _ in range(len(batch["input_ids"][0]))]
    ]

    # the inner lists above are shared references, so this puts global attention
    # on the first (<s>) token of every example in the batch
    batch["global_attention_mask"][0][0] = 1
    batch["labels"] = outputs.input_ids

    # mask padding tokens so they are ignored by the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch


# construct training dataset
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default='PLOS')
parser.add_argument("--datatype", type=str, default="train")
args = parser.parse_args()
# training data
article_train, lay_sum_train, _ = load_task1_data(args)
train_dataset = {'article': article_train, 'abstract': lay_sum_train}
train_dataset = Dataset.from_dict(train_dataset)
# validation data
args.datatype = 'val'
article_val, lay_sum_val, _ = load_task1_data(args)
val_dataset = {'article': article_val, 'abstract': lay_sum_val}
val_dataset = Dataset.from_dict(val_dataset)

# ---- subsample for a quick test run: 1000 training / 10 validation examples ----
train_dataset = train_dataset.select(range(1000))
val_dataset = val_dataset.select(range(10))
# ---------------------------------------------------------------------------------

# map train data
train_dataset = train_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "abstract"],
)
# map val data
val_dataset = val_dataset.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "abstract"],
)

# the datasets should be converted into the PyTorch format
train_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)
val_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)

from rouge import Rouge

rouge = Rouge()


# pred holds the generated ids (pred.predictions) and the gold labels (pred.label_ids)
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    # average ROUGE-2 over the whole evaluation set
    rouge_output = rouge.get_scores(pred_str, label_str, avg=True)['rouge-2']

    return {
        "rouge2_precision": round(rouge_output['p'], 4),
        "rouge2_recall": round(rouge_output['r'], 4),
        "rouge2_fmeasure": round(rouge_output['f'], 4),
    }


led = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-base-16384", gradient_checkpointing=True, use_cache=False)

# set generation hyperparameters
led.config.num_beams = 2
led.config.max_length = 512
led.config.min_length = 100
led.config.length_penalty = 2.0
led.config.early_stopping = True
led.config.no_repeat_ngram_size = 3

# Training
model_name = 'long-2'
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    output_dir=f"./{model_name}",
    logging_steps=5,
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
)

trainer = Seq2SeqTrainer(
    model=led,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()
@@ -0,0 +1,45 @@
import argparse
import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer
from utils import *
from write_data_csv import write_data_txt


# generate candidate sentences through beam search
def sen_generation(device, tokenizer, model, text: str, max_length: int, beam_nums):
    inputs = tokenizer.encode(text, padding=True, max_length=max_length, truncation=True,
                              return_tensors='pt')
    inputs = inputs.to(device)
    model = model.to(device)

    res = model.generate(
        inputs, length_penalty=2, num_beams=4, no_repeat_ngram_size=3,
        max_length=max_length, num_return_sequences=beam_nums
    )

    decode_tokens = []
    for beam_res in res:
        decode_tokens.append(tokenizer.decode(beam_res.squeeze(), skip_special_tokens=True).lower())

    return decode_tokens


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default='PLOS')
    parser.add_argument("--datatype", type=str, default="val")
    parser.add_argument("--max_len", type=int, default=512)
    parser.add_argument("--beam_nums", type=int, default=1)
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # load the fine-tuned BART checkpoint and the matching base tokenizer
    new_model = BartForConditionalGeneration.from_pretrained("./bart-2/")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    article, _, _ = load_task1_data(args)
    sys_out = []
    for sen in tqdm(article):
        # generate the candidate summary list and keep the top beam
        result = sen_generation(device, tokenizer, new_model, sen,
                                args.max_len, args.beam_nums)
        sys_out.append(result[0])

    write_data_txt(sys_out, "bart_plos_1")
@@ -0,0 +1,38 @@
import argparse
import torch
from datasets import Dataset
from transformers import LEDTokenizer, LEDForConditionalGeneration
from utils import *

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# load the fine-tuned LED checkpoint (tokenizer and model)
tokenizer = LEDTokenizer.from_pretrained("./checkpoint-10/")
model = LEDForConditionalGeneration.from_pretrained("./checkpoint-10/").to(device)


def generate_sum(batch):
    inputs_dict = tokenizer(batch["article"], padding="max_length", max_length=8192, return_tensors="pt", truncation=True)
    input_ids = inputs_dict.input_ids.to(device)
    attention_mask = inputs_dict.attention_mask.to(device)
    global_attention_mask = torch.zeros_like(attention_mask)
    # put global attention on the <s> token
    global_attention_mask[:, 0] = 1

    predicted_abstract_ids = model.generate(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
    batch["predicted_abstract"] = tokenizer.batch_decode(predicted_abstract_ids, skip_special_tokens=True)
    return batch


parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="PLOS")
parser.add_argument("--datatype", type=str, default="val")
args = parser.parse_args()

article_val, lay_sum_val, _ = load_task1_data(args)
val_dataset = {'article': article_val, 'abstract': lay_sum_val}
val_dataset = Dataset.from_dict(val_dataset)
# only summarise the first 50 validation articles
val_dataset = val_dataset.select(range(50))

result = val_dataset.map(generate_sum, batched=True, batch_size=1)
# print(result["predicted_abstract"][0])

from write_data_csv import write_data_txt
write_data_txt(result["predicted_abstract"], "long_PLOS_1")