Commit

Merge pull request #217 from mistralai/pixtral
Pixtral
patrickvonplaten authored Sep 13, 2024
2 parents 3fd585d + 510f7ae commit 4304e4f
Showing 10 changed files with 576 additions and 160 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistral_inference"
version = "1.3.1"
version = "1.4.0"
description = ""
authors = ["bam4d <[email protected]>"]
readme = "README.md"
@@ -27,8 +27,9 @@ python = "^3.9.10"
xformers = ">=0.0.24"
simple-parsing = ">=0.1.5"
fire = ">=0.6.0"
mistral_common = "^1.3.0"
mistral_common = ">=1.4.0"
safetensors = ">=0.4.0"
pillow = ">=10.3.0"

[tool.poetry.group.dev.dependencies]
types-protobuf = "4.24.0.20240129"
2 changes: 1 addition & 1 deletion src/mistral_inference/__init__.py
@@ -1 +1 @@
__version__ = "1.3.1"
__version__ = "1.4.0"
19 changes: 17 additions & 2 deletions src/mistral_inference/args.py
@@ -7,6 +7,19 @@
from mistral_inference.moe import MoeArgs


@dataclass
class VisionEncoderArgs:
hidden_size: int
num_channels: int
image_size: int
patch_size: int
intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
rope_theta: float = 1e4 # for rope-2D
image_token_id: int = 10


@dataclass
class TransformerArgs(Serializable):
dim: int
@@ -28,7 +41,9 @@ class TransformerArgs(Serializable):
lora: Optional[LoraArgs] = None
model_type: str = "transformer"

def __post_init__(self):
vision_encoder: Optional[VisionEncoderArgs] = None

def __post_init__(self) -> None:
assert self.model_type == "transformer", self.model_type


@@ -45,5 +60,5 @@ class MambaArgs(Serializable):
tie_embeddings: bool
model_type: str = "mamba"

def __post_init__(self):
def __post_init__(self) -> None:
assert self.model_type == "mamba", self.model_type
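
As a reference for the new config surface, a minimal sketch of constructing the vision settings; every numeric value below is an illustrative placeholder, not Pixtral's actual configuration (real checkpoints carry their own params.json):

from mistral_inference.args import VisionEncoderArgs

vision_args = VisionEncoderArgs(
    hidden_size=1024,          # placeholder value
    num_channels=3,
    image_size=1024,           # placeholder value
    patch_size=16,             # placeholder value
    intermediate_size=4096,    # placeholder value
    num_hidden_layers=24,      # placeholder value
    num_attention_heads=16,    # placeholder value
    rope_theta=1e4,            # default base frequency for 2D RoPE
    image_token_id=10,         # default id of the image placeholder token
)

# Text-only models keep TransformerArgs.vision_encoder = None; a multimodal
# checkpoint populates it with a VisionEncoderArgs like the one above.
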
13 changes: 13 additions & 0 deletions src/mistral_inference/generate.py
@@ -1,5 +1,6 @@
from typing import List, Optional, Tuple

import numpy as np
import torch

from mistral_inference.cache import BufferCache
@@ -43,12 +44,21 @@ def generate_mamba(
def generate(
encoded_prompts: List[List[int]],
model: Transformer,
images: List[List[np.ndarray]] = [],
*,
max_tokens: int,
temperature: float,
chunk_size: Optional[int] = None,
eos_id: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[float]]]:
images_torch: List[List[torch.Tensor]] = []
if images:
assert chunk_size is None
images_torch = [
[torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample]
for images_for_sample in images
]

model = model.eval()
B, V = len(encoded_prompts), model.args.vocab_size

@@ -75,12 +85,15 @@ def generate(
if chunk_size is None:
chunk_size = max_prompt_len

flattened_images: List[torch.Tensor] = sum(images_torch, [])

# Encode prompt by chunks
for s in range(0, max_prompt_len, chunk_size):
prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
assert all(len(p) > 0 for p in prompt_chunks)
prelogits = model.forward(
torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
images=flattened_images,
seqlens=[len(p) for p in prompt_chunks],
cache=cache,
)
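
A hedged usage sketch of the extended generate() signature. Model and tokenizer loading are elided (`model`, `prompt_tokens`, and `tokenizer` are assumed to already be in scope), and the placeholder numpy array stands in for whatever image preprocessing mistral_common produces:

import numpy as np

from mistral_inference.generate import generate

# Assumptions: `model` is a Transformer whose args include a vision_encoder,
# `prompt_tokens` holds the token ids for one prompt (image tokens included),
# and `tokenizer` exposes eos_id. One inner list of image arrays per prompt;
# an empty outer list keeps the call text-only.
images = [[np.zeros((3, 224, 224), dtype=np.float32)]]  # placeholder image

out_tokens, logprobs = generate(
    [prompt_tokens],
    model,
    images,                   # positional, before the keyword-only arguments
    max_tokens=64,
    temperature=0.0,
    eos_id=tokenizer.eos_id,
)
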
68 changes: 60 additions & 8 deletions src/mistral_inference/main.py
@@ -3,19 +3,31 @@
import os
import warnings
from pathlib import Path
from typing import List, Optional, Type, Union
from typing import List, Optional, Tuple, Type, Union

import fire # type: ignore
import torch
import torch.distributed as dist
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
from mistral_common.protocol.instruct.messages import (
AssistantMessage,
ContentChunk,
ImageChunk,
ImageURLChunk,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer, SpecialTokenPolicy
from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
from mistral_common.tokens.tokenizers.tekken import is_tekken

from mistral_common.tokens.tokenizers.tekken import (
SpecialTokenPolicy,
Tekkenizer,
is_tekken,
)
from PIL import Image

from mistral_inference.args import TransformerArgs
from mistral_inference.generate import generate, generate_mamba
from mistral_inference.mamba import Mamba
from mistral_inference.transformer import Transformer
@@ -62,6 +74,31 @@ def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> Li
return padded_lists


def _get_multimodal_input() -> Tuple[UserMessage, bool]:
chunks: List[ContentChunk] = []

response = input("Text prompt: ")
if response:
chunks.append(TextChunk(text=response))

print("[You can input zero, one or more images now.]")
while True:
did_something = False
response = input("Image path or url [Leave empty and press enter to finish image input]: ")
if response:
if Path(response).is_file():
chunks.append(ImageChunk(image=Image.open(response)))
else:
assert response.startswith("http"), f"{response} does not seem to be a valid url."
chunks.append(ImageURLChunk(image_url=response))
did_something = True

if not did_something:
break

return UserMessage(content=chunks), not chunks


def interactive(
model_path: str,
max_tokens: int = 35,
@@ -85,6 +122,10 @@ def interactive(

model_cls = get_model_cls(model_path)
model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
is_multimodal = isinstance(model.args, TransformerArgs) and model.args.vision_encoder is not None

if is_multimodal:
assert instruct, "Multimodal models should only be used in instruct mode"

# load LoRA
if lora_path is not None:
@@ -95,17 +136,27 @@

while True:
if should_print:
user_input = input("Prompt: ")
if not is_multimodal:
user_input = input("Prompt: ")

if instruct:
messages += [UserMessage(content=user_input)]
if is_multimodal:
mm_input, finished = _get_multimodal_input()
if finished:
break
messages += [mm_input]
else:
messages += [UserMessage(content=user_input)]
chat_completion_request = ChatCompletionRequest(messages=messages)

tokens = mistral_tokenizer.encode_chat_completion(chat_completion_request).tokens
tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
tokens = tokenized.tokens
images = tokenized.images
else:
prompt += user_input

tokens = tokenizer.encode(prompt, bos=True, eos=False)
images = []

length_tensor = torch.tensor([len(tokens)], dtype=torch.int)
else:
@@ -121,6 +172,7 @@
generated_tokens, _ = generate_fn( # type: ignore[operator]
[tokens],
model,
[images],
max_tokens=max_tokens,
temperature=temperature,
eos_id=tokenizer.eos_id,
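
A sketch of the same multimodal prompt path outside the interactive loop, using only classes visible in the imports above; the tokenizer file and image path are placeholders to substitute with real files:

from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from PIL import Image

mistral_tokenizer = MistralTokenizer.from_file("/path/to/model/tekken.json")  # placeholder path

message = UserMessage(
    content=[
        TextChunk(text="Describe this image."),
        ImageChunk(image=Image.open("/path/to/image.png")),  # placeholder path
    ]
)
tokenized = mistral_tokenizer.encode_chat_completion(
    ChatCompletionRequest(messages=[message])
)
tokens, images = tokenized.tokens, tokenized.images  # passed on to generate(...)
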
32 changes: 30 additions & 2 deletions src/mistral_inference/rope.py
@@ -18,6 +18,34 @@ def apply_rotary_emb(
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = freqs_cis[:, None, :]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)


def precompute_freqs_cis_2d(
dim: int,
height: int,
width: int,
theta: float,
) -> torch.Tensor:
"""
freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
(height, width) position tuples
"""
# (dim / 2) frequency bases
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

h = torch.arange(height, device=freqs.device)
w = torch.arange(width, device=freqs.device)

freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
freqs_2d = torch.cat(
[
freqs_h[:, None, :].repeat(1, width, 1),
freqs_w[None, :, :].repeat(height, 1, 1),
],
dim=-1,
)
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
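
A small sketch of how the 2D frequencies can be indexed per patch; the head dimension and patch-grid size below are illustrative:

import torch

from mistral_inference.rope import precompute_freqs_cis_2d

dim, height, width = 64, 8, 8  # illustrative head dim and patch grid
freqs_cis = precompute_freqs_cis_2d(dim=dim, height=height, width=width, theta=1e4)
print(freqs_cis.shape)  # torch.Size([8, 8, 32]) -- (height, width, dim // 2), complex

# Gather one frequency row per patch by its (row, col) position, e.g. to pair
# with apply_rotary_emb over the flattened sequence of patches.
grid = torch.stack(
    torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij"),
    dim=-1,
).reshape(-1, 2)
per_patch = freqs_cis[grid[:, 0], grid[:, 1]]  # shape (height * width, dim // 2)
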