Skip to content

Commit

Permalink
Update CI to store HAL executable file artifacts and fix some comments
Browse files Browse the repository at this point in the history
Signed-off-by: aviator19941 <[email protected]>
  • Loading branch information
aviator19941 committed Oct 16, 2024
1 parent 83bcae5 commit 62b286f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 45 deletions.
14 changes: 10 additions & 4 deletions .github/workflows/ci-llama.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Integration Tests
name: Llama Benchmarking Tests

on:
workflow_dispatch:
Expand All @@ -16,11 +16,11 @@ concurrency:

jobs:
test_llama:
name: "Integration Tests - llama"
name: "Llama Benchmarking Tests"
strategy:
matrix:
version: [3.11]
os: [ubuntu-latest, windows-latest]
os: [llama-mi300]
fail-fast: false
runs-on: ${{matrix.os}}
defaults:
Expand Down Expand Up @@ -67,4 +67,10 @@ jobs:
"numpy<2.0"
- name: Run llama test
run: pytest sharktank/tests/models/llama/benchmark-tests.py
run: pytest sharktank/tests/models/llama/benchmark-tests.py

- name: Upload llama executable files
uses: actions/upload-artifact@v4
with:
name: llama-files
path: ${{ github.workspace }}/files/
4 changes: 2 additions & 2 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def main():
action="store_true",
)
parser.add_argument(
"--attn-kernel",
"--attention-kernel",
type=str,
default="decomposed",
help='["decomposed", "torch_sdpa"],',
choices=["decomposed", "torch_sdpa"],
)

args = cli.parse(parser)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import os
import sys
import unittest
import pytest
import subprocess
Expand Down Expand Up @@ -121,7 +122,7 @@ def iree_benchmark_module(
benchmark_args += args
cmd = subprocess.list2cmdline(benchmark_args)
logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
return_code = proc.returncode
if return_code != 0:
raise Exception(f"{cmd} failed to run")
Expand All @@ -134,16 +135,9 @@ def setUp(self):
artifacts_dir = "/data/extra/models/llama3.1_8B/"
self.irpa_path = artifacts_dir + "llama8b_f16.irpa"
self.irpa_path_fp8 = artifacts_dir + "llama8b_fp8.irpa"
self.output_mlir = self.repo_root + "llama8b_f16.mlir"
self.output_json = self.repo_root + "llama8b_f16.json"
self.output_vmfb = self.repo_root + "llama8b_f16.vmfb"
self.output_mlir_fp8 = self.repo_root + "llama8b_fp8.mlir"
self.output_json_fp8 = self.repo_root + "llama8b_fp8.json"
self.output_vmfb_fp8 = self.repo_root + "llama8b_fp8.vmfb"
self.iree_compile_args = [
"--iree-hal-target-backends=rocm",
"--iree-hip-target=gfx942",
f"--iree-hal-dump-executable-files-to={self.repo_root}/files/llama",
]
self.prefill_args_f16 = artifacts_dir + "prefill_args"
self.decode_args_f16 = artifacts_dir + "decode_args"
Expand Down Expand Up @@ -185,121 +179,141 @@ def setUp(self):
]

def testBenchmark8B_f16_Decomposed(self):
    """Export, compile, and benchmark the llama 8B f16 model with decomposed attention.

    Uses per-test output file names so artifacts from different test
    variants do not overwrite each other, and dumps HAL executable files
    to a variant-specific directory for CI artifact upload.
    """
    # Per-variant artifact paths (previously shared across tests, which
    # caused one variant to clobber another's outputs).
    output_mlir = self.repo_root + "llama8b_f16_decomposed.mlir"
    output_json = self.repo_root + "llama8b_f16_decomposed.json"
    output_vmfb = self.repo_root + "llama8b_f16_decomposed.vmfb"
    self.export_mlir(
        "decomposed",
        self.irpa_path,
        output_mlir,
        output_json,
        self.repo_root,
    )
    # Extend the shared compile flags with a variant-specific dump
    # directory so the CI can collect the HAL executable files.
    iree_compile_args = self.iree_compile_args + [
        f"--iree-hal-dump-executable-files-to={self.repo_root}/files/llama-8b/f16-decomposed"
    ]
    self.iree_compile(output_mlir, output_vmfb, iree_compile_args, self.repo_root)
    # benchmark prefill
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_prefill_args,
        self.repo_root,
    )
    # benchmark decode
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_decode_args,
        self.repo_root,
    )

