
Commit

Update llama_benchmarking.md
aviator19941 authored Dec 13, 2024
1 parent 9426559 commit 486c6dd
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions llama_benchmarking.md
@@ -47,7 +47,9 @@ python3 -m sharktank.examples.export_paged_llm_v1 --bs=4 --irpa-file=8b_f16.irpa

## 3. Get the numpy inputs

Get the 8b f16 tp1 unsharded numpy inputs: [get_8b_f16_tp1_numpy_inputs.sh](https://gist.github.com/aviator19941/380acabc77aeb4749fac14262e17db69)
Get the 8b f16 tp1 unsharded prefill numpy inputs: [get_8b_f16_tp1_prefill_inputs.sh](https://gist.github.com/aviator19941/380acabc77aeb4749fac14262e17db69)

Get the 8b f16 tp1 unsharded decode numpy inputs: [get_8b_f16_tp1_decode_inputs.sh](https://gist.github.com/aviator19941/5f7db8ada6a4a95efe1d9a7975fed276)
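
A quick way to sanity-check the generated inputs before compiling or benchmarking is to print their shapes and dtypes. A minimal sketch, assuming the scripts above leave the `.npy` files under the `prefill_args_bs4_128_stride_32/` and `decode_args_bs4_128_stride_32/` directories referenced by the benchmark commands later in this doc:

```
# Sketch: report shape/dtype of each prefill and decode input file.
# Directory and file names are taken from the benchmark commands below;
# adjust the paths if your scripts write elsewhere.
python3 - <<'PY'
import numpy as np
input_sets = {
    "prefill_args_bs4_128_stride_32": ["tokens", "seq_lens", "seq_block_ids", "cs_f16"],
    "decode_args_bs4_128_stride_32": ["next_tokens", "seq_lens", "start_positions", "seq_block_ids", "cs_f16"],
}
for directory, names in input_sets.items():
    print(directory)
    for name in names:
        arr = np.load(f"{directory}/{name}.npy")
        print(f"  {name}.npy: shape={arr.shape} dtype={arr.dtype}")
PY
```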

## 4. Compile command
This command compiles the full IR (both prefill + decode) into a vmfb.
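
The compile invocation itself sits in a collapsed part of this diff. As a rough sketch only, not the command from this repo, and with every flag and file name below treated as an assumption (flag spellings vary across IREE releases and target GPUs), an `iree-compile` call along these lines turns the exported MLIR into a vmfb:

```
# Assumed sketch, not taken from this commit: compile the exported IR
# (prefill + decode entry points) for a ROCm/HIP target into a .vmfb.
iree-compile 8b_fp16.mlir \
  --iree-hal-target-backends=rocm \
  --iree-hip-target=gfx942 \
  -o 8b_fp16.vmfb
```

Here `gfx942` targets MI300-class GPUs; substitute the architecture of the card you benchmark on.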
@@ -58,7 +60,7 @@ This command compiles the full IR (both prefill + decode) into a vmfb.

## 5. Benchmark command
In order to benchmark prefill, make sure you specify the function as `prefill_bs{batch_size}` and specify the 4 inputs using the numpy files in
`/data/llama-3.1/weights/8b/prefill_args_bs4_128`.
`prefill_args_bs4_128_stride_32`.

Prefill benchmark command:

@@ -71,15 +73,14 @@ ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
--parameters=model=8b_fp16.irpa \
--device=hip://4 \
--function=prefill_bs4 \
--input=@/data/llama-3.1/weights/8b/prefill_args_bs4_128/random_tokens.npy \
--input=@/data/llama-3.1/weights/8b/prefill_args_bs4_128/seq_lens.npy \
--input=@/data/llama-3.1/weights/8b/prefill_args_bs4_128/seq_block_ids.npy \
--input=@/data/llama-3.1/weights/8b/prefill_args_bs4_128/cs_f16.npy \
--input=@prefill_args_bs4_128_stride_32/tokens.npy \
--input=@prefill_args_bs4_128_stride_32/seq_lens.npy \
--input=@prefill_args_bs4_128_stride_32/seq_block_ids.npy \
--input=@prefill_args_bs4_128_stride_32/cs_f16.npy \
--benchmark_repetitions=3
```
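
The command above makes all eight GPUs visible through `ROCR_VISIBLE_DEVICES` and then selects the fifth one with `--device=hip://4`. Under the usual ROCm visibility semantics the HIP device index is relative to the visible set, so an equivalent way to pin the benchmark to one physical GPU is to restrict visibility instead. A sketch of that variant (the `iree-benchmark-module` and `--module=` lines are collapsed out of this diff, so the module file name here is an assumption):

```
# Assumed variant, not taken from this commit: expose only physical GPU 4,
# which then enumerates as the first (and only) visible HIP device.
ROCR_VISIBLE_DEVICES=4 \
iree-benchmark-module \
  --module=8b_fp16.vmfb \
  --parameters=model=8b_fp16.irpa \
  --device=hip://0 \
  --function=prefill_bs4 \
  --input=@prefill_args_bs4_128_stride_32/tokens.npy \
  --input=@prefill_args_bs4_128_stride_32/seq_lens.npy \
  --input=@prefill_args_bs4_128_stride_32/seq_block_ids.npy \
  --input=@prefill_args_bs4_128_stride_32/cs_f16.npy \
  --benchmark_repetitions=3
```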

In order to benchmark decode, make sure you specify the function as `decode_bs{batch_size}` and specify the 5 inputs using the numpy files in
`/data/llama-3.1/weights/8b/decode_args_bs4_128`.
In order to benchmark decode, make sure you specify the function as `decode_bs{batch_size}` and specify the 5 inputs using the numpy files in `decode_args_bs4_128_stride_32`.

Decode benchmark command:

@@ -92,11 +93,11 @@ ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
--parameters=model=8b_fp16.irpa \
--device=hip://4 \
--function=decode_bs4 \
--input=@../SHARK-Platform/decode_args_bs4_128_stride_32/next_tokens.npy \
--input=@../SHARK-Platform/decode_args_bs4_128_stride_32/seq_lens.npy \
--input=@../SHARK-Platform/decode_args_bs4_128_stride_32/start_positions.npy \
--input=@../SHARK-Platform/decode_args_bs4_128_stride_32/seq_block_ids.npy \
--input=@../SHARK-Platform/decode_args_bs4_128_stride_32/cs_f16.npy \
--input=@decode_args_bs4_128_stride_32/next_tokens.npy \
--input=@decode_args_bs4_128_stride_32/seq_lens.npy \
--input=@decode_args_bs4_128_stride_32/start_positions.npy \
--input=@decode_args_bs4_128_stride_32/seq_block_ids.npy \
--input=@decode_args_bs4_128_stride_32/cs_f16.npy \
--benchmark_repetitions=3
```
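
Before launching the decode benchmark it can be worth confirming that the decode inputs agree on the batch size, since `decode_bs4` expects batch-size-4 arguments. A minimal check, under the assumption (not stated in this doc) that the leading dimension of `next_tokens`, `seq_lens`, `start_positions`, and `seq_block_ids` is the batch dimension; the paged cache state `cs_f16.npy` is laid out differently and is skipped:

```
# Sketch: check that the decode inputs share batch size 4 on their leading
# dimension ("dim 0 == batch" is an assumption, not stated in this doc).
python3 - <<'PY'
import numpy as np
batch_size = 4
directory = "decode_args_bs4_128_stride_32"
for name in ("next_tokens", "seq_lens", "start_positions", "seq_block_ids"):
    arr = np.load(f"{directory}/{name}.npy")
    assert arr.shape[0] == batch_size, f"{name}: expected leading dim {batch_size}, got {arr.shape}"
print(f"decode inputs look consistent for decode_bs{batch_size}")
PY
```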
