Create a SAS token in Azure:
- Replace `[Add your SAS token here]` (including the `[` and `]`) with your SAS token string in the instructions below

```
azcopy copy \
  'https://sharkblobs.blob.core.windows.net/halo-models/llm-dev/llama3_8b/8b_f16.irpa?[Add SAS token here]' \
  '8b_f16.irpa'
```

If you have trouble accessing `sharkblobs`, you can copy the 8b f16 unsharded irpa file from the `SharkMi300x` machine:

```
scp [email protected]:/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16.irpa 8b_f16.irpa
```

## 2. Generate the IR

a. To generate the IR for prefill only:

```
python3 -m sharktank.examples.export_paged_llm_v1 \
  --bs=4 \
  --irpa-file=8b_f16.irpa \
  --output-mlir=8b_f16_prefill_nondecomposed.mlir \
  --output-config=8b_f16_prefill_nondecomposed.json \
  --attention-kernel=torch \
  --skip-decode \
  --block-seq-stride=32
```

b. To generate the IR for both prefill + decode (remove the `--skip-decode` flag):

```
python3 -m sharktank.examples.export_paged_llm_v1 \
  --bs=4 \
  --irpa-file=8b_f16.irpa \
  --output-mlir=8b_f16_prefill_nondecomposed.mlir \
  --output-config=8b_f16_prefill_nondecomposed.json \
  --attention-kernel=torch \
  --block-seq-stride=32
```

## 3. Get the numpy inputs

Get the 8b f16 tp1 unsharded decode numpy inputs: `get_8b_f16_tp1_decode_inputs.sh`
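
To sanity-check the inputs before running anything, you can print the shape and dtype of each `.npy` file. This is a quick check rather than part of the original guide, and the file name below is a hypothetical example; substitute the files the input script actually writes:

```
# Hypothetical file name: replace with one of the .npy files
# produced by the numpy-inputs script above.
python3 -c "import numpy as np; a = np.load('prefill_token_ids.npy'); print(a.shape, a.dtype)"
```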

## 4. Compile the IR

This command compiles the full IR (both prefill + decode) into a vmfb.

```
../iree-build-no-trace/tools/iree-compile 8b_f16_prefill_nondecomposed.mlir \
  --iree-hip-target=gfx942 \
  -o=prefill_8b.vmfb \
  --iree-hal-target-device=hip \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hip-legacy-sync=false \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
```
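
As a quick smoke test of the compiled vmfb, you can run the prefill entry point once with `iree-run-module`. This is a sketch rather than part of the original guide: `prefill_bs4` is the function name `export_paged_llm_v1` is expected to export for `--bs=4`, and the `--input` file names are hypothetical stand-ins for the step 3 numpy inputs.

```
# Sketch: run the exported prefill function once to verify the vmfb loads.
# The function name and the input file names/order are assumptions; adjust
# them to match your exported module and the files produced in step 3.
../iree-build-no-trace/tools/iree-run-module \
  --device=hip://0 \
  --module=prefill_8b.vmfb \
  --parameters=model=8b_f16.irpa \
  --function=prefill_bs4 \
  --input=@prefill_token_ids.npy \
  --input=@prefill_seq_lens.npy \
  --input=@prefill_seq_block_ids.npy \
  --input=@prefill_cache_state.npy
```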

## 5. Benchmark command

The benchmark runs with all eight GPUs visible (`ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7`).
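
A minimal single-device sketch with `iree-benchmark-module` is below; as with the smoke test above, the function name and input file names are assumptions, and the full multi-GPU command prefixes `ROCR_VISIBLE_DEVICES` as noted.

```
# Sketch: time the prefill function over a few repetitions.
# Function name and input files are assumed; match them to your export.
../iree-build-no-trace/tools/iree-benchmark-module \
  --device=hip://0 \
  --module=prefill_8b.vmfb \
  --parameters=model=8b_f16.irpa \
  --function=prefill_bs4 \
  --input=@prefill_token_ids.npy \
  --input=@prefill_seq_lens.npy \
  --input=@prefill_seq_block_ids.npy \
  --input=@prefill_cache_state.npy \
  --benchmark_repetitions=3
```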

## Sharded

If you want to create your own tp8 sharded irpa files, use this command:

```
python3 -m sharktank.examples.sharding.shard_llm_dataset \
  --irpa-file 8b_fp16.irpa \
  --output-irpa 8b_fp16_tp8.irpa \
  --tensor-parallelism-size 8
```
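
Sharding produces the unranked top-level irpa plus per-rank shard files alongside it (exact rank-file naming can vary by sharktank version); a quick listing confirms what was written:

```
# List the sharded outputs: the unranked file plus per-rank shards.
ls -lh 8b_fp16_tp8*
```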

Larger sharded irpa files (e.g. 70b, 405b) will be stored in `sharkblobs` soon. Until then, you can copy the 70b/405b f16 sharded irpa files from the `SharkMi300x` machine (expect a long copy time):

```
scp [email protected]:/data/llama3.1/weights/405b/fp16/tp8/* .
```

To generate the sharded IR, use the unranked sharded irpa file:

```
python3 -m sharktank.examples.export_paged_llm_v1 \
  --bs=4 \
  --irpa-file=/shark-dev/405b/llama3.1_405b_fp16_tp8_parameters.irpa \
  --output-mlir=405b_f16_prefill_tp8_nondecomposed.mlir \
  --output-config=405b_f16_prefill_tp8_nondecomposed.json \
  --attention-kernel=torch \
  --skip-decode
```

Get the 8b f16 tp8 sharded numpy inputs: [get_8b_f16_tp8_numpy_inputs.sh](https://gist.github.com/aviator19941/9b3cd6511347e57671b7ff1da7c80bfa)