From 9b078a1607e40124ebceb228a7b30dd979df2a78 Mon Sep 17 00:00:00 2001 From: Jintao Date: Mon, 30 Dec 2024 18:41:36 +0800 Subject: [PATCH] fix glm4v (#2806) --- swift/llm/__init__.py | 4 +- swift/llm/argument/rlhf_args.py | 2 +- swift/llm/model/__init__.py | 2 +- swift/llm/model/constant.py | 2 +- swift/llm/model/model/glm.py | 12 ++-- swift/llm/model/register.py | 58 +------------------ swift/llm/template/template/glm.py | 1 + swift/llm/template/template_inputs.py | 2 +- swift/ui/base.py | 3 +- tests/test_align/test_template/test_vision.py | 15 +++-- 10 files changed, 28 insertions(+), 73 deletions(-) diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py index d098e849a..839707641 100644 --- a/swift/llm/__init__.py +++ b/swift/llm/__init__.py @@ -20,7 +20,7 @@ HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys, ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup, Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth, - git_clone_github) + git_clone_github, get_matched_model_meta) from .dataset import (AlpacaPreprocessor, ResponsePreprocessor, MessagesPreprocessor, AutoPreprocessor, DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor, LazyLLMDataset, ConstantLengthDataset, standard_keys, load_dataset, DATASET_TYPE, @@ -54,7 +54,7 @@ 'ModelInfo', 'ModelMeta', 'ModelKeys', 'register_model_arch', 'MultiModelKeys', 'ModelArch', 'MODEL_ARCH_MAPPING', 'get_model_arch', 'get_model_info_meta', 'get_model_name', 'register_model', 'ModelGroup', 'Model', 'get_model_tokenizer_with_flash_attn', 'get_model_tokenizer_multimodal', - 'load_by_unsloth', 'git_clone_github' + 'load_by_unsloth', 'git_clone_github', 'get_matched_model_meta' ], 'dataset': [ 'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING', diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 8d54f9eb7..68dbd9f7d 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -48,7 +48,7 @@ def __post_init__(self): self._set_default() super().__post_init__() - if self.rlhf_type not in ['cpo', 'orpo', 'rm'] and (self.train_type == 'full' or self.rlhf_type == 'ppo'): + if self.rlhf_type in ['dpo', 'kto'] and self.train_type == 'full' or self.rlhf_type == 'ppo': self.ref_model = self.ref_model or self.model self.ref_model_type = self.ref_model_type or self.model_type self.ref_model_revision = self.ref_model_revision or self.model_revision diff --git a/swift/llm/model/__init__.py b/swift/llm/model/__init__.py index 58ba7500c..5e6a49cbb 100644 --- a/swift/llm/model/__init__.py +++ b/swift/llm/model/__init__.py @@ -5,5 +5,5 @@ from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map, get_default_torch_dtype, get_matched_model_meta, get_model_info_meta, get_model_name, get_model_tokenizer, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, - get_model_with_value_head, load_by_unsloth, register_model) + load_by_unsloth, register_model) from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index a87f901c7..c7b6dad4c 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -93,7 +93,7 @@ class LLMModelType: mamba = 'mamba' polylm = 'polylm' aya = 'aya' - + # bert modern_bert = 'modern_bert' 
bert = 'bert' diff --git a/swift/llm/model/model/glm.py b/swift/llm/model/model/glm.py index bb22dfc79..0fb802fbb 100644 --- a/swift/llm/model/model/glm.py +++ b/swift/llm/model/model/glm.py @@ -178,14 +178,18 @@ def get_model_tokenizer_glm4v(model_dir: str, register_model( ModelMeta( MLLMModelType.glm4v, - [ModelGroup([ - Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'), - ])], + [ + ModelGroup( + [ + Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'), + ], + requires=['transformers>=4.42,<4.45'], + ), + ], TemplateType.glm4v, get_model_tokenizer_glm4v, architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'], model_arch=ModelArch.glm4v, - requires=['transformers>=4.42'], )) diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index a98406eb8..8e3428cc6 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str, return model, tokenizer -def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead': - from trl import AutoModelForCausalLMWithValueHead - lm_head_namings = ['lm_head', 'embed_out'] - if not any(hasattr(model, attribute) for attribute in lm_head_namings): - setattr(model, 'lm_head', None) # avoid ValueError - - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - - def patch_valuehead_model(model): - attr_list = [ - 'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower', - 'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds', - 'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding', - 'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual' - ] - for attr in attr_list: - if hasattr(model.pretrained_model, attr) and not hasattr(model, attr): - setattr(model, attr, getattr(model.pretrained_model, attr)) - - # PPO compatible - if not hasattr(model, 'score'): - setattr(model, 'score', model.v_head) - if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'): - model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix - - base_model_prefix = model.pretrained_model.base_model_prefix - if hasattr(model.pretrained_model, base_model_prefix): - setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix)) - - patch_valuehead_model(model) - - # try to load local vhead weights - vhead_params = None - try: - from safetensors import safe_open - vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.safetensors') - with safe_open(vhead_file, framework='pt', device='cpu') as f: - vhead_params = {key: f.get_tensor(key) for key in f.keys()} - except Exception: - pass - - try: - vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.bin') - vhead_params = torch.load(vhead_file, map_location='cpu') - except Exception: - pass - - if vhead_params is not None: - model.load_state_dict(vhead_params, strict=False) - logger.info(f'Loading value head weights from {vhead_file}') - else: - logger.info('The local value head weight file was not detected.' 
- 'Ignore it if this is during the reward modeling phase,') - return model - - def get_model_tokenizer_with_flash_attn(model_dir: str, model_info: ModelInfo, model_kwargs: Dict[str, Any], @@ -430,7 +374,7 @@ def get_model_info_meta( if model_meta is None and model_type is not None: model_meta = MODEL_MAPPING[model_type] if model_meta is None: - model_meta = ModelMeta('', [], 'dummy', get_model_tokenizer_from_local, model_arch=None) + model_meta = ModelMeta(None, [], 'dummy', get_model_tokenizer_from_local, model_arch=None) logger.info(f'Temporarily create model_meta: {model_meta}') if torch_dtype is None: diff --git a/swift/llm/template/template/glm.py b/swift/llm/template/template/glm.py index 46b73043d..e9ee3d008 100644 --- a/swift/llm/template/template/glm.py +++ b/swift/llm/template/template/glm.py @@ -64,6 +64,7 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: encoded['images'] = inputs2['images'] encoded['input_ids'] = input_ids encoded['labels'] = labels + encoded['position_ids'] = list(range(len(input_ids))) return encoded def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: diff --git a/swift/llm/template/template_inputs.py b/swift/llm/template/template_inputs.py index c4ed439e6..1f6734dde 100644 --- a/swift/llm/template/template_inputs.py +++ b/swift/llm/template/template_inputs.py @@ -45,7 +45,7 @@ def __post_init__(self): val = getattr(self, key) if isinstance(val, str): setattr(self, key, [val]) - + assert isinstance(self.messages, list), f'messages: {self.messages}' self.remove_response(self.messages) @staticmethod diff --git a/swift/ui/base.py b/swift/ui/base.py index 508b61f4b..6e1c84713 100644 --- a/swift/ui/base.py +++ b/swift/ui/base.py @@ -15,8 +15,7 @@ from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video from modelscope.hub.utils.utils import get_cache_dir -from swift.llm import TEMPLATE_MAPPING, BaseArguments -from swift.llm.model.register import get_matched_model_meta +from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta all_langs = ['zh', 'en'] builder: Type['BaseUI'] = None diff --git a/tests/test_align/test_template/test_vision.py b/tests/test_align/test_template/test_vision.py index 6cdffa39e..637bebe9d 100644 --- a/tests/test_align/test_template/test_vision.py +++ b/tests/test_align/test_template/test_vision.py @@ -17,6 +17,8 @@ def _infer_model(pt_engine, system=None, messages=None, images=None): resp = pt_engine.infer([{'messages': messages}], request_config=request_config) response = resp[0].choices[0].message.content messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '这是什么'}] + else: + messages = messages.copy() if images is None: images = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'] resp = pt_engine.infer([{'messages': messages, 'images': images}], request_config=request_config) @@ -69,9 +71,14 @@ def test_yi_vl(): def test_glm4v(): # There will be differences in '\n'. This is normal. 
pt_engine = PtEngine('ZhipuAI/glm-4v-9b') - _infer_model(pt_engine) + messages = [{'role': 'user', 'content': '描述这张图片'}] + response = _infer_model(pt_engine, messages=messages) pt_engine.default_template.template_backend = 'jinja' - _infer_model(pt_engine) + response2 = _infer_model(pt_engine, messages=messages) + assert response == ('这张图片是一只小猫的特写,它有着非常醒目的蓝色眼睛和混合了灰色、白色和棕色毛发的皮毛。小猫的耳朵竖立着,胡须清晰可见。它的眼神看起来既好奇又警觉,整体上显得非常可爱。') + assert response2 == ('这是一张特写照片,展示了一只毛茸茸的小猫。小猫的眼睛大而圆,呈深蓝色,眼珠呈金黄色,非常明亮。它的鼻子短而小巧,' + '是粉色的。小猫的嘴巴紧闭,胡须细长。它的耳朵竖立着,耳朵内侧是白色的,外侧是棕色的。小猫的毛发看起来柔软而浓密,' + '主要是白色和棕色相间的花纹。背景模糊不清,但似乎是一个室内环境。') def test_minicpmv(): @@ -307,11 +314,11 @@ def test_doc_owl2(): # test_deepseek_vl() # test_deepseek_vl2() # test_qwen_vl() - # test_glm4v() + test_glm4v() # test_minicpmv() # test_got_ocr() # test_paligemma() - test_paligemma2() + # test_paligemma2() # test_pixtral() # test_llama_vision() # test_llava_hf()
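
Notes on this patch (illustrative sketches, not part of the diff):

The rewritten `ref_model` condition in `rlhf_args.py` relies on Python's operator precedence (`and` binds tighter than `or`). A minimal paraphrase with explicit parentheses, to make the intended grouping visible; `needs_ref_model` is a hypothetical helper for illustration, not part of swift:

```python
# Paraphrase of the new condition in RLHFArguments.__post_init__.
# The parentheses make the grouping explicit; they match what Python's
# precedence already does in the patched line.
def needs_ref_model(rlhf_type: str, train_type: str) -> bool:
    return (rlhf_type in ['dpo', 'kto'] and train_type == 'full') or rlhf_type == 'ppo'

assert needs_ref_model('dpo', 'full')      # full-parameter DPO keeps a ref model
assert not needs_ref_model('dpo', 'lora')  # LoRA DPO does not
assert needs_ref_model('ppo', 'lora')      # PPO always does
assert not needs_ref_model('orpo', 'full')  # cpo/orpo/rm never need one
```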
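The patch re-exports `get_matched_model_meta` from `swift.llm` (and `swift/ui/base.py` now imports it from there instead of from `swift.llm.model.register`). A usage sketch, assuming the helper takes a model id or local path and returns the matching registered `ModelMeta` or `None`; the attribute name below comes from this diff, and the exact return shape may differ:

```python
from swift.llm import get_matched_model_meta

# Assumption: the helper matches a model id/path against the model registry
# and returns the registered ModelMeta, or None for unregistered models.
model_meta = get_matched_model_meta('ZhipuAI/glm-4v-9b')
if model_meta is not None:
    print(model_meta.model_type)  # expected: 'glm4v'
else:
    # Per the register.py hunk, get_model_info_meta then falls back to a
    # temporary ModelMeta(None, [], 'dummy', get_model_tokenizer_from_local, ...)
    print('not registered; a dummy ModelMeta is created at load time')
```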
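The glm4v registration moves the `requires` pin from the `ModelMeta` onto the `ModelGroup` and tightens it to `transformers>=4.42,<4.45`. A self-contained sketch of what that pin means, using `packaging` directly; this is illustrative and is not swift's actual requirement-checking code:

```python
# Illustrative check of the pin requires=['transformers>=4.42,<4.45'].
import transformers
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet('>=4.42,<4.45')
if Version(transformers.__version__) not in spec:
    raise RuntimeError(
        f'glm-4v-9b expects transformers{spec}, found {transformers.__version__}')
```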
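The template change precomputes sequential `position_ids` per sample in `_encode` (`list(range(len(input_ids)))`), so they can be padded alongside `input_ids` when batching variable-length glm4v samples. One plausible reading of why this matters is sketched below; `pad_sequence` stands in for the real `_data_collator`, which this patch does not show:

```python
# Sketch of padding per-sample position_ids in a batch; illustrative only.
import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples of different lengths, each carrying the position_ids that
# _encode now produces: list(range(len(input_ids)))
batch = [list(range(5)), list(range(3))]
position_ids = pad_sequence([torch.tensor(p) for p in batch],
                            batch_first=True, padding_value=0)
print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 0, 0]])
```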