
Commit

fix: percentiles
Hugoch committed Sep 18, 2024
1 parent 3758d0d commit 6dcd931
Showing 9 changed files with 707 additions and 42 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
build:
	cargo build --release --package text-generation-inference-benchmark --bin text-generation-inference-benchmark

run: build
	cargo run --package text-generation-inference-benchmark --bin text-generation-inference-benchmark -- $@
108 changes: 102 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
# Text Generation Inference benchmarking tool

A lightweight benchmarking tool for inference servers.
A lightweight benchmarking tool for LLM inference servers.
Benchmarks using constant arrival rate or constant virtual user count.



![ui.png](assets%2Fui.png)

## Table of contents

<!-- TOC -->
* [Text Generation Inference benchmarking tool](#text-generation-inference-benchmarking-tool)
* [Table of contents](#table-of-contents)
* [TODO](#todo)
* [Running a benchmark](#running-a-benchmark)
* [Development](#development)
* [Frequently Asked Questions](#frequently-asked-questions)
<!-- TOC -->

## TODO
- [X] Customizable token count and variance
- [ ] Check results
- [X] Allow for multiturn prompts for prefix caching
- [X] Allow for system prompts for prefix caching
- [ ] Allow for multi-turn prompts
- [ ] Push results to Optimum benchmark backend
- [ ] Script to generate plots from results
- [X] Script to generate plots from results

## Running a benchmark

```
## Get started

### Run a benchmark

Run a benchmark using the Docker image:

```shell
# start a TGI/vLLM server somewhere, then run benchmark...
# ... we mount results to the current directory
$ docker run \
@@ -33,4 +52,81 @@ $ docker run \
--decode-options "num_tokens=50,max_tokens=60,min_tokens=40,variance=10"
```

Results will be saved in `results.json` in the current directory.
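
The results file is plain JSON. As a rough, non-authoritative sketch based on the fields that the bundled `optimum.py` script reads (the schema may evolve), you can inspect the per-rate results like this:

```python
# Sketch: inspect a results.json produced by a benchmark run.
# Field names are taken from what optimum.py consumes; treat them as indicative.
import json

with open("results.json") as f:
    data = json.load(f)

for result in data["results"]:
    if result["id"] in ("warmup", "throughput"):
        continue  # skip the warmup and max-throughput passes
    qps = result["config"]["rate"]
    print(f"qps={qps} "
          f"e2e_p90={result['e2e_latency_ms_p90']}ms "
          f"throughput={result['token_throughput_secs']} tokens/s")
```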


### Configure your benchmark

#### Benchmark mode

In the default mode, the tool runs a `sweep` benchmark. It first runs a throughput test to find the maximum attainable
throughput, then sweeps over QPS values up to that maximum.

Available modes:
- `sweep`: runs a sweep benchmark
- `rate`: runs a benchmark at a fixed request rate
- `throughput`: runs a benchmark at a fixed throughput (constant VUs)
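
To illustrate the `sweep` mode described above, here is a minimal sketch of how such a schedule could be derived; it is not the tool's actual implementation, just the idea of sweeping QPS values up to a measured maximum:

```python
# Illustrative sweep schedule: benchmark at evenly spaced QPS values
# from max/steps up to the measured maximum throughput.
def sweep_rates(max_throughput_qps: float, steps: int = 10) -> list[float]:
    return [max_throughput_qps * (i + 1) / steps for i in range(steps)]

print(sweep_rates(120.0, steps=4))  # [30.0, 60.0, 90.0, 120.0]
```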


#### Dataset configuration

Prompts are sampled from a Hugging Face dataset file, using a [subset of ShareGPT
by default](https://huggingface.co/datasets/hlarcher/share_gpt_small). You can specify a different dataset file using the
`--dataset` and `--dataset-file` options.

The dataset is expected to be JSON with the following format:
```json
[
  {
    "conversations": [
      {
        "role": "user",
        "content": "rewrite that entire paragraph in the same style like this one: "
      }
    ]
  }
]
```

To benchmark with prefix caching, you can add a system prompt to a conversation; it will be sent with every request, so the server can reuse the cached prefix.
```json
[
  {
    "conversations": [
      {
        "role": "system",
        "content": "You are a helpful assistant that makes jokes at each response."
      },
      {
        "role": "user",
        "content": "rewrite that entire paragraph in the same style like this one:"
      }
    ]
  }
]
```
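
For illustration only, a conversation entry like the one above can be turned into per-request messages as sketched below, with the shared system prompt sent first so the server can reuse its cached prefix. The tool handles this internally; `messages_from_entry` and `dataset.json` are hypothetical names used here for the sketch:

```python
# Hypothetical helper: build chat messages from one dataset entry, reusing the
# system prompt for every request so the server can cache the shared prefix.
import json

def messages_from_entry(entry: dict) -> list[dict]:
    convo = entry["conversations"]
    system = [m for m in convo if m["role"] == "system"]
    user = [m for m in convo if m["role"] == "user"]
    # one shared system prompt + one sampled user turn per request
    return system + user[:1]

with open("dataset.json") as f:
    dataset = json.load(f)
print(messages_from_entry(dataset[0]))
```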


#### Prompt configuration
For consistent results, you can configure the prompt token count and its variance. The tool samples prompts whose token
counts are drawn from a normal distribution centered on the specified value, with the specified variance.

```shell
--prompt-options "num_tokens=50,max_tokens=60,min_tokens=40,variance=10"
```
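
As a sketch of the assumed semantics (treating `variance` as the variance of a normal distribution around `num_tokens`, clamped to `[min_tokens, max_tokens]`; this is not the tool's Rust implementation):

```python
# Illustrative token-count sampling under the assumptions stated above.
import random

def sample_token_count(num_tokens=50, min_tokens=40, max_tokens=60, variance=10) -> int:
    n = round(random.gauss(num_tokens, variance ** 0.5))  # std dev = sqrt(variance)
    return max(min_tokens, min(max_tokens, n))

print([sample_token_count() for _ in range(5)])
```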


## Development

You need [Rust](https://rustup.rs/) installed to build the benchmarking tool.
```shell
$ make build
```


## Frequently Asked Questions
* **What's the difference between constant arrival rate and constant virtual user count?**
  * **Constant virtual user count** means that the number of virtual users is fixed. Each virtual user sends a single request and waits for the server's response before sending the next one. It basically simulates a fixed number of users querying the server.
  * **Constant arrival rate** means that the rate of requests is fixed and the number of virtual users is adjusted to maintain that rate. Queries hit the server independently of response times.

**Constant virtual user count** is a closed-loop model where the server's response time dictates the number of iterations. **Constant arrival rate** is an open-loop model, more representative of real-life workloads.
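
A minimal sketch of the two models, purely illustrative and unrelated to the tool's Rust implementation:

```python
# Closed loop (constant VUs) vs. open loop (constant arrival rate), simulated with asyncio.
import asyncio
import random

async def fake_request():
    await asyncio.sleep(random.uniform(0.05, 0.2))  # stand-in for an inference call

async def closed_loop(n_users: int, duration: float):
    """Constant VUs: each user waits for a response before issuing the next request."""
    async def user():
        deadline = asyncio.get_running_loop().time() + duration
        while asyncio.get_running_loop().time() < deadline:
            await fake_request()
    await asyncio.gather(*(user() for _ in range(n_users)))

async def open_loop(qps: float, duration: float):
    """Constant arrival rate: requests fire on a fixed schedule, regardless of responses."""
    tasks = []
    for _ in range(int(qps * duration)):
        tasks.append(asyncio.create_task(fake_request()))
        await asyncio.sleep(1 / qps)
    await asyncio.gather(*tasks)

asyncio.run(closed_loop(n_users=4, duration=1.0))
asyncio.run(open_loop(qps=20, duration=1.0))
```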
252 changes: 252 additions & 0 deletions optimum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import argparse
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Protocol, Optional
from urllib.parse import urlparse

from opensearchpy import OpenSearch

PERFORMANCE_RECORD_LATENCY_MS = "latency"
PERFORMANCE_RECORD_THROUGHPUT_SAMPLE_PER_SEC = "throughput"


@dataclass
class PerformanceRecord:
    metric: str
    kind: str
    value: Any

    when: datetime = field(default_factory=lambda: datetime.now())
    meta: Dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def latency(metric: str, value_ms: float, meta: Optional[Dict[str, Any]] = None, when: Optional[datetime] = None):
        r"""
        Create a PerformanceRecord tracking latency information
        Args:
            `metric` (`str`):
                Metric identifier
            `value_ms` (`float`):
                The recorded latency, in milliseconds, for the underlying metric record
            `meta` (`Optional[Dict[str, Any]]`, defaults to `{}`)
                Information relative to the recorded metric to store alongside the metric readout
            `when` (`Optional[datetime]`, defaults to `datetime.now()`)
                Indicates when the underlying metric was recorded
        Returns:
            The performance record for the target metric representing latency
        """
        return PerformanceRecord(
            metric=metric, kind=PERFORMANCE_RECORD_LATENCY_MS, value=value_ms, when=when, meta=meta
        )

    @staticmethod
    def throughput(metric: str, value_sample_per_sec: float, meta: Optional[Dict[str, Any]] = None,
                   when: Optional[datetime] = None):
        r"""
        Create a PerformanceRecord tracking throughput information
        Args:
            `metric` (`str`):
                Metric identifier
            `value_sample_per_sec` (`float`):
                The recorded throughput, in samples per second, for the underlying metric record
            `meta` (`Optional[Dict[str, Any]]`, defaults to `{}`)
                Information relative to the recorded metric to store alongside the metric readout
            `when` (`Optional[datetime]`, defaults to `datetime.now()`)
                Indicates when the underlying metric was recorded
        Returns:
            The performance record for the target metric representing throughput
        """
        return PerformanceRecord(
            metric=metric,
            kind=PERFORMANCE_RECORD_THROUGHPUT_SAMPLE_PER_SEC,
            value=value_sample_per_sec,
            when=when,
            meta=meta
        )

    def as_document(self) -> Dict[str, Any]:
        r"""
        Convert this `PerformanceRecord` to a dictionary-based representation compatible with document storage
        Returns:
            Dictionary with string keys holding the information stored in this record
        """
        parcel = {"date": self.when.timestamp(), "metric": self.metric, "kind": self.kind, "value": self.value}
        return parcel | self.meta


class PerformanceTrackerStore(Protocol):
    r"""
    Base interface defining a performance tracker tool
    """

    @staticmethod
    def from_uri(uri: str) -> "PerformanceTrackerStore":
        r"""
        Create the `PerformanceTrackerStore` from the provided URI information
        Args:
            `uri` (`str`):
                URI specifying over which protocol and where the record(s) will be stored
        Returns:
            Instance of a `PerformanceTrackerStore` whose configuration is inferred from the specified URI
        """
        pass

    def push(self, collection: str, record: "PerformanceRecord"):
        r"""
        Attempt to append the provided record to the underlying tracker, under the specified collection
        Args:
            `collection` (`str`):
                Name of the bucket the specified record should be pushed to
            `record` (`PerformanceRecord`):
                The materialized record to push
        """
        pass


class OpenSearchPerformanceTrackerStore(PerformanceTrackerStore):
    r"""
    Amazon Web Services (AWS) OpenSearch based PerformanceTrackerStore
    Supported URIs are as follows:
    - es://<username:password@><hostname>:<port>
    - es+aws://<aws_access_key_id:aws_secret_access_key@><hostname>:<port>
    - es+aws://<hostname>:<port> - will use the AWS credentials stored on the system
    """

    # Extract region and service from an AWS URL (ex: us-east-1.es.amazonaws.com)
    AWS_URL_RE = re.compile(r"([a-z]+-[a-z]+-[0-9])\.(.*)?\.amazonaws.com")

    def __init__(self, url: str, auth):
        uri = urlparse(url)
        self._client = OpenSearch(
            [{"host": uri.hostname, "port": uri.port or 443}],
            http_auth=auth,
            http_compress=True,
            use_ssl=True
        )

        # Sanity check
        self._client.info()

    @staticmethod
    def from_uri(uri: str) -> "PerformanceTrackerStore":
        if not (_uri := urlparse(uri)).scheme.startswith("es"):
            raise ValueError(f"Invalid URI {uri}: should start with es:// or es+aws://")

        if _uri.scheme == "es+aws":
            from boto3 import Session as AwsSession
            from botocore.credentials import Credentials as AwsCredentials
            from opensearchpy import Urllib3AWSV4SignerAuth

            # Create an AWS session from the (eventual) credentials
            if not _uri.username and not _uri.password:
                session = AwsSession()
                creds = session.get_credentials()
            else:
                creds = AwsCredentials(_uri.username, _uri.password)

            # Parse the URL to extract region and service
            if len(match := re.findall(OpenSearchPerformanceTrackerStore.AWS_URL_RE, _uri.netloc)) != 1:
                raise ValueError(f"Failed to parse AWS OpenSearch service URL {uri}")

            region, service = match[0]
            auth = Urllib3AWSV4SignerAuth(creds, region, service)
        else:
            auth = (_uri.username, _uri.password)

        return OpenSearchPerformanceTrackerStore(uri, auth)

    def _ensure_collection_exists(self, collection: str):
        if not self._client.indices.exists(collection):
            self._client.indices.create(collection)

    def push(self, collection: str, record: "PerformanceRecord"):
        self._ensure_collection_exists(collection)
        self._client.index(collection, record.as_document())


class AutoPerformanceTracker:

    @staticmethod
    def from_uri(uri: str) -> "PerformanceTrackerStore":
        if uri.startswith("es://") or uri.startswith("es+aws://"):
            return OpenSearchPerformanceTrackerStore.from_uri(uri)

        raise ValueError(
            f"Unable to determine the service associated with URI: {uri}. "
            "Valid schemes are es:// or es+aws://"
        )


def main():
    parser = argparse.ArgumentParser(
        prog='text-generation-inference-benchmark-optimum',
        description='Pushes benchmark results to an OpenSearch instance'
    )
    parser.add_argument(
        '--uri',
        type=str,
        required=False,
        help='URI to the OpenSearch instance where to push the benchmark results',
        default='es+aws://search-optimum-benchmarks-kb3meoztyufprqul537nq7deny.us-east-1.es.amazonaws.com'
    )
    parser.add_argument(
        '--collection',
        type=str,
        required=False,
        help='Collection name where to push the benchmark results',
        default='ci_tgi_performances_tracker'
    )
    parser.add_argument(
        '--meta',
        action='append',
        required=False,
        help='Meta information to store alongside the benchmark results, use multiple times for multiple values',
        nargs='?'
    )
    parser.add_argument(
        'results',
        type=str,
        help='File containing the benchmark results to push',
    )
    args = parser.parse_args()

    # Build the metadata dictionary and tag it with a hash of the results file
    meta = flatten(args.meta or [])
    with open(args.results, 'rb') as f:
        bench_id = hashlib.md5(f.read()).hexdigest()
    meta['bench_id'] = bench_id

    with open(args.results, 'r') as f:
        data = json.load(f)

    tracker = AutoPerformanceTracker.from_uri(args.uri)
    # Keep only the fixed-rate runs; drop the warmup and max-throughput passes
    filtered_results = [result for result in data['results'] if
                        result['id'] != 'warmup' and result['id'] != 'throughput']
    latency_metrics_to_push = ['inter_token_latency_ms_p90', 'time_to_first_token_ms_p90', 'e2e_latency_ms_p90']
    throughput_metrics_to_push = ['token_throughput_secs']
    start_time = data['start_time']
    for result in filtered_results:
        for metric in latency_metrics_to_push:
            record = PerformanceRecord.latency(metric, result[metric], {**meta, 'qps': result['config']['rate']},
                                               when=start_time)
            print(record)
            tracker.push(args.collection, record)
        for metric in throughput_metrics_to_push:
            record = PerformanceRecord.throughput(metric, result[metric], {**meta, 'qps': result['config']['rate']},
                                                  when=start_time)
            print(record)
            tracker.push(args.collection, record)


def flatten(l: list[str]) -> dict[str, str]:
    # Turn ["key=value", ...] pairs from --meta into a dictionary
    d = {}
    for e in l:
        key, value = e.split('=', 1)
        d[key] = value
    return d


if __name__ == '__main__':
    main()
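
# Example invocation (the --meta values are hypothetical; the OpenSearch endpoint must be
# reachable with valid credentials). URI and collection shown are the script's defaults:
#   python optimum.py \
#       --uri es+aws://search-optimum-benchmarks-kb3meoztyufprqul537nq7deny.us-east-1.es.amazonaws.com \
#       --collection ci_tgi_performances_tracker \
#       --meta model=llama --meta hardware=h100 \
#       results.json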