🧹 Cleanup of the batch size environment variables (#121)
* refactor HF_BATCH_SIZE and BATCH_SIZE into MAX_BATCH_SIZE

* change default batch size to 4
baptistecolle authored Nov 28, 2024
1 parent 8c2c199 commit ffa990d
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion optimum/tpu/model.py

@@ -10,7 +10,7 @@
 
 
 def get_export_kwargs_from_env():
-    batch_size = os.environ.get("HF_BATCH_SIZE", None)
+    batch_size = os.environ.get("MAX_BATCH_SIZE", None)
     if batch_size is not None:
         batch_size = int(batch_size)
     sequence_length = os.environ.get("HF_SEQUENCE_LENGTH", None)
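
A minimal usage sketch of the renamed variable (only the batch-size handling is visible in the hunk above, so the exact return shape of get_export_kwargs_from_env is an assumption):

```python
import os

from optimum.tpu.model import get_export_kwargs_from_env

os.environ["MAX_BATCH_SIZE"] = "4"         # formerly HF_BATCH_SIZE
os.environ["HF_SEQUENCE_LENGTH"] = "1024"  # unchanged by this commit

# Assumed to return a dict of export settings; per the visible lines it
# should now carry batch_size=4 as an int.
export_kwargs = get_export_kwargs_from_env()
```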
2 changes: 1 addition & 1 deletion text-generation-inference/README.md

@@ -65,7 +65,7 @@ docker run -p 8080:80 \
   --net=host --privileged \
   -v $(pwd)/data:/data \
   -e HF_TOKEN=${HF_TOKEN} \
-  -e HF_BATCH_SIZE=1 \
+  -e MAX_BATCH_SIZE=4 \
   -e HF_SEQUENCE_LENGTH=1024 \
   ghcr.io/huggingface/tpu-tgi:latest \
   --model-id mistralai/Mistral-7B-v0.1 \
8 changes: 4 additions & 4 deletions text-generation-inference/docker/entrypoint.sh

@@ -5,10 +5,10 @@
 ulimit -l 68719476736
 
 # Hugging Face Hub related
-if [[ -z "${BATCH_SIZE}" ]]; then
-  BATCH_SIZE=2
+if [[ -z "${MAX_BATCH_SIZE}" ]]; then
+  MAX_BATCH_SIZE=4
 fi
-export BATCH_SIZE="${BATCH_SIZE}"
+export MAX_BATCH_SIZE="${MAX_BATCH_SIZE}"
 
 if [[ -z "${JSON_OUTPUT_DISABLE}" ]]; then
   JSON_OUTPUT_DISABLE=--json-output
@@ -33,6 +33,6 @@ export QUANTIZATION="${QUANTIZATION}"
 
 
 exec text-generation-launcher --port 8080 \
-  --max-batch-size ${BATCH_SIZE} \
+  --max-batch-size ${MAX_BATCH_SIZE} \
   ${JSON_OUTPUT_DISABLE} \
   --model-id ${MODEL_ID}
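
For readers more at home in Python, a hedged sketch of what the updated entrypoint does (flag names are taken from the script above; MODEL_ID is assumed to be set by the deployment, as in the shell script):

```python
import os
import subprocess

# Mirror of entrypoint.sh after this commit: default MAX_BATCH_SIZE to 4,
# then hand it to text-generation-launcher as --max-batch-size.
env = dict(os.environ)
env.setdefault("MAX_BATCH_SIZE", "4")

subprocess.run(
    [
        "text-generation-launcher",
        "--port", "8080",
        "--max-batch-size", env["MAX_BATCH_SIZE"],
        "--model-id", env["MODEL_ID"],  # assumed set by the caller
    ],
    env=env,
    check=True,
)
```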
2 changes: 1 addition & 1 deletion text-generation-inference/integration-tests/conftest.py

@@ -109,7 +109,7 @@ def docker_launcher(
     if HUGGING_FACE_HUB_TOKEN is not None:
         env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
 
-    for var in ["HF_BATCH_SIZE", "HF_SEQUENCE_LENGTH"]:
+    for var in ["MAX_BATCH_SIZE", "HF_SEQUENCE_LENGTH"]:
         if var in os.environ:
             env[var] = os.environ[var]
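
The forwarding loop can equivalently be written as a comprehension; a sketch with the same behavior:

```python
import os

# Same effect as the loop in docker_launcher: forward only the variables
# that are actually set on the host into the container environment.
forwarded = ("MAX_BATCH_SIZE", "HF_SEQUENCE_LENGTH")
env = {var: os.environ[var] for var in forwarded if var in os.environ}
```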
2 changes: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ def serve(
     from .server import serve
 
     # Read environment variables forwarded by the launcher
-    max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "1"))
+    max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "4"))
    max_total_tokens = int(os.environ.get("MAX_TOTAL_TOKENS", "64"))
 
     # Start the server
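
A sketch of the precedence the serve command now follows: an explicit MAX_BATCH_SIZE in the environment wins, otherwise the default is 4 (raised from 1 by this commit). The helper below is hypothetical, not part of the repo:

```python
import os

def read_max_batch_size() -> int:  # hypothetical helper for illustration
    return int(os.environ.get("MAX_BATCH_SIZE", "4"))

os.environ.pop("MAX_BATCH_SIZE", None)
assert read_max_batch_size() == 4   # unset: the new default applies
os.environ["MAX_BATCH_SIZE"] = "8"
assert read_max_batch_size() == 8   # an explicit value still wins
```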
