Skip to content

Commit

Permalink
Fix passing base_url in model_id in InferenceEndpointsLLM (#924)
Browse files Browse the repository at this point in the history
* Fix passing `base_url` in `model_id`

* Print ruff version

* Install dev dependencies after

* Update `ruff`

* noqa

* Skip ray tests on 3.12

* Do not run ray tests in 3.12
  • Loading branch information
gabrielmbmb authored Aug 23, 2024
1 parent 2b6b238 commit 379c756
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ sources = src/distilabel tests

.PHONY: format
format:
ruff --version
ruff check --fix $(sources)
ruff format $(sources)

.PHONY: lint
lint:
ruff --version
ruff check $(sources)
ruff format --check $(sources)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ distilabel = "distilabel.cli.app:app"
"distilabel/components-gallery" = "distilabel.utils.mkdocs.components_gallery:ComponentsGalleryPlugin"

[project.optional-dependencies]
dev = ["ruff == 0.4.5", "pre-commit >= 3.5.0"]
dev = ["ruff == 0.6.2", "pre-commit >= 3.5.0"]
docs = [
"mkdocs-material >=9.5.17",
"mkdocstrings[python] >= 0.24.0",
Expand Down
4 changes: 3 additions & 1 deletion scripts/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")

python -m pip install uv

uv pip install --system -e ".[dev,tests,anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu]"
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu]"

if [ "${python_version}" != "(3, 12)" ]; then
uv pip install --system -e .[ray]
fi

./scripts/install_cpu_vllm.sh
uv pip install --system git+https://github.com/argilla-io/LLM-Blender.git

uv pip install --system -e ".[dev,tests]"
2 changes: 1 addition & 1 deletion src/distilabel/llms/huggingface/inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def load(self) -> None: # noqa: C901
self._model_name = client.repository

self._aclient = AsyncInferenceClient(
model=self.base_url,
base_url=self.base_url,
token=self.api_key.get_secret_value(),
)

Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/llms/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

if TYPE_CHECKING:
from openai import OpenAI
from openai import OpenAI # noqa
from transformers import PreTrainedTokenizer
from vllm import LLM as _vLLM

Expand Down
40 changes: 40 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import TYPE_CHECKING, List

import pytest

if TYPE_CHECKING:
from _pytest.config import Config
from _pytest.nodes import Item


def pytest_configure(config: "Config") -> None:
config.addinivalue_line(
"markers",
"skip_python_versions(versions): mark test to be skipped on specified Python versions",
)


def pytest_collection_modifyitems(config: "Config", items: List["Item"]) -> None:
current_version = f"{sys.version_info.major}.{sys.version_info.minor}"
for item in items:
skip_versions_marker = item.get_closest_marker("skip_python_versions")
if skip_versions_marker:
versions_to_skip = skip_versions_marker.args[0]
if current_version in versions_to_skip:
skip_reason = f"Test not supported on Python {current_version}"
item.add_marker(pytest.mark.skip(reason=skip_reason))
5 changes: 1 addition & 4 deletions tests/integration/test_ray_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import TYPE_CHECKING, Dict, List

import pytest
Expand Down Expand Up @@ -149,9 +148,7 @@ def outputs(self) -> List[str]:
return ["response"]


@pytest.mark.skipif(
sys.version_info >= (3, 12), reason="`ray` is not compatible with `python>=3.12`"
)
@pytest.mark.skip_python_versions(["3.12"])
def test_run_pipeline() -> None:
import ray
from ray.cluster_utils import Cluster
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/pipeline/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from typing import Generator

import pytest
import ray
from ray.cluster_utils import Cluster

from distilabel.llms.vllm import vLLM
from distilabel.pipeline.ray import RayPipeline
Expand All @@ -27,6 +25,9 @@

@pytest.fixture
def ray_test_cluster() -> Generator[None, None, None]:
import ray
from ray.cluster_utils import Cluster

cluster = Cluster(
initialize_head=True,
head_node_args={
Expand All @@ -43,6 +44,7 @@ def ray_test_cluster() -> Generator[None, None, None]:
ray.shutdown()


@pytest.mark.skip_python_versions(["3.12"])
@pytest.mark.usefixtures("ray_test_cluster")
class TestRayPipeline:
def test_dump(self) -> None:
Expand Down

0 comments on commit 379c756

Please sign in to comment.