Add files via upload
Silin159 authored Feb 26, 2024
0 parents commit ed08039
Showing 41 changed files with 9,168 additions and 0 deletions.
164 changes: 164 additions & 0 deletions Readme.md
@@ -0,0 +1,164 @@
# DiffuCOMET

This is the source code for the paper DiffuCOMET: Contextual Commonsense Knowledge Diffusion.

Part of our code is modified from the [SeqDiffuSeq](https://github.com/Yuanhy1997/SeqDiffuSeq) repository.

## Getting Started

Create a **Python 3.8** Conda environment and install the following packages:
```
conda install mpi4py
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
```

## Preparing Datasets and Toolkits

Our preprocessed datasets can be downloaded from [this link](https://drive.google.com/file/d/1DIbF0WxscgEPKv4mX00wypWeVt39LMWQ/view?usp=sharing). Please place ``data/`` under this root directory and ``data_rp/`` under the ``BART_Rel_Pred/`` directory.

Please also download our commonsense fact linking toolkit (ComFact_Linker) from [this link](https://drive.google.com/file/d/1BDwh1ZQZXWXw3gduSlICB81xOelBhHb0/view?usp=sharing), and place ``ComFact_Linker/`` under this root directory.

## Training

**DiffuCOMET-Fact seeded with BART-{base | large} models**:
```
# Training fact embedding module
# on ComFact benchmark knowledge (ATOMIC 2020):
bash ./train_scripts/train_embedding_{base|large}.sh comfact facts 32
# on WebNLG+ 2020 benchmark knowledge:
bash ./train_scripts/train_embedding_{base|large}.sh webnlg facts 64
# Training fact diffusion module
# on ComFact benchmark (ROCStories portion):
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc facts 32 32 comfact_facts
# on WebNLG+ 2020 benchmark:
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg facts 8 64 webnlg_facts
```

**DiffuCOMET-Entity seeded with BART-{base | large} models**:
```
# Training entity embedding module
# on ComFact benchmark knowledge (ATOMIC 2020):
bash ./train_scripts/train_embedding_{base|large}.sh comfact entities 32
# on WebNLG+ 2020 benchmark knowledge:
bash ./train_scripts/train_embedding_{base|large}.sh webnlg entities 64
# Training head entity diffusion module
# on ComFact benchmark (ROCStories portion):
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc heads 16 16 comfact_entities
# on WebNLG+ 2020 benchmark:
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg heads 8 16 webnlg_entities
# Training tail entity diffusion module
# on ComFact benchmark (ROCStories portion):
bash ./train_scripts/train_diffusion_{base|large}.sh comfact_roc tails 8 24 comfact_entities
# on WebNLG+ 2020 benchmark:
bash ./train_scripts/train_diffusion_{base|large}.sh webnlg tails 8 64 webnlg_entities
# Training relation prediction module
# on ComFact benchmark (ROCStories portion):
bash ./BART_Rel_Pred/train_rel_pred.sh comfact_roc
# on WebNLG+ 2020 benchmark:
bash ./BART_Rel_Pred/train_rel_pred.sh webnlg
```

## Inference

**DiffuCOMET-Fact seeded with BART-{base | large} models**:
```
# Testing on ComFact ROCStories (comfact_roc), PersonaChat (comfact_persona), MuTual (comfact_mutual), MovieSummaries (comfact_movie) or WebNLG+ 2020 (webnlg):
bash ./inference_scripts/inference.sh ${train_dataset} ${test_dataset} facts {base|large} test ${train_step} ${schedule} ${ctx_len}
# ${train_dataset}: {comfact_roc|webnlg}
# ${test_dataset}: {comfact_roc|comfact_persona|comfact_mutual|comfact_movie|webnlg}
# ${train_step}: training step (ID) of tested model checkpoint, e.g., 130000
# ${schedule}: noise schedule ID of tested model checkpoint, should be ${train_step}-2000, e.g., 128000
# ${ctx_len}: maximum narrative context length, should be 256 when testing on comfact_movie and 128 otherwise
# generations will be saved in results/${test_dataset}_facts_{base|large}_${train_step}/generations.json
# Post-processing fact generations, post-processed generations will be saved in ${result_dir}/gen_processed.json:
python ./diffu_eval/post_process_facts.py --context_dir data/${test_dataset}_facts/test.contexts \
--result_dir results/${test_dataset}_facts_{base|large}_${train_step}
```
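The naming rules in the comments above (schedule ID equal to the training step minus 2000, a context length of 256 only for `comfact_movie`, and the `results/` directory pattern) can be sketched as a small helper. This is only an illustration: `fact_inference_args` is a hypothetical function that is not part of this repository, and it simply restates the conventions listed in the comments.

```python
def fact_inference_args(train_dataset, test_dataset, size, train_step):
    """Hypothetical helper: build the inference.sh call for DiffuCOMET-Fact."""
    # Noise schedule ID, following the README rule ${train_step} - 2000.
    schedule = train_step - 2000
    # 256 tokens of narrative context for comfact_movie, 128 for other datasets.
    ctx_len = 256 if test_dataset == "comfact_movie" else 128
    # Directory where generations.json is expected to be written.
    result_dir = f"results/{test_dataset}_facts_{size}_{train_step}"
    command = [
        "bash", "./inference_scripts/inference.sh",
        train_dataset, test_dataset, "facts", size, "test",
        str(train_step), str(schedule), str(ctx_len),
    ]
    return command, result_dir


command, result_dir = fact_inference_args("comfact_roc", "comfact_movie", "base", 130000)
print(" ".join(command))
print(result_dir)
```

The returned list could be passed to `subprocess.run` from the repository root; printing it first makes it easy to check the arguments before launching a long inference job.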

**DiffuCOMET-Entity seeded with BART-{base | large} models**:
```
# Head entity generation
bash ./inference_scripts/inference.sh ${train_dataset} ${test_dataset} heads {base|large} test ${train_step_head} ${schedule_head} ${ctx_len}
# Post-processing head entity generations:
python ./diffu_eval/post_process_heads.py --context data/${test_dataset}_heads/test.contexts \
--result_dir results/${test_dataset}_heads_{base|large}_${train_step_head} \
--tail_gen_input_dir data/${test_dataset}_tails/test_{base|large}_${train_step_head}
# Tail entity generation
bash ./inference_scripts/inference.sh ${train_dataset} ${test_dataset} tails {base|large} test_{base|large}_${train_step_head} ${train_step_tail} ${schedule_tail} ${ctx_len}
# Post-processing tail entity generations:
python ./diffu_eval/post_process_tails.py --gold_dir data/${test_dataset}_facts/test \
--tail_gen_input_dir data/${test_dataset}_tails/test_{base|large}_${train_step_head} \
--tail_gen_result_dir results/${test_dataset}_tails_{base|large}_${train_step_tail} \
--pipeline_result_dir results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail} \
--rel_pred_input_dir BART_Rel_Pred/data_rp/${test_dataset}/rel_pred_inf_{base|large}/test
# Relation prediction
bash ./BART_Rel_Pred/run_rel_pred.sh ${train_dataset} ${test_dataset} ${train_step_rel_pred} {base|large}
# Post-processing relation predictions:
python ./diffu_eval/post_process_rel_pred.py --test_data ${test_dataset} \
--rel_pred_ids BART_Rel_Pred/data_rp/${test_dataset}/rel_pred_inf_{base|large}/test/labels.json \
--rel_pred_results BART_Rel_Pred/pred/${test_dataset}-{base|large}/predictions.json \
--pipeline_result_dir results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail}
```
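The entity pipeline above reuses the same few variables in many directory names, so a mistake in one path can break a later step. As a rough sketch, the paths can be derived in one place. `EntityPipelinePaths` is a hypothetical class, not part of this repository, and the checkpoint steps in the example are made-up values; the path patterns themselves are copied from the commands above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EntityPipelinePaths:
    """Hypothetical container deriving the DiffuCOMET-Entity pipeline paths."""
    test_dataset: str
    size: str              # "base" or "large"
    train_step_head: int   # checkpoint step of the head diffusion module
    train_step_tail: int   # checkpoint step of the tail diffusion module

    @property
    def head_result_dir(self):
        return f"results/{self.test_dataset}_heads_{self.size}_{self.train_step_head}"

    @property
    def tail_gen_input_dir(self):
        # Written by post_process_heads.py and read by tail generation.
        return f"data/{self.test_dataset}_tails/test_{self.size}_{self.train_step_head}"

    @property
    def tail_gen_result_dir(self):
        return f"results/{self.test_dataset}_tails_{self.size}_{self.train_step_tail}"

    @property
    def pipeline_result_dir(self):
        return (f"results/{self.test_dataset}_pipeline_{self.size}"
                f"_{self.train_step_head}_{self.train_step_tail}")

    @property
    def rel_pred_input_dir(self):
        return f"BART_Rel_Pred/data_rp/{self.test_dataset}/rel_pred_inf_{self.size}/test"


# Example with made-up checkpoint steps.
paths = EntityPipelinePaths("comfact_roc", "base", 130000, 100000)
print(paths.tail_gen_input_dir)
print(paths.pipeline_result_dir)
```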

## Evaluation

**Evaluating on traditional NLG metrics**
```
python ./diffu_eval/eval_nlg.py --generation ${processed_gen} --eval_result_dir ${eval_result_dir}
# ${processed_gen}: results/${test_dataset}_facts_{base|large}_${train_step}/gen_processed.json
# or results/${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail}/gen_processed.json
# ${eval_result_dir}: directory for saving evaluation scores, e.g., results/${test_dataset}_facts_{base|large}_${train_step}
# evaluation scores will be saved in ${eval_result_dir}/nlg_eval.json
```

**Evaluating on our proposed clustering-based metrics**
```
# Pre-processing generations for ComFact linker to score relevance
python ./diffu_eval/prepare_comfact_linking.py --test_data ${test_dataset} --generation ${processed_gen} \
--comfact_input_dir ComFact_Linker/data_fl/all/fact_link/nlu/${eval_model}
# ${eval_model}: ${test_dataset}_facts_{base|large}_${train_step}
# or ${test_dataset}_pipeline_{base|large}_${train_step_head}_${train_step_tail}
# Switching to ComFact original environment (optional)
# Please refer to ComFact_Linker/README.md
# Run ComFact linker
bash ComFact_Linker/run_fact_link.sh ${eval_model}
# scoring results will be saved in ComFact_Linker/pred/${eval_model}/predictions.json
# Switching back to DiffuCOMET environment (optional)
# Please refer to ComFact_Linker/README.md
# Post-processing ComFact linker scoring results:
python ./diffu_eval/write_comfact_scores.py --comfact_output ComFact_Linker/pred/${eval_model}/predictions.json \
--generation ${processed_gen}
# Clustering-based evaluation
python ./diffu_eval/eval_cluster.py --test_data ${test_dataset} --generation ${processed_gen} \
--eval_result_dir ${eval_result_dir}
# evaluation scores will be saved in ${eval_result_dir}/cluster_eval.csv
# each line of the CSV file records one metric's scores across a range of clustering thresholds (DBSCAN eps)
```

**Evaluating on WebNLG metrics (for the webnlg test dataset)**
```
python ./diffu_eval/eval_webnlg.py --generation ${processed_gen} --eval_result_dir ${eval_result_dir}
# evaluation scores will be saved in ${eval_result_dir}/scores_webnlg.json
```
167 changes: 167 additions & 0 deletions args_utils.py
@@ -0,0 +1,167 @@

import argparse

def create_argparser():
    defaults = dict(
        data_dir="",
        src='src',
        tgt='tgt',
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=30000,
        warmup=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=50,
        save_interval=25000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
        seed=101,
        gradient_clipping=-1.0,
        eval_interval=2000,
        checkpoint_path="diff_models",
        train_txt_path="data/quotes_train.txt",
        val_txt_path="data/quotes_valid.txt",
        dataset="",
        notes="",
    )
    text_defaults = dict(
        modality="text",
        emb_scale_factor=1.0,
        in_channel=16,
        out_channel=16,
        noise_level=0.0,
        cache_mode="no",
        use_bert_tokenizer="no",
        padding_mode="block",
        preprocessing_num_workers=1,
        tok_thresh=150
    )

    guided_generation_defaults = dict(
        classifier_num_epochs=15
    )

    defaults.update(model_and_diffusion_defaults())
    defaults.update(text_defaults)
    defaults.update(guided_generation_defaults)
    defaults.update(decoding_defaults())
    defaults.update(additional_args_for_translation())
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")

    add_dict_to_argparser(parser, defaults)
    return parser

def additional_args_for_translation():

    return dict(
        pretrained_tokenizer=None,
        sequence_len_src=128,
        use_pretrained_tokenizer=False,
        generate_by_q=False,
        generate_by_mix=False,
        generate_by_mix_prob=0.0,
        generate_by_mix_part=1.0,
    )


def model_and_diffusion_defaults():
    """
    Defaults for text-diffusion model training.
    """
    return dict(
        encoder_layers=6,
        decoder_layers=6,
        sequence_len=32,
        sequence_len_fact=32,
        num_channels=16,
        num_heads=4,
        dropout=0.0,
        learn_sigma=False,
        sigma_small=False,
        class_cond=False,
        diffusion_steps=10000,
        noise_schedule="linear",
        timestep_respacing="",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=True,
        rescale_learned_sigmas=True,
        use_checkpoint=False,
        model_arch="transformer",
        in_channel=16,
        out_channel=16,
        vocab_size=66,
        config_name="bert-base-uncased",
        config_name_embedder="bert-base-uncased",
        dae=False,
        gamma_nll=1.0,
        noise_amplifier=1.0,
        fg_do_sample=False,
        fg_max_len=20,
        fg_top_k=50,
        fg_top_p=1.0,
        fg_input="greedy_mean",  # "sample", "pred_xstart" or "greedy_mean"
        logits_mode=1,
        training_mode="diffusion-lm",
        init_pretrained=False,
        init_pretrained_embedder=False,
        freeze_embeddings=False,
        use_pretrained_embeddings=False,
        load_ckpt=None,
        loss_update_granu=None,
        schedule_update_stride=0,
    )


def decoding_defaults():
    return dict(
        num_samples=50,
        top_p=0.9,
        out_dir="",
        model_name_or_path="",
        checkpoint_path="",
        use_ddim=False,
        clip_denoised=False,
        batch_size=64,
        mbr_sample=1,
        verbose="yes",
        clamp="clamp",
        preprocessing_num_workers=1,
        emb_scale_factor=1.0,
        classifier_path="",
        time_schedule_path='',
        comment='',
    )


def add_dict_to_argparser(parser, default_dict):
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def args_to_dict(args, keys):
    return {k: getattr(args, k) for k in keys}


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")
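
Two behaviours of `args_utils.py` are easy to miss when reading it: the default groups are merged with `dict.update`, so a key defined in several groups keeps the value from the last group applied (in `create_argparser`, `batch_size` ends up as 64 and `checkpoint_path` as `""` because `decoding_defaults()` is applied after the first `defaults` dict), and boolean flags are parsed from strings through `str2bool`. The standalone sketch below reproduces both with toy dictionaries. The two functions are copied from the file above; the dictionary contents and the `model.pt` value are made up for illustration.

```python
import argparse


def str2bool(v):
    # Same logic as in args_utils.py: accept common spellings of true/false.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    # One --key option per entry; the option type is inferred from the default.
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


# Toy groups sharing the key "batch_size" (illustrative values only).
training_group = dict(batch_size=1, lr=1e-4, use_fp16=False)
decoding_group = dict(batch_size=64, use_ddim=False, load_ckpt=None)

defaults = dict(training_group)
defaults.update(decoding_group)  # the later group wins for shared keys

parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
args = parser.parse_args(
    ["--use_fp16", "yes", "--lr", "3e-5", "--load_ckpt", "model.pt"]
)
print(args.batch_size, args.use_fp16, args.lr, args.load_ckpt)
```

Because the last group wins, a training script that relies on a small default `batch_size` would need to pass `--batch_size` explicitly rather than depend on the value in the first dictionary.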
2 changes: 2 additions & 0 deletions ckpts/README.md
@@ -0,0 +1,2 @@

## Directory for checkpointing DiffuCOMET models
2 changes: 2 additions & 0 deletions ckpts/diffusion/README.md
@@ -0,0 +1,2 @@

## Directory for checkpointing DiffuCOMET diffusion modules
2 changes: 2 additions & 0 deletions ckpts/embedding/README.md
@@ -0,0 +1,2 @@

## Directory for checkpointing DiffuCOMET embedding modules