Skip to content

Commit

Permalink
Merge pull request #2 from qidi1/public_model
Browse files Browse the repository at this point in the history
rename moudle name.
  • Loading branch information
luliwjc authored Jan 12, 2024
2 parents c7e7018 + 58f1587 commit e6c705d
Show file tree
Hide file tree
Showing 44 changed files with 46 additions and 1,344 deletions.
151 changes: 0 additions & 151 deletions hf_inference.py

This file was deleted.

17 changes: 11 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
LlamaForCausalLM,
StoppingCriteria,
)
from peft import PeftModel

from utils.extract_sql_meta import isConstCanFind, convert_schema, fetch_column_all_value, is_number

Expand Down Expand Up @@ -121,9 +122,12 @@ def load_model_tokenizer(path, model_type=None, peft_path=None, quantization=Non
trust_remote_code=True,
# use_safetensors=False,
)

print("Loading Original MODEL...")
model = base_model
if peft_path:
print("Loading PEFT MODEL...")
model = PeftModel.from_pretrained(base_model, peft_path, torch_dtype=torch_dtype)
else:
print("Loading Original MODEL...")
model = base_model

model.eval()

Expand Down Expand Up @@ -296,9 +300,9 @@ def second_round_prompt_check_constrain(sql, may_be_other_fields, db_list):
return '\n'.join(prompt_str_list)


