diff --git a/benchmarks/inference-server/text-generation-inference/main.tf b/benchmarks/inference-server/text-generation-inference/main.tf
index 1a9a711f2..e7d2c4dc6 100644
--- a/benchmarks/inference-server/text-generation-inference/main.tf
+++ b/benchmarks/inference-server/text-generation-inference/main.tf
@@ -68,6 +68,7 @@ resource "kubernetes_manifest" "default" {
     namespace                      = var.namespace
     model_id                       = var.model_id
     gpu_count                      = var.gpu_count
+    max_concurrent_requests        = var.max_concurrent_requests
     ksa                            = var.ksa
     hugging_face_token_secret_list = local.hugging_face_token_secret == null ? [] : [local.hugging_face_token_secret]
   }))
diff --git a/benchmarks/inference-server/text-generation-inference/manifest-templates/text-generation-inference.tftpl b/benchmarks/inference-server/text-generation-inference/manifest-templates/text-generation-inference.tftpl
index 70f234977..7c1fe6496 100644
--- a/benchmarks/inference-server/text-generation-inference/manifest-templates/text-generation-inference.tftpl
+++ b/benchmarks/inference-server/text-generation-inference/manifest-templates/text-generation-inference.tftpl
@@ -50,7 +50,7 @@ spec:
         ports:
           - containerPort: 80
         image: "ghcr.io/huggingface/text-generation-inference:1.4.2"
-        args: ["--model-id", "${model_id}", "--num-shard", "${gpu_count}"] # , "{token}" tensor parallelism, should correspond to number of gpus below
+        args: ["--model-id", "${model_id}", "--num-shard", "${gpu_count}", "--max-concurrent-requests", "${max_concurrent_requests}"]
 %{ for hugging_face_token_secret in hugging_face_token_secret_list ~}
         env:
           - name: HUGGING_FACE_HUB_TOKEN # Related token consumption
diff --git a/benchmarks/inference-server/text-generation-inference/variables.tf b/benchmarks/inference-server/text-generation-inference/variables.tf
index 1a90313a8..a80a9fcb9 100644
--- a/benchmarks/inference-server/text-generation-inference/variables.tf
+++ b/benchmarks/inference-server/text-generation-inference/variables.tf
@@ -58,6 +58,18 @@ variable "gpu_count" {
   }
 }

+variable "max_concurrent_requests" {
+  description = "Maximum number of concurrent requests for TGI to handle at once. TGI rejects incoming requests once this limit is reached."
+  type        = number
+  nullable    = false
+  # TODO: default matches TGI's default for now; update with a reasonable number.
+  default     = 128
+  validation {
+    condition     = var.max_concurrent_requests > 0
+    error_message = "Max concurrent requests must be greater than 0."
+  }
+}
+
 variable "ksa" {
   description = "Kubernetes Service Account used for workload."
   type        = string
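
For context, a minimal sketch of how a caller of this module could set the new variable. The module source path and all values below are illustrative assumptions, not taken from this change; only the variable names come from the diff:

module "text_generation_inference" {
  # Hypothetical source path, inferred from the file layout in this diff.
  source = "./benchmarks/inference-server/text-generation-inference"

  namespace = "benchmark"        # assumed value
  model_id  = "tiiuae/falcon-7b" # assumed value
  gpu_count = 2                  # assumed value
  ksa       = "benchmark-ksa"    # assumed value

  # New in this change: caps simultaneous requests served by TGI.
  # Omitting it falls back to the module default of 128, which matches
  # TGI's own --max-concurrent-requests default.
  max_concurrent_requests = 256
}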