-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add example to use vision model (#16)
* first commit * remove jupyter notebook checkpoints and init files * remove check attribute method * update readme * update comment in image_input.py * update init method and add examples with several images in notebook * update naming and type for list of image files variables * add docstring for the methods of ImageInput class * update some typo * update module paths in example notebook * rename methods only called in the class ImageInput * add copy method on input variable for the ImageInput class * update typo for naming methods called only in ImageInput class * update typo for naming methods called only in ImageInput class
- Loading branch information
1 parent
b604264
commit b84c3f8
Showing
7 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Image Input | ||
|
||
## Add image in put in a prompt | ||
To avoid repeatedly encoding an image and extracting the related extension, you can create a class inheriting from ```image_input.ImageInput```. | ||
|
||
You need to instantiate the following attributes : list of paths of image files (list) | ||
|
||
## How to use | ||
After deploying your REST API, you just need to first call it with the method encode() and extension(), and then with the following route ```/v1/chat/completions``` and a body as follows: | ||
|
||
``` | ||
{"messages": [ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{"type": "text", "text": "Summarize all the information included in this image."}, | ||
{ "type": "image_url", | ||
"image_url": { | ||
"url": "data:image/{image_extension[0]};base64,{image_base64[0]}" | ||
} | ||
}, | ||
{ "type": "image_url", | ||
"image_url": { | ||
"url": "data:image/{image_extension[1]};base64,{image_base64[1]}" | ||
} | ||
}, | ||
{ "type": "image_url", | ||
"image_url": { | ||
"url": "data:image/{image_extension[2]};base64,{image_base64[2]}" | ||
} | ||
} | ||
] | ||
} | ||
], | ||
"model": "llava-v1.6-mistral-7b-hf", | ||
"max_tokens": 1024, | ||
"temperature":0.0 | ||
}''' | ||
``` | ||
|
||
## Warning | ||
|
||
Regarding vllm versions until 0.5.3.post1, only one image can be passed per call. | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import base64 | ||
import os | ||
from typing import Union | ||
|
||
|
||
class ImageInput: | ||
""" | ||
Represents a class with specific methods to help to prepare a list of image inputs for vision models api call. | ||
Methods: | ||
__init__(image_paths: list): | ||
Initializes a ImageInput instance with the provided attribute. Raises NotImplementedError if any attribute is None. | ||
encode() -> dict: | ||
Encodes and returns a dictionary representation of the images encoded in base 64. | ||
extension() -> dict: | ||
Returns a dictionary of the extension of each image input (needed in the api call) : | ||
"type": "image_url", | ||
"image_url": { | ||
"url": f"data:image/png;base64,{base64_images[0]}"} | ||
""" | ||
|
||
def __init__(self, image_paths: list): | ||
self.image_paths: list = image_paths.copy() | ||
self.base64_images = self._encode() | ||
self.extension_images = self._extension() | ||
|
||
def _encode(self) -> dict: | ||
"""Opens the image files and encode it as a base64 string | ||
Returns: | ||
dict : The base 64 of all images | ||
""" | ||
|
||
base64_images = {} | ||
for i in range(len(self.image_paths)): | ||
with open(self.image_paths[i], "rb") as image_file: | ||
base64_images[i] = base64.b64encode(image_file.read()).decode("utf-8") | ||
return base64_images | ||
|
||
def _extension(self) -> dict: | ||
"""Extracts the string of the extension of the image files | ||
Returns: | ||
dict : The extension of all images | ||
""" | ||
|
||
extension_images = {} | ||
for i in range(len(self.image_paths)): | ||
extension_images[i] = os.path.splitext(self.image_paths[i])[1][1:] | ||
return extension_images |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "03269354-1f22-4ef9-9eb9-b1825bc95724", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from IPython.display import Image, display, Audio, Markdown\n", | ||
"from happy_vllm.image_input import ImageInput\n", | ||
"\n", | ||
"image_paths = [\"/path/image1.png\", \"/path/image2.png\", \"/path/image3.png\"]\n", | ||
"\n", | ||
"# Preview image for context\n", | ||
"display(Image(image_paths[0]))\n", | ||
"\n", | ||
"# Open the image file and encode it as a base64 string\n", | ||
"\n", | ||
"image_input = ImageInput(image_paths)\n", | ||
"image_base64 = image_input.base64_images\n", | ||
"image_extension = image_input.extension_images" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4b58fd68-7dbf-49eb-9b9e-03a93dd70a3e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import requests\n", | ||
"from dotenv import load_dotenv\n", | ||
"\n", | ||
"load_dotenv()\n", | ||
"\n", | ||
"# Configuration\n", | ||
"URL = os.getenv('HOST')+':'+os.getenv('PORT')+'/'+os.getenv('API_ENDPOINT_PREFIX')\n", | ||
"ENDPOINT = \"/v1/chat/completions\"\n", | ||
"\n", | ||
"# data for the request\n", | ||
"data = f'''{{\n", | ||
" \"messages\": [\n", | ||
" {{\n", | ||
" \"role\": \"user\",\n", | ||
" \"content\": [\n", | ||
" {{\"type\": \"text\", \"text\": \"Summarize all the information included in this image.\"}},\n", | ||
" {{\n", | ||
" \"type\": \"image_url\",\n", | ||
" \"image_url\": {{\n", | ||
" \"url\": \"data:image/{image_extension[0]};base64,{image_base64[0]}\"\n", | ||
" }}\n", | ||
" }},\n", | ||
" {{\n", | ||
" \"type\": \"image_url\",\n", | ||
" \"image_url\": {{\n", | ||
" \"url\": \"data:image/{image_extension[1]};base64,{image_base64[1]}\"\n", | ||
" }}\n", | ||
" }},\n", | ||
" {{\n", | ||
" \"type\": \"image_url\",\n", | ||
" \"image_url\": {{\n", | ||
" \"url\": \"data:image/{image_extension[2]};base64,{image_base64[2]}\"\n", | ||
" }}\n", | ||
" }}\n", | ||
" ]\n", | ||
" }}\n", | ||
" ],\n", | ||
" \"model\": \"llava-v1.6-mistral-7b-hf\",\n", | ||
" \"max_tokens\": 1024,\n", | ||
" \"temperature\":0.0\n", | ||
" }}'''\n", | ||
"\n", | ||
"# Try to send the request\n", | ||
"try:\n", | ||
" response = requests.post(URL+ENDPOINT, data=data)\n", | ||
" response.raise_for_status() # An HTTPError will be raised if an unsuccessful status code is returned\n", | ||
"except requests.RequestException as e:\n", | ||
" raise SystemExit(f\"Failed to make the request. Error: {e}\")\n", | ||
"\n", | ||
"# Print the response\n", | ||
"print(response.json())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import pytest | ||
|
||
from happy_vllm.utils import ImageInput | ||
|
||
|
||
def test_imageinput(): | ||
# Test ImageInput class creation | ||
image_input = ImageInput(image_paths=["./test.jpg"]) | ||
assert image_input.image_paths == ["./test.jpg"] | ||
|
||
# Test ImageInput class encode method | ||
file_base64 = open('./test_image_base64.txt') | ||
test_image_base64 = file_base64.read() | ||
file_base64.close() | ||
expected_base64 = {0: test_image_base64} | ||
assert image_input.base64_images == expected_base64 | ||
|
||
# Test ImageInput class extension method | ||
expected_extension = {0: 'jpg'} | ||
assert image_input.extension_images == expected_extension |