Skip to content

Commit

Permalink
[CI/Build][TPU] Add TPU CI test (vllm-project#6277)
Browse files Browse the repository at this point in the history
Co-authored-by: kevin <[email protected]>
  • Loading branch information
2 people authored and dtrifiro committed Jul 17, 2024
1 parent d09c37b commit d0813a2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
16 changes: 16 additions & 0 deletions .buildkite/run-tpu-test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
set -e

# Build the docker image.
docker build -f Dockerfile.tpu -t vllm-tpu .

# Set up cleanup.
remove_docker_container() { docker rm -f tpu-test || true; }
trap remove_docker_container EXIT
# Remove the container that might not be cleaned up in the previous run.
remove_docker_container

# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
python3 /workspace/vllm/examples/offline_inference_tpu.py
5 changes: 0 additions & 5 deletions Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,4 @@ COPY . /workspace/vllm
ENV VLLM_TARGET_DEVICE="tpu"
RUN cd /workspace/vllm && python setup.py develop

# Re-install outlines to avoid dependency errors.
# The outlines version must follow requirements-common.txt.
RUN pip uninstall outlines -y
RUN pip install "outlines>=0.0.43"

CMD ["/bin/bash"]
28 changes: 28 additions & 0 deletions examples/offline_inference_tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from vllm import LLM, SamplingParams

prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
]
N = 1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params = SamplingParams(temperature=0.7,
top_p=1.0,
n=N,
max_tokens=16)

# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
llm = LLM(model="google/gemma-2b", enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert generated_text.startswith(answer)

0 comments on commit d0813a2

Please sign in to comment.