Skip to content

Commit

Permalink
Fix chatglm tokenizer failure when transformers>=4.45.0 (#2520)
Browse files Browse the repository at this point in the history
* Fix chatglm tokenizer failure when transformers>=4.45.0

* fix chatglm2-6b
  • Loading branch information
AllentDan authored Sep 26, 2024
1 parent 0323103 commit bb1dfa6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 27 additions & 0 deletions lmdeploy/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,15 @@ class ChatGLM4Tokenizer(HuggingFaceTokenizer):

def __init__(self, model_path):
super(ChatGLM4Tokenizer, self).__init__(model_path)
original_pad = self.model._pad

def __pad(*args, **kwargs):
if 'padding_side' in kwargs:
kwargs.pop('padding_side')
return original_pad(*args, **kwargs)

# fix for transformers>4.45.0
self.model._pad = __pad

def encode(self,
s: str,
Expand All @@ -534,6 +543,22 @@ def encode(self,
**kwargs)


class ChatGLMTokenizer(HuggingFaceTokenizer):
"""tokenizer of GLM2."""

def __init__(self, model_path):
super(ChatGLMTokenizer, self).__init__(model_path)
original_pad = self.model._pad

def __pad(*args, **kwargs):
if 'padding_side' in kwargs:
kwargs.pop('padding_side')
return original_pad(*args, **kwargs)

# fix for transformers>4.45.0
self.model._pad = __pad


class Tokenizer:
"""Tokenize prompts or de-tokenize tokens into texts.
Expand Down Expand Up @@ -563,6 +588,8 @@ def __init__(self, model_file: str):
config_tokenizer_class = tokenizer_config.get('tokenizer_class')
if config_tokenizer_class == 'ChatGLM4Tokenizer':
self.model = ChatGLM4Tokenizer(model_folder)
elif config_tokenizer_class == 'ChatGLMTokenizer':
self.model = ChatGLMTokenizer(model_folder)
else:
self.model = HuggingFaceTokenizer(model_folder)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_lmdeploy/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer
from lmdeploy.tokenizer import DetokenizeState, HuggingFaceTokenizer, Tokenizer


@pytest.mark.parametrize('model_path', [
Expand All @@ -20,7 +20,7 @@
@pytest.mark.parametrize('skip_special_tokens', [True, False])
def test_tokenizer(model_path, input, interval, add_special_tokens,
skip_special_tokens):
tokenizer = HuggingFaceTokenizer(model_path)
tokenizer = Tokenizer(model_path).model
encoded = tokenizer.encode(input,
False,
add_special_tokens=add_special_tokens)
Expand Down

0 comments on commit bb1dfa6

Please sign in to comment.