
Commit

[side effect] vlm quant failed
lvhan028 committed Dec 17, 2024
1 parent 1efed79 commit 3a3dae7
Showing 15 changed files with 147 additions and 90 deletions.
25 changes: 14 additions & 11 deletions lmdeploy/lite/apis/calibrate.py
@@ -11,6 +11,7 @@
 from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2
 from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
                                  load_hf_from_pretrained)
+from lmdeploy.vl.model.builder import load_vl_model

 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -243,18 +244,20 @@ def calibrate(model: str,
     # Load tokenizer and configuration
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

-    model = load_hf_from_pretrained(model,
-                                    torch_dtype=torch.float16,
-                                    trust_remote_code=True)
-    vl_model = None
-    if model_type == 'vlm':
-        vl_model = model
-        if hasattr(model, 'language_model'):
-            model = model.language_model
-        if hasattr(model, 'llm'):
-            model = model.llm
+    if model_type == 'llm':
+        model = load_hf_from_pretrained(model,
+                                        torch_dtype=torch.float16,
+                                        trust_remote_code=True)
+        vl_model = None
+    elif model_type == 'vlm':
+        vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
+        model = vl_model
+        if hasattr(vl_model, 'language_model'):  # deepseek vl
+            model = vl_model.language_model
+        if hasattr(vl_model, 'llm'):  # MiniCPMV
+            model = vl_model.llm
     model.config.use_cache = False
-    model = model.half().eval()
+    model.half().eval()

     model_type = type(model).__name__
     if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
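For reference, a hedged sketch of how this new branch is exercised from the Python API. The keyword arguments besides the model path follow common lmdeploy lite usage and are not verified against this exact revision; the repo id is only a placeholder, and it is assumed here that the 'vlm' model type is detected from the model's config.

from lmdeploy.lite.apis.calibrate import calibrate

# A VLM repo routes execution into the `elif model_type == 'vlm'` branch above:
# the whole VLM is loaded on CPU via load_vl_model(..., with_llm=True) and the
# inner language model is then picked out for calibration.
calibrate('OpenGVLab/InternVL2-8B',
          calib_dataset='ptb',
          calib_samples=128,
          calib_seqlen=2048,
          work_dir='./internvl2-calib')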
8 changes: 5 additions & 3 deletions lmdeploy/vl/model/base.py
@@ -17,11 +17,13 @@ class VisonModel(ABC):

     def __init__(self,
                  model_path: str,
+                 with_llm: bool = False,
                  max_memory: Dict[int, int] = None,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
         """init."""
         self.model_path = model_path
+        self.with_llm = with_llm
         self.max_memory = max_memory
         self.backend = backend
         if hf_config is None:
@@ -38,11 +40,11 @@ def build_preprocessor(self, ):
         raise NotImplementedError()

     def build_model(self, ):
-        """build model.
+        """build the vision part of a VLM model when backend is turbomind.

-        ONLY implement it when the backend is turbomind engine
+        But when `with_llm=True`, load the whole VLM model
         """
-        if self.backend == 'turbomind':
+        if self.backend == 'turbomind' or self.with_llm:
             raise NotImplementedError()

     @abstractmethod
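Every model-specific class below implements the same contract. A minimal sketch of that contract, assuming a hypothetical subclass; ToyVisionModel and its stand-in modules are illustrative and not part of lmdeploy:

from torch import nn
from accelerate import init_empty_weights

from lmdeploy.vl.model.base import VisonModel


class ToyVisionModel(VisonModel):
    """Illustrative subclass; not an existing lmdeploy model."""

    def build_preprocessor(self):
        self.preprocessor = lambda images: images            # stand-in preprocessor

    def build_model(self):
        # Called when backend == 'turbomind' or when self.with_llm is True.
        with init_empty_weights():
            model = nn.ModuleDict({
                'vision_model': nn.Linear(1024, 4096),       # stand-in vision tower
                'language_model': nn.Linear(4096, 32000),    # stand-in LLM
            })
        self.vl_model = model                  # expose the full VLM for quantization
        if not self.with_llm:
            del model['language_model']        # inference: drop the LLM weights
        # Real subclasses then call load_checkpoint_and_dispatch with
        # device_map='auto' if not self.with_llm else {'': 'cpu'}.
        self.model = model

    def forward(self, messages):
        raise NotImplementedError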
13 changes: 10 additions & 3 deletions lmdeploy/vl/model/builder.py
@@ -32,13 +32,16 @@

 def load_vl_model(model_path: str,
                   backend: str,
+                  with_llm: bool = False,
                   backend_config: Optional[Union[TurbomindEngineConfig,
                                                  PytorchEngineConfig]] = None):
     """load visual model.

     Args:
         model_path(str): the path or repo_id from model hub of the model
         backend(str): the name of inference backend
+        with_llm(bool): load LLM model or not. Set it to False for VLM
+            inference scenarios and True for VLM quantization
         backend_config: the config of the inference engine
     """
     if not os.path.exists(model_path):
@@ -49,11 +52,13 @@ def load_vl_model(model_path: str,
                                       download_dir=download_dir)

     max_memory = None
-    tp = getattr(backend_config, 'tp', 1)
-    max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}
+    if not with_llm:
+        tp = getattr(backend_config, 'tp', 1)
+        max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}

     _, hf_config = get_model_arch(model_path)
     kwargs = dict(model_path=model_path,
+                  with_llm=with_llm,
                   max_memory=max_memory,
                   hf_config=hf_config,
                   backend=backend)
@@ -63,7 +68,9 @@ def load_vl_model(model_path: str,
             logger.info(f'matching vision model: {name}')
             model = module(**kwargs)
             model.build_preprocessor()
-            if backend == 'turbomind':
+            # build the vision part of a VLM model when backend is
+            # turbomind, or load the whole VLM model when `with_llm==True`
+            if backend == 'turbomind' or with_llm:
                 model.build_model()
             return model
         except Exception:
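A sketch of the two call sites this flag separates; the repo id is a placeholder, and the quantization call mirrors calibrate.py above:

from lmdeploy.vl.model.builder import load_vl_model

# Inference: only the vision part is built; the GPU memory budget is derived
# from backend_config (tp defaults to 1 when no config is given).
vision_only = load_vl_model('OpenGVLab/InternVL2-8B', backend='turbomind')

# Quantization: the whole VLM is loaded (on CPU) and exposed via `.vl_model`.
full = load_vl_model('OpenGVLab/InternVL2-8B', backend=None, with_llm=True)
vl_model = full.vl_model
# Pick the inner LLM the same way calibrate.py does.
llm = getattr(vl_model, 'language_model', getattr(vl_model, 'llm', vl_model))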
16 changes: 11 additions & 5 deletions lmdeploy/vl/model/deepseek.py
@@ -36,11 +36,15 @@ def build_preprocessor(self):
             self.model_path).image_processor

     def build_model(self):
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         from accelerate import init_empty_weights
         with init_empty_weights():
             warnings.simplefilter('ignore')
             model = AutoModelForCausalLM.from_pretrained(self.model_path)
-            del model.language_model
+            self.vl_model = model
+            if not self.with_llm:
+                del model.language_model

         from accelerate.utils import get_balanced_memory, infer_auto_device_map
         max_memory = get_balanced_memory(model,
@@ -74,11 +78,13 @@ def build_model(self):

         from accelerate import load_checkpoint_and_dispatch
         with disable_logging():
-            load_checkpoint_and_dispatch(model=model,
-                                         checkpoint=self.model_path,
-                                         device_map=device_map,
-                                         dtype=torch.half)
+            load_checkpoint_and_dispatch(
+                model=model,
+                checkpoint=self.model_path,
+                device_map=device_map if not self.with_llm else {'': 'cpu'},
+                dtype=torch.half)

         self.model = model.eval()
         self.vision_model = model.vision_model.eval()
         self.aligner = model.aligner.eval()
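A note on `device_map={'': 'cpu'}`: accelerate device maps assign module-name prefixes to devices, and the empty prefix matches every submodule, so the quantization path keeps the entire VLM on CPU instead of sharding it across GPUs. A small, self-contained illustration of the prefix semantics (the lookup function below is a simplified stand-in, not accelerate's implementation):

# Keys are module-name prefixes, values are target devices.
inference_map = {'vision_model': 0, 'aligner': 0}   # vision parts on GPU 0
quant_map = {'': 'cpu'}                             # '' matches everything: whole model on CPU


def target_device(module_name, device_map):
    """Pick the device for a module by its longest matching prefix (simplified)."""
    matches = [k for k in device_map
               if k == '' or module_name == k or module_name.startswith(k + '.')]
    return device_map[max(matches, key=len)]


print(target_device('vision_model.encoder.layers.0', inference_map))  # -> 0
print(target_device('language_model.model.layers.0', quant_map))      # -> 'cpu'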
12 changes: 8 additions & 4 deletions lmdeploy/vl/model/internvl.py
@@ -80,10 +80,11 @@ class InternVLVisionModel(VisonModel):

     def __init__(self,
                  model_path: str,
+                 with_llm: bool = False,
                  max_memory: Dict[int, int] = None,
                  hf_config: AutoConfig = None,
                  backend: str = ''):
-        super().__init__(model_path, max_memory, hf_config, backend)
+        super().__init__(model_path, with_llm, max_memory, hf_config, backend)

     def build_preprocessor(self):
         self.config = self.hf_config
@@ -124,21 +125,24 @@ def build_preprocessor(self):
             (force_image_size // patch_size)**2 * (downsample_ratio**2))

     def build_model(self):
-        """Load model."""
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         from accelerate import init_empty_weights
         with init_empty_weights():
             # transformers below 4.37.0 may raise error about flash_attn
             self.config.llm_config.attn_implementation = 'eager'
             model = AutoModel.from_config(self.config, trust_remote_code=True)
-            del model.language_model
+            self.vl_model = model
+            if not self.with_llm:
+                del model.language_model

         model.half()
         from accelerate import load_checkpoint_and_dispatch
         with disable_logging():
             load_checkpoint_and_dispatch(
                 model=model,
                 checkpoint=self.model_path,
-                device_map='auto',
+                device_map='auto' if not self.with_llm else {'': 'cpu'},
                 max_memory=self.max_memory,
                 no_split_module_classes=['InternVisionEncoderLayer'],
                 dtype=torch.half)
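Why `del model.language_model` happens under `init_empty_weights()`: parameters created there live on the meta device and cost no memory, but anything still attached when `load_checkpoint_and_dispatch` runs is materialized from the checkpoint. A minimal, self-contained sketch (ToyVLM is illustrative, not an lmdeploy or transformers class):

from torch import nn
from accelerate import init_empty_weights


class ToyVLM(nn.Module):
    """Stand-in for a VLM with a vision tower and a language model."""

    def __init__(self):
        super().__init__()
        self.vision_model = nn.Linear(1024, 4096)
        self.language_model = nn.Linear(4096, 32000)


with init_empty_weights():
    model = ToyVLM()                        # parameters live on the meta device
print(next(model.parameters()).device)      # meta

with_llm = False
if not with_llm:
    del model.language_model                # these weights will never be materialized
# A subsequent load_checkpoint_and_dispatch(...) now reads only the vision weights;
# with with_llm=True the whole model is kept and dispatched with device_map={'': 'cpu'}.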
15 changes: 9 additions & 6 deletions lmdeploy/vl/model/internvl_llava.py
@@ -80,7 +80,8 @@ def build_preprocessor(self):
         return super().build_preprocessor()

     def build_model(self):
-        """build model & load weights."""
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         check_llava_install()
         # currently, only support llava llama
         from llava.model.language_model.llava_llama import ( # noqa
@@ -98,10 +99,12 @@ def build_model(self):
             } # disable vision part quantization
             model = AutoModelForCausalLM.from_config(self.config,
                                                      trust_remote_code=True)
-            del model.lm_head
-            del model.model.embed_tokens
-            del model.model.layers
-            del model.model.norm
+            self.vl_model = model
+            if not self.with_llm:
+                del model.lm_head
+                del model.model.embed_tokens
+                del model.model.layers
+                del model.model.norm

         with init_empty_vit():
             vision_tower = model.get_vision_tower()
@@ -126,7 +129,7 @@ def build_model(self):
                 model=model,
                 max_memory=self.max_memory,
                 checkpoint=self.model_path,
-                device_map='auto',
+                device_map='auto' if not self.with_llm else {'': 'cpu'},
                 no_split_module_classes=['InternVisionEncoderLayer'],
                 dtype=torch.half)

17 changes: 10 additions & 7 deletions lmdeploy/vl/model/llava.py
@@ -242,7 +242,8 @@ def build_preprocessor(self):
             self.n_token_per_image += 1

     def build_model(self):
-        """build model & load weights."""
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         check_llava_install()

         self.arch = self.hf_config.architectures[0]
@@ -271,11 +272,13 @@ def build_model(self):
             model = AutoModelForCausalLM.from_config(self.config,
                                                      trust_remote_code=True)

-            # remove the LLM part from llava model.
-            del model.lm_head
-            del model.model.embed_tokens
-            del model.model.layers
-            del model.model.norm
+            self.vl_model = model
+            if not self.with_llm:
+                # remove the LLM part from llava model.
+                del model.lm_head
+                del model.model.embed_tokens
+                del model.model.layers
+                del model.model.norm

             # init empty vision_tower, the embedding layer in CLIPVisionModel
             # can't init right under init_empty_weights
@@ -292,7 +295,7 @@ def build_model(self):
                 model=model,
                 max_memory=self.max_memory,
                 checkpoint=self.model_path,
-                device_map='auto',
+                device_map='auto' if not self.with_llm else {'': 'cpu'},
                 no_split_module_classes=['CLIPEncoderLayer'],
                 dtype=torch.half)

24 changes: 14 additions & 10 deletions lmdeploy/vl/model/llava_hf.py
@@ -32,26 +32,30 @@ def build_preprocessor(self):
             self.n_token_per_image += 1

     def build_model(self):
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         from accelerate import init_empty_weights, load_checkpoint_and_dispatch

         with init_empty_weights(), warnings.catch_warnings():
             warnings.simplefilter('ignore')
             from transformers import LlavaForConditionalGeneration
             model = LlavaForConditionalGeneration._from_config(self.hf_config)
-            del model.language_model
+            self.vl_model = model
+            if not self.with_llm:
+                del model.language_model

         # fix for llava-hf/llava-interleave-qwen-7b-hf
         setattr(model.config, 'tie_word_embeddings', False)
         with disable_logging():
-            load_checkpoint_and_dispatch(model=model,
-                                         max_memory=self.max_memory,
-                                         checkpoint=self.model_path,
-                                         device_map='auto',
-                                         no_split_module_classes=[
-                                             'CLIPEncoderLayer',
-                                             'SiglipEncoderLayer'
-                                         ],
-                                         dtype=torch.half)
+            load_checkpoint_and_dispatch(
+                model=model,
+                max_memory=self.max_memory,
+                checkpoint=self.model_path,
+                device_map='auto' if not self.with_llm else {'': 'cpu'},
+                no_split_module_classes=[
+                    'CLIPEncoderLayer', 'SiglipEncoderLayer'
+                ],
+                dtype=torch.half)
         model.eval()
         self.model = model

8 changes: 6 additions & 2 deletions lmdeploy/vl/model/llava_next.py
@@ -28,9 +28,13 @@ def build_preprocessor(self):
         from transformers import LlavaNextForConditionalGeneration
         self.model = LlavaNextForConditionalGeneration._from_config(
             self.hf_config)
-        del self.model.language_model
+        self.vl_model = self.model
+        if not self.with_llm:
+            del self.model.language_model

     def build_model(self):
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         from accelerate import load_checkpoint_and_dispatch
         from accelerate.utils import get_balanced_memory, infer_auto_device_map

@@ -58,7 +62,7 @@ def build_model(self):
            load_checkpoint_and_dispatch(
                model=self.model,
                checkpoint=self.model_path,
-               device_map=device_map,
+               device_map=device_map if not self.with_llm else {'': 'cpu'},
                no_split_module_classes=no_split_module_classes,
                dtype=torch.half)
         self.model.eval()
14 changes: 9 additions & 5 deletions lmdeploy/vl/model/mini_gemeni.py
@@ -178,6 +178,8 @@ def build_preprocessor(self):
         pass

     def build_model(self):
+        """build the vision part of a VLM model when backend is turbomind, or
+        load the whole VLM model when `self.with_llm==True`"""
         check_mini_gemini_install()
         # empty init
         from accelerate import init_empty_weights
@@ -201,10 +203,12 @@ def build_model(self):
             vision_tower.load_model()
             vision_tower_aux = model.get_vision_tower_aux()
             vision_tower_aux.load_model()
-            del model.lm_head
-            del model.model.embed_tokens
-            del model.model.layers
-            del model.model.norm
+            self.vl_model = model
+            if not self.with_llm:
+                del model.lm_head
+                del model.model.embed_tokens
+                del model.model.layers
+                del model.model.norm

         from accelerate.utils import get_balanced_memory, infer_auto_device_map
         max_memory = get_balanced_memory(
@@ -230,7 +234,7 @@ def build_model(self):
            load_checkpoint_and_dispatch(
                model=model,
                checkpoint=self.model_path,
-               device_map=device_map,
+               device_map=device_map if not self.with_llm else {'': 'cpu'},
                no_split_module_classes=['CLIPEncoderLayer', 'ConvNeXtStage'],
                dtype=torch.half)

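deepseek.py, llava_next.py and mini_gemeni.py precompute `device_map` with accelerate's `get_balanced_memory` and `infer_auto_device_map` instead of passing `device_map='auto'`. A hedged, self-contained sketch of that precomputation; the toy module and sizes are illustrative:

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory, infer_auto_device_map


class Block(nn.Module):
    """Toy stand-in for a layer class that must not be split across devices."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4096, 4096)


with init_empty_weights():
    model = nn.Sequential(*[Block() for _ in range(8)])

# Balance the (meta) parameters across available GPUs, then derive a module->device map.
max_memory = get_balanced_memory(model,
                                 dtype=torch.half,
                                 no_split_module_classes=['Block'])
device_map = infer_auto_device_map(model,
                                   max_memory=max_memory,
                                   no_split_module_classes=['Block'],
                                   dtype=torch.half)
print(device_map)  # e.g. {'0': 0, '1': 0, '2': 1, ...} on a multi-GPU machine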