Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fp16 inference support (trt) #875

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install open-clip-torch==2.7.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please specify the open-clip-torch and tensorrt versions in setup.py.

- name: Test
id: test
run: |
Expand Down Expand Up @@ -159,11 +160,13 @@ jobs:
pip install -e "server/[tensorrt]"
pip install -e "server/[onnx]"
pip install -e "server/[transformers]"
pip install nvidia-tensorrt==8.4.1.5
{
pip install -e "server/[flash-attn]"
} || {
echo "flash attention was not installed."
}
pip install open-clip-torch==2.7.0
- name: Test
id: test
run: |
Expand Down
6 changes: 5 additions & 1 deletion server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
access_paths: str = '@r',
dtype: Optional[str] = 'fp32',
**kwargs,
):
"""
Expand All @@ -36,6 +37,7 @@ def __init__(
number if you encounter OOM errors.
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
:param dtype: inference data type, defaults to 'fp32'.
"""
super().__init__(**kwargs)

Expand All @@ -51,6 +53,7 @@ def __init__(
self._access_paths = kwargs['traversal_paths']

self._device = device
self._dtype = dtype

import torch

Expand All @@ -63,7 +66,7 @@ def __init__(
torch.cuda.is_available()
), "CUDA/GPU is not available on Pytorch. Please check your CUDA installation"

self._model = CLIPTensorRTModel(name)
self._model = CLIPTensorRTModel(name=name, dtype=dtype)

self._model.start_engines()

Expand All @@ -85,6 +88,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand Down
37 changes: 25 additions & 12 deletions server/clip_server/model/clip_trt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, Optional

try:
import tensorrt as trt
Expand Down Expand Up @@ -51,6 +51,7 @@ class CLIPTensorRTModel(BaseCLIPModel):
def __init__(
self,
name: str,
dtype: Optional[str] = 'fp32',
):
super().__init__(name)

Expand All @@ -59,23 +60,35 @@ def __init__(
f'~/.cache/clip/{name.replace("/", "-").replace("::", "-")}'
)

self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.trt',
)
if dtype == 'fp16':
self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.fp16.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.fp16.trt',
)
else:
self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.trt',
)

if not os.path.exists(self._textual_path) or not os.path.exists(
self._visual_path
):
from clip_server.model.clip_onnx import CLIPOnnxModel

fp16 = dtype == 'fp16'

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
onnx_model = CLIPOnnxModel(name)
onnx_model = CLIPOnnxModel(name=name, dtype=dtype)

visual_engine = build_engine(
runtime=runtime,
Expand All @@ -95,7 +108,7 @@ def __init__(
onnx_model.image_size,
),
workspace_size=10000 * 1024 * 1024,
fp16=False,
fp16=fp16,
int8=False,
)
save_engine(visual_engine, self._visual_path)
Expand All @@ -108,7 +121,7 @@ def __init__(
optimal_shape=(768, 77),
max_shape=(1024, 77),
workspace_size=10000 * 1024 * 1024,
fp16=False,
fp16=fp16,
int8=False,
)
save_engine(text_engine, self._textual_path)
Expand Down
2 changes: 1 addition & 1 deletion server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
'torchvision<=0.13.0' if sys.version_info <= (3, 7, 2) else 'torchvision',
'jina>=3.12.0',
'prometheus-client',
'open_clip_torch>=2.7.0',
'open_clip_torch==2.7.0',
],
extras_require={
'onnx': [
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ def make_trt_flow(port_generator, request):
yield f


@pytest.fixture(scope='session', params=['tensorrt'])
def make_trt_flow_fp16(port_generator, request):
    """Yield a running Flow whose TensorRT ``CLIPEncoder`` is given ``dtype='fp16'``.

    The Flow listens on a port taken from ``port_generator`` and is kept open
    for the whole test session; it is closed when the ``with`` block exits.
    """
    # Imported lazily inside the fixture, as in the original.
    from clip_server.executors.clip_tensorrt import CLIPEncoder

    executor_args = {'dtype': 'fp16'}
    flow = Flow(port=port_generator())
    flow = flow.add(name=request.param, uses=CLIPEncoder, uses_with=executor_args)
    with flow:
        yield flow


@pytest.fixture(params=['torch'])
def make_search_flow(tmpdir, port_generator, request):
from clip_server.executors.clip_torch import CLIPEncoder
Expand Down
33 changes: 33 additions & 0 deletions tests/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ def test_docarray_inputs(make_trt_flow, inputs):
assert inputs[0] is r[0]


@pytest.mark.gpu
@pytest.mark.parametrize(
    'inputs',
    [
        [Document(text='hello, world'), Document(text='goodbye, world')],
        DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
        lambda: (Document(text='hello, world') for _ in range(10)),
        DocumentArray(
            [
                Document(uri='https://docarray.jina.ai/_static/favicon.png'),
                Document(
                    uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
                ),
                Document(text='hello, world'),
                Document(
                    uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
                ).load_uri_to_image_tensor(),
            ]
        ),
        DocumentArray.from_files(
            f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
        ),
    ],
)
def test_docarray_inputs_fp16(make_trt_flow_fp16, inputs):
    """Encode each parametrized input through the fp16 TensorRT flow.

    Checks that the result is a ``DocumentArray`` with a truthy embeddings
    shape and, for inputs that define ``__len__``, that the first returned
    document is the very same object as the first input document.
    """
    client = Client(server=f'grpc://0.0.0.0:{make_trt_flow_fp16.port}')

    # The generator case is parametrized as a factory so that every test run
    # gets a fresh, unconsumed generator.
    payload = inputs() if callable(inputs) else inputs
    result = client.encode(payload)

    assert isinstance(result, DocumentArray)
    assert result.embeddings.shape
    if hasattr(inputs, '__len__'):
        assert inputs[0] is result[0]


@pytest.mark.gpu
@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down