support local LLM #16

Merged
merged 8 commits on Jan 3, 2025
Binary file added .DS_Store
Binary file not shown.
9 changes: 8 additions & 1 deletion README.md
@@ -36,7 +36,7 @@ pip install .

#### Setup for LLMs

To run generating qusetions and generalization functions based on LLMs,apply API-KEY before you run the whole flow.
To run the question generation and generalization functions based on LLMs, apply for an API-KEY before you run the whole flow.

1. Apply API-KEY

@@ -52,6 +52,13 @@ echo $DASHSCOPE_API_KEY
conda activate text2gql
```

#### Setup for Local LLMs
To run the question generation and generalization functions with a local LLM, use a model id from the HuggingFace model hub if you can access HuggingFace, or use the local file path where the LLM model is stored.

1. Change the model path in each LLM-related .sh file to a model id or a local model path. If you want to use an online LLM API, keep the model path empty (MODEL_PATH='').

2. You can also change model_path in the config.json file to set up your local LLM, as shown in the sketch below.
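
For reference, here is a minimal, hypothetical sketch of what the local-LLM path does with the configured model path; it is an illustration, not code from this repository. It assumes `torch` and `transformers` are installed (they are added to `setup.py` in this PR) and that the model path is either a HuggingFace model id or a local directory containing the model files; the model id and the example prompt are placeholders.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Placeholder: a HuggingFace model id, or a local path such as "../Qwen/Qwen2.5-Coder-7B-Instruct"
model_path = "Qwen/Qwen2.5-Coder-7B-Instruct"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Half precision on GPU; fall back to float32 on CPU, where fp16 kernels are often unavailable.
dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype).to(device)

# Example chat-style request; in the real flow the prompts come from the generalization scripts.
messages = [{"role": "user", "content": "MATCH (m:movie) RETURN m LIMIT 5"}]
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)
output = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```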

### Run

#### The whole flow
6 changes: 4 additions & 2 deletions config.json
@@ -22,10 +22,12 @@
"work_mode": "400",
"input_dir_or_file": "./output/output_query.txt",
"output_dir": "./output",
"suffix": "_t"
"suffix": "_t",
"model_path":"Qwen/Qwen2.5-Coder-7B-Instruct"
},
"generate_dataset": {
"input_corpus_dir_or_file": "./output/output_query_t_general.txt",
"output_corpus_path": "./output/text2gql_train.json"
}
}
}

130 changes: 110 additions & 20 deletions generalize_llm.py
@@ -9,10 +9,13 @@
import os
import copy
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch


def gen_question_directly(
input_path, output_path
input_path, output_path, tokenizer, model, current_device
): # generate multiple questions from the input cypher
# 1. read files
with open(input_path, "r") as file:
@@ -30,17 +33,17 @@ def gen_question_directly(
{"role": "user", "content": cypher},
]
# 3. get response
responses = call_with_messages(massages)
responses = call_with_messages(massages, tokenizer, model, current_device)
if responses != "":
questions = process_handler.process(responses)
# 4. save to file
# 4. save to file
save2file(db_id, cypher, questions, output_path)
print("corpus output file:", output_path)


# deprecated
def general_question_directly(
input_path, output_path
input_path, output_path, tokenizer, model, current_device
): # generate multiple questions from the input question
# 1. read file
with open(input_path, "r") as file:
@@ -58,7 +61,7 @@ def general_question_directly(
{"role": "user", "content": question},
]
# 3. get response
responses = call_with_messages(massages)
responses = call_with_messages(massages, tokenizer, model, current_device)
# 4. postprocess and save
if responses != "":
questions = process_handler.process(responses)
Expand All @@ -68,7 +71,7 @@ def general_question_directly(

# recommended
def generalization(
input_path, output_path
input_path, output_path, tokenizer, model, current_device
): # generate multiple questions from the input cypher and question
# 1. read file
with open(input_path, "r") as file:
@@ -87,15 +90,16 @@ def generalization(
{"role": "user", "content": content},
]
# 3. get response
responses = call_with_messages(massages)
responses = call_with_messages(massages, tokenizer, model, current_device)

# 4. postprocess and save
if responses != "":
questions = process_handler.process(responses)
save2file(db_id, cypher, questions, output_path)
print("corpus output file:", output_path)


def gen_question_with_template(input_path, output_path):
def gen_question_with_template(input_path, output_path, tokenizer, model, current_device):
(
db_id,
tmplt_cypher_list,
@@ -118,23 +122,39 @@ def gen_question_with_template(input_path, output_path):
+ "cyphers:\n"
+ cypher_content
)
massages = [

messages = [
{
"role": "system",
"content": "设想你是一个图数据库的前端,用户给你一个提问,你要给出对应的cypher语句。现在需要你反过来,将我给你的cypher语句翻译为使用者可能输入的提问,要求符合图数据库使用者的口吻,尽量准确地符合cypher含义,不要遗漏cypher中关键字如DISTINCT、OPTIONAL等,可以修改为问句或者陈述句,必须是中文。我每次会给你一个跟需要翻译的cypher相同句式的template帮助你理解cypher的含义。这是一个例子:\ntempalte:\nMATCH (m:keyword{name: 'news report'})<-[:has_keyword]-(a:movie) RETURN a,m ,关键词是news report的电影有哪些?返回相应的节点。cypher:MATCH (m:movie{title: 'The Dark Knight'})<-[:write]-(a:person) RETURN a,m\nMATCH (m:user{login: 'Sherman'})<-[:is_friend]-(a:user) RETURN a,m\n你应当回答:电影The Dark Knight的作者有哪些?返回相关节点。\n在图中找到登录用户Sheman的朋友节点,返回相关的节点信息。\n下面请你对cyphers逐条cypher语句输出翻译的结果,不需要注明是对哪条语句进行的泛化,结果按照换行符隔开,注意句子应当有标点符号",
},
{"role": "user", "content": content},
]

# 3. get response
responses = call_with_messages(massages)
responses = call_with_messages(messages, tokenizer, model, current_device)

# 4. postprocess and save
if responses != "":
questions = process_handler.process(responses)

# pad or truncate so the number of questions matches the cypher chunk size
chunk_size = len(cypher_trunk)
questions_size = len(questions)

if questions_size > chunk_size:
questions = questions[0:chunk_size]
elif questions_size < chunk_size:
filled_questions = ['Question 生成失败'] * (chunk_size - questions_size)
questions = questions + filled_questions
else:
pass

save2file_t(db_id, cypher_trunk, questions, output_path)
print("output file:", output_path)


def call_with_messages(messages):
def call_with_messages_online(messages):
response = Generation.call(
model="qwen-plus-0723",
messages=messages,
@@ -146,11 +166,10 @@ def call_with_messages(messages):
)
if response.status_code == HTTPStatus.OK:
content = response.output.choices[0].message.content
# print(content)
return content
else:
if response.code == 429: # Requests rate limit exceeded
call_with_messages(messages)
return call_with_messages_online(messages)
else:
print(
"Request id: %s, Status code: %s, error code: %s, error message: %s"
Expand All @@ -164,6 +183,34 @@ def call_with_messages(messages):
print("Failed!", messages[1]["content"])
return ""

def call_with_messages_local(messages, tokenizer, model, current_device):
# tokenize the chat messages with the model's chat template and move them to the target device
inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_dict=True, return_tensors="pt").to(current_device)

# sampling configuration; tune these generation arguments as needed
output = model.generate(
**inputs,
do_sample=True,
temperature=0.8,
top_p=0.8,
top_k=50,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=2048
)

# decode only the newly generated tokens and return the text
output = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

return output

def call_with_messages(messages, tokenizer="", model="", current_device=""):
if model_path == "":
output = call_with_messages_online(messages)
else:
output = call_with_messages_local(messages, tokenizer, model, current_device)
return output
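
As a hedged illustration of the design choice above, not part of this PR's diff: call sites keep a single signature, and which backend runs is decided solely by the module-level model_path. The message text below is made up for the example.

```python
# Editor's illustrative sketch, assuming this module's call_with_messages and a
# configured module-level model_path; the placeholder arguments mirror how main() calls it.
example_messages = [
    {"role": "system", "content": "Translate the cypher statement into a natural-language question."},
    {"role": "user", "content": "MATCH (m:movie) RETURN m LIMIT 5"},
]

# Online DashScope backend: model_path == "", so the extra arguments are unused placeholders.
response = call_with_messages(example_messages, "", "", "")

# Local transformers backend: model_path != "", so pass the loaded tokenizer, model, and device instead.
# response = call_with_messages(example_messages, tokenizer, model, current_device)
```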


def load_file_gen_question_with_template(input_path):
with open(input_path, "r") as file:
@@ -224,15 +271,15 @@ def chunk_list(lst, chunk_size=5):
yield lst[i : i + chunk_size]


def state_machine(input_path, output_path):
def state_machine(input_path, output_path, tokenizer, model, current_device):
if mode == Status.GEN_QUESTION_DIRECTLY.value[0]: # 100
gen_question_directly(input_path, output_path)
gen_question_directly(input_path, output_path, tokenizer, model, current_device)
elif mode == Status.GENERAL_QUESTION_DIRECTLY.value[0]: # 200
general_question_directly(input_path, output_path) # deprecated # 300
general_question_directly(input_path, output_path, tokenizer, model, current_device) # deprecated # 300
elif mode == Status.GENERALIZATION.value[0]:
generalization(input_path, output_path) # recommended # 400
generalization(input_path, output_path, tokenizer, model, current_device) # recommended # 400
elif mode == Status.GEN_QUESTION_WITH_TEMPLATE.value[0]:
gen_question_with_template(input_path, output_path)
gen_question_with_template(input_path, output_path, tokenizer, model, current_device)
else:
print("[ERROR]: work_mode is not proper, current work_mode is:", mode)

@@ -244,8 +291,47 @@ def main():
dir, file_name = os.path.split(input_path)
file_base, file_extension = os.path.splitext(file_name)
output_path = os.path.join(output_dir, file_base + suffix + ".txt")
state_machine(input_path, output_path)

#load local model
if model_path != "":
#0. check current device
current_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("model running on %s"%current_device)
print("the model path is %s"%model_path)

#1.load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

#2.load model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(current_device)

#3.call
state_machine(input_path, output_path, tokenizer, model, current_device)
else:
tokenizer = ""
model = ""
current_device = ""
state_machine(input_path, output_path, tokenizer, model, current_device)

elif os.path.isdir(input_dir_or_file):
#load local model
if model_path != "":
#0. check current device
current_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("model running on %s"%current_device)
print("the model path is %s"%model_path)

#1.load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

#2.load model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(current_device)

else:
tokenizer = ""
model = ""
current_device = ""

input_dir = input_dir_or_file
for root, dirs, file_names in os.walk(input_dir):
for file_name in file_names:
@@ -259,7 +345,7 @@ def main():
output_path = os.path.join(root, file_name).replace(
input_dir, output_dir
)
state_machine(input_path, output_path)
state_machine(input_path, output_path, tokenizer, model, current_device)
else:
print("[ERROR]: input file is not exsit", input_dir_or_file)

@@ -271,6 +357,8 @@ def main():
input_dir_or_file = sys.argv[2]
output_dir = sys.argv[3]
suffix = sys.argv[4]
model_path = sys.argv[5]
print(model_path)
if not os.path.isdir(output_dir):
print("[ERROR]: output_dir do not exsit!")
sys.exit()
@@ -282,8 +370,10 @@
input_dir_or_file = configs["input_dir_or_file"]
output_dir = configs["output_dir"]
suffix = configs["suffix"]
model_path = configs["model_path"]
if not os.path.isdir(output_dir):
print("[ERROR]: output_dir do not exsit!")
sys.exit()
process_handler = CorpusPostProcess()
main()

7 changes: 6 additions & 1 deletion scripts/gen_question_directly_llm.sh
@@ -4,6 +4,11 @@ MODE=100
INPUT_DIR_OR_FILE='./input_examples/query_only.txt'
OUTPUT_DIR="./output"
SUFFIX='_gen'
# For calling an online model service, keep MODEL_PATH empty
MODEL_PATH=''
# For calling a local model, set MODEL_PATH to a local model path or a HuggingFace model id
#MODEL_PATH='../Qwen/Qwen2.5-Coder-7B-Instruct'

echo "----------------Running generalize_llm.py to GEN_PROMPT_DIRECTLY----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"
7 changes: 6 additions & 1 deletion scripts/gen_question_with_template_llm.sh
@@ -4,6 +4,11 @@ MODE=400
INPUT_DIR_OR_FILE='./input_examples/gen_question_llm.txt'
OUTPUT_DIR="./output"
SUFFIX='_t'
# For calling an online model service, keep MODEL_PATH empty
MODEL_PATH=''
# For calling a local model, set MODEL_PATH to a local model path or a HuggingFace model id
#MODEL_PATH='../Qwen/Qwen2.5-Coder-7B-Instruct'

echo "----------------Running generalize_llm.py to GEN_PROMPT_WITH_TEMPLATE----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"
7 changes: 6 additions & 1 deletion scripts/general_questions_directly_llm.sh
@@ -4,6 +4,11 @@ MODE=200
INPUT_DIR_OR_FILE='./input_examples/corpus.txt'
OUTPUT_DIR="./output"
SUFFIX='_general_directly'
# For calling an online model service, keep MODEL_PATH empty
MODEL_PATH=''
# For calling a local model, set MODEL_PATH to a local model path or a HuggingFace model id
#MODEL_PATH='../Qwen/Qwen2.5-Coder-7B-Instruct'

echo "----------------Running generalize_llm.py to GENERAL_PROMPT_DIRECTLY----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"
7 changes: 6 additions & 1 deletion scripts/generalize_llm.sh
@@ -4,6 +4,11 @@ MODE=300
INPUT_DIR_OR_FILE='./input_examples/corpus.txt'
OUTPUT_DIR="./output"
SUFFIX='_general'
# For calling an online model service, keep MODEL_PATH empty
MODEL_PATH=''
# For calling a local model, set MODEL_PATH to a local model path or a HuggingFace model id
#MODEL_PATH='../Qwen/Qwen2.5-Coder-7B-Instruct'

echo "----------------Running generalize_llm.py to GENERALIZE with query----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"
12 changes: 9 additions & 3 deletions scripts/run_the_whole_flow.sh
@@ -9,22 +9,28 @@ fi
CONFIG_PATH='./config.json'
GEN_QUERY=true
DB_ID='movie'
# For calling an online model service, keep MODEL_PATH empty
MODEL_PATH=''
# For calling a local model, set MODEL_PATH to a local model path or a HuggingFace model id
#MODEL_PATH='../Qwen/Qwen2.5-Coder-7B-Instruct'
echo "----------------Running generator.py to generate cyphers----------------"
python3 generator.py "$CONFIG_PATH" "$GEN_QUERY" "$DB_ID"

MODE=400
INPUT_DIR_OR_FILE="./output/output_query.txt"
SUFFIX='_t'
echo "----------------Running generalize_llm.py to GEN_QUESTION_WITH_TEMPLATE----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"

MODE=300
INPUT_DIR_OR_FILE="./output/output_query_t.txt"
SUFFIX='_general'
echo "----------------Running generalize_llm.py to GENERALIZE with query----------------"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX"
python3 generalize_llm.py "$MODE" "$INPUT_DIR_OR_FILE" "$OUTPUT_DIR" "$SUFFIX" "$MODEL_PATH"

CONFIG_PATH='./config.json'
INPUT_DIR_OR_FILE="./output/output_query_t_general.txt"
echo "----------------make dataset----------------"
python3 generate_dataset.py "$CONFIG_PATH" "$INPUT_DIR_OR_FILE"
python3 generate_dataset.py "$CONFIG_PATH" "$INPUT_DIR_OR_FILE"

2 changes: 2 additions & 0 deletions setup.py
@@ -25,6 +25,8 @@ def core_dependencies():
"graphviz==0.20.1",
"dashscope",
"tqdm",
"torch",
"transformers"
]

