Skip to content

Commit

Permalink
fix(tests): fix broken connection to docker container
Browse files (browse the repository at this point in the history)
  • Loading branch information
baptistecolle committed Dec 4, 2024
1 parent c822676 commit 1af9edc
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 203 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ style_check:
ruff check .

style:
ruff check . --fix

# Utilities to release to PyPI
build_dist_install_tools:
Expand Down
57 changes: 28 additions & 29 deletions text-generation-inference/integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,28 @@
import contextlib
import os
import shlex
import signal
import subprocess
import sys
import threading
import time
import signal
from tempfile import TemporaryDirectory
from typing import List

import docker
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from loguru import logger
from test_model import MODEL_CONFIGS
from text_generation import AsyncClient
from text_generation.types import Response
from loguru import logger


DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "huggingface/optimum-tpu:latest")
HF_TOKEN = os.getenv("HF_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")

# Configure loguru logger
logger.remove() # Remove default handler
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
Expand Down Expand Up @@ -62,7 +61,7 @@ def stream_container_logs(container):

class LauncherHandle:
def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}", timeout=600)
self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)

def _inner_health(self):
raise NotImplementedError
Expand All @@ -71,7 +70,7 @@ async def health(self, timeout: int = 60):
assert timeout > 0
start_time = time.time()
logger.info(f"Starting health check with timeout of {timeout}s")

for attempt in range(timeout):
if not self._inner_health():
logger.error("Launcher crashed during health check")
Expand Down Expand Up @@ -126,13 +125,6 @@ def _inner_health(self) -> bool:
return self.process.poll() is None


@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()


@pytest.fixture(scope="module")
def data_volume():
tmpdir = TemporaryDirectory()
Expand All @@ -145,18 +137,21 @@ def data_volume():


@pytest.fixture(scope="module")
def launcher(event_loop, data_volume):
def launcher(data_volume):
@contextlib.contextmanager
def docker_launcher(
model_id: str,
trust_remote_code: bool = False,
):
logger.info(f"Starting docker launcher for model {model_id}")
# TODO: consider finding out how to forward a port in the container instead of leaving it at 80.
# For now this is necessary because TPU Docker containers have to run with net=host and in privileged mode.
port = 80
port = 8080

args = ["--env"]
args = [
"--max-input-length", "512",
"--max-total-tokens", "1024",
"--max-batch-prefill-tokens", "512",
"--max-batch-total-tokens", "1024"
]

if trust_remote_code:
args.append("--trust-remote-code")
Expand All @@ -175,17 +170,14 @@ def docker_launcher(
except Exception as e:
logger.error(f"Error handling existing container: {str(e)}")

env = {
"LOG_LEVEL": "info,text_generation_router,text_generation_launcher=debug",
"MAX_BATCH_SIZE": "4",
"HF_HUB_ENABLE_HF_TRANSFER": "0",
"JETSTREAM_PT": "1",
"SKIP_WARMUP": "1",
"MODEL_ID": model_id,
}
model_name = next(name for name, cfg in MODEL_CONFIGS.items() if cfg["model_id"] == model_id)
env = MODEL_CONFIGS[model_name]["env_config"].copy()

# Add model_id to env
env["MODEL_ID"] = model_id

if HF_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
env["HF_TOKEN"] = HF_TOKEN

for var in ["MAX_BATCH_SIZE", "HF_SEQUENCE_LENGTH"]:
if var in os.environ:
Expand All @@ -198,7 +190,13 @@ def docker_launcher(
logger.debug(f"Creating container with image {DOCKER_IMAGE}")
logger.debug(f"Container environment: {env}")
logger.debug(f"Container volumes: {volumes}")


# Log equivalent docker run command
env_str = ' '.join([f'-e {k}="{v}"' for k,v in env.items()])
volume_str = ' '.join([f'-v {v}' for v in volumes])
cmd_str = f'docker run -d --name {container_name} {env_str} {volume_str} --shm-size 16G --privileged --ipc host {DOCKER_IMAGE} {" ".join(args)}'
logger.debug(f"Equivalent docker run command:\n{cmd_str}")

container = client.containers.run(
DOCKER_IMAGE,
command=args,
Expand All @@ -210,6 +208,7 @@ def docker_launcher(
shm_size="16G",
privileged=True,
ipc_mode="host",
ports={"80/tcp": 8080}
)
logger.info(f"Container {container_name} started successfully")

Expand Down Expand Up @@ -245,7 +244,7 @@ def docker_launcher(
logger.info(f"Stopping container {container_name}")
container.stop()
container.wait()

container_output = container.logs().decode("utf-8")
print(container_output, file=sys.stderr)

Expand Down
87 changes: 0 additions & 87 deletions text-generation-inference/integration-tests/test_gemma.py

This file was deleted.

86 changes: 0 additions & 86 deletions text-generation-inference/integration-tests/test_gpt2.py

This file was deleted.

Loading

0 comments on commit 1af9edc

Please sign in to comment.