Support chatglm2 in pytorch_poc (#360)
* draft support for chatglm2

* debug llama

* gitignore

* update input_id

* better patching

* patch chatglm2 model

* fix after merge

* remove inits

* q_seq_info & remove some debug & orig_self

* remove old unsqueeze input_id

* update patch and model config

* remove debugs and clean codes

* clean codes

* add credit

* add update id / fix dependency
wangruohui authored Sep 25, 2023
1 parent cdbea77 commit 8123c8e
Showing 5 changed files with 389 additions and 11 deletions.
23 changes: 23 additions & 0 deletions lmdeploy/model.py
@@ -87,6 +87,10 @@ def stop_words(self):
"""Return the stop-words' token ids."""
return None

def update_input_ids(self, input_ids: List[int]):
"""Further modify input ids of the prompt."""
return input_ids


@MODELS.register_module(name='vicuna')
class Vicuna(BaseModel):
@@ -481,6 +485,25 @@ def stop_words(self):
return [151645] # <|im_end|>


@MODELS.register_module(name='chatglm2-6b')
class ChatGLM2(BaseModel):

def __init__(self):
super().__init__()
self.count = 0

def get_prompt(self, prompt, sequence_start=True):
# needs further verification, see:
# https://github.com/THUDM/ChatGLM2-6B/issues/48
# [64790, 64792] to be prepended
self.count += 1
return f'[Round {self.count}]\n\n问:{prompt}\n\n答:'

def update_input_ids(self, input_ids: List):
input_ids = [64790, 64792] + input_ids
return input_ids


def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
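A minimal sketch of how the new template and update_input_ids hook compose. The class, the prompt format, and the prefix ids all come from the diff above; the literal token list stands in for tokenizer output, which pytorch_poc/chat.py supplies at runtime:

from lmdeploy.model import ChatGLM2

model = ChatGLM2()
prompt = model.get_prompt('hello')   # -> '[Round 1]\n\n问:hello\n\n答:'
prompt = model.get_prompt('hello')   # -> '[Round 2]\n\n问:hello\n\n答:' (count increments per call)
# [1, 2, 3] stands in for tokenizer.encode(prompt), as done in pytorch_poc/chat.py
input_ids = model.update_input_ids([1, 2, 3])   # -> [64790, 64792, 1, 2, 3]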
10 changes: 7 additions & 3 deletions lmdeploy/pytorch_poc/chat.py
@@ -33,6 +33,9 @@ def main(
model_path,
model_name: str,  # cannot get model_name from the hf model
session_id: int = 1,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty: float = 1.0,
tp: int = 1,
stream_output=True):
@@ -73,12 +76,13 @@ def main(
continue
prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
input_ids = model.update_input_ids(input_ids)
print(f'{prompt} ', end='', flush=True)
response_size = 0
sampling_param = SamplingParam(
top_k=40,
top_p=0.8,
temperature=0.8,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=False,
random_seed=seed,
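The three new keyword arguments flow straight into SamplingParam, so sampling can now be tuned per run. A usage sketch calling the chat entry point directly from Python (this starts the interactive loop); the checkpoint path is a placeholder, and the model name must be one registered in lmdeploy/model.py, such as the chatglm2-6b template added above:

from lmdeploy.pytorch_poc.chat import main

main('/path/to/chatglm2-6b',      # placeholder path to a local HF checkpoint
     model_name='chatglm2-6b',    # template name registered in lmdeploy/model.py
     top_k=1,                     # near-greedy decoding
     top_p=1.0,
     temperature=1.0,
     repetition_penalty=1.0)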
27 changes: 19 additions & 8 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -447,14 +447,25 @@ def __init__(
cache_config = CacheConfig(block_size=64,
num_cpu_blocks=0,
num_gpu_blocks=0)
model_config = ModelConfig(
hf_config.hidden_size,
hf_config.num_hidden_layers,
hf_config.num_attention_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
if 'chatglm' in model_path:
model_config = ModelConfig(
hf_config.hidden_size // hf_config.num_attention_heads *
hf_config.multi_query_group_num,
hf_config.num_layers,
hf_config.multi_query_group_num,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)
else:
model_config = ModelConfig(
hf_config.hidden_size,
hf_config.num_hidden_layers,
hf_config.num_attention_heads,
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
dtype=torch_dtype,
)

self.scheduler_config = scheduler_config
self.cache_config = cache_config
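ChatGLM2 uses multi-query attention, and its HF config names the relevant fields num_layers and multi_query_group_num rather than the LLaMA-style num_hidden_layers and a full set of key/value heads, which is why the branch above reads those fields and passes hidden_size // num_attention_heads * multi_query_group_num as the first ModelConfig argument instead of hidden_size. A back-of-the-envelope check with the published chatglm2-6b config values, assumed here purely for illustration (the engine reads them from hf_config at runtime):

# chatglm2-6b config.json values assumed for illustration:
#   hidden_size=4096, num_attention_heads=32, multi_query_group_num=2, num_layers=28
hidden_size = 4096
num_attention_heads = 32
multi_query_group_num = 2

head_dim = hidden_size // num_attention_heads   # 128
kv_width = head_dim * multi_query_group_num     # 256, the first ModelConfig argument above
print(head_dim, kv_width)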
(diffs for the remaining 2 changed files not shown)
