Multi gpu support (#358)
* feat: Added multi gpu support for text embedding

* feat: Add support for multi-gpu for special text models

* fix: Fix lazy_load to load the model in child processes when parallel is not None

* feat: Added lazy_load and multi-gpu to colbert

* feat: Add lazy_load and multi gpu to image models

* feat: Support lazy_load and multi-gpu for sparse models (except BM25)

* fix: Fixed BM25 not working

* refactor: Remove redundant GPUParallelProcessor

* refactor: Refactor _embed_*_parallel

* feat: Add cuda argument
refactor: Refactor how workers assign devices

* fix: Fix if providers and cuda are None

* fix: Fix providers and cuda are None

* WIP: Multi gpu support review (#361)

* WIP: review

* wip: review

* refactor: refactor images

* refactor: refactor sparse

* refactor: refactor late interaction

* add model loading

* add tests

* fix: uncomment models in tests

* fix: fix variable declaration order

* fix: fix device id assignment

* tests: add multi gpu tests

* fix: fix device id assignment for sparse embeddings

* tests: update multi gpu tests

---------

Co-authored-by: George Panchuk <[email protected]>

* refactor: remove redundant declarations

* fix: rollback redundant changes

* fix: remove num workers device ids dep, fix type hint

* fix: fix post process for sparse models

* fix: remove redundant model loading

* new: add lazy load and new gpu support to cross encoders

* fix: add rerankers to multi gpu tests

* fix: unlock multilingual test

* fix: fix gpu test with cross encoder

---------

Co-authored-by: Andrey Vasnetsov <[email protected]>
Co-authored-by: George Panchuk <[email protected]>
3 people authored Oct 16, 2024
1 parent 58b5a8e commit eaecf7d
Showing 29 changed files with 823 additions and 152 deletions.
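
Taken together, the changes in this commit add three knobs across the embedding classes: `cuda` (use the CUDA execution provider), `device_ids` (GPUs for data-parallel workers), and `lazy_load` (defer model loading so each worker loads its own copy). A minimal usage sketch based on the parameters introduced below — the model name and GPU count are illustrative assumptions, not part of the diff:

    from fastembed import TextEmbedding

    # lazy_load=True keeps the parent process from loading the model;
    # each spawned worker loads it on its own GPU instead.
    model = TextEmbedding(
        model_name="BAAI/bge-small-en-v1.5",  # assumed supported model
        cuda=True,
        device_ids=[0, 1],
        lazy_load=True,
    )

    documents = ["fastembed now supports multi-GPU encoding"] * 10_000
    # parallel=2 spawns two worker processes; batches are split across them.
    embeddings = list(model.embed(documents, batch_size=256, parallel=2))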
20 changes: 16 additions & 4 deletions fastembed/common/onnx_model.py
@@ -50,19 +50,28 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def load_onnx_model(
+    def _load_onnx_model(
         self,
         model_dir: Path,
         model_file: str,
         threads: Optional[int],
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_id: Optional[int] = None,
     ) -> None:
         model_path = model_dir / model_file
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
 
-        onnx_providers = (
-            ["CPUExecutionProvider"] if providers is None else list(providers)
-        )
+        if providers is not None:
+            onnx_providers = list(providers)
+        elif cuda:
+            if device_id is None:
+                onnx_providers = ["CUDAExecutionProvider"]
+            else:
+                onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
+        else:
+            onnx_providers = ["CPUExecutionProvider"]
 
         available_providers = ort.get_available_providers()
         requested_provider_names = []
         for provider in onnx_providers:
@@ -94,6 +103,9 @@ def load_onnx_model(
                 RuntimeWarning,
             )
 
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
     def onnx_embed(self, *args, **kwargs) -> OnnxOutputContext:
         raise NotImplementedError("Subclasses must implement this method")
 
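The new provider-selection branch in `_load_onnx_model` has a simple precedence: explicit `providers` win, then `cuda` (optionally pinned to one GPU via `device_id`), then CPU. A standalone sketch of that precedence — `resolve_providers` is a hypothetical helper mirroring the diff, not part of the codebase:

    from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

    def resolve_providers(
        providers: Optional[Sequence[Any]] = None,
        cuda: bool = False,
        device_id: Optional[int] = None,
    ) -> List[Union[str, Tuple[str, Dict[str, int]]]]:
        if providers is not None:
            return list(providers)  # explicit providers override the cuda shortcut
        if cuda:
            if device_id is None:
                return ["CUDAExecutionProvider"]
            # onnxruntime accepts (name, options) pairs to pin a session to a GPU
            return [("CUDAExecutionProvider", {"device_id": device_id})]
        return ["CPUExecutionProvider"]

    assert resolve_providers(cuda=True, device_id=1) == [
        ("CUDAExecutionProvider", {"device_id": 1})
    ]
    assert resolve_providers(cuda=True) == ["CUDAExecutionProvider"]
    assert resolve_providers() == ["CPUExecutionProvider"]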
12 changes: 7 additions & 5 deletions fastembed/image/image_embedding.py
@@ -45,21 +45,23 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
         **kwargs,
     ):
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
             supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
-            if any(
-                model_name.lower() == model["model"].lower()
-                for model in supported_models
-            ):
+            if any(model_name.lower() == model["model"].lower() for model in supported_models):
                 self.model = EMBEDDING_MODEL_TYPE(
                     model_name,
                     cache_dir,
                     threads=threads,
                     providers=providers,
+                    cuda=cuda,
+                    device_ids=device_ids,
+                    lazy_load=lazy_load,
                     **kwargs,
                 )
                 return
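Since `ImageEmbedding` simply forwards the new arguments to the underlying model class, image models get the same interface. A hedged example — the model name and image paths are placeholders, not taken from the diff:

    from fastembed import ImageEmbedding

    model = ImageEmbedding(
        model_name="Qdrant/clip-ViT-B-32-vision",  # assumed supported model
        cuda=True,
        device_ids=[0],
        lazy_load=True,
    )
    vectors = list(model.embed(["cat.jpg", "dog.jpg"], batch_size=16))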
2 changes: 1 addition & 1 deletion fastembed/image/image_embedding_base.py
@@ -30,7 +30,7 @@ def embed(
         Embeds a list of images into a list of embeddings.
 
         Args:
-            images - The list of image paths to preprocess and embed.
+            images: The list of image paths to preprocess and embed.
             batch_size: Batch size for encoding
             parallel:
                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
61 changes: 53 additions & 8 deletions fastembed/image/onnx_embedding.py
@@ -59,6 +59,10 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -68,24 +72,56 @@ def __init__(
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multi-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
 
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        model_description = self._get_model_description(model_name)
+        self.providers = providers
+        self.lazy_load = lazy_load
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load model in the current process
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+        else:
+            self.device_id = None
+
+        self.model_description = self._get_model_description(model_name)
         self.cache_dir = define_cache_dir(cache_dir)
         self._model_dir = self.download_model(
-            model_description, self.cache_dir, local_files_only=self._local_files_only
+            self.model_description, self.cache_dir, local_files_only=self._local_files_only
         )
 
-        self.load_onnx_model(
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    def load_onnx_model(self) -> None:
+        """
+        Load the onnx model.
+        """
+        self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=model_description["model_file"],
-            threads=threads,
-            providers=providers,
+            model_file=self.model_description["model_file"],
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
         )
 
     @classmethod
@@ -120,12 +156,16 @@ def embed(
         Returns:
             List of embeddings, one per document
         """
 
         yield from self._embed_images(
             model_name=self.model_name,
             cache_dir=str(self.cache_dir),
             images=images,
             batch_size=batch_size,
             parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
             **kwargs,
         )
 
@@ -148,4 +188,9 @@ def _post_process_onnx_output(self, output: OnnxOutputContext) -> Iterable[np.ndarray]:
 
 class OnnxImageEmbeddingWorker(ImageEmbeddingWorker):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> OnnxImageEmbedding:
-        return OnnxImageEmbedding(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return OnnxImageEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
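
Worth noting in the `__init__` above: the singular `device_id` is resolved with a fixed precedence — an explicit per-process `device_id` (as set in a worker) wins, else the first entry of `device_ids`, else None, which leaves device placement to the provider. A sketch of that rule with a hypothetical helper, not part of the codebase:

    from typing import List, Optional

    def resolve_device_id(
        device_id: Optional[int], device_ids: Optional[List[int]]
    ) -> Optional[int]:
        if device_id is not None:
            return device_id      # worker process pinned to a specific GPU
        if device_ids is not None:
            return device_ids[0]  # parent process defaults to the first GPU
        return None               # provider decides

    assert resolve_device_id(None, [2, 3]) == 2
    assert resolve_device_id(1, [2, 3]) == 1
    assert resolve_device_id(None, None) is None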
48 changes: 33 additions & 15 deletions fastembed/image/onnx_image_model.py
@@ -36,21 +36,28 @@ def _preprocess_onnx_input(
         """
         return onnx_input
 
-    def load_onnx_model(
+    def _load_onnx_model(
         self,
         model_dir: Path,
         model_file: str,
         threads: Optional[int],
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_id: Optional[int] = None,
     ) -> None:
-        super().load_onnx_model(
+        super()._load_onnx_model(
             model_dir=model_dir,
             model_file=model_file,
             threads=threads,
             providers=providers,
+            cuda=cuda,
+            device_id=device_id,
         )
         self.processor = load_preprocessor(model_dir=model_dir)
 
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
     def _build_onnx_input(self, encoded: np.ndarray) -> Dict[str, np.ndarray]:
         return {node.name: encoded for node in self.model.get_inputs()}
 
@@ -74,33 +81,44 @@ def _embed_images(
         images: ImageInput,
         batch_size: int = 256,
         parallel: Optional[int] = None,
+        providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
         **kwargs,
     ) -> Iterable[T]:
         is_small = False
 
-        if (
-            isinstance(images, str)
-            or isinstance(images, Path)
-            or (isinstance(images, Image.Image))
-        ):
+        if isinstance(images, (str, Path, Image.Image)):
             images = [images]
             is_small = True
 
-        if isinstance(images, list):
-            if len(images) < batch_size:
-                is_small = True
-
-        if parallel == 0:
-            parallel = os.cpu_count()
+        if isinstance(images, list) and len(images) < batch_size:
+            is_small = True
 
         if parallel is None or is_small:
+            if not hasattr(self, "model") or self.model is None:
+                self.load_onnx_model()
+
             for batch in iter_batch(images, batch_size):
                 yield from self._post_process_onnx_output(self.onnx_embed(batch))
         else:
+            if parallel == 0:
+                parallel = os.cpu_count()
+
             start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
-            params = {"model_name": model_name, "cache_dir": cache_dir, **kwargs}
+            params = {
+                "model_name": model_name,
+                "cache_dir": cache_dir,
+                "providers": providers,
+                **kwargs,
+            }
+
             pool = ParallelWorkerPool(
-                parallel, self._get_worker_class(), start_method=start_method
+                parallel,
+                self._get_worker_class(),
+                cuda=cuda,
+                device_ids=device_ids,
+                start_method=start_method,
             )
             for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
                 yield from self._post_process_onnx_output(batch)
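The dispatch logic above embeds in-process for small inputs (loading the model on demand if `lazy_load` deferred it) and only builds a `ParallelWorkerPool` for large ones, forwarding `cuda` and `device_ids` so each worker can pin itself to a GPU. A usage sketch under those assumptions — the paths and model name are illustrative:

    from fastembed.image.onnx_embedding import OnnxImageEmbedding

    paths = [f"img_{i}.jpg" for i in range(50_000)]  # placeholder image paths

    model = OnnxImageEmbedding(
        model_name="Qdrant/clip-ViT-B-32-vision",  # assumed supported model
        cuda=True,
        device_ids=[0, 1],
        lazy_load=True,  # parent skips loading; workers load on their own GPUs
    )
    # A list shorter than batch_size would be embedded in-process instead.
    vectors = list(model.embed(images=paths, batch_size=64, parallel=2))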
62 changes: 54 additions & 8 deletions fastembed/late_interaction/colbert.py
@@ -117,6 +117,10 @@ def __init__(
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         providers: Optional[Sequence[OnnxProvider]] = None,
+        cuda: bool = False,
+        device_ids: Optional[List[int]] = None,
+        lazy_load: bool = False,
+        device_id: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -126,29 +130,60 @@ def __init__(
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to False.
+            device_ids (Optional[List[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multi-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
 
         Raises:
             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load model in the current process
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+        else:
+            self.device_id = None
 
-        model_description = self._get_model_description(model_name)
+        self.model_description = self._get_model_description(model_name)
         self.cache_dir = define_cache_dir(cache_dir)
 
         self._model_dir = self.download_model(
-            model_description, self.cache_dir, local_files_only=self._local_files_only
+            self.model_description, self.cache_dir, local_files_only=self._local_files_only
         )
+        self.mask_token_id = None
+        self.pad_token_id = None
+        self.skip_list = set()
 
+        if not self.lazy_load:
+            self.load_onnx_model()
 
-        self.load_onnx_model(
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
             model_dir=self._model_dir,
-            model_file=model_description["model_file"],
-            threads=threads,
-            providers=providers,
+            model_file=self.model_description["model_file"],
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
         )
         self.mask_token_id = self.special_token_to_id["[MASK]"]
         self.pad_token_id = self.tokenizer.padding["pad_id"]
 
         self.skip_list = {
             self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
             for symbol in string.punctuation
@@ -182,13 +217,19 @@ def embed(
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
             **kwargs,
         )
 
     def query_embed(self, query: Union[str, List[str]], **kwargs) -> Iterable[np.ndarray]:
         if isinstance(query, str):
             query = [query]
 
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()
+
         for text in query:
             yield from self._post_process_onnx_output(
                 self.onnx_embed([text], is_doc=False), is_doc=False
@@ -201,4 +242,9 @@ def _get_worker_class(cls) -> Type[TextEmbeddingWorker]:
 
 class ColbertEmbeddingWorker(TextEmbeddingWorker):
     def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> Colbert:
-        return Colbert(model_name=model_name, cache_dir=cache_dir, threads=1, **kwargs)
+        return Colbert(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
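
End to end, `Colbert` now follows the same pattern, and `query_embed` lazy-loads the model on first use. A hedged sketch — it assumes the model is reached through fastembed's `LateInteractionTextEmbedding` wrapper, with the model name as an assumption:

    from fastembed import LateInteractionTextEmbedding

    model = LateInteractionTextEmbedding(
        model_name="colbert-ir/colbertv2.0",  # assumed supported model
        lazy_load=True,  # nothing is loaded at construction time
    )
    # The first call triggers load_onnx_model() in this process.
    query_vecs = list(model.query_embed("what is multi-gpu inference?"))
    doc_vecs = list(model.embed(["fastembed supports data-parallel encoding"]))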