Skip to content

Commit

Permalink
Merge pull request #26 from Oscilloscope98/embedding-device-logic-update
Browse files Browse the repository at this point in the history
[Community] Update optimization logic regarding device for `llama-index-embeddings-ipex-llm`
  • Loading branch information
Oscilloscope98 authored May 22, 2024
2 parents 8f1e978 + 4bc7180 commit d7fc6c4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.utils import get_cache_dir, infer_torch_device
from llama_index.core.utils import get_cache_dir
from llama_index.embeddings.ipex_llm.utils import (
DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
BGE_MODELS,
Expand Down Expand Up @@ -57,7 +57,15 @@ def __init__(
callback_manager: Optional[CallbackManager] = None,
**model_kwargs,
):
self._device = device or infer_torch_device()
# Set "cpu" as default device
self._device = device or "cpu"

if self._device not in ["cpu", "xpu"]:
logger.warning(
"IpexLLMEmbedding currently only supports device to be 'cpu' or 'xpu', "
f"but you have: {self._device}; Use 'cpu' instead."
)
self._device = "cpu"

cache_folder = cache_folder or get_cache_dir()

Expand All @@ -84,13 +92,11 @@ def __init__(
**model_kwargs,
)

if self._device == "cpu":
self._model = _optimize_pre(self._model)
self._model = _optimize_post(self._model)
# TODO: optimize using ipex-llm optimize_model
elif self._device == "xpu":
self._model = _optimize_pre(self._model)
self._model = _optimize_post(self._model)
# Apply ipex-llm optimizations
self._model = _optimize_pre(self._model)
self._model = _optimize_post(self._model)
if self._device == "xpu":
# TODO: apply `ipex_llm.optimize_model`
self._model = self._model.half().to(self._device)

if max_length:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-embeddings-ipex-llm"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.1"
version = "0.1.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down

0 comments on commit d7fc6c4

Please sign in to comment.