@pytest.mark.skip(reason="TODO: Need to plumb through attention_kernel")
def testBenchmark8B_f16_Non_Decomposed(self):
    """Export, compile, and benchmark the llama 8B f16 model with torch_sdpa attention.

    Skipped until the attention_kernel option is plumbed through the
    export path. Mirrors the decomposed variant but writes to
    torch-sdpa-specific artifact paths and dump directory.
    """
    output_mlir = self.repo_root + "llama8b_f16_torch_sdpa.mlir"
    output_json = self.repo_root + "llama8b_f16_torch_sdpa.json"
    output_vmfb = self.repo_root + "llama8b_f16_torch_sdpa.vmfb"
    self.export_mlir(
        "torch_sdpa",
        self.irpa_path,
        output_mlir,
        output_json,
        self.repo_root,
    )
    # Variant-specific HAL executable dump directory for CI artifacts.
    iree_compile_args = self.iree_compile_args + [
        f"--iree-hal-dump-executable-files-to={self.repo_root}/files/llama-8b/f16-torch-sdpa"
    ]
    self.iree_compile(output_mlir, output_vmfb, iree_compile_args, self.repo_root)
    # benchmark prefill
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_prefill_args,
        self.repo_root,
    )
    # benchmark decode
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_decode_args,
        self.repo_root,
    )

@pytest.mark.xfail
def testBenchmark8B_fp8_Decomposed(self):
    """Export, compile, and benchmark the llama 8B fp8 model with decomposed attention.

    Marked xfail: the fp8 path is not yet expected to pass.
    """
    output_mlir = self.repo_root + "llama8b_fp8_decomposed.mlir"
    output_json = self.repo_root + "llama8b_fp8_decomposed.json"
    output_vmfb = self.repo_root + "llama8b_fp8_decomposed.vmfb"
    self.export_mlir(
        "decomposed",
        self.irpa_path_fp8,
        output_mlir,
        output_json,
        self.repo_root,
    )
    # Variant-specific HAL executable dump directory for CI artifacts.
    iree_compile_args = self.iree_compile_args + [
        f"--iree-hal-dump-executable-files-to={self.repo_root}/files/llama-8b/fp8-decomposed"
    ]
    # BUG FIX: pass the extended local `iree_compile_args` (with the dump
    # flag) rather than the bare `self.iree_compile_args`, which silently
    # dropped the --iree-hal-dump-executable-files-to flag.
    self.iree_compile(
        output_mlir,
        output_vmfb,
        iree_compile_args,
        self.repo_root,
    )
    # benchmark prefill
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_prefill_args,
        self.repo_root,
    )
    # benchmark decode
    self.iree_benchmark_module(
        "0",
        output_vmfb,
        self.irpa_path,
        self.iree_run_decode_args,
        self.repo_root,
    )

@pytest.mark.xfail
def testBenchmark8B_fp8_Non_Decomposed(self):
output_mlir = self.repo_root + "llama8b_fp8_torch_sdpa.mlir"
output_json = self.repo_root + "llama8b_fp8_torch_sdpa.json"
output_vmfb = self.repo_root + "llama8b_fp8_torch_sdpa.vmfb"
self.export_mlir(
"torch_sdpa",
self.irpa_path_fp8,
self.output_mlir_fp8,
self.output_json_fp8,
output_mlir,
output_json,
self.repo_root,
)
iree_compile_args = self.iree_compile_args + [
f"--iree-hal-dump-executable-files-to={self.repo_root}/files/llama-8b/fp8-torch-sdpa"
]
self.iree_compile(
self.output_mlir_fp8,
self.output_vmfb_fp8,
output_mlir,
output_vmfb,
self.iree_compile_args,
self.repo_root,
)
# benchmark prefill
self.iree_benchmark_module(
"0",
self.output_vmfb_fp8,
output_vmfb,
self.irpa_path_fp8,
self.iree_run_prefill_args_fp8,
self.repo_root,
)
# benchmark decode
self.iree_benchmark_module(
"0",
self.output_vmfb_fp8,
output_vmfb,
self.irpa_path_fp8,
self.iree_run_decode_args_fp8,
self.repo_root,
Expand Down

0 comments on commit 62b286f

Please sign in to comment.