Skip to content

Commit

Permalink
fix molmo
Browse files Browse the repository at this point in the history
  • Loading branch information
lvhan028 committed Dec 6, 2024
1 parent 8f7a56f commit 29c3558
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions lmdeploy/vl/model/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class MolmoVisionModel(VisonModel):
def build_preprocessor(self):
self.processor = AutoProcessor.from_pretrained(self.model_path,
trust_remote_code=True,
torch_dtype='auto',
torch_dtype=torch.half,
device_map='auto')

def build_model(self):
Expand All @@ -46,7 +46,8 @@ def build_model(self):
max_memory=self.max_memory,
no_split_module_classes=[
'ResidualAttentionBlock', 'Embedding'
])
],
dtype=torch.half)

# We need eval mode to freeze the weights in model, thus,
# avoid randomness in inference.
Expand All @@ -58,8 +59,7 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
if not isinstance(message['content'], List):
continue
images = [
x['image'].convert('RGB') for x in message['content']
if x['type'] == 'image'
x['image'] for x in message['content'] if x['type'] == 'image'
]
content = [
x['text'] for x in message['content'] if x['type'] == 'text'
Expand Down Expand Up @@ -111,6 +111,8 @@ def forward(self, messages: List[Dict]) -> List[Dict]:
assert batch_size == 1
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
embeddings = self.model.model.transformer.wte(input_ids)
images = images.to(self.model.dtype)
image_masks = image_masks.to(self.model.dtype)
image_features, _ = self.model.model.vision_backbone(
images, image_masks)
num_image, num_patch = image_features.shape[1:3]
Expand All @@ -126,8 +128,11 @@ def forward(self, messages: List[Dict]) -> List[Dict]:
batch_idx = torch.tile(batch_idx[:, None],
[1, image_features.shape[1]])
image_features = image_features.to(embeddings.device)
embeddings[batch_idx[valid],
image_input_idx[valid]] += image_features[valid]
# Since we remove bos_id from input_ids during `preprocess`,
# the index `image_input_idx[valid]` should be shift to left
# by subtracting 1
index = image_input_idx[valid] - 1
embeddings[batch_idx[valid], index] += image_features[valid]
assert embeddings.shape[:2] == (batch_size, seq_len)
messages[i].update(
dict(forward=dict(input_ids=input_ids.flatten(),
Expand All @@ -139,8 +144,6 @@ def proc_messages(cls, messages):
IMAGE_TOKEN = '<IMAGE_TOKEN>'
for message in messages:
role, content = message['role'], message['content']
if role == 'images':
continue
if isinstance(content, List):
n_images = len([1 for x in content if x['type'] == 'image'])
content = [x['text'] for x in content if x['type'] == 'text']
Expand Down Expand Up @@ -170,22 +173,25 @@ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start):
results.append(([bos], None))

for i, message in enumerate(messages):
prompt = ''
role, content = message['role'], message['content']
if role == 'images':
continue
if isinstance(content, List):
forward_result = message.pop('forward')
input_ids = forward_result['input_ids']
embeddings = forward_result['embeddings']
results.append((input_ids.tolist(), embeddings))
else:
prompt = ''
if role == 'user':
prompt = f' User: {content}'
elif role == 'assistant':
prompt = f' Assistant:{content}'
else:
assert 0, f'molmo does not support role {role}, message is {message}' # noqa
if i == len(messages) - 1:
# the last message
assert role == 'user', f'the role of last message is expected to be user, but got {role}' # noqa
prompt += ' Assistant:'
if prompt:
input_ids = self.processor.tokenizer.encode(
prompt, add_special_tokens=False)
results.append((input_ids, None))
Expand Down

0 comments on commit 29c3558

Please sign in to comment.