Skip to content

Commit

Permalink
fix: phi_3.5_v loading (#896)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale authored Dec 16, 2024
1 parent e14223d commit 908ff75
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 10 deletions.
3 changes: 3 additions & 0 deletions aphrodite/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aphrodite.platforms import current_platform
from aphrodite.quantization import QUANTIZATION_METHODS
from aphrodite.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from aphrodite.triton_utils import HAS_TRITON

Expand Down Expand Up @@ -203,6 +204,8 @@ def __init__(
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

# Choose a default enforce_eager value if the user did not specify
Expand Down
11 changes: 9 additions & 2 deletions aphrodite/inputs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from array import array
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
Tuple, Type)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
Protocol, Tuple, Type)

from loguru import logger
from torch import nn
Expand Down Expand Up @@ -49,6 +49,13 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:

return hf_config

def get_hf_image_processor_config(self) -> Dict[str, Any]:
    """
    Get the HuggingFace image processor configuration of the model.

    Returns:
        The ``hf_image_processor_config`` attribute of
        ``self.model_config``, returned as-is (no copy is made).
    """
    model_config = self.model_config
    return model_config.hf_image_processor_config



N = TypeVar("N", bound=Type[nn.Module])

Expand Down
12 changes: 6 additions & 6 deletions aphrodite/modeling/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)

import numpy as np
import torch
Expand Down Expand Up @@ -320,12 +320,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):

# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: PretrainedConfig,
hf_config: Dict[str, Any],
*,
input_height: int,
input_width: int,
) -> int:
num_crops = getattr(hf_config, "num_crops", 16)
num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
Expand All @@ -337,7 +337,7 @@ def get_phi3v_image_feature_size(
def get_max_phi3v_image_tokens(ctx: InputContext):

return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
Expand Down Expand Up @@ -391,7 +391,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_image_processor_config()

image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
Expand Down
13 changes: 13 additions & 0 deletions aphrodite/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
try_to_load_from_cache)
from loguru import logger
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
Expand Down Expand Up @@ -243,6 +245,17 @@ def recurse_elems(elem: Any):
return config


def get_hf_image_processor_config(
    model: Union[str, Path],
    revision: Optional[str] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Fetch the HuggingFace image processor config for ``model``.

    Thin wrapper around ``get_image_processor_config`` from
    ``transformers.models.auto.image_processing_auto``.

    Args:
        model: Model identifier or local path. If it points at an existing
            ``.gguf`` file, the file's parent directory is used instead.
        revision: Forwarded as the ``revision`` keyword argument.
        **kwargs: Forwarded unchanged to ``get_image_processor_config``.

    Returns:
        Whatever ``get_image_processor_config`` returns for the resolved
        location.
    """
    model_path = Path(model)
    # Separate model folder from file path for GGUF models: the lookup is
    # done against the directory that contains the .gguf file.
    points_at_gguf_file = model_path.is_file() and model_path.suffix == ".gguf"
    location = model_path.parent if points_at_gguf_file else model
    return get_image_processor_config(location, revision=revision, **kwargs)


def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
No op for pure text models.
Expand Down
2 changes: 1 addition & 1 deletion examples/vision/vision_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run_phi3v(question):
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_num_seqs=5,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
})

models = ["microsoft/Phi-3-vision-128k-instruct"]
models = ["microsoft/Phi-3.5-vision-instruct"]


def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,
Expand Down

0 comments on commit 908ff75

Please sign in to comment.