Skip to content

Commit

Permalink
model matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Glen authored and Glen committed Dec 8, 2024
1 parent fb8d870 commit c8f9372
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
10 changes: 5 additions & 5 deletions .github/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime


async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
"""
Measures the performance of an API endpoint by sending a prompt and recording metrics.
Expand All @@ -19,7 +19,6 @@ async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
Returns:
Dict[str, Any]: A dictionary containing performance metrics or error information.
"""
model = os.environ.get('model', 'llama-3.2-1b')

results = {
'model': model,
Expand Down Expand Up @@ -100,17 +99,18 @@ async def main() -> None:
prompt_warmup = "what is the capital of France?"
prompt_essay = "write an essay about cats"

model = os.environ.get('model', 'llama-3.2-1b')
# Warmup request
print("\nPerforming warmup request...", flush=True)
try:
warmup_results = await measure_performance(api_endpoint, prompt_warmup)
warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
print("Warmup completed successfully", flush=True)
except Exception as e:
print(f"Warmup request failed: {e}", flush=True)

# Measure performance for the essay prompt
print("\nMeasuring performance for the essay prompt...", flush=True)
results = await measure_performance(api_endpoint, prompt_essay)
results = await measure_performance(api_endpoint, prompt_essay, model)

try:
s3_client = boto3.client(
Expand All @@ -124,7 +124,7 @@ async def main() -> None:
now = datetime.utcnow()
timestamp = now.strftime('%H-%M-%S')
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
s3_key = f"{job_name}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"

# Upload to S3
s3_client.put_object(
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/bench_job.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
ps aux | grep exo || true
CALLING_JOB="${{ inputs.calling_job_name }}"
UNIQUE_JOB_ID="${CALLING_JOB}_${GITHUB_RUN_ID}"
UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
source env/bin/activate
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ on:

jobs:
test-m4-cluster:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b']
# Optional: add fail-fast: false if you want all matrix jobs to continue even if one fails
fail-fast: false
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 2}'
model: 'llama-3.2-1b'
model: ${{ matrix.model }}
calling_job_name: 'test-m4-cluster'
secrets: inherit

0 comments on commit c8f9372

Please sign in to comment.