diff --git a/src/mistral_common/multimodal.py b/src/mistral_common/multimodal.py
index a167a77..88224c3 100644
--- a/src/mistral_common/multimodal.py
+++ b/src/mistral_common/multimodal.py
@@ -7,11 +7,14 @@
 from pydantic import BeforeValidator, PlainSerializer, SerializationInfo
 from typing_extensions import Annotated
 
+from mistral_common import __version__
+
 
 def download_image(url: str) -> Image.Image:
+    headers = {"User-Agent": f"mistral-common/{__version__}"}
     try:
         # Make a request to download the image
-        response = requests.get(url)
+        response = requests.get(url, headers=headers)
         response.raise_for_status()  # Raise an error for bad responses (4xx, 5xx)
 
         # Convert the image content to a PIL Image
diff --git a/src/mistral_common/tokens/tokenizers/base.py b/src/mistral_common/tokens/tokenizers/base.py
index 3853a4b..b51f917 100644
--- a/src/mistral_common/tokens/tokenizers/base.py
+++ b/src/mistral_common/tokens/tokenizers/base.py
@@ -150,8 +150,7 @@ def __call__(self, content: Union[ImageChunk, ImageURLChunk]) -> ImageEncoding:
         ...
 
     @property
-    def image_token(self) -> int:
-        ...
+    def image_token(self) -> int: ...
 
 
 class InstructTokenizer(Generic[InstructRequestType, FIMRequestType, TokenizedType, AssistantMessageType]):
@@ -182,8 +181,7 @@ def encode_user_message(
         is_first: bool,
         system_prompt: Optional[str] = None,
         force_img_first: bool = False,
-    ) -> Tuple[List[int], List[np.ndarray]]:
-        ...
+    ) -> Tuple[List[int], List[np.ndarray]]: ...
 
     @abstractmethod
     def encode_user_content(
@@ -192,5 +190,4 @@ def encode_user_content(
         is_last: bool,
         system_prompt: Optional[str] = None,
         force_img_first: bool = False,
-    ) -> Tuple[List[int], List[np.ndarray]]:
-        ...
+    ) -> Tuple[List[int], List[np.ndarray]]: ...
diff --git a/tests/test_multimodal.py b/tests/test_multimodal.py
index c385e45..ffb7161 100644
--- a/tests/test_multimodal.py
+++ b/tests/test_multimodal.py
@@ -36,6 +36,19 @@ def test_image_to_num_tokens(mm_config: MultimodalConfig, special_token_ids: Spe
     assert mm_encoder._image_to_num_tokens(img) == (exp1, exp2)
 
 
+def test_download_gated_image(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
+    mm_encoder = ImageEncoder(mm_config, special_token_ids)
+
+    url1 = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
+    url2 = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
+
+    for url in [url1, url2]:
+        content = ImageURLChunk(image_url=url)
+        image = mm_encoder(content).image
+
+        assert image is not None, "Make sure gated wikipedia images can be downloaded"
+
+
 def test_image_encoder(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
     mm_encoder = ImageEncoder(mm_config, special_token_ids)
 