fix glm4v (#2806)
Jintao-Huang authored Dec 30, 2024
1 parent 0af291d commit 9b078a1
Showing 10 changed files with 28 additions and 73 deletions.
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -20,7 +20,7 @@
HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys,
ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup,
Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth,
- git_clone_github)
+ git_clone_github, get_matched_model_meta)
from .dataset import (AlpacaPreprocessor, ResponsePreprocessor, MessagesPreprocessor, AutoPreprocessor,
DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor,
LazyLLMDataset, ConstantLengthDataset, standard_keys, load_dataset, DATASET_TYPE,
@@ -54,7 +54,7 @@
'ModelInfo', 'ModelMeta', 'ModelKeys', 'register_model_arch', 'MultiModelKeys', 'ModelArch',
'MODEL_ARCH_MAPPING', 'get_model_arch', 'get_model_info_meta', 'get_model_name', 'register_model',
'ModelGroup', 'Model', 'get_model_tokenizer_with_flash_attn', 'get_model_tokenizer_multimodal',
- 'load_by_unsloth', 'git_clone_github'
+ 'load_by_unsloth', 'git_clone_github', 'get_matched_model_meta'
],
'dataset': [
'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING',
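
Note: get_matched_model_meta is promoted here to the top-level swift.llm namespace (it is consumed from there in swift/ui/base.py below). A minimal usage sketch, assuming the function returns the matched ModelMeta or None — the model id is only an example:

    from swift.llm import get_matched_model_meta

    # Look up the registered ModelMeta whose model groups contain this id.
    meta = get_matched_model_meta('ZhipuAI/glm-4v-9b')
    print(meta.model_type if meta is not None else None)  # expected: 'glm4v'
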
2 changes: 1 addition & 1 deletion swift/llm/argument/rlhf_args.py
@@ -48,7 +48,7 @@ def __post_init__(self):
self._set_default()
super().__post_init__()

- if self.rlhf_type not in ['cpo', 'orpo', 'rm'] and (self.train_type == 'full' or self.rlhf_type == 'ppo'):
+ if self.rlhf_type in ['dpo', 'kto'] and self.train_type == 'full' or self.rlhf_type == 'ppo':
self.ref_model = self.ref_model or self.model
self.ref_model_type = self.ref_model_type or self.model_type
self.ref_model_revision = self.ref_model_revision or self.model_revision
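
Since `and` binds tighter than `or` in Python, the rewritten condition provisions a reference model either for full-parameter DPO/KTO training or for any PPO run. A sketch with the parentheses made explicit:

    def needs_ref_model(rlhf_type: str, train_type: str) -> bool:
        # Equivalent parenthesization of the new condition above.
        return (rlhf_type in ['dpo', 'kto'] and train_type == 'full') or rlhf_type == 'ppo'
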
2 changes: 1 addition & 1 deletion swift/llm/model/__init__.py
@@ -5,5 +5,5 @@
from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map,
get_default_torch_dtype, get_matched_model_meta, get_model_info_meta, get_model_name,
get_model_tokenizer, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn,
- get_model_with_value_head, load_by_unsloth, register_model)
+ load_by_unsloth, register_model)
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download
2 changes: 1 addition & 1 deletion swift/llm/model/constant.py
@@ -93,7 +93,7 @@ class LLMModelType:
mamba = 'mamba'
polylm = 'polylm'
aya = 'aya'
-
+ # bert
modern_bert = 'modern_bert'
bert = 'bert'

12 changes: 8 additions & 4 deletions swift/llm/model/model/glm.py
@@ -178,14 +178,18 @@ def get_model_tokenizer_glm4v(model_dir: str,
register_model(
ModelMeta(
MLLMModelType.glm4v,
- [ModelGroup([
-     Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'),
- ])],
+ [
+     ModelGroup(
+         [
+             Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'),
+         ],
+         requires=['transformers>=4.42,<4.45'],
+     ),
+ ],
TemplateType.glm4v,
get_model_tokenizer_glm4v,
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
model_arch=ModelArch.glm4v,
- requires=['transformers>=4.42'],
))


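
The substantive fix: the transformers constraint moves from the ModelMeta down onto the ModelGroup and gains an upper bound (<4.45), presumably because glm-4v-9b's custom modeling code breaks on newer transformers releases. The `requires` entries are standard pip-style specifiers, so a check along these lines can validate them (check_requires is a hypothetical helper, not ms-swift API):

    import importlib.metadata

    from packaging.requirements import Requirement

    def check_requires(requires):  # hypothetical helper
        for spec in requires:
            req = Requirement(spec)
            installed = importlib.metadata.version(req.name)
            if installed not in req.specifier:
                raise RuntimeError(f'{req.name}=={installed} does not satisfy {spec!r}')

    check_requires(['transformers>=4.42,<4.45'])
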
58 changes: 1 addition & 57 deletions swift/llm/model/register.py
@@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str,
return model, tokenizer


- def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead':
-     from trl import AutoModelForCausalLMWithValueHead
-     lm_head_namings = ['lm_head', 'embed_out']
-     if not any(hasattr(model, attribute) for attribute in lm_head_namings):
-         setattr(model, 'lm_head', None)  # avoid ValueError
-
-     model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-
-     def patch_valuehead_model(model):
-         attr_list = [
-             'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower',
-             'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds',
-             'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding',
-             'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual'
-         ]
-         for attr in attr_list:
-             if hasattr(model.pretrained_model, attr) and not hasattr(model, attr):
-                 setattr(model, attr, getattr(model.pretrained_model, attr))
-
-         # PPO compatible
-         if not hasattr(model, 'score'):
-             setattr(model, 'score', model.v_head)
-         if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'):
-             model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix
-
-         base_model_prefix = model.pretrained_model.base_model_prefix
-         if hasattr(model.pretrained_model, base_model_prefix):
-             setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix))
-
-     patch_valuehead_model(model)
-
-     # try to load local vhead weights
-     vhead_params = None
-     try:
-         from safetensors import safe_open
-         vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.safetensors')
-         with safe_open(vhead_file, framework='pt', device='cpu') as f:
-             vhead_params = {key: f.get_tensor(key) for key in f.keys()}
-     except Exception:
-         pass
-
-     try:
-         vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.bin')
-         vhead_params = torch.load(vhead_file, map_location='cpu')
-     except Exception:
-         pass
-
-     if vhead_params is not None:
-         model.load_state_dict(vhead_params, strict=False)
-         logger.info(f'Loading value head weights from {vhead_file}')
-     else:
-         logger.info('The local value head weight file was not detected.'
-                     'Ignore it if this is during the reward modeling phase,')
-     return model
-
-
def get_model_tokenizer_with_flash_attn(model_dir: str,
model_info: ModelInfo,
model_kwargs: Dict[str, Any],
@@ -430,7 +374,7 @@ def get_model_info_meta(
if model_meta is None and model_type is not None:
model_meta = MODEL_MAPPING[model_type]
if model_meta is None:
- model_meta = ModelMeta('', [], 'dummy', get_model_tokenizer_from_local, model_arch=None)
+ model_meta = ModelMeta(None, [], 'dummy', get_model_tokenizer_from_local, model_arch=None)
logger.info(f'Temporarily create model_meta: {model_meta}')

if torch_dtype is None:
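
get_model_with_value_head is dropped from the registry module (its re-export was removed from swift/llm/model/__init__.py above), and the second hunk switches the temporary ModelMeta's placeholder model_type from '' to None. For reference, the core of the deleted helper is the standard trl wrapping call — a sketch, assuming `base_model` is an already-loaded causal LM:

    from trl import AutoModelForCausalLMWithValueHead

    # Attach a scalar value head, as the deleted helper did before patching
    # attributes through to pretrained_model and loading local v_head weights.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(base_model)
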
1 change: 1 addition & 0 deletions swift/llm/template/template/glm.py
@@ -64,6 +64,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded['images'] = inputs2['images']
encoded['input_ids'] = input_ids
encoded['labels'] = labels
+ encoded['position_ids'] = list(range(len(input_ids)))
return encoded

def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
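
The one-line template fix for glm4v: encode explicit sequential position ids up front, presumably so the collator below can pad them alongside input_ids. Illustratively:

    input_ids = [64790, 64792, 30910, 34607]    # toy token ids
    position_ids = list(range(len(input_ids)))  # [0, 1, 2, 3]
    assert position_ids == [0, 1, 2, 3]
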
2 changes: 1 addition & 1 deletion swift/llm/template/template_inputs.py
@@ -45,7 +45,7 @@ def __post_init__(self):
val = getattr(self, key)
if isinstance(val, str):
setattr(self, key, [val])
-
+ assert isinstance(self.messages, list), f'messages: {self.messages}'
self.remove_response(self.messages)

@staticmethod
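
The new assertion fails fast when `messages` arrives as something other than a list (for example a bare message dict), instead of surfacing a confusing error later in remove_response. A toy reproduction — InferRequest stands in for whichever dataclass in this file owns this __post_init__:

    from swift.llm import InferRequest  # stand-in name, see note above

    # Raises: AssertionError: messages: {'role': 'user', 'content': 'hi'}
    InferRequest(messages={'role': 'user', 'content': 'hi'})
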
3 changes: 1 addition & 2 deletions swift/ui/base.py
@@ -15,8 +15,7 @@
from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video
from modelscope.hub.utils.utils import get_cache_dir

- from swift.llm import TEMPLATE_MAPPING, BaseArguments
- from swift.llm.model.register import get_matched_model_meta
+ from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta

all_langs = ['zh', 'en']
builder: Type['BaseUI'] = None
15 changes: 11 additions & 4 deletions tests/test_align/test_template/test_vision.py
@@ -17,6 +17,8 @@ def _infer_model(pt_engine, system=None, messages=None, images=None):
resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
response = resp[0].choices[0].message.content
messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '<image>这是什么'}]
+ else:
+     messages = messages.copy()
if images is None:
images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']
resp = pt_engine.infer([{'messages': messages, 'images': images}], request_config=request_config)
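
The new else branch copies the caller's list: tests such as test_glm4v below reuse one messages list across two _infer_model calls, and without the copy the helper's in-place += would append turns to the caller's list and leak state into the second run. The underlying aliasing:

    msgs = [{'role': 'user', 'content': 'hi'}]
    alias = msgs                                       # no copy: one shared list
    alias += [{'role': 'assistant', 'content': 'ok'}]
    assert len(msgs) == 2                              # caller's list mutated in place
    safe = msgs.copy()                                 # a shallow copy decouples appends
    safe += [{'role': 'user', 'content': 'again'}]
    assert len(msgs) == 2
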
@@ -69,9 +71,14 @@ def test_yi_vl():
def test_glm4v():
# There will be differences in '\n'. This is normal.
pt_engine = PtEngine('ZhipuAI/glm-4v-9b')
- _infer_model(pt_engine)
+ messages = [{'role': 'user', 'content': '描述这张图片'}]
+ response = _infer_model(pt_engine, messages=messages)
pt_engine.default_template.template_backend = 'jinja'
- _infer_model(pt_engine)
+ response2 = _infer_model(pt_engine, messages=messages)
+ assert response == ('这张图片是一只小猫的特写,它有着非常醒目的蓝色眼睛和混合了灰色、白色和棕色毛发的皮毛。小猫的耳朵竖立着,胡须清晰可见。它的眼神看起来既好奇又警觉,整体上显得非常可爱。')
+ assert response2 == ('这是一张特写照片,展示了一只毛茸茸的小猫。小猫的眼睛大而圆,呈深蓝色,眼珠呈金黄色,非常明亮。它的鼻子短而小巧,'
+                      '是粉色的。小猫的嘴巴紧闭,胡须细长。它的耳朵竖立着,耳朵内侧是白色的,外侧是棕色的。小猫的毛发看起来柔软而浓密,'
+                      '主要是白色和棕色相间的花纹。背景模糊不清,但似乎是一个室内环境。')


def test_minicpmv():
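
test_glm4v now pins exact outputs for both the default swift template backend and the jinja backend. Stripped of the helper, the call pattern is roughly the following sketch (the RequestConfig values are illustrative, not taken from this diff):

    from swift.llm import PtEngine, RequestConfig

    engine = PtEngine('ZhipuAI/glm-4v-9b')
    resp = engine.infer(
        [{
            'messages': [{'role': 'user', 'content': '描述这张图片'}],  # "describe this image"
            'images': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'],
        }],
        request_config=RequestConfig(max_tokens=128, temperature=0))
    print(resp[0].choices[0].message.content)
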
@@ -307,11 +314,11 @@ def test_doc_owl2():
# test_deepseek_vl()
# test_deepseek_vl2()
# test_qwen_vl()
- # test_glm4v()
+ test_glm4v()
# test_minicpmv()
# test_got_ocr()
# test_paligemma()
- test_paligemma2()
+ # test_paligemma2()
# test_pixtral()
# test_llama_vision()
# test_llava_hf()
