Test inference endpoint model config parsing from path (#434)
* Add example model config for existing endpoint

* Test InferenceEndpointModelConfig.from_path

* Comment default main branch in example

* Fix typo

* Delete unused add_special_tokens param in endpoint example config

* Fix typo

* Implement InferenceEndpointModelConfig.from_path

* Use InferenceEndpointModelConfig.from_path

* Refactor InferenceEndpointModelConfig.from_path

* Align docs
albertvillanova authored Dec 12, 2024
1 parent de8dba3 commit f907a34
Showing 6 changed files with 105 additions and 31 deletions.
4 changes: 1 addition & 3 deletions docs/source/evaluate-the-model-on-a-server-or-container.mdx
@@ -31,7 +31,7 @@ model:
     # endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
     # reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
     model_name: "meta-llama/Llama-2-7b-hf"
-    revision: "main"
+    # revision: "main" # defaults to "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
   instance:
     accelerator: "gpu"
@@ -45,8 +45,6 @@ model:
     image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models.
     env_vars:
       null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
-  generation:
-    add_special_tokens: true
 ```
 
 ### Text Generation Inference (TGI)
6 changes: 2 additions & 4 deletions examples/model_configs/endpoint_model.yaml
@@ -4,7 +4,7 @@ model:
     # endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
     # reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
     model_name: "meta-llama/Llama-2-7b-hf"
-    revision: "main"
+    # revision: "main" # defaults to "main"
     dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16"
   instance:
     accelerator: "gpu"
@@ -14,9 +14,7 @@ model:
     instance_size: "x1"
     framework: "pytorch"
     endpoint_type: "protected"
-    namespace: null # The namespace under which to launch the endopint. Defaults to the current user's namespace
+    namespace: null # The namespace under which to launch the endpoint. Defaults to the current user's namespace
     image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models.
     env_vars:
       null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048`
-  generation:
-    add_special_tokens: true
5 changes: 5 additions & 0 deletions examples/model_configs/endpoint_model_reuse_existing.yaml
@@ -0,0 +1,5 @@
model:
  base_params:
    # Pass either model_name, or endpoint_name and true reuse_existing
    endpoint_name: "llama-2-7B-lighteval" # needs to be lower case without special characters
    reuse_existing: true # defaults to false; if true, ignore all params in instance, and don't delete the endpoint after evaluation
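
As a quick sanity check, a minimal sketch (assuming the repository root as working directory so the relative path resolves) of loading this reuse-existing config through the new classmethod:

```python
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig

# Only base_params is present in this file, so from_path just maps it onto the dataclass fields.
config = InferenceEndpointModelConfig.from_path(
    "examples/model_configs/endpoint_model_reuse_existing.yaml"
)
print(config.endpoint_name)   # "llama-2-7B-lighteval"
print(config.reuse_existing)  # True
```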
25 changes: 2 additions & 23 deletions src/lighteval/main_endpoint.py
@@ -198,7 +198,6 @@ def inference_endpoint(
     """
     Evaluate models using inference-endpoints as backend.
     """
-    import yaml
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.models.endpoints.endpoint_model import (
@@ -220,31 +219,11 @@ def inference_endpoint(
 
     parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote
 
-    with open(model_config_path, "r") as f:
-        config = yaml.safe_load(f)["model"]
-
-    # Find a way to add this back
-    # if config["base_params"].get("endpoint_name", None):
-    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
-    all_params = {
-        "model_name": config["base_params"].get("model_name", None),
-        "endpoint_name": config["base_params"].get("endpoint_name", None),
-        "model_dtype": config["base_params"].get("dtype", None),
-        "revision": config["base_params"].get("revision", None) or "main",
-        "reuse_existing": config["base_params"].get("reuse_existing"),
-        "accelerator": config.get("instance", {}).get("accelerator", None),
-        "region": config.get("instance", {}).get("region", None),
-        "vendor": config.get("instance", {}).get("vendor", None),
-        "instance_size": config.get("instance", {}).get("instance_size", None),
-        "instance_type": config.get("instance", {}).get("instance_type", None),
-        "namespace": config.get("instance", {}).get("namespace", None),
-        "image_url": config.get("instance", {}).get("image_url", None),
-        "env_vars": config.get("instance", {}).get("env_vars", None),
-    }
-    model_config = InferenceEndpointModelConfig(
-        # We only initialize params which have a non default value
-        **{k: v for k, v in all_params.items() if v is not None},
-    )
-
+    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
 
     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
11 changes: 10 additions & 1 deletion src/lighteval/models/endpoints/endpoint_model.py
@@ -103,12 +103,21 @@ def __post_init__(self):
         # xor operator, one is None but not the other
         if (self.instance_size is None) ^ (self.instance_type is None):
             raise ValueError(
-                "When creating an inference endpoint, you need to specify explicitely both instance_type and instance_size, or none of them for autoscaling."
+                "When creating an inference endpoint, you need to specify explicitly both instance_type and instance_size, or none of them for autoscaling."
             )
 
         if not (self.endpoint_name is None) ^ int(self.model_name is None):
             raise ValueError("You need to set either endpoint_name or model_name (but not both).")
 
+    @classmethod
+    def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        config["base_params"]["model_dtype"] = config["base_params"].pop("dtype", None)
+        return cls(**config["base_params"], **config.get("instance", {}))
+
     def get_dtype_args(self) -> Dict[str, str]:
         if self.model_dtype is None:
             return {}
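
For reference, a small usage sketch of the classmethod added above, using examples/model_configs/endpoint_model.yaml; the expected values mirror the test below, and the YAML `dtype` key is renamed to the `model_dtype` field before construction:

```python
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig

# base_params and instance sections are merged into a single set of keyword arguments.
config = InferenceEndpointModelConfig.from_path("examples/model_configs/endpoint_model.yaml")

print(config.model_name)    # "meta-llama/Llama-2-7b-hf"
print(config.model_dtype)   # "float16" (from the YAML "dtype" key)
print(config.instance_type, config.instance_size)  # "nvidia-a10g" "x1"
```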
85 changes: 85 additions & 0 deletions tests/models/test_endpoint_model.py
@@ -0,0 +1,85 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import pytest

from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig


# "examples/model_configs/endpoint_model.yaml"


class TestInferenceEndpointModelConfig:
    @pytest.mark.parametrize(
        "config_path, expected_config",
        [
            (
                "examples/model_configs/endpoint_model.yaml",
                {
                    "model_name": "meta-llama/Llama-2-7b-hf",
                    "revision": "main",
                    "model_dtype": "float16",
                    "endpoint_name": None,
                    "reuse_existing": False,
                    "accelerator": "gpu",
                    "region": "eu-west-1",
                    "vendor": "aws",
                    "instance_type": "nvidia-a10g",
                    "instance_size": "x1",
                    "framework": "pytorch",
                    "endpoint_type": "protected",
                    "namespace": None,
                    "image_url": None,
                    "env_vars": None,
                },
            ),
            (
                "examples/model_configs/endpoint_model_lite.yaml",
                {
                    "model_name": "meta-llama/Llama-3.1-8B-Instruct",
                    # Defaults:
                    "revision": "main",
                    "model_dtype": None,
                    "endpoint_name": None,
                    "reuse_existing": False,
                    "accelerator": "gpu",
                    "region": "us-east-1",
                    "vendor": "aws",
                    "instance_type": None,
                    "instance_size": None,
                    "framework": "pytorch",
                    "endpoint_type": "protected",
                    "namespace": None,
                    "image_url": None,
                    "env_vars": None,
                },
            ),
            (
                "examples/model_configs/endpoint_model_reuse_existing.yaml",
                {"endpoint_name": "llama-2-7B-lighteval", "reuse_existing": True},
            ),
        ],
    )
    def test_from_path(self, config_path, expected_config):
        config = InferenceEndpointModelConfig.from_path(config_path)
        for key, value in expected_config.items():
            assert getattr(config, key) == value
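
To run just this new test module locally, a sketch using pytest's Python entry point (assuming pytest is installed and the working directory is the repository root, so the relative example paths resolve):

```python
import pytest

# Equivalent to `pytest tests/models/test_endpoint_model.py -v` on the command line.
pytest.main(["tests/models/test_endpoint_model.py", "-v"])
```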
