From 3efc07c28417235a4ae55cbca025e25e67319dce Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sat, 16 Nov 2024 10:22:16 -0800 Subject: [PATCH 1/2] refactor -> select_model(functional) --- libs/infinity_emb/infinity_emb/engine.py | 15 +--- .../infinity_emb/inference/batch_handler.py | 88 +++++++++++-------- .../infinity_emb/inference/select_model.py | 88 +++++++++---------- .../infinity_emb/infinity_server.py | 4 +- .../infinity_emb/transformer/abstract.py | 7 +- libs/infinity_emb/poetry.lock | 56 ++++++++++-- libs/infinity_emb/pyproject.toml | 1 + .../unit_test/inference/test_batch_handler.py | 25 +++--- .../unit_test/inference/test_select_model.py | 3 +- .../tests/unit_test/test_engine.py | 8 +- 10 files changed, 179 insertions(+), 116 deletions(-) diff --git a/libs/infinity_emb/infinity_emb/engine.py b/libs/infinity_emb/infinity_emb/engine.py index 73867a7d..33fe5147 100644 --- a/libs/infinity_emb/infinity_emb/engine.py +++ b/libs/infinity_emb/infinity_emb/engine.py @@ -52,9 +52,7 @@ def __init__( self.running = False self._running_sepamore: Optional[Semaphore] = None - self._model_replicas, self._min_inference_t, self._max_inference_t = select_model( - self._engine_args - ) + self._model_replicas_functions = select_model(self._engine_args) @classmethod def from_args( @@ -72,11 +70,7 @@ def from_args( return engine def __str__(self) -> str: - return ( - f"AsyncEmbeddingEngine(running={self.running}, " - f"inference_time={[self._min_inference_t, self._max_inference_t]}, " - f"{self._engine_args})" - ) + return f"AsyncEmbeddingEngine(running={self.running}, " f"{self._engine_args})" async def astart(self): """startup engine""" @@ -87,8 +81,7 @@ async def astart(self): self.running = True self._batch_handler = BatchHandler( max_batch_size=self._engine_args.batch_size, - model_replicas=self._model_replicas, - # batch_delay=self._min_inference_t / 2, + model_replicas=self._model_replicas_functions, vector_disk_cache_path=self._engine_args.vector_disk_cache_path, verbose=logger.level <= 10, lengths_via_tokenize=self._engine_args.lengths_via_tokenize, @@ -124,7 +117,7 @@ def is_running(self) -> bool: @property def capabilities(self) -> set[ModelCapabilites]: - return self._model_replicas[0].capabilities + return self._batch_handler.capabilities @property def engine_args(self) -> EngineArgs: diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 7ed4b83d..92d3a0ec 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -10,8 +10,8 @@ from concurrent.futures import ThreadPoolExecutor from queue import Queue from typing import Any, Optional, Sequence, Union, TYPE_CHECKING - import numpy as np +from functools import cached_property from infinity_emb.env import MANAGER from infinity_emb.inference.caching_layer import Cache @@ -39,7 +39,7 @@ from infinity_emb.transformer.vision.utils import resolve_images if TYPE_CHECKING: - from infinity_emb.transformer.abstract import BaseTypeHint + from infinity_emb.transformer.abstract import CallableReturningBaseTypeHint QUEUE_TIMEOUT = 0.5 @@ -64,7 +64,7 @@ def submit(self, *args, **kwargs): class BatchHandler: def __init__( self, - model_replicas: list["BaseTypeHint"], + model_replicas: list["CallableReturningBaseTypeHint"], max_batch_size: int, max_queue_wait: int = MANAGER.queue_size, batch_delay: float = 5e-3, @@ -91,6 +91,9 @@ def __init__( self._max_queue_wait = max_queue_wait 
self._lengths_via_tokenize = lengths_via_tokenize + self._max_batch_size = max_batch_size + self._batch_delay = batch_delay + self._verbose = verbose self._shutdown = threading.Event() self._threadpool = ThreadPoolExecutor() @@ -114,18 +117,8 @@ def __init__( self._result_store = ResultKVStoreFuture(cache) # model - self.model_worker = [ - ModelWorker( - shutdown=ShutdownReadOnly(self._shutdown), - model=model_replica, - threadpool=ThreadPoolExecutorReadOnly(self._threadpool), - input_q=self._publish_to_model_queue, - output_q=self._result_queue, - verbose=self.batch_delay, - batch_delay=batch_delay, - ) - for model_replica in model_replicas - ] + self.model_replica_fns = model_replicas + self._capabilities = None if batch_delay > 0.1: logger.warning(f"high batch delay of {batch_delay}") @@ -136,6 +129,12 @@ def __init__( " Consider increasing queue size" ) + @cached_property + def _tiktoken_encoding(self): + import tiktoken + + return tiktoken.encoding_for_model("gpt-3.5-turbo") + async def embed(self, sentences: list[str]) -> tuple[list["EmbeddingReturnType"], int]: """Schedule a sentence to be embedded. Awaits until embedded. @@ -289,10 +288,7 @@ async def audio_embed( f"Options are {self.capabilities}." ) - items = await resolve_audios( - audios, - getattr(self.model_worker[0]._model, "sampling_rate", -42), - ) + items = await resolve_audios(audios, self._extras.get("sampling_rate", -42)) embeddings, usage = await self._schedule(items) return embeddings, usage @@ -319,8 +315,8 @@ async def _schedule(self, list_queueitem: Sequence[AbstractSingle]) -> tuple[lis @property def capabilities(self) -> set[ModelCapabilites]: - # TODO: try to remove inheritance here and return upon init. - return self.model_worker[0].capabilities + assert self._capabilities is not None, "Model not loaded" + return self._capabilities def is_overloaded(self) -> bool: """checks if more items can be queued. 
@@ -352,12 +348,11 @@ async def _get_prios_usage(self, items: Sequence[AbstractSingle]) -> tuple[list[ if not self._lengths_via_tokenize: return get_lengths_with_tokenize([it.str_repr() for it in items]) else: - return await to_thread( - get_lengths_with_tokenize, - self._threadpool, - _sentences=[it.str_repr() for it in items], - tokenize=self.model_worker[0].tokenize_lengths, - ) + tokenized = [ + len(i) + for i in self._tiktoken_encoding.encode_batch([it.str_repr() for it in items]) + ] + return tokenized, sum(tokenized) def _publish_towards_model( self, @@ -452,8 +447,21 @@ async def spawn(self): ShutdownReadOnly(self._shutdown), self._result_queue, self._threadpool ) ) - for worker in self.model_worker: - worker.spawn() + + def get_model_worker(model_replica_fn) -> tuple[set[ModelCapabilites], dict]: + return ModelWorker( + shutdown=ShutdownReadOnly(self._shutdown), + model_fn=model_replica_fn, + threadpool=ThreadPoolExecutorReadOnly(self._threadpool), + input_q=self._publish_to_model_queue, + output_q=self._result_queue, + verbose=self.batch_delay, + batch_delay=self._batch_delay, + ).spawn() + + self._capabilities, self._extras = get_model_worker(self.model_replica_fns[0]) + if len(self.model_replica_fns) > 1: + self._threadpool.map(get_model_worker, self.model_replica_fns[1:]) async def shutdown(self): """ @@ -473,7 +481,7 @@ class ModelWorker: def __init__( self, shutdown: ShutdownReadOnly, - model: "BaseTypeHint", + model_fn: "CallableReturningBaseTypeHint", threadpool: ThreadPoolExecutorReadOnly, input_q: Queue, output_q: Queue, @@ -481,7 +489,7 @@ def __init__( verbose=False, ) -> None: self._shutdown = shutdown - self._model = model + self._model_fn = model_fn self._threadpool = threadpool self._feature_queue: Queue = Queue(3) self._postprocess_queue: Queue = Queue(5) @@ -492,20 +500,28 @@ def __init__( self._verbose = verbose self._ready = False - def spawn(self): + def spawn(self) -> tuple[set[ModelCapabilites], dict]: if self._ready: raise ValueError("already spawned") # start the threads + self._model = self._model_fn() self._threadpool.submit(self._preprocess_batch) self._threadpool.submit(self._core_batch) self._threadpool.submit(self._postprocess_batch) + extras = {} + if hasattr(self._model, "sampling_rate"): + extras["sampling_rate"] = self._model.sampling_rate # type: ignore + + return self._model.capabilities, extras # type: ignore + @property - def capabilities(self) -> set[ModelCapabilites]: - return self._model.capabilities + def model(self): + assert self._model is not None, "Model not loaded" + return self._model def tokenize_lengths(self, *args, **kwargs): - return self._model.tokenize_lengths(*args, **kwargs) + return self.model.tokenize_lengths(*args, **kwargs) def _preprocess_batch(self): """loops and checks if the _core_batch has worked on all items""" @@ -560,7 +576,7 @@ def _core_batch(self): if self._verbose: logger.debug("[🧠] Inference on batch_size=%s", len(batch)) self._last_inference = time.perf_counter() - embed = self._model.encode_core(feat) + embed = self.model.encode_core(feat) # while-loop just for shutdown while not self._shutdown.is_set(): diff --git a/libs/infinity_emb/infinity_emb/inference/select_model.py b/libs/infinity_emb/infinity_emb/inference/select_model.py index 4f0b24a4..38e92c7d 100644 --- a/libs/infinity_emb/infinity_emb/inference/select_model.py +++ b/libs/infinity_emb/infinity_emb/inference/select_model.py @@ -3,13 +3,11 @@ import json from pathlib import Path -from typing import Union +from typing import Union, TYPE_CHECKING 
+ -from infinity_emb.args import ( - EngineArgs, -) from infinity_emb.log_handler import logger -from infinity_emb.transformer.abstract import BaseCrossEncoder, BaseEmbedder +from functools import partial from infinity_emb.transformer.utils import ( AudioEmbedEngine, EmbedderEngine, @@ -19,9 +17,15 @@ RerankEngine, ) +if TYPE_CHECKING: + from infinity_emb.transformer.abstract import CallableReturningBaseTypeHint, BaseTypeHint + from infinity_emb.args import ( + EngineArgs, + ) + def get_engine_type_from_config( - engine_args: EngineArgs, + engine_args: "EngineArgs", ) -> Union[EmbedderEngine, RerankEngine, PredictEngine, ImageEmbedEngine, AudioEmbedEngine]: """resolved the class of inference engine path from config.json of the repo.""" if engine_args.engine in [InferenceEngine.debugengine]: @@ -57,55 +61,51 @@ def get_engine_type_from_config( return EmbedderEngine.from_inference_engine(engine_args.engine) +def _get_engine_replica(unloaded_engine, engine_args, device_map) -> "BaseTypeHint": + engine_args_copy = engine_args.copy() + engine_args_copy._loading_strategy.device_placement = device_map + loaded_engine = unloaded_engine.value(engine_args=engine_args_copy) + + if engine_args.model_warmup: + # size one, warm up warm start timings. + # loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1) + # size one token + min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(5)) + emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup( + batch_size=engine_args.batch_size, n_tokens=1 + ) + + logger.info(log_msg) + # now warm up with max_token, max batch size + loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512) + emb_per_sec, _, log_msg = loaded_engine.warmup( + batch_size=engine_args.batch_size, n_tokens=512 + ) + logger.info(log_msg) + logger.info( + f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}" + f" embeddings/sec at batch_size={engine_args.batch_size}" + ) + return loaded_engine + + def select_model( - engine_args: EngineArgs, -) -> tuple[list[Union[BaseCrossEncoder, BaseEmbedder]], float, float]: + engine_args: "EngineArgs", +) -> list["CallableReturningBaseTypeHint"]: """based on engine args, fully instantiates the Engine.""" logger.info( f"model=`{engine_args.model_name_or_path}` selected, " f"using engine=`{engine_args.engine.value}`" f" and device=`{engine_args.device.resolve()}`" ) - # engine_args.update_loading_strategy() - unloaded_engine = get_engine_type_from_config(engine_args) engine_replicas = [] - min_inference_t = 4e-3 - max_inference_t = 4e-3 - # TODO: Can be parallelized for device_map in engine_args._loading_strategy.device_mapping: # type: ignore - engine_args_copy = engine_args.copy() - engine_args_copy._loading_strategy.device_placement = device_map - loaded_engine = unloaded_engine.value(engine_args=engine_args_copy) - - if engine_args.model_warmup: - # size one, warm up warm start timings. 
- # loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1) - # size one token - min_inference_t = min( - min(loaded_engine.warmup(batch_size=1, n_tokens=1)[1] for _ in range(10)), - min_inference_t, - ) - loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=1) - emb_per_sec_short, max_inference_temp, log_msg = loaded_engine.warmup( - batch_size=engine_args.batch_size, n_tokens=1 - ) - max_inference_t = max(max_inference_temp, max_inference_t) - - logger.info(log_msg) - # now warm up with max_token, max batch size - loaded_engine.warmup(batch_size=engine_args.batch_size, n_tokens=512) - emb_per_sec, _, log_msg = loaded_engine.warmup( - batch_size=engine_args.batch_size, n_tokens=512 - ) - logger.info(log_msg) - logger.info( - f"model warmed up, between {emb_per_sec:.2f}-{emb_per_sec_short:.2f}" - f" embeddings/sec at batch_size={engine_args.batch_size}" - ) - engine_replicas.append(loaded_engine) + engine_replicas.append( + partial(_get_engine_replica, unloaded_engine, engine_args, device_map) + ) assert len(engine_replicas) > 0, "No engine replicas were loaded" - return engine_replicas, min_inference_t, max_inference_t + return engine_replicas # type: ignore diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 0ac1c835..95d71357 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -90,13 +90,13 @@ def create_server( async def lifespan(app: FastAPI): instrumentator.expose(app) # type: ignore logger.info( - f"Creating {len(engine_args_list)}engines: engines={[e.served_model_name for e in engine_args_list]}" + f"Creating {len(engine_args_list)} engines: engines={[e.served_model_name for e in engine_args_list]}" ) telemetry_log_info() app.engine_array = AsyncEngineArray.from_args(engine_args_list) # type: ignore th = threading.Thread( target=send_telemetry_start, - args=(engine_args_list, [e.capabilities for e in app.engine_array]), # type: ignore + args=(engine_args_list, [{} for e in app.engine_array]), # type: ignore ) th.daemon = True th.start() diff --git a/libs/infinity_emb/infinity_emb/transformer/abstract.py b/libs/infinity_emb/infinity_emb/transformer/abstract.py index e5430aa6..945dac5e 100644 --- a/libs/infinity_emb/infinity_emb/transformer/abstract.py +++ b/libs/infinity_emb/infinity_emb/transformer/abstract.py @@ -4,7 +4,7 @@ import random from abc import ABC, abstractmethod from time import perf_counter -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Union, Protocol from infinity_emb._optional_imports import CHECK_PIL # , CHECK_SOUNDFILE from infinity_emb.primitives import ( @@ -225,6 +225,11 @@ def warmup(self, *, batch_size: int = 64, n_tokens=1) -> tuple[float, float, str ] +class CallableReturningBaseTypeHint(Protocol): + def __call__(self) -> BaseTypeHint: + pass + + def run_warmup(model, inputs) -> tuple[float, float, str]: inputs_formated = [i.content.to_input() for i in inputs] start = perf_counter() diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock index 89c34b56..95edaebc 100644 --- a/libs/infinity_emb/poetry.lock +++ b/libs/infinity_emb/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. 
[[package]] name = "accelerate" @@ -4166,6 +4166,53 @@ files = [ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, ] +[[package]] +name = "tiktoken" +version = "0.8.0" +description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, + {file = "tiktoken-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9269348cb650726f44dd3bbb3f9110ac19a8dcc8f54949ad3ef652ca22a38e21"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e13f37bc4ef2d012731e93e0fef21dc3b7aea5bb9009618de9a4026844e560"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d13c981511331eac0d01a59b5df7c0d4060a8be1e378672822213da51e0a2"}, + {file = "tiktoken-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6b2ddbc79a22621ce8b1166afa9f9a888a664a579350dc7c09346a3b5de837d9"}, + {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, + {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, + {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04"}, + {file = "tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc"}, + {file = "tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:02be1666096aff7da6cbd7cdaa8e7917bfed3467cd64b38b1f112e96d3b06a24"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94ff53c5c74b535b2cbf431d907fc13c678bbd009ee633a2aca269a04389f9a"}, + {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:6b231f5e8982c245ee3065cd84a4712d64692348bc609d84467c57b4b72dcbc5"}, + {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4177faa809bd55f699e88c96d9bb4635d22e3f59d635ba6fd9ffedf7150b9953"}, + {file = "tiktoken-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5376b6f8dc4753cd81ead935c5f518fa0fbe7e133d9e25f648d8c4dabdd4bad7"}, + {file = "tiktoken-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:18228d624807d66c87acd8f25fc135665617cab220671eb65b50f5d70fa51f69"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e17807445f0cf1f25771c9d86496bd8b5c376f7419912519699f3cc4dc5c12e"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:886f80bd339578bbdba6ed6d0567a0d5c6cfe198d9e587ba6c447654c65b8edc"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6adc8323016d7758d6de7313527f755b0fc6c72985b7d9291be5d96d73ecd1e1"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b591fb2b30d6a72121a80be24ec7a0e9eb51c5500ddc7e4c2496516dd5e3816b"}, + {file = "tiktoken-0.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845287b9798e476b4d762c3ebda5102be87ca26e5d2c9854002825d60cdb815d"}, + {file = "tiktoken-0.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:1473cfe584252dc3fa62adceb5b1c763c1874e04511b197da4e6de51d6ce5a02"}, + {file = "tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2"}, +] + +[package.dependencies] +regex = ">=2022.1.18" +requests = ">=2.26.0" + +[package.extras] +blobfile = ["blobfile (>=2)"] + [[package]] name = "timm" version = "1.0.11" @@ -4527,11 +4574,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -5310,4 +5352,4 @@ vision = ["colpali-engine", "pillow", "timm", "torchvision"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4" -content-hash = "23feae6cd9a95ff4a6ed50da692d28ba9b514d3067adc9bcc8e4860a70a13942" +content-hash = "b2e7c00402b2caa33c1233c66af26849b5c33e5f83c2f196f990adad542e9e94" diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml index 017c8502..f525e7fe 100644 --- 
a/libs/infinity_emb/pyproject.toml +++ b/libs/infinity_emb/pyproject.toml @@ -56,6 +56,7 @@ diskcache = {version = "*", optional=true} onnxruntime-gpu = {version = "1.19.*", optional=true} tensorrt = {version = "^10.6.0", optional=true} soundfile = {version="^0.12.1", optional=true} +tiktoken = "^0.8.0" [tool.poetry.scripts] diff --git a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py index 3a36ac34..f57f24cd 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_batch_handler.py @@ -27,16 +27,19 @@ reason="windows and macos are not stable with python3.12", ) async def load_patched_bh() -> tuple[SentenceTransformerPatched, BatchHandler]: - model = SentenceTransformerPatched( - engine_args=EngineArgs( - model_name_or_path=MODEL_NAME, - bettertransformer=not torch.backends.mps.is_available(), + def get_m(): + model = SentenceTransformerPatched( + engine_args=EngineArgs( + model_name_or_path=MODEL_NAME, + bettertransformer=not torch.backends.mps.is_available(), + ) ) - ) - model.encode(["hello " * 512] * BATCH_SIZE) - bh = BatchHandler(model_replicas=[model], max_batch_size=BATCH_SIZE) + model.encode(["hello " * 512] * BATCH_SIZE) + return model + + bh = BatchHandler(model_replicas=[get_m], max_batch_size=BATCH_SIZE) await bh.spawn() - return model, bh + return get_m(), bh @pytest.mark.performance @@ -91,13 +94,13 @@ def method_st(_sentences): # yappi.stop() method_st(sentences[::10]) await method_batch_handler(sentences[::10]) - time.sleep(2) + time.sleep(1) time_batch_handler = np.median( [(await method_batch_handler(sentences)) for _ in range(N_TIMINGS)] ) - time.sleep(2) + time.sleep(1) time_st = np.median([method_st(sentences) for _ in range(N_TIMINGS)]) - time.sleep(2) + time.sleep(1) time_st_patched = np.median([method_patched(sentences) for _ in range(N_TIMINGS)]) print( diff --git a/libs/infinity_emb/tests/unit_test/inference/test_select_model.py b/libs/infinity_emb/tests/unit_test/inference/test_select_model.py index ec75bbf4..afd58474 100644 --- a/libs/infinity_emb/tests/unit_test/inference/test_select_model.py +++ b/libs/infinity_emb/tests/unit_test/inference/test_select_model.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize("engine", [e for e in InferenceEngine if e != InferenceEngine.neuron]) def test_engine(engine): - select_model( + model_funcs = select_model( EngineArgs( engine=engine, model_name_or_path=(pytest.DEFAULT_BERT_MODEL), @@ -16,3 +16,4 @@ def test_engine(engine): model_warmup=False, ) ) + [model_func() for model_func in model_funcs] diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index c3bf054b..5d30f7dc 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -45,8 +45,9 @@ async def test_async_api_torch(): device="cpu", ) ) - assert engine.capabilities == {"embed"} + async with engine: + assert engine.capabilities == {"embed"} embeddings, usage = await engine.embed(sentences=sentences) assert isinstance(embeddings, list) assert isinstance(embeddings[0], np.ndarray) @@ -74,8 +75,9 @@ async def test_async_api_torch_double_launch(): device="cpu", ) ) - assert engine.capabilities == {"embed"} + async with engine: + assert engine.capabilities == {"embed"} embeddings, usage = await engine.embed(sentences=sentences) assert isinstance(embeddings, list) assert isinstance(embeddings[0], 
np.ndarray) @@ -164,9 +166,9 @@ async def test_async_api_torch_CLASSIFY(): device="cpu", ) ) - assert engine.capabilities == {"classify"} async with engine: + assert engine.capabilities == {"classify"} predictions, usage = await engine.classify(sentences=sentences) assert usage == sum([len(s) for s in sentences]) assert len(predictions) == len(sentences) From 80d65beadc2ab1fc24e94a48ec5c90751f8d712e Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Sat, 16 Nov 2024 10:56:45 -0800 Subject: [PATCH 2/2] add better typing --- libs/embed_package/embed/_infer.py | 2 +- .../infinity_emb/fastapi_schemas/pymodels.py | 6 ++--- .../infinity_emb/inference/batch_handler.py | 4 ++-- .../infinity_emb/inference/select_model.py | 6 ++--- .../infinity_emb/infinity_server.py | 23 ++++++++++--------- .../infinity_emb/transformer/audio/utils.py | 2 +- .../infinity_emb/transformer/embedder/ct2.py | 2 +- .../transformer/embedder/neuron.py | 4 ++-- .../transformer/quantization/interface.py | 2 +- .../transformer/vision/torch_vision.py | 2 +- 10 files changed, 27 insertions(+), 26 deletions(-) diff --git a/libs/embed_package/embed/_infer.py b/libs/embed_package/embed/_infer.py index 93e185c8..e5a5c789 100644 --- a/libs/embed_package/embed/_infer.py +++ b/libs/embed_package/embed/_infer.py @@ -1,7 +1,7 @@ from concurrent.futures import Future from typing import Collection, Literal, Union -from infinity_emb import EngineArgs, SyncEngineArray # type: ignore +from infinity_emb import EngineArgs, SyncEngineArray from infinity_emb.infinity_server import AutoPadding __all__ = ["BatchedInference"] diff --git a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py index 3a623471..5fe7e2ba 100644 --- a/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py +++ b/libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py @@ -66,7 +66,7 @@ class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput): ), Annotated[str, INPUT_STRING], ] - modality: Literal[Modality.text] = Modality.text # type: ignore + modality: Literal[Modality.text] = Modality.text class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput): @@ -82,11 +82,11 @@ class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput): class OpenAIEmbeddingInput_Audio(_OpenAIEmbeddingInput_URI): - modality: Literal[Modality.audio] = Modality.audio # type: ignore + modality: Literal[Modality.audio] = Modality.audio class OpenAIEmbeddingInput_Image(_OpenAIEmbeddingInput_URI): - modality: Literal[Modality.image] = Modality.image # type: ignore + modality: Literal[Modality.image] = Modality.image def get_modality(obj: dict) -> str: diff --git a/libs/infinity_emb/infinity_emb/inference/batch_handler.py b/libs/infinity_emb/infinity_emb/inference/batch_handler.py index 92d3a0ec..32c8d4f7 100644 --- a/libs/infinity_emb/infinity_emb/inference/batch_handler.py +++ b/libs/infinity_emb/infinity_emb/inference/batch_handler.py @@ -511,9 +511,9 @@ def spawn(self) -> tuple[set[ModelCapabilites], dict]: extras = {} if hasattr(self._model, "sampling_rate"): - extras["sampling_rate"] = self._model.sampling_rate # type: ignore + extras["sampling_rate"] = self._model.sampling_rate - return self._model.capabilities, extras # type: ignore + return self._model.capabilities, extras @property def model(self): diff --git a/libs/infinity_emb/infinity_emb/inference/select_model.py b/libs/infinity_emb/infinity_emb/inference/select_model.py index 38e92c7d..289a4e16 100644 --- a/libs/infinity_emb/infinity_emb/inference/select_model.py +++ 
b/libs/infinity_emb/infinity_emb/inference/select_model.py @@ -18,7 +18,7 @@ ) if TYPE_CHECKING: - from infinity_emb.transformer.abstract import CallableReturningBaseTypeHint, BaseTypeHint + from infinity_emb.transformer.abstract import BaseTypeHint # , CallableReturningBaseTypeHint from infinity_emb.args import ( EngineArgs, ) @@ -91,7 +91,7 @@ def _get_engine_replica(unloaded_engine, engine_args, device_map) -> "BaseTypeHi def select_model( engine_args: "EngineArgs", -) -> list["CallableReturningBaseTypeHint"]: +) -> list[partial["BaseTypeHint"]]: """based on engine args, fully instantiates the Engine.""" logger.info( f"model=`{engine_args.model_name_or_path}` selected, " @@ -108,4 +108,4 @@ def select_model( ) assert len(engine_replicas) > 0, "No engine replicas were loaded" - return engine_replicas # type: ignore + return engine_replicas diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py index 95d71357..3df3572d 100644 --- a/libs/infinity_emb/infinity_emb/infinity_server.py +++ b/libs/infinity_emb/infinity_emb/infinity_server.py @@ -88,20 +88,20 @@ def create_server( @asynccontextmanager async def lifespan(app: FastAPI): - instrumentator.expose(app) # type: ignore + instrumentator.expose(app) logger.info( f"Creating {len(engine_args_list)} engines: engines={[e.served_model_name for e in engine_args_list]}" ) telemetry_log_info() - app.engine_array = AsyncEngineArray.from_args(engine_args_list) # type: ignore + engine_array = AsyncEngineArray.from_args(engine_args_list) th = threading.Thread( target=send_telemetry_start, - args=(engine_args_list, [{} for e in app.engine_array]), # type: ignore + args=(engine_args_list, [{} for e in engine_array]), ) th.daemon = True th.start() # start in a threadpool - await app.engine_array.astart() # type: ignore + await engine_array.astart() logger.info( docs.startup_message( @@ -120,8 +120,9 @@ async def kill_later(seconds: int): logger.info(f"Preloaded configuration successfully. {engine_args_list} " " -> exit .") asyncio.create_task(kill_later(3)) + app.engine_array = engine_array # type: ignore yield - await app.engine_array.astop() # type: ignore + await engine_array.astop() # shutdown! 
app = FastAPI( @@ -691,7 +692,7 @@ def v1( device: Device = MANAGER.device[0], # type: ignore lengths_via_tokenize: bool = MANAGER.lengths_via_tokenize[0], dtype: Dtype = MANAGER.dtype[0], # type: ignore - embedding_dtype: EmbeddingDtype = EmbeddingDtype.default_value(), # type: ignore + embedding_dtype: EmbeddingDtype = EmbeddingDtype.default_value(), pooling_method: PoolingMethod = MANAGER.pooling_method[0], # type: ignore compile: bool = MANAGER.compile[0], bettertransformer: bool = MANAGER.bettertransformer[0], @@ -701,7 +702,7 @@ def v1( url_prefix: str = MANAGER.url_prefix, host: str = MANAGER.host, port: int = MANAGER.port, - log_level: UVICORN_LOG_LEVELS = MANAGER.log_level, # type: ignore + log_level: UVICORN_LOG_LEVELS = MANAGER.log_level, ): """Infinity API ♾️ cli v1 - deprecated, consider use cli v2 via `infinity_emb v2`.""" if api_key: @@ -719,9 +720,9 @@ def v1( time.sleep(1) v2( model_id=[model_name_or_path], - served_model_name=[served_model_name], # type: ignore + served_model_name=[served_model_name], batch_size=[batch_size], - revision=[revision], # type: ignore + revision=[revision], trust_remote_code=[trust_remote_code], engine=[engine], dtype=[dtype], @@ -732,7 +733,7 @@ def v1( lengths_via_tokenize=[lengths_via_tokenize], compile=[compile], bettertransformer=[bettertransformer], - embedding_dtype=[EmbeddingDtype.float32], # set to float32 + embedding_dtype=[EmbeddingDtype.float32], # unique kwargs preload_only=preload_only, url_prefix=url_prefix, @@ -846,7 +847,7 @@ def v2( ), log_level: UVICORN_LOG_LEVELS = typer.Option( **_construct("log_level"), help="console log level." - ), # type: ignore + ), permissive_cors: bool = typer.Option( **_construct("permissive_cors"), help="whether to allow permissive cors." ), diff --git a/libs/infinity_emb/infinity_emb/transformer/audio/utils.py b/libs/infinity_emb/infinity_emb/transformer/audio/utils.py index 4893df2a..bdb59d6c 100644 --- a/libs/infinity_emb/infinity_emb/transformer/audio/utils.py +++ b/libs/infinity_emb/infinity_emb/transformer/audio/utils.py @@ -12,7 +12,7 @@ import aiohttp if CHECK_SOUNDFILE.is_available: - import soundfile as sf # type: ignore + import soundfile as sf # type: ignore[import-untyped] async def resolve_audio( diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/ct2.py b/libs/infinity_emb/infinity_emb/transformer/embedder/ct2.py index 7ae7f1e3..95f4effe 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/ct2.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/ct2.py @@ -32,7 +32,7 @@ class Module: # type: ignore[no-redef] if CHECK_CTRANSLATE2.is_available: - import ctranslate2 # type: ignore + import ctranslate2 # type: ignore[import-untyped] class CT2SentenceTransformer(SentenceTransformerPatched): diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py index 240e27ae..93b3cd9a 100644 --- a/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py +++ b/libs/infinity_emb/infinity_emb/transformer/embedder/neuron.py @@ -140,8 +140,8 @@ def encode_core(self, input_dict: dict[str, np.ndarray]) -> dict: } @quant_embedding_decorator() - def encode_post(self, embedding: dict) -> EmbeddingReturnType: - embedding = self.pooling( # type: ignore + def encode_post(self, embedding: dict[str, "torch.Tensor"]) -> EmbeddingReturnType: + embedding = self.pooling( embedding["token_embeddings"].numpy(), embedding["attention_mask"].numpy() ) diff --git 
a/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py b/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py index 88c37a96..8169d516 100644 --- a/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py +++ b/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Union import numpy as np -import requests # type: ignore +import requests # type: ignore[import-untyped] from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH from infinity_emb.env import MANAGER diff --git a/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py b/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py index 205ef9b9..6d99d425 100644 --- a/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py +++ b/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py @@ -27,7 +27,7 @@ if CHECK_TORCH.is_available: import torch if CHECK_TRANSFORMERS.is_available: - from transformers import AutoConfig, AutoModel, AutoProcessor # type: ignore + from transformers import AutoConfig, AutoModel, AutoProcessor # type: ignore[import-untyped] if CHECK_PIL.is_available: from PIL import Image
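
For readers of this series: a minimal usage sketch of the refactored select_model from PATCH 1/2. It now returns a list of zero-argument callables (functools.partial objects) that lazily construct each replica, instead of eagerly loaded models plus min/max inference timings. The model name is a placeholder, not something the patch prescribes:

    from infinity_emb.args import EngineArgs
    from infinity_emb.inference.select_model import select_model

    args = EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", model_warmup=False)  # placeholder model
    replica_fns = select_model(args)   # cheap: returns functools.partial objects, nothing is loaded yet
    model = replica_fns[0]()           # loading (and warmup, if enabled) happens on call
    print(model.capabilities)          # e.g. {"embed"}

Deferring construction this way is what lets BatchHandler.spawn() build the first replica synchronously (to capture capabilities and extras) and map the remaining replicas across its thread pool.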
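
Because capabilities are now discovered when the first ModelWorker is spawned, they are only available once the engine has started — which is why the tests move the capabilities assertion inside `async with engine:`. A rough end-to-end sketch under the same assumptions (placeholder model name, top-level import assumed from the package root):

    import asyncio
    from infinity_emb import AsyncEmbeddingEngine, EngineArgs

    async def main():
        engine = AsyncEmbeddingEngine.from_args(
            EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", model_warmup=False)
        )
        # before astart() the replicas are still un-called partials,
        # so engine.capabilities is not available yet
        async with engine:  # astart() spawns the ModelWorker(s) and records capabilities
            assert "embed" in engine.capabilities
            embeddings, usage = await engine.embed(sentences=["hello world"])

    asyncio.run(main())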
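
With this patch, _get_prios_usage estimates token lengths via tiktoken's gpt-3.5-turbo encoding instead of the replica's own tokenizer, so scheduling no longer needs a loaded model. A roughly equivalent standalone snippet:

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    sentences = ["hello world", "a somewhat longer example sentence"]
    lengths = [len(tokens) for tokens in enc.encode_batch(sentences)]
    usage = sum(lengths)  # _get_prios_usage now returns (lengths, usage) in this form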