
Commit

Merge branch 'refactor-vl' of github.com:InternLM/lmdeploy into refactor-vl
grimoire committed Dec 9, 2024
2 parents a65007b + 74e7bf8 commit 88f99d4
Showing 18 changed files with 345 additions and 692 deletions.
151 changes: 10 additions & 141 deletions lmdeploy/vl/engine.py
@@ -1,12 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import queue
import time
from threading import Thread
from typing import Dict, List, Optional, Union

import torch
from PIL.Image import Image

from lmdeploy.messages import (PytorchEngineConfig, TurbomindEngineConfig,
VisionConfig)
@@ -26,58 +22,6 @@ def _raise_exception_on_finish(task: asyncio.Task) -> None:
raise e


class Record:
"""Batching manager."""

def __init__(self, thread_safe):
self.thread_safe = thread_safe
self.number = []
self.waiting = []
self.kwargs = []
self.done = []
self.res_que = []
self.total = 0

def enqueue(self, images: List[Image], kwargs: List[Dict],
que: Union[queue.Queue, asyncio.Queue]):
"""add ith request to manager."""
self.number.append(len(images))
self.waiting.extend(images)
self.kwargs.extend(kwargs)
self.res_que.append(que)
self.total += len(images)
self.log('received', len(images))

def dequeue(self, max_batch_size):
"""try to dequeue max batch size images."""
inputs = self.waiting[:max_batch_size]
kwargs = self.kwargs[:max_batch_size]
self.waiting = self.waiting[max_batch_size:]
self.kwargs = self.kwargs[max_batch_size:]
self.total -= len(inputs)
self.log('process', len(inputs))
return inputs, kwargs

def notify(self):
"""set result if request i is finished."""
if len(self.number) == 0 or self.number[0] > len(self.done):
return False
num_images = self.number.pop(0)
outputs = self.done[:num_images]
self.done = self.done[num_images:]
que = self.res_que.pop(0)
self.log('done', num_images)
if self.thread_safe:
que._loop.call_soon_threadsafe(que.put_nowait, outputs)
else:
que.put_nowait(outputs)
return True

def log(self, task: str, num: int):
logger.info(f'ImageEncoder {task} {num} images, '
f'left {self.total} images.')


class ImageEncoder:
"""Image encoder."""

@@ -97,93 +41,14 @@ def __init__(
self.vision_config = vision_config
self.max_batch_size = vision_config.max_batch_size
torch.cuda.empty_cache()
self._que: asyncio.Queue = None
self._loop_task: asyncio.Task = None
if vision_config.thread_safe:
self._create_thread_safe_task()

def _create_thread_safe_task(self):
"""thread safe loop task."""
self._loop = asyncio.new_event_loop()
self._que = asyncio.Queue()

def _work_thread():
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._forward_loop())

thread = Thread(target=_work_thread, daemon=True)
thread.start()
self._loop_thread = thread

def _create_event_loop_task(self):
"""event loop task."""
task = asyncio.get_event_loop().create_task(self._forward_loop())
self._loop_task = task
self._loop = task.get_loop()

@property
def req_que(self):
if self.vision_config.thread_safe:
return self._que
if self._que is None:
self._que = asyncio.Queue()
if self._loop_task is None:
self._create_event_loop_task()
if asyncio.get_event_loop() != self._loop:
raise RuntimeError('Current event loop is different from'
' the one bound to loop task!')
return self._que

async def _forward_loop(self):
"""working loop to process images."""
logger.info('start ImageEncoder._forward_loop')
record = Record(self.vision_config.thread_safe)
while True:
while record.total == 0 or (self._que.qsize() and
record.total < self.max_batch_size):
while self._que.qsize() == 0:
await asyncio.sleep(0.01)
item = await self._que.get()
record.enqueue(item[0], item[1], item[2])
inputs, kwargs = record.dequeue(self.max_batch_size)
future = asyncio.get_event_loop().run_in_executor(
None, self.forward, inputs, kwargs)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
record.done.extend(outputs)
while record.notify():
pass

def forward(self, messages: List[List[Dict]]) -> Dict:
# messages in batch
assert all(isinstance(message, List) for message in messages)

time_start = time.perf_counter()
outputs, n_image = self.model.forward(messages)
if isinstance(outputs[0], torch.Tensor):
outputs = [x.cpu() for x in outputs]
time_end = time.perf_counter()
logger.info(f'ImageEncoder forward {n_image} images, '
f'cost {time_end - time_start:.3f}s')
return outputs

def infer(self, messages: List[Dict]) -> Dict:
"""perform vision encoding to get a dict, in which there are input_ids,
embeddings, embedding_ranges and so on. They will be used by turbomind
engine. The keys in the dict must be the same as those defined in turbomind
engine's infer API.
Args:
messages (List[Dict]): user's input in GPT4V format
"""
assert isinstance(messages, List)
assert all(isinstance(item, Dict) for item in messages)

return self.forward(messages)

async def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""preprocess multimodal data in the messages."""
return self.model.preprocess(messages)
future = asyncio.get_event_loop().run_in_executor(
None, self.model.preprocess, messages)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
return outputs

async def async_infer(self, messages: List[Dict]) -> List[Dict]:
"""get multimodal embedding.
@@ -192,7 +57,11 @@ async def async_infer(self, messages: List[Dict]) -> List[Dict]:
messages (List[Dict]): a list of messages, which is the output
of `preprocess()`
"""
return self.model.forward(messages)
future = asyncio.get_event_loop().run_in_executor(
None, self.model.forward, messages, self.max_batch_size)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
return outputs

async def wrap_for_pytorch(self, messages: List[Dict], chat_template,
tokenizer, sequence_start) -> List[Dict]:
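
To put the engine.py changes in context: with the Record/forward-loop machinery removed, batching is delegated to each vision model's forward(), and ImageEncoder now exposes plain coroutines that push the blocking work onto the default executor. The sketch below is an illustration of that call pattern, not part of the commit; it assumes an already-constructed ImageEncoder and GPT-4V-style messages.

from typing import Dict, List

from lmdeploy.vl.engine import ImageEncoder


async def encode(encoder: ImageEncoder, messages: List[Dict]) -> List[Dict]:
    # Stage 1: image preprocessing; runs self.model.preprocess in the
    # default executor so the event loop is not blocked.
    messages = await encoder.preprocess(messages)
    # Stage 2: feature extraction (turbomind backend only); the model's
    # forward() chunks the work using encoder.max_batch_size.
    messages = await encoder.async_infer(messages)
    return messages

# Example driver (construction of `encoder` and `messages` omitted):
# results = asyncio.run(encode(encoder, messages))
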
10 changes: 5 additions & 5 deletions lmdeploy/vl/model/base.py
@@ -39,7 +39,6 @@ def build_preprocessor(self, ):
"""
raise NotImplementedError()

@abstractmethod
def build_model(self, ):
"""build model.
@@ -90,21 +89,23 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
""" # noqa
raise NotImplementedError()

@abstractmethod
def forward(self, messages: List[Dict]) -> List[Dict]:
def forward(self,
messages: List[Dict],
max_batch_size: int = 1) -> List[Dict]:
"""extract image feature. ONLY implement it when the backend is
turbomind engine.
Args:
messages(List[Dict]): the outputs of `preprocess`
max_batch_size(int): the max batch size when forwarding vision
model
Return:
the message list with forwarding results included, which is
determined by the derived classes
"""
if self.backend == 'turbomind':
raise NotImplementedError()

@abstractmethod
def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
"""pack the preprocessing results in a format compatible with what is
required by pytorch engine. ONLY implement it when the backend is
@@ -119,7 +120,6 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
if self.backend == 'pytorch':
raise NotImplementedError()

@abstractmethod
def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
"""pack the forwarding results in a format compatible with what is
required by turbomind engine. ONLY implement it when the backend is
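
The base-class change is the crux of the refactor: forward() now receives max_batch_size, and build_model/forward/to_pytorch/to_turbomind are no longer abstract, so a backend overrides only what its engine needs. Below is a hedged sketch of a derived model implementing the new batched forward(); the class name, the pixel_values layout and the self.vision_model attribute are illustrative assumptions, and the other required methods (build_preprocessor, preprocess, ...) are omitted.

from typing import Dict, List

import torch

from lmdeploy.vl.model.base import VisonModel


class ToyVLModel(VisonModel):
    """Illustrative backend following the refactored forward() contract."""

    def forward(self,
                messages: List[Dict],
                max_batch_size: int = 1) -> List[Dict]:
        # Collect the tensors that preprocess() appended under the
        # 'preprocess' role.
        inputs = [x['content'] for x in messages
                  if x['role'] == 'preprocess'][0]
        outputs = []
        for idx in range(0, len(inputs), max_batch_size):
            batch = inputs[idx:idx + max_batch_size]
            pixel_values = torch.cat([x['pixel_values'] for x in batch],
                                     dim=0)
            # self.vision_model stands in for whatever build_model() set up.
            feats = self.vision_model(pixel_values)
            outputs.extend(x.squeeze() for x in torch.split(feats, 1, dim=0))
        # Hand the features back to the engine under the 'forward' role.
        messages.append(dict(role='forward', content=outputs))
        return messages
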
63 changes: 0 additions & 63 deletions lmdeploy/vl/model/cogvlm.py
@@ -1,13 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')

@@ -32,49 +27,6 @@ def build_preprocessor(self):
patch_size = self.hf_config.vision_config['patch_size']
self.n_token_per_image = 2 + (image_size // patch_size // 2)**2

def build_model(self):
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate.utils import get_balanced_memory, infer_auto_device_map
with init_empty_weights(), warnings.catch_warnings():
self.model = AutoModelForCausalLM.from_config(
self.hf_config, trust_remote_code=True)
if not self.with_llm:
del self.model.lm_head
for key in ['layers', 'norm', 'embed_tokens']:
setattr(self.model.model, key, None)
else:
self.vl_model = self.model

no_split_module_classes = ['TransformerLayer']
max_memory = get_balanced_memory(
self.model,
max_memory=self.max_memory,
dtype=torch.half,
no_split_module_classes=no_split_module_classes)
device_map = infer_auto_device_map(
self.model,
no_split_module_classes=no_split_module_classes,
max_memory=max_memory,
dtype=torch.half)
same_device_keys = [('model.vision.linear_proj', 'model.vision.boi',
'model.vision.eoi')]
for keys in same_device_keys:
keys = [k for k in keys if k in device_map]
if len(keys) <= 1:
continue
for k in keys[1:]:
device_map[k] = device_map[keys[0]]

with disable_logging():
load_checkpoint_and_dispatch(
model=self.model,
checkpoint=self.model_path,
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)
self.model = self.model.model.vision
self.model.eval()

def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""refer to the spec of `super().preprocess`"""
images = self.collect_images(messages)
@@ -90,18 +42,6 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
messages.append(dict(role='preprocess', content=outputs))
return messages

@torch.no_grad()
def forward(self, messages: List[Dict]) -> List[Dict]:
"""extract image feature. ONLY implement it when the backend is
turbomind engine.
Args:
messages(List[Dict]): the outputs of `preprocess`
Return:
the message list with forwarding results included
"""
assert 0, 'cogvlm is not supported by turbomind'

@classmethod
def proc_messages(cls, messages, chat_template, sequence_start):
"""apply chat template to get the prompt."""
@@ -145,6 +85,3 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start):
sequence_start)
return super().to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer,
sequence_start)

def to_turbomind(self, messages, chat_template, sequence_start):
assert 0, 'cogvlm is not supported by turbomind'
29 changes: 19 additions & 10 deletions lmdeploy/vl/model/deepseek.py
@@ -106,26 +106,35 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
return messages

@torch.no_grad()
def forward(self, messages: List[Dict]) -> List[Dict]:
def forward(self,
messages: List[Dict],
max_batch_size: int = 1) -> List[Dict]:
"""extract image feature. ONLY implement it when the backend is
turbomind engine.
Args:
messages(List[Dict]): the outputs of `preprocess`
max_batch_size(int): the max batch size when forwarding vision
model
Return:
the message list with forwarding results included
"""
inputs = [x['content'] for x in messages if x['role'] == 'preprocess']
inputs = inputs[0]
pixel_values = [x['pixel_values'] for x in inputs]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_values = pixel_values.to(device=next(
self.vision_model.parameters()).device,
dtype=torch.float16)
# [b x n_images, T2, D]
images_embeds = self.aligner(self.vision_model(pixel_values))
outputs = torch.split(images_embeds, 1, dim=0)
outputs = [x.squeeze() for x in outputs]
outputs = []
for idx in range(0, len(inputs), max_batch_size):
pixel_values = [
x['pixel_values'] for x in inputs[idx:idx + max_batch_size]
]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_values = pixel_values.to(device=next(
self.vision_model.parameters()).device,
dtype=torch.float16)
# [b x n_images, T2, D]
logger.info(f'vision forward shape: {pixel_values.shape}')
feats = self.aligner(self.vision_model(pixel_values))
feats = torch.split(feats, 1, dim=0)
outputs.extend([x.squeeze() for x in feats])
messages.append(dict(role='forward', content=outputs))
return messages
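
The slicing above is a plain chunking idiom: the preprocessed inputs are consumed in windows of max_batch_size, with a smaller final batch if the total does not divide evenly. A tiny standalone check with toy values (not from the diff):

# Toy illustration of the batching arithmetic used in the loop above.
inputs = list(range(5))            # stand-in for 5 preprocessed images
max_batch_size = 2
batches = [inputs[i:i + max_batch_size]
           for i in range(0, len(inputs), max_batch_size)]
assert batches == [[0, 1], [2, 3], [4]]   # forward passes of 2, 2 and 1 images
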

(diff for the remaining 14 changed files not loaded)