def start_inference(base_model_path, valid_file_path, db_dir):
def start_inference(base_model_path, peft_path, valid_file_path, db_dir):
content_list, database_list = load_test_data(valid_file_path)
model, tokenizer = load_model_tokenizer(base_model_path, model_type='deepseek',
model, tokenizer = load_model_tokenizer(base_model_path, peft_path= peft_path, model_type='deepseek',
eos_token='<|end▁of▁sentence|>', pad_token='<|end▁of▁sentence|>')
cnt, err = 0, 0
predict_result = []
Expand Down Expand Up @@ -330,7 +334,7 @@ def start_inference(base_model_path, valid_file_path, db_dir):


def main(opt):
predict_result = start_inference(opt.model_path, opt.eval_file, opt.base_dir)
predict_result = start_inference(opt.model_path,opt.peft_path, opt.eval_file, opt.base_dir)
with open(opt.output, 'w') as f:
f.write("\n".join(predict_result))
f.flush()
Expand All @@ -342,5 +346,6 @@ def main(opt):
parser_arg.add_argument('--eval_file', type=str, default="./data/preprocessed_data/resdsql_dev.json")
parser_arg.add_argument('--base_dir', type=str, default="./data/preprocessed_data/spider/database")
parser_arg.add_argument('--output', type=str, default="./predict_result/sqlgpt.sql")
parser_arg.add_argument('--peft_path', type=str, default=None)
opt = parser_arg.parse_args()
main(opt)
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
from ply import lex

from sqlgpt_parser.parser.mysql_parser.reserved import (
from sql_metadata.parser.mysql_parser.reserved import (
reversed,
nonreserved,
not_keyword_token,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import types, threading

from sqlgpt_parser.parser.tree.window import (
from sql_metadata.parser.tree.window import (
FrameBound,
FrameClause,
FrameExpr,
Expand All @@ -23,9 +23,9 @@
WindowSpec,
)

from sqlgpt_parser.parser.tree.with_stmt import CommonTableExpr, With, WithHasQuery
from sqlgpt_parser.parser.tree.index_type import IndexType
from sqlgpt_parser.parser.tree.expression import (
from sql_metadata.parser.tree.with_stmt import CommonTableExpr, With, WithHasQuery
from sql_metadata.parser.tree.index_type import IndexType
from sql_metadata.parser.tree.expression import (
AggregateFunc,
ArithmeticBinaryExpression,
ArithmeticUnaryExpression,
Expand Down Expand Up @@ -61,9 +61,9 @@
JsonTableColumn,
TimeInterval,
)
from sqlgpt_parser.parser.tree.grouping import SimpleGroupBy
from sqlgpt_parser.parser.tree.join_criteria import JoinOn, JoinUsing, NaturalJoin
from sqlgpt_parser.parser.tree.literal import (
from sql_metadata.parser.tree.grouping import SimpleGroupBy
from sql_metadata.parser.tree.join_criteria import JoinOn, JoinUsing, NaturalJoin
from sql_metadata.parser.tree.literal import (
BooleanLiteral,
DateLiteral,
DoubleLiteral,
Expand All @@ -75,21 +75,21 @@
ErrorLiteral,
TimestampLiteral,
)
from sqlgpt_parser.parser.tree.node import Node
from sqlgpt_parser.parser.tree.qualified_name import QualifiedName
from sqlgpt_parser.parser.tree.query_specification import QuerySpecification
from sqlgpt_parser.parser.tree.relation import AliasedRelation, Join
from sqlgpt_parser.parser.tree.select import Select
from sqlgpt_parser.parser.tree.select_item import Partition, SingleColumn
from sqlgpt_parser.parser.tree.set_operation import Except, Intersect, Union
from sqlgpt_parser.parser.tree.sort_item import ByItem, PartitionByClause, SortItem
from sqlgpt_parser.parser.tree.statement import Delete, Insert, Query, Update
from sqlgpt_parser.parser.tree.table import Table
from sqlgpt_parser.parser.tree.values import Values
from sqlgpt_parser.parser.tree.field_type import UNSPECIFIEDLENGTH, FieldType, SQLType
from sql_metadata.parser.tree.node import Node
from sql_metadata.parser.tree.qualified_name import QualifiedName
from sql_metadata.parser.tree.query_specification import QuerySpecification
from sql_metadata.parser.tree.relation import AliasedRelation, Join
from sql_metadata.parser.tree.select import Select
from sql_metadata.parser.tree.select_item import Partition, SingleColumn
from sql_metadata.parser.tree.set_operation import Except, Intersect, Union
from sql_metadata.parser.tree.sort_item import ByItem, PartitionByClause, SortItem
from sql_metadata.parser.tree.statement import Delete, Insert, Query, Update
from sql_metadata.parser.tree.table import Table
from sql_metadata.parser.tree.values import Values
from sql_metadata.parser.tree.field_type import UNSPECIFIEDLENGTH, FieldType, SQLType

from ply import yacc
from sqlgpt_parser.parser.mysql_parser.lexer import tokens, lexer
from sql_metadata.parser.mysql_parser.lexer import tokens, lexer

tokens = tokens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
"""
import copy

from sqlgpt_parser.parser.tree.grouping import GroupingSets, SimpleGroupBy
from sqlgpt_parser.parser.tree.literal import StringLiteral, Literal
from sqlgpt_parser.parser.tree.qualified_name import QualifiedName
from sqlgpt_parser.parser.tree.visitor import DefaultTraversalVisitor
from sqlgpt_parser.parser.tree.expression import (
from sql_metadata.parser.tree.grouping import GroupingSets, SimpleGroupBy
from sql_metadata.parser.tree.literal import StringLiteral, Literal
from sql_metadata.parser.tree.qualified_name import QualifiedName
from sql_metadata.parser.tree.visitor import DefaultTraversalVisitor
from sql_metadata.parser.tree.expression import (
InListExpression,
QualifiedNameReference,
SubqueryExpression,
)
from sqlgpt_parser.parser.tree.query_specification import QuerySpecification
from sqlgpt_parser.parser.tree.join_criteria import JoinOn, JoinUsing
from sqlgpt_parser.parser.tree.relation import AliasedRelation
from sqlgpt_parser.parser.tree.table import Table
from sqlgpt_parser.utils.untils import convert_nested_strings_to_lowercase, get_string_values
from sqlgpt_parser.parser.tree.statement import Delete, Insert, Query, Update
from sqlgpt_parser.parser.tree.set_operation import Except, Intersect, Union
from sql_metadata.parser.tree.query_specification import QuerySpecification
from sql_metadata.parser.tree.join_criteria import JoinOn, JoinUsing
from sql_metadata.parser.tree.relation import AliasedRelation
from sql_metadata.parser.tree.table import Table
from sql_metadata.utils.untils import convert_nested_strings_to_lowercase, get_string_values
from sql_metadata.parser.tree.statement import Delete, Insert, Query, Update
from sql_metadata.parser.tree.set_operation import Except, Intersect, Union


class ParserUtils(object):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit e6c705d

Please sign in to comment.