- `main.py`: main function for the ODQA/CQA chatbot
- `module.py`: chatbot module classes
- `finetune.py`: fine-tuning functions for T5/OPT
- `model_utils.py`: helper functions for the module
- `data_utils.py`: helper functions/classes for the training data
- `reranker.py`: reranker for response quality, loaded from `main_fn`
- `entity_tracker.py`: entity tracker for dialog (in progress)
- `push_all.sh`: push the code to GitHub
- `run_chatbot.sh`: script for running the chatbot
  - For single-turn QA, use `--task 'odqa'` in the script; for conversational QA, use `--task 'cqa'`.
  - To serve OPT as the generator, use `--gen_model_path "facebook/opt-1.3b"` and `--model_name "opt"`; to use Flan-T5, change them to `--gen_model_path "google/flan-t5-large"` and `--model_name t5`.
  - To load a fine-tuned weight file, add `--gen_cp_path <path_to_weight>`. Without this flag, the pretrained weights of the model are loaded. An example invocation is sketched below.
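  A minimal example with the flags above might look like the following sketch; it assumes `run_chatbot.sh` simply forwards these flags to `main.py`, so adjust it to match the actual script:

  ```bash
  # Hypothetical invocation (flag combination is illustrative):
  # run the conversational-QA chatbot with Flan-T5 as the generator and
  # load a fine-tuned checkpoint instead of the pretrained weights.
  python main.py \
      --task 'cqa' \
      --gen_model_path "google/flan-t5-large" \
      --model_name t5 \
      --gen_cp_path <path_to_weight>
  ```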
- `run_train_opt.sh`: script for fine-tuning OPT
  - Feel free to change:
    - `--device`: the GPU number used for training.
    - `--logger_path`: path to save the training log.
    - `--saved_model_path`: path to save the fine-tuned weights.
    - `--outer_batch_size`: the effective "batch size" used for gradient accumulation.
    - `--dataset_path`: path to the converted training dataset, the same path as `--saved_dataset_path` from `run_convert_dataset.sh`. Default: `./dataset/train_data`.
    - `--epochs`: number of fine-tuning epochs.
    - `--lr`: learning rate.
  - Since the training data is very large, gradient accumulation is applied in the code. Keep `--batch_size 1` to avoid OUT_OF_MEMORY issues and change `--outer_batch_size` to adjust the effective batch size for training (see the example below).
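  For reference, a training run with the flags above might look like this sketch; it assumes `run_train_opt.sh` forwards the flags to `finetune.py`, and the paths, epoch count, and learning rate are illustrative placeholders only:

  ```bash
  # Hypothetical invocation; keep --batch_size 1 to avoid OUT_OF_MEMORY and
  # raise --outer_batch_size to grow the effective batch size via gradient accumulation.
  python finetune.py \
      --device 0 \
      --logger_path ./logs/opt_finetune.log \
      --saved_model_path ./checkpoints/opt_ece120 \
      --dataset_path ./dataset/train_data \
      --epochs 3 \
      --lr 1e-5 \
      --batch_size 1 \
      --outer_batch_size 16
  ```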
- `run_train_t5.sh`: script for fine-tuning T5
  - Same flags as `run_train_opt.sh`.
- `run_convert_dataset.sh`: script to convert the dataset generated from GPT-3 into the format used for training
  - `--original_json_path`: path to the GPT-3 generated data. The default dataset is `GPT-3_semantic_search.json`.
  - `--converted_json_path`: path to the converted JSON form of the dataset.
  - `--saved_dataset_path`: path to save the converted dataset. An example conversion command is sketched below.
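  A conversion run might look like the sketch below; the Python entry point name is hypothetical (substitute whatever command `run_convert_dataset.sh` actually calls) and the output paths are illustrative. Note that `--saved_dataset_path` should match the `--dataset_path` later passed to `run_train_opt.sh`/`run_train_t5.sh`:

  ```bash
  # Hypothetical invocation; convert_dataset.py stands in for the real entry point.
  python convert_dataset.py \
      --original_json_path GPT-3_semantic_search.json \
      --converted_json_path ./dataset/converted_train_data.json \
      --saved_dataset_path ./dataset/train_data
  ```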
- `ece_rag_dataset_new`: ECE120 course note dataset for DPR.
- `train_data`: ECE120 finetune dataset, converted from `GPT-3_semantic_search.json`.
- `GPT-3_semantic_search.json`: GPT-3 generated dataset for training.
- The context in the training data is very large, so the text is truncated, which may cause problems. (Change the numbers in `data_utils.py`, lines 214/247, to adjust the truncation size.)
- The maximum input length for T5 is only 512; fine-tuning may still work, but it may lead to accuracy issues.
- The entity extractor for dialog is in progress.