Skip to content

Commit

Permalink
Merge pull request #51 from ywang96/add-header-to-download
Browse files Browse the repository at this point in the history
Add headers to download images
  • Loading branch information
patrickvonplaten authored Sep 18, 2024
2 parents 88bee5e + 5ee6372 commit 1b56550
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/mistral_common/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from pydantic import BeforeValidator, PlainSerializer, SerializationInfo
from typing_extensions import Annotated

from mistral_common import __version__


def download_image(url: str) -> Image.Image:
headers = {"User-Agent": f"mistral-common/{__version__}"}
try:
# Make a request to download the image
response = requests.get(url)
response = requests.get(url, headers=headers)
response.raise_for_status() # Raise an error for bad responses (4xx, 5xx)

# Convert the image content to a PIL Image
Expand Down
9 changes: 3 additions & 6 deletions src/mistral_common/tokens/tokenizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def __call__(self, content: Union[ImageChunk, ImageURLChunk]) -> ImageEncoding:
...

@property
def image_token(self) -> int:
...
def image_token(self) -> int: ...


class InstructTokenizer(Generic[InstructRequestType, FIMRequestType, TokenizedType, AssistantMessageType]):
Expand Down Expand Up @@ -182,8 +181,7 @@ def encode_user_message(
is_first: bool,
system_prompt: Optional[str] = None,
force_img_first: bool = False,
) -> Tuple[List[int], List[np.ndarray]]:
...
) -> Tuple[List[int], List[np.ndarray]]: ...

@abstractmethod
def encode_user_content(
Expand All @@ -192,5 +190,4 @@ def encode_user_content(
is_last: bool,
system_prompt: Optional[str] = None,
force_img_first: bool = False,
) -> Tuple[List[int], List[np.ndarray]]:
...
) -> Tuple[List[int], List[np.ndarray]]: ...
13 changes: 13 additions & 0 deletions tests/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def test_image_to_num_tokens(mm_config: MultimodalConfig, special_token_ids: Spe
assert mm_encoder._image_to_num_tokens(img) == (exp1, exp2)


def test_download_gated_image(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
mm_encoder = ImageEncoder(mm_config, special_token_ids)

url1 = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
url2 = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"

for url in [url1, url2]:
content = ImageURLChunk(image_url=url)
image = mm_encoder(content).image

assert image is not None, "Make sure gated wikipedia images can be downloaded"


def test_image_encoder(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
mm_encoder = ImageEncoder(mm_config, special_token_ids)

Expand Down

0 comments on commit 1b56550

Please sign in to comment.