
Support Falcon models #406

Merged: 25 commits, Oct 18, 2023

Commits (25)
aabbbaf - move q seq info into context (wangruohui, Sep 11, 2023)
e9e6af0 - Merge branch 'pytorch-poc' into falcon (wangruohui, Sep 14, 2023)
415ca10 - falcon aligned (wangruohui, Sep 19, 2023)
1701fba - trust_remote_code_argument (wangruohui, Sep 19, 2023)
82ab6d1 - Merge branch 'trust_remote_code_argument' into falcon (wangruohui, Sep 19, 2023)
95cd446 - fix for falcon (wangruohui, Sep 19, 2023)
25144c5 - comment out debugs (wangruohui, Sep 19, 2023)
ee26a25 - comment out debugs (wangruohui, Sep 19, 2023)
4278671 - use position id in context (wangruohui, Sep 20, 2023)
1fd7380 - remove codes in falcon model (wangruohui, Sep 20, 2023)
96a550f - Revert "comment out debugs" (wangruohui, Sep 20, 2023)
41165b5 - Merge branch 'pytorch-poc' into falcon (wangruohui, Sep 22, 2023)
7bda327 - 7b correct (wangruohui, Sep 22, 2023)
981d0a6 - 1b aligned (wangruohui, Sep 25, 2023)
d389273 - Merge branch 'pytorch-poc' into falcon (wangruohui, Sep 25, 2023)
892b254 - remove debugs (wangruohui, Sep 25, 2023)
0386288 - patch to ignore position ids (wangruohui, Sep 26, 2023)
9673622 - remove debug in alibi, avoid empty inputs (wangruohui, Sep 26, 2023)
6e92564 - fix (wangruohui, Sep 26, 2023)
7e23002 - Merge branch 'pytorch-poc' into falcon (wangruohui, Oct 9, 2023)
f8df92e - Merge branch 'pytorch-poc' into falcon (wangruohui, Oct 10, 2023)
de7b81e - rename dir to replace to "models" (wangruohui, Oct 10, 2023)
4f3b6de - Merge branch 'pytorch-poc' into falcon (wangruohui, Oct 16, 2023)
8496d5c - use position_id and new fill kernel (wangruohui, Oct 16, 2023)
85dc284 - remove useless get_prompt func (wangruohui, Oct 17, 2023)
13 changes: 13 additions & 0 deletions lmdeploy/model.py
@@ -485,6 +485,19 @@ def stop_words(self):
return [151645] # <|im_end|>


@MODELS.register_module(name='falcon')
class Falcon(BaseModel):

def __init__(self):
super().__init__()

def update_input_ids(self, input_ids: List):
Review comment (Collaborator): This function does not get used.

Reply (Author): fixed

if len(input_ids) == 0:
# avoid empty input to model
input_ids = [11]
return input_ids


@MODELS.register_module(name='chatglm2-6b')
class ChatGLM2(BaseModel):

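For reference, a minimal standalone sketch of the empty-input guard introduced in Falcon.update_input_ids above. The free function and the assertions below are illustrative only and are not part of the patch:

```python
from typing import List


def update_input_ids(input_ids: List[int]) -> List[int]:
    """Sketch of the guard in Falcon.update_input_ids.

    An empty prompt would otherwise hand the model an empty tensor, so a
    single placeholder token id (11, the value used in the patch) is
    substituted instead.
    """
    if len(input_ids) == 0:
        input_ids = [11]
    return input_ids


assert update_input_ids([]) == [11]        # empty prompt gets the placeholder
assert update_input_ids([5, 7]) == [5, 7]  # non-empty prompts pass through
```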
17 changes: 16 additions & 1 deletion lmdeploy/pytorch_poc/engine/engine.py
@@ -479,7 +479,22 @@ def __init__(
cache_config = CacheConfig(block_size=64,
num_cpu_blocks=0,
num_gpu_blocks=0)
- if 'chatglm' in model_path:
+ if 'falcon' in model_path:
if hf_config.multi_query:
kv_dim = hf_config.hidden_size // hf_config.num_attention_heads
kv_head = 1
else:
kv_dim = hf_config.hidden_size
kv_head = hf_config.num_attention_heads
model_config = ModelConfig(
kv_dim,
hf_config.num_hidden_layers,
kv_head,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
elif 'chatglm' in model_path:
model_config = ModelConfig(
hf_config.hidden_size // hf_config.num_attention_heads *
hf_config.multi_query_group_num,
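The branch above sizes the KV cache differently depending on whether the checkpoint uses multi-query attention. Below is a small sketch of that derivation with Falcon-7B-style numbers plugged in; the helper function is illustrative and assumes the Hugging Face FalconConfig field semantics used in the patch:

```python
def falcon_kv_shape(hidden_size: int, num_attention_heads: int,
                    multi_query: bool):
    """Derive KV-cache sizing the same way the engine patch does (sketch)."""
    if multi_query:
        # Multi-query attention: every query head shares one KV head,
        # so the cache only needs a single head of per-head width.
        kv_dim = hidden_size // num_attention_heads
        kv_head = 1
    else:
        # Standard multi-head attention: one KV head per query head.
        kv_dim = hidden_size
        kv_head = num_attention_heads
    return kv_dim, kv_head


# Falcon-7B-style config: hidden_size=4544, 71 attention heads, multi_query=True
print(falcon_kv_shape(4544, 71, True))    # (64, 1)
print(falcon_kv_shape(4544, 71, False))   # (4544, 71)
```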
7 changes: 6 additions & 1 deletion lmdeploy/pytorch_poc/kernels/alibi_pagedattention.py
@@ -51,6 +51,7 @@ def _fwd_kernel(
K,
V,
sm_scale,
alibi_scale,
B_Start_Loc,
B_Seqlen,
B_kvlen,
@@ -134,8 +135,9 @@ def _fwd_kernel(
qk *= sm_scale

mask = start_n + offs_n[None, :]
- bias = mask.to(tl.float32) * head_slope
+ bias = mask.to(tl.float32) * (head_slope * alibi_scale)
qk += bias

# NOTE: inf - inf = nan, and nan will lead to errors
qk = tl.where(
(history_len + offs_m[:, None]) >= mask,
@@ -191,6 +193,7 @@ def alibi_paged_attention_fwd(
max_input_len: int,
head_offset: int = 0,
num_heads: int = -1,
alibi_scale: float = 1.0,
BLOCK: int = 64,
):
"""Paged attention forward with alibi bias.
@@ -230,6 +233,7 @@
k,
v,
sm_scale,
alibi_scale,
b_start_loc,
b_seq_len,
b_kv_seq_len,
@@ -257,4 +261,5 @@
num_warps=num_warps,
num_stages=1,
)

return
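To make the alibi_scale change concrete, here is a toy single-head PyTorch reference of the bias and causal mask applied inside the Triton kernel. It is a readability aid under the stated assumptions and does not reproduce the paged, blocked layout of the real kernel:

```python
import torch


def add_alibi_bias(qk: torch.Tensor, head_slope: float, history_len: int,
                   alibi_scale: float = 1.0) -> torch.Tensor:
    """Single-head sketch of the bias applied in _fwd_kernel.

    qk: [q_len, kv_len] attention scores, already multiplied by sm_scale.
    The bias grows linearly with the key position and is scaled by
    head_slope * alibi_scale, matching the change in this hunk.
    """
    q_len, kv_len = qk.shape
    key_pos = torch.arange(kv_len, dtype=qk.dtype)   # plays the role of "mask" above
    qk = qk + key_pos[None, :] * (head_slope * alibi_scale)
    # Causal mask: query i sits at absolute position history_len + i and may
    # only attend to keys at positions <= its own.
    q_pos = history_len + torch.arange(q_len)
    return torch.where(q_pos[:, None] >= key_pos[None, :], qk,
                       torch.full_like(qk, float('-inf')))


scores = torch.zeros(2, 4)   # 2 new queries attending over 4 cached keys
print(add_alibi_bias(scores, head_slope=0.25, history_len=2))
```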