Adding Jetstream support to locust solution.
- Made --stop-timeout an opt-in variable
- Piped namespace var to locust master KNS name
- Bumped transformers version for Gemma support
kfswain committed Apr 15, 2024
1 parent 82228d7 commit 92394d3
Showing 9 changed files with 38 additions and 14 deletions.

@@ -24,7 +24,7 @@ pyzmq==25.0.0
 requests==2.31.0
 roundrobin==0.0.2
 six==1.16.0
-transformers==4.36.0
+transformers==4.39.3
 typing_extensions==4.1.1
 urllib3==1.26.18
 Werkzeug==2.3.8

@@ -15,7 +15,7 @@
 # limitations under the License.
 
 LOCUST="/usr/local/bin/locust"
-LOCUS_OPTS="-f /locust-tasks/tasks.py --host=$TARGET_HOST"
+LOCUST_OPTS="-f /locust-tasks/tasks.py --host=$TARGET_HOST"
 LOCUST_MODE=${LOCUST_MODE:-standalone}
 
 if [[ "$LOCUST_MODE" = "master" ]]; then
@@ -24,17 +24,20 @@ if [[ "$LOCUST_MODE" = "master" ]]; then
   # For inferencing workloads with large payload having no wait time is unreasonable.
   # This timeout is set to large amount to avoid user tasks being killed too early.
   # TODO: turn timeout into a variable.
-  LOCUS_OPTS="$LOCUS_OPTS --master --stop-timeout 10800"
+  LOCUST_OPTS="$LOCUST_OPTS --master "
+  if [[ "$STOP_TIMEOUT" != 0 ]]; then
+    LOCUST_OPTS="$LOCUST_OPTS --stop-timeout $STOP_TIMEOUT"
+  fi
 elif [[ "$LOCUST_MODE" = "worker" ]]; then
   huggingface-cli login --token $HUGGINGFACE_TOKEN
   FILTER_PROMPTS="python /locust-tasks/load_data.py"
   FILTER_PROMPTS_OPTS="--gcs_path=$GCS_PATH --tokenizer=$TOKENIZER --max_prompt_len=$MAX_PROMPT_LEN --max_num_prompts=$MAX_NUM_PROMPTS"
   echo "$FILTER_PROMPTS $FILTER_PROMPTS_OPTS"
   $FILTER_PROMPTS $FILTER_PROMPTS_OPTS
 
-  LOCUS_OPTS="$LOCUS_OPTS --worker --master-host=$LOCUST_MASTER"
+  LOCUST_OPTS="$LOCUST_OPTS --worker --master-host=$LOCUST_MASTER"
 fi
 
-echo "$LOCUST $LOCUS_OPTS"
+echo "$LOCUST $LOCUST_OPTS"
 
-$LOCUST $LOCUS_OPTS
+$LOCUST $LOCUST_OPTS
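
For illustration only, a minimal Python sketch of the opt-in behaviour the entrypoint now implements: the --stop-timeout flag is appended only when STOP_TIMEOUT is set to a non-zero value, otherwise Locust's default stop behaviour applies. The target host below is invented.

    import os
    import shlex

    def build_master_opts(target_host: str) -> list[str]:
        # Mirror of the master branch of the entrypoint: --master always,
        # --stop-timeout only when the operator opts in via STOP_TIMEOUT.
        opts = ["-f", "/locust-tasks/tasks.py", f"--host={target_host}", "--master"]
        stop_timeout = os.environ.get("STOP_TIMEOUT", "0")
        if stop_timeout != "0":
            opts += ["--stop-timeout", stop_timeout]
        return opts

    # STOP_TIMEOUT=10800 reproduces the previously hard-coded timeout; unset/0 omits the flag.
    os.environ["STOP_TIMEOUT"] = "10800"
    print(shlex.join(["/usr/local/bin/locust"] + build_master_opts("http://my-inference-service")))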

@@ -97,6 +97,11 @@ def generate_request(prompt):
             "max_tokens": output_len,
             "stream": False,
         }
+    elif backend == "jetstream":
+        pload = {
+            "prompt": prompt,
+            "max_tokens": output_len,
+        }
     else:
         raise ValueError(f"Unknown backend: {backend}")
     return pload
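
As a quick check of the new branch above, a standalone sketch of the JetStream request body: unlike the vLLM/TGI payloads it carries only the prompt and max_tokens, with no stream flag. The helper name and sample prompt are invented for the example.

    def jetstream_payload(prompt: str, output_len: int) -> dict:
        # Same shape as the pload built in the jetstream branch of generate_request.
        return {
            "prompt": prompt,
            "max_tokens": output_len,
        }

    print(jetstream_payload("Explain KV caching in one sentence.", 128))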

@@ -122,6 +127,8 @@ def get_token_count(prompt, resp):
             tokenizer.encode(resp_dict['text_output']))
     elif backend == "sax":
         number_of_output_tokens = 0  # to be added
+    elif backend == "jetstream":
+        number_of_output_tokens = 0
     else:
         raise ValueError(f"Unknown backend: {backend}")
     return number_of_input_tokens, number_of_output_tokens
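
Output-token counting for jetstream is left at 0 for now, mirroring the sax placeholder, while input tokens are still counted with the Hugging Face tokenizer. A rough sketch of that split, assuming a tokenizer such as google/gemma-2b (which needs Hugging Face access and may differ from the benchmark's actual model):

    from transformers import AutoTokenizer

    def count_tokens(prompt: str, backend: str, tokenizer) -> tuple[int, int]:
        # Input tokens are always derived from the prompt.
        number_of_input_tokens = len(tokenizer.encode(prompt))
        # Output-token accounting is backend specific; jetstream is a placeholder for now.
        if backend == "jetstream":
            number_of_output_tokens = 0
        else:
            raise NotImplementedError(f"sketch only covers jetstream, got {backend}")
        return number_of_input_tokens, number_of_output_tokens

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # assumed model name
    print(count_tokens("Hello Gemma", "jetstream", tokenizer))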

@@ -19,10 +19,11 @@
 class LocustRun:
     """Represents single run of Locust load tests."""
 
-    def __init__(self, duration, users, rate):
+    def __init__(self, duration, users, rate, namespace):
         self.duration = duration
         self.users = users
         self.rate = rate
+        self.namespace = namespace
 
     duration: int = 120
     users: int = 10

@@ -26,19 +26,21 @@
 
 
 @app.get("/run")
-async def root(background_tasks: BackgroundTasks, duration=os.environ["DURATION"], users=os.environ["USERS"], rate=os.environ["RATE"]):
+async def root(background_tasks: BackgroundTasks, duration=os.environ["DURATION"], users=os.environ["USERS"], rate=os.environ["RATE"], namespace=os.environ["NAMESPACE"]):
 
     run: LocustRun = LocustRun(duration=duration,
                                users=users,
-                               rate=rate)
+                               rate=rate,
+                               namespace=namespace)
 
     background_tasks.add_task(call_locust, run)
 
     return {"message": f"""Swarming started"""}
 
 
 def call_locust(run: LocustRun):
-    locust_service = "locust-master.benchmark.svc.cluster.local"
+
+    locust_service = f"locust-master.{run.namespace}.svc.cluster.local"
 
     run.start_time = time.time()
 
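A small sketch (not the commit's code) of what piping the namespace through buys: the locust master Service DNS name is no longer hard-coded to the benchmark namespace but derived from wherever the stack is deployed.

    def locust_master_dns(namespace: str) -> str:
        # In-cluster DNS name of the locust-master Service in the given namespace.
        return f"locust-master.{namespace}.svc.cluster.local"

    print(locust_master_dns("benchmark"))      # the previously hard-coded value
    print(locust_master_dns("my-namespace"))   # now driven by the NAMESPACE env var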

@@ -96,7 +98,7 @@ def grab_metrics(start_time: float, end_time: float, filter: str, type: MetricTy
         return results
     except:
         print("No metrics found")
-        []
+        results = []
 
     return results
 
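A minimal illustration of the bug the hunk above fixes: a bare [] in the except branch is a discarded expression, so results could be undefined (or stale) when the final return runs; assigning results = [] guarantees an empty fallback. The failing fetch function is invented for the example.

    def grab_metrics_safe(fetch):
        try:
            results = fetch()
            return results
        except Exception:
            print("No metrics found")
            results = []  # without this assignment the return below could raise NameError
        return results

    def failing_fetch():
        raise RuntimeError("metrics backend unavailable")

    print(grab_metrics_safe(failing_fetch))  # prints "No metrics found", then []
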
1 change: 1 addition & 0 deletions benchmarks/benchmark/tools/locust-load-inference/main.tf
@@ -44,6 +44,7 @@ locals {
     enable_custom_metrics = var.enable_custom_metrics
     huggingface_secret    = var.huggingface_secret
     csv_upload_frequency  = var.csv_upload_frequency
+    stop_timeout          = var.stop_timeout
   })) : data]
 ])
}

@@ -28,6 +28,8 @@ spec:
           value: ${inference_server_framework}
         - name: ENABLE_CUSTOM_METRICS
           value: ${enable_custom_metrics}
+        - name: STOP_TIMEOUT
+          value: ${stop_timeout}
         ports:
         - name: loc-master-web
           containerPort: 8089

@@ -21,4 +21,6 @@ spec:
         - name: USERS
           value: ${users}
         - name: RATE
-          value: ${rate}
+          value: ${rate}
+        - name: NAMESPACE
+          value: ${namespace}

10 changes: 8 additions & 2 deletions benchmarks/benchmark/tools/locust-load-inference/variables.tf
@@ -71,6 +71,12 @@ variable "num_locust_workers" {
   default     = 1
 }
 
+variable "stop_timeout" {
+  description = "Length of time before a locust job is stopped."
+  type        = number
+  default     = 0
+}
+
 variable "inference_server_service" {
   description = "Inference server service"
   type        = string
@@ -83,8 +89,8 @@ variable "inference_server_framework" {
   nullable    = false
   default     = "tgi"
   validation {
-    condition     = var.inference_server_framework == "vllm" || var.inference_server_framework == "tgi" || var.inference_server_framework == "tensorrt_llm_triton" || var.inference_server_framework == "sax"
-    error_message = "The inference_server_framework must be one of: vllm, tgi, tensorrt_llm_triton, sax."
+    condition     = var.inference_server_framework == "vllm" || var.inference_server_framework == "tgi" || var.inference_server_framework == "tensorrt_llm_triton" || var.inference_server_framework == "sax" || var.inference_server_framework == "jetstream"
+    error_message = "The inference_server_framework must be one of: vllm, tgi, tensorrt_llm_triton, sax, or jetstream."
   }
 }
 
