diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh
new file mode 100644
index 0000000000000..4aabd123ae234
--- /dev/null
+++ b/.buildkite/run-tpu-test.sh
@@ -0,0 +1,16 @@
+set -e
+
+# Build the docker image.
+docker build -f Dockerfile.tpu -t vllm-tpu .
+
+# Set up cleanup.
+remove_docker_container() { docker rm -f tpu-test || true; }
+trap remove_docker_container EXIT
+# Remove the container that might not be cleaned up in the previous run.
+remove_docker_container
+
+# For HF_TOKEN.
+source /etc/environment
+# Run a simple end-to-end example.
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
+    python3 /workspace/vllm/examples/offline_inference_tpu.py
diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index 23bb78682da2c..6ad8e8ccfac78 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -15,9 +15,4 @@ COPY . /workspace/vllm
 ENV VLLM_TARGET_DEVICE="tpu"
 RUN cd /workspace/vllm && python setup.py develop
 
-# Re-install outlines to avoid dependency errors.
-# The outlines version must follow requirements-common.txt.
-RUN pip uninstall outlines -y
-RUN pip install "outlines>=0.0.43"
-
 CMD ["/bin/bash"]
diff --git a/examples/offline_inference_tpu.py b/examples/offline_inference_tpu.py
new file mode 100644
index 0000000000000..251629b8027ce
--- /dev/null
+++ b/examples/offline_inference_tpu.py
@@ -0,0 +1,28 @@
+from vllm import LLM, SamplingParams
+
+prompts = [
+    "A robot may not injure a human being",
+    "It is only with the heart that one can see rightly;",
+    "The greatest glory in living lies not in never falling,",
+]
+answers = [
+    " or, through inaction, allow a human being to come to harm.",
+    " what is essential is invisible to the eye.",
+    " but in rising every time we fall.",
+]
+N = 1
+# Currently, top-p sampling is disabled. `top_p` should be 1.0.
+sampling_params = SamplingParams(temperature=0.7,
+                                 top_p=1.0,
+                                 n=N,
+                                 max_tokens=16)
+
+# Set `enforce_eager=True` to avoid ahead-of-time compilation.
+# In real workloads, `enforce_eager` should be `False`.
+llm = LLM(model="google/gemma-2b", enforce_eager=True)
+outputs = llm.generate(prompts, sampling_params)
+for output, answer in zip(outputs, answers):
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    assert generated_text.startswith(answer)
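
The example's comments note that `enforce_eager=True` is only there to skip ahead-of-time compilation and that real workloads should leave it at `False`. A minimal sketch of that production-style configuration (not part of the patch; it reuses the model and sampling settings from the example above and assumes a TPU host running the image built from Dockerfile.tpu):

from vllm import LLM, SamplingParams

# Same settings as examples/offline_inference_tpu.py, but with
# enforce_eager=False so the model is compiled ahead of time
# (slower startup, faster steady-state generation).
sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=1, max_tokens=16)
llm = LLM(model="google/gemma-2b", enforce_eager=False)

outputs = llm.generate(["A robot may not injure a human being"], sampling_params)
print(outputs[0].outputs[0].text)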