0.25.0 +Florence
matatonic committed Jun 19, 2024
1 parent 2d9d3b8 commit d087b6c
Showing 9 changed files with 310 additions and 112 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -34,6 +34,8 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
- - [X] [idefics2-8b-chatty-AWQ](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty-AWQ) (won't gpu split)
- [X] [Microsoft](https://huggingface.co/microsoft/)
- - [X] [Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
- - [X] [Florence-2-large-ft](https://huggingface.co/microsoft/Florence-2-large-ft) (won't gpu split)
- - [X] [Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft) (won't gpu split)
- [X] [failspy](https://huggingface.co/failspy)
- - [X] [Phi-3-vision-128k-instruct-abliterated-alpha](https://huggingface.co/failspy/Phi-3-vision-128k-instruct-abliterated-alpha)
- [X] [qihoo360](https://huggingface.co/qihoo360)
@@ -103,6 +105,11 @@ See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_le

## Recent updates

Version 0.25.0

- New model support: the microsoft/Florence-2 family of models. These are not chat models, but simple questions are fine and all of the task commands are functional, e.g. `<MORE_DETAILED_CAPTION>`, `<OCR>`, `<OD>`, etc. (see the request sketch below).
- Improved error handling & logging

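A rough request sketch, assuming the server is running with a Florence checkpoint loaded; the base URL, API key, and image URL below are placeholders for your own setup. The task tag goes in the text part of an ordinary vision chat request:

```python
# Sketch only: adjust base_url, api_key, model and the image URL for your setup.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5006/v1", api_key="sk-none")

response = client.chat.completions.create(
    model="microsoft/Florence-2-large-ft",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "<MORE_DETAILED_CAPTION>"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```

Any of the other task tags can be substituted for `<MORE_DETAILED_CAPTION>`.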
Version 0.24.1

- Compatibility: Support generation without images for most models (llava-based models still require an image).
4 changes: 2 additions & 2 deletions backend/emu.py
@@ -2,7 +2,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download

from loguru import logger
from vision_qna import *

# BAAI/Emu2-Chat
@@ -36,7 +36,7 @@ def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_p
        self.model = load_checkpoint_and_dispatch(self.model, checkpoint=checkpoint, device_map=device_map).eval()

        # self.model.device/dtype are overloaded with some other object
        print(f"Loaded {model_id} on device: {self.device} with dtype: {self.params['torch_dtype']}")
        logger.info(f"Loaded {model_id} on device: {self.device} with dtype: {self.params['torch_dtype']}")

    async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
        images, prompt, system = await emu_images_prompt_system_from_messages(request.messages)
63 changes: 63 additions & 0 deletions backend/florence.py
@@ -0,0 +1,63 @@
from transformers import AutoProcessor, AutoModelForCausalLM

from vision_qna import *

# microsoft/Florence-2-large-ft
# microsoft/Florence-2-base-ft

def select_task(prompt):
    tasks = ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>", "<OCR>", # simple tasks
             "<OD>", "<DENSE_REGION_CAPTION>", "<REGION_PROPOSAL>", "<CAPTION_TO_PHRASE_GROUNDING>",
             "<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>", "<OPEN_VOCABULARY_DETECTION>",
             "<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>", "<OCR_WITH_REGION>"
             ]
    for task in tasks:
        if task in prompt:
            return task

    return None

class VisionQnA(VisionQnABase):
    model_name: str = "florence"
    format: str = "florence"
    visual_layers: List[str] = ['vision_tower', 'image_proj_norm', 'image_pos_embed', 'visual_temporal_embed']

    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
        super().__init__(model_id, device, device_map, extra_params, format)

        if not format:
            self.format = guess_model_format(model_id)

        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
        self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval()

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model = self.model.to(self.device)

        self.loaded_banner()

    async def chat_with_images(self, request: ImageChatRequest) -> str:
        images, prompt = await prompt_from_messages(request.messages, self.format)

        inputs = self.processor(text=prompt, images=images[0], return_tensors="pt").to(device=self.model.device, dtype=self.model.dtype)

        default_params = {
            'do_sample': False,
            'num_beams': 3,
        }

        params = self.get_generation_params(request, default_params=default_params)

        generation_kwargs = dict(
            **inputs,
            **params,
        )

        generated_ids = self.model.generate(**generation_kwargs)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = self.processor.post_process_generation(generated_text, task=select_task(prompt), image_size=(images[0].width, images[0].height))

        # parsed_answer is a dict keyed by the task tag; return its single value as a string
        for k, v in parsed_answer.items():
            return str(v)

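For reference, below is a minimal standalone sketch of the task-prompt flow that backend/florence.py implements, using transformers directly. Assumptions: a CUDA device is available, the checkpoint is loaded with trust_remote_code (Florence-2 ships custom modeling code), and the image path and max_new_tokens value are placeholders.

```python
# Minimal sketch mirroring backend/florence.py; paths and generation limits are placeholders.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-base-ft"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, trust_remote_code=True
).to("cuda").eval()

task = "<OD>"  # any task tag from select_task() above
image = Image.open("example.jpg").convert("RGB")  # placeholder image

# Only floating-point tensors are cast to float16; input_ids stay integer.
inputs = processor(text=task, images=image, return_tensors="pt").to("cuda", torch.float16)
generated_ids = model.generate(**inputs, do_sample=False, num_beams=3, max_new_tokens=1024)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

# post_process_generation returns a dict keyed by the task tag,
# e.g. {'<OD>': {'bboxes': [...], 'labels': [...]}} for object detection.
parsed = processor.post_process_generation(
    generated_text, task=task, image_size=(image.width, image.height)
)
print(parsed[task])
```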
4 changes: 4 additions & 0 deletions model_conf_tests.json
@@ -83,6 +83,10 @@
["llava-hf/llava-v1.6-vicuna-13b-hf", "--use-flash-attn"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn", "--load-in-4bit"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn"],
["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0"],
["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0"],
["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn", "--load-in-4bit"],
["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn"],
["openbmb/MiniCPM-Llama3-V-2_5", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
103 changes: 101 additions & 2 deletions openedai.py
@@ -1,13 +1,72 @@
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from fastapi.responses import PlainTextResponse, JSONResponse
from loguru import logger

class OpenAIError(Exception):
    pass

class APIError(OpenAIError):
    message: str
    code: int = None
    param: str = None
    type: str = None

    def __init__(self, message: str, code: int = 500, param: str = None, internal_message: str = ''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        self.type = self.__class__.__name__
        self.internal_message = internal_message

    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )

class InternalServerError(APIError):
    pass

class ServiceUnavailableError(APIError):
    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        super().__init__(message, code, internal_message=internal_message)

class APIStatusError(APIError):
    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        super().__init__(message, self.status_code, param, internal_message)

class BadRequestError(APIStatusError):
    status_code: int = 400

class AuthenticationError(APIStatusError):
    status_code: int = 401

class PermissionDeniedError(APIStatusError):
    status_code: int = 403

class NotFoundError(APIStatusError):
    status_code: int = 404

class ConflictError(APIStatusError):
    status_code: int = 409

class UnprocessableEntityError(APIStatusError):
    status_code: int = 422

class RateLimitError(APIStatusError):
    status_code: int = 429

class OpenAIStub(FastAPI):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
@@ -16,6 +75,46 @@ def __init__(self, **kwargs) -> None:
            allow_headers=["*"]
        )

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Generic server errors
            #logger.opt(exception=exc).error("Logging exception traceback")

            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            # Server error
            logger.opt(exception=exc).error("Logging exception traceback")

            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code = exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Client side error
            logger.info(repr(exc))

            if exc.internal_message:
                logger.info(exc.internal_message)

            return JSONResponse(status_code = exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            logger.debug(f"Request path: {request.url.path}")
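To illustrate how the new error classes are meant to be used, here is a sketch; the route, validation check, and message are hypothetical, but the JSON shape and status code follow the handlers registered above. A backend raises an APIStatusError subclass and OpenAIStub converts it into an OpenAI-style error response:

```python
# Sketch only: the endpoint and validation are hypothetical examples.
from openedai import OpenAIStub, BadRequestError

app = OpenAIStub()

@app.post("/v1/chat/completions")
async def chat_completions(body: dict):
    if not body.get("messages"):
        # The APIStatusError handler above returns this to the client as HTTP 400:
        # {"message": "messages is required", "code": 400,
        #  "type": "BadRequestError", "param": "messages"}
        raise BadRequestError("messages is required", param="messages")
    ...
```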
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,6 +11,7 @@ https://github.com/bdashore3/flash-attention/releases/download/v2.5.9.post1/flas
flash_attn; python_version != "3.10" and python_version != "3.11"
hf_transfer
loguru
numpy<2
openai
peft
protobuf
5 changes: 2 additions & 3 deletions vision.py
@@ -108,8 +108,7 @@ async def streamer():
        }
    }

    if os.environ.get('OPENEDAI_DEBUG', False):
        print(f'Response: {vis_chat_resp}')
    logger.debug(f'Response: {vis_chat_resp}')

    return vis_chat_resp

@@ -141,7 +140,7 @@ def parse_args(argv=None):
if not args.backend:
    args.backend = guess_backend(args.model)

print(f"Loading VisionQnA[{args.backend}] with {args.model}")
logger.info(f"Loading VisionQnA[{args.backend}] with {args.model}")
backend = importlib.import_module(f'backend.{args.backend}')

extra_params = {}
(The remaining changed files are not shown.)
