forked from Silin159/DiffuCOMET
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ed08039
Showing
41 changed files
with
9,168 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
# DiffuCOMET | ||
|
||
This is the source code for paper DiffuCOMET: Contextual Commonsense Knowledge Diffusion. | ||
|
||
Part of our code is modified from [SeqDiffuSeq](https://github.com/Yuanhy1997/SeqDiffuSeq) repository. | ||
|
||
## Getting Started | ||
|
||
Create a **python 3.8** Conda environment and install the following packages: | ||
``` | ||
conda install mpi4py | ||
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Preparing Datasets and Toolkits | ||
|
||
Our preprocessed datasets can be downloaded from [this link](https://drive.google.com/file/d/1DIbF0WxscgEPKv4mX00wypWeVt39LMWQ/view?usp=sharing), please place ``data/`` under this root directory, and ``data_rp/`` under the ``BART_Rel_Pred/`` directory. | ||
|
||
Please also download our commonsense fact linking toolkit (ComFact_Linker) from [this link](https://drive.google.com/file/d/1BDwh1ZQZXWXw3gduSlICB81xOelBhHb0/view?usp=sharing), and place ``ComFact_Linker/`` under this root directory. | ||
|
||
## Training | ||
|
||
**DiffuCOMET-Fact seeded with BART-{base | large} models**: | ||
``` | ||
# Training fact embedding module | ||
# on ComFact benchmark knowledge (ATOMIC 2020): | ||
bash ./train_scripts/train_embedding_{base|large}.sh comfact facts 32 | ||
# on WebNLG+ 2020 benchmark knowledge: | ||
bash ./train_scripts/train_embedding_{base|large}.sh webnlg facts 64 | ||
# Training fact diffusion module | ||
# on ComFact benchmark (ROCStories portion): | ||
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc facts 32 32 comfact_facts | ||
# on WebNLG+ 2020 benchmark: | ||
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg facts 8 64 webnlg_facts | ||
``` | ||
|
||
**DiffuCOMET-Entity seeded with BART-{base | large} models**: | ||
``` | ||
# Training entity embedding module | ||
# on ComFact benchmark knowledge (ATOMIC 2020): | ||
bash ./train_scripts/train_embedding_{base|large}.sh comfact entities 32 | ||
# on WebNLG+ 2020 benchmark knowledge: | ||
bash ./train_scripts/train_embedding_{base|large}.sh webnlg entities 64 | ||
# Training head entity diffusion module | ||
# on ComFact benchmark (ROCStories portion): | ||
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc heads 16 16 comfact_entities | ||
# on WebNLG+ 2020 benchmark: | ||
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg heads 8 16 webnlg_entities | ||
# Training tail entity diffusion module | ||
# on ComFact benchmark (ROCStories portion): | ||
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc tails 8 24 comfact_entities | ||
# on WebNLG+ 2020 benchmark: | ||
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg tails 8 64 webnlg_entities | ||
# Training relation prediction module | ||
# on ComFact benchmark (ROCStories portion): | ||
bash ./BART_Rel_Pred/train_rel_pred.sh comfact_roc | ||
# on WebNLG+ 2020 benchmark: | ||
bash ./BART_Rel_Pred/train_rel_pred.sh webnlg | ||
``` | ||
|
||
## Inference | ||
|
||
**DiffuCOMET-Fact seeded with BART-{base | large} models**: | ||
``` | ||
# Testing on ComFact ROCStories (comfact_roc), PersonaChat (comfact_persona), MuTual (comfact_mutual), MovieSummaries (comfact_movie) or WebNLG+ 2020 (webnlg): | ||
bash ./inference_scripts/inference.sh ${train_dataset} # ${test_dataset} facts {base|large} test ${train_step} ${schedule} ${ctx_len} | ||
# ${train_dataset}: {comfact_roc|webnlg} | ||
# ${test_dataset}: {comfact_roc|comfact_persona|comfact_mutual|comfact_movie|webnlg} | ||
# ${train_step}: training step (ID) of tested model checkpoint, e.g., 130000 | ||
# ${schedule}: noise schedule ID of tested model checkpoint, should be ${train_step}-2000, e.g., 128000 | ||
# ${ctx_len}: maximum narrative context length, should be 256 for testing on comfact_movie, while 128 for others | ||
# generations will be saved in results/${test_dataset}_facts_{base|large}_${train_step}/generations.json | ||
# Post-processing fact generations, post-processed generations will be saved in ${result_dir}/gen_processed.json: | ||
python ./diffu_eval/post_process_facts.py --context_dir data/${test_dataset}_facts/test.contexts \ | ||
--result_dir results/${test_dataset}_facts_{base|large}_${train_step} | ||
``` | ||
|
||
**DiffuCOMET-Entity seeded with BART-{base | large} models**: | ||
``` | ||
# Head entity generation | ||
bash ./inference_scripts/inference.sh ${train_dataset} ${test_dataset} heads {base|large} test ${train_step_head} ${schedule_head} ${ctx_len} | ||
# Post-processing head entity generations: | ||
python ./diffu_eval/post_process_heads.py --context data/${test_dataset}_heads/test.contexts \ | ||
--result_dir results/${test_dataset}_heads_{base|large}_${train_step_head} \ | ||
--tail_gen_input_dir data/${test_dataset}_tails/test_{base|large}_${train_step_head} | ||
# Tail entity generation | ||
bash ./inference_scripts/inference.sh ${train_dataset} ${test_dataset} tails {base|large} test_{base|large}_${train_step_head} ${train_step_tail} ${schedule_tail} ${ctx_len} | ||
# Post-processing tail entity generations: | ||
python ./diffu_eval/post_process_tails.py --gold_dir data/${test_dataset}_facts/test \ | ||
--tail_gen_input_dir data/${test_dataset}_tails/test_{base|large}_${train_step_head} \ | ||
--tail_gen_result_dir results/${test_dataset}_tails_{base|large}_${train_step_tail} \ | ||
--pipeline_result_dir results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail} \ | ||
--rel_pred_input_dir BART_Rel_Pred/data_rp/${test_dataset}/rel_pred_inf_{base|large}/test | ||
# Relation prediction | ||
bash ./BART_Rel_Pred/run_rel_pred.sh ${train_dataset} ${test_dataset} ${train_step_rel_pred} {base|large} | ||
# Post-processing relation predictions: | ||
python ./diffu_eval/post_process_rel_pred.py --test_data ${test_dataset} \ | ||
--rel_pred_ids BART_Rel_Pred/data_rp/${test_dataset}/rel_pred_inf_{base|large}/test/labels.json \ | ||
--rel_pred_results BART_Rel_Pred/pred/${test_dataset}-{base|large}/predictions.json \ | ||
--pipeline_result_dir results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail} | ||
``` | ||
|
||
## Evaluation | ||
|
||
**Evaluating on traditional NLG metrics** | ||
``` | ||
python ./diffu_eval/eval_nlg.py --generation ${processed_gen} --eval_result_dir ${eval_result_dir} | ||
# ${processed_gen}: results/${test_dataset}_facts_{base|large}_${train_step}/gen_processed.json | ||
# or results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail}/gen_processed.json | ||
# ${eval_result_dir}: directory for saving evaluation scores, e.g., results/${test_dataset}_facts_{base|large}_${train_step} | ||
# evaluation scores will be saved in ${eval_result_dir}/nlg_eval.json | ||
``` | ||
|
||
**Evaluating on our proposed clustering-based metrics** | ||
``` | ||
# Pre-processing generations for ComFact linker to score relevance | ||
python ./diffu_eval/prepare_comfact_linking.py --test_data ${test_dataset} --generation ${processed_gen} \ | ||
--comfact_input_dir ComFact_Linker/data_fl/all/fact_link/nlu/${eval_model} | ||
# ${eval_model}: ${test_dataset}_facts_{base|large}_${train_step} | ||
# or ${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail} | ||
# Switching to ComFact original environment (optional) | ||
# Please refer to ComFact_Linker/README.md | ||
# Run ComFact linker | ||
bash ComFact_Linker/run_fact_link.sh ${eval_model} | ||
# scoring results will be saved in ComFact_Linker/pred/${eval_model}/predictions.json | ||
# Switching back to DiffuCOMET environment (optional) | ||
# Please refer to ComFact_Linker/README.md | ||
# Post-processing ComFact linker scoring results: | ||
python ./diffu_eval/write_comfact_scores.py --comfact_output ComFact_Linker/pred/${eval_model}/predictions.json \ | ||
--generation ${processed_gen} | ||
# Clustering-based evaluation | ||
python ./diffu_eval/eval_cluster.py --test_data ${test_dataset} --generation ${processed_gen} \ | ||
--eval_result_dir ${eval_result_dir} | ||
# evaluation scores will be saved in ${eval_result_dir}/cluster_eval.csv | ||
# each line of the CSV file records a metric scoring on a range of clustering thresholds (DBSCAN eps) | ||
``` | ||
|
||
**Evaluating on WebNLG metrics (for testing dataset webnlg)** | ||
``` | ||
python ./diffu_eval/eval_webnlg.py --generation ${processed_gen} --eval_result_dir ${eval_result_dir} | ||
# evaluation scores will be saved in ${eval_result_dir}/scores_webnlg.json | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
|
||
import argparse | ||
|
||
def create_argparser(): | ||
defaults = dict( | ||
data_dir="", | ||
src='src', | ||
tgt='tgt', | ||
schedule_sampler="uniform", | ||
lr=1e-4, | ||
weight_decay=0.0, | ||
lr_anneal_steps=30000, | ||
warmup=0, | ||
batch_size=1, | ||
microbatch=-1, # -1 disables microbatches | ||
ema_rate="0.9999", # comma-separated list of EMA values | ||
log_interval=50, | ||
save_interval=25000, | ||
resume_checkpoint="", | ||
use_fp16=False, | ||
fp16_scale_growth=1e-3, | ||
seed=101, | ||
gradient_clipping=-1.0, | ||
eval_interval=2000, | ||
checkpoint_path="diff_models", | ||
train_txt_path="data/quotes_train.txt", | ||
val_txt_path="data/quotes_valid.txt", | ||
dataset="", | ||
notes="", | ||
) | ||
text_defaults = dict( | ||
modality="text", | ||
emb_scale_factor=1.0, | ||
in_channel=16, | ||
out_channel=16, | ||
noise_level=0.0, | ||
cache_mode="no", | ||
use_bert_tokenizer="no", | ||
padding_mode="block", | ||
preprocessing_num_workers=1, | ||
tok_thresh=150 | ||
) | ||
|
||
guided_generation_defaults = dict( | ||
classifier_num_epochs=15 | ||
) | ||
|
||
defaults.update(model_and_diffusion_defaults()) | ||
defaults.update(text_defaults) | ||
defaults.update(guided_generation_defaults) | ||
defaults.update(decoding_defaults()) | ||
defaults.update(additional_args_for_translation()) | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--debug", action="store_true") | ||
|
||
add_dict_to_argparser(parser, defaults) | ||
return parser | ||
|
||
def additional_args_for_translation(): | ||
|
||
return dict( | ||
pretrained_tokenizer=None, | ||
sequence_len_src=128, | ||
use_pretrained_tokenizer=False, | ||
generate_by_q=False, | ||
generate_by_mix=False, | ||
generate_by_mix_prob=0.0, | ||
generate_by_mix_part=1.0, | ||
) | ||
|
||
|
||
def model_and_diffusion_defaults(): | ||
""" | ||
Defaults for text-diffusion model training. | ||
""" | ||
return dict( | ||
encoder_layers=6, | ||
decoder_layers=6, | ||
sequence_len=32, | ||
sequence_len_fact=32, | ||
num_channels=16, | ||
num_heads=4, | ||
dropout=0.0, | ||
learn_sigma=False, | ||
sigma_small=False, | ||
class_cond=False, | ||
diffusion_steps=10000, | ||
noise_schedule="linear", | ||
timestep_respacing="", | ||
use_kl=False, | ||
predict_xstart=False, | ||
rescale_timesteps=True, | ||
rescale_learned_sigmas=True, | ||
use_checkpoint=False, | ||
model_arch="transformer", | ||
in_channel=16, | ||
out_channel=16, | ||
vocab_size=66, | ||
config_name="bert-base-uncased", | ||
config_name_embedder="bert-base-uncased", | ||
dae=False, | ||
gamma_nll=1.0, | ||
noise_amplifier=1.0, | ||
fg_do_sample=False, | ||
fg_max_len=20, | ||
fg_top_k=50, | ||
fg_top_p=1.0, | ||
fg_input="greedy_mean", # "sample", "pred_xstart" or "greedy_mean" | ||
logits_mode=1, | ||
training_mode="diffusion-lm", | ||
init_pretrained=False, | ||
init_pretrained_embedder=False, | ||
freeze_embeddings=False, | ||
use_pretrained_embeddings=False, | ||
load_ckpt=None, | ||
loss_update_granu=None, | ||
schedule_update_stride=0, | ||
) | ||
|
||
|
||
def decoding_defaults(): | ||
return dict( | ||
num_samples=50, | ||
top_p=0.9, | ||
out_dir="", | ||
model_name_or_path="", | ||
checkpoint_path="", | ||
use_ddim=False, | ||
clip_denoised=False, | ||
batch_size=64, | ||
mbr_sample=1, | ||
verbose="yes", | ||
clamp="clamp", | ||
preprocessing_num_workers=1, | ||
emb_scale_factor=1.0, | ||
classifier_path="", | ||
time_schedule_path='', | ||
comment='', | ||
) | ||
|
||
|
||
def add_dict_to_argparser(parser, default_dict): | ||
for k, v in default_dict.items(): | ||
v_type = type(v) | ||
if v is None: | ||
v_type = str | ||
elif isinstance(v, bool): | ||
v_type = str2bool | ||
parser.add_argument(f"--{k}", default=v, type=v_type) | ||
|
||
|
||
def args_to_dict(args, keys): | ||
return {k: getattr(args, k) for k in keys} | ||
|
||
|
||
def str2bool(v): | ||
""" | ||
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse | ||
""" | ||
if isinstance(v, bool): | ||
return v | ||
if v.lower() in ("yes", "true", "t", "y", "1"): | ||
return True | ||
elif v.lower() in ("no", "false", "f", "n", "0"): | ||
return False | ||
else: | ||
raise argparse.ArgumentTypeError("boolean value expected") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
## Directory for checkpointing DiffuCOMET models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
## Directory for checkpointing DiffuCOMET diffusion modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
## Directory for checkpointing DiffuCOMET embedding modules |
Oops, something went wrong.