Tracking running llama models through IREE #22

Open
ScottTodd opened this issue May 7, 2024 · 17 comments

ScottTodd (Collaborator) commented May 7, 2024

Goal

Run a llama model from https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/models/llama/llama.py through IREE

Starting with open_llama_3b_v2_f16_gguf since we have that in the docs. Could try another model or data type, but we should eventually get all sorts of variants working.

Approach

https://github.com/nod-ai/sharktank/tree/main/sharktank/sharktank/examples has a few files already:

file                     description
paged_llm_v1.py          Run LLM (from GGUF or hyperparameter config + parameter weights) in PyTorch
export_paged_llm_v1.py   Export LLM to a .mlir file for IREE

Next steps from there could be (a minimal Python sketch follows the list below):

  1. Compile the .mlir file using iree-compile and run it using iree-run-module
  2. Add an IREE version of paged_llm_v1.py that could either
    • Export (e.g. from GGUF) -> compile -> run, all in-process
    • Compile from .mlir -> run
    • Take an already compiled .vmfb and run it
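As a rough illustration of the first option (and the "compile from .mlir -> run" flavor of the second), here is a minimal Python sketch that just shells out to the existing tools. The paths are placeholders and the flags mirror commands used later in this issue.

import subprocess

MLIR = "/tmp/open_llama_3b_v2_f16.mlir"   # produced by export_paged_llm_v1.py
VMFB = "/tmp/open_llama_3b_v2_f16_cpu.vmfb"
GGUF = "/tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf"

# Compile the exported MLIR for the llvm-cpu backend.
subprocess.run(
    ["iree-compile", MLIR,
     "--iree-hal-target-backends=llvm-cpu",
     "--iree-opt-strip-assertions",
     "-o", VMFB],
    check=True)

# Run decode with zero-filled placeholder inputs and GGUF parameters.
subprocess.run(
    ["iree-run-module",
     f"--module={VMFB}",
     "--device=local-task",
     "--function=decode_bs4",
     "--input=4x1xi64", "--input=4xi64", "--input=4xi64", "--input=4x1xi64",
     "--input=1x2662400xf32",
     f"--parameters=model={GGUF}"],
    check=True)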

Worklog

Export -> try compiling the entire program ("prefill" and "decode")

Next: continue triaging compilation errors for prefill.

Export and run just "decode"

  • Since the compilation issues happened with "prefill" and not "decode", I tried exporting a version of the program with just "decode" by commenting out the prefill export line. The exported IR: https://sharkpublic.blob.core.windows.net/sharkpublic/scotttodd/issue_reports/open_llama_3b_v2_f16_decode_only.mlir
  • I was able to compile that to a .vmfb file: iree-compile open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-unknown-unknown -o /tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --iree-hal-executable-debug-level=3
  • To run the program with iree-run-module I need the inputs and a parameter file
    • Inputs should be something like --input=4xi64 --input=4xi64 --input=4xi64 --input=4xi64 --input=1x2662400xf32 (need to verify)
    • For parameters, I downloaded GGUF files with huggingface-cli download --local-dir /tmp/open_llama_3b_v2_gguf SlyEcho/open_llama_3b_v2_gguf (that folder then contains /tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf)
  • Trying to run: iree-run-module --module=/tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --device=vulkan --input=4xi64 --input=4xi64 --input=4xi64 --input=4xi64 --input=1x2662400xf32 --parameters=model=/tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf produces this error: iree\runtime\src\iree\io\formats\gguf\gguf_parser.c:678: UNIMPLEMENTED; GGUF format version 2 is unsupported; expected version 3

Next: try upgrading GGUF version 2 to 3? Load from safetensors? Convert to IRPA?
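To confirm which version a given GGUF file uses, the version field can be read straight from the header (a GGUF file starts with the 4-byte magic "GGUF" followed by a little-endian uint32 version). A minimal sketch:

import struct
import sys

def gguf_version(path: str) -> int:
    # GGUF header layout: 4-byte magic "GGUF", then uint32 version (little-endian).
    with open(path, "rb") as f:
        magic = f.read(4)
        if magic != b"GGUF":
            raise ValueError(f"{path} is not a GGUF file (magic={magic!r})")
        (version,) = struct.unpack("<I", f.read(4))
        return version

if __name__ == "__main__":
    print(gguf_version(sys.argv[1]))  # e.g. open-llama-3b-v2-f16.gguf -> 2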

ScottTodd self-assigned this May 7, 2024
pashu123 commented May 8, 2024

For the spirv-vulkan backend, here's a minimal repro:

func.func @torch_add(%arg0: !torch.vtensor<[1,1,?,?],i1>, %arg1: !torch.vtensor<[4,1,1,?],i1>) -> !torch.vtensor<[4, 1, ?, ?],i1> {
   %int1 = torch.constant.int 1
   %2 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[1,1,?,?],i1>, !torch.vtensor<[4,1,1,?],i1>, !torch.int -> !torch.vtensor<[4,1,?,?],i1>
   return %2 : !torch.vtensor<[4,1,?,?],i1>
 }

error: spirv.IAdd op operand #0 must be 8/16/32/64-bit integer but got i1 .

pashu123 commented May 8, 2024

Pulling in some of the comments from the chat.
For the CPU backend there are two options:

  1. Use the --iree-opt-demote-i64-to-i32 flag; however, these models deal with large numbers, so truncating might not be a good strategy.
  2. Use the --iree-opt-strip-assertions flag; there are assertions hanging around, and this strips them so the model compiles.

For the spirv-vulkan backend I have posted the minimal repro above.

stellaraccident (Contributor)

I think that assert can be safely dropped at the torch level in the same way as the broadcast asserts: when in strict mode from torch, the invariant being checked for dynamic legality must be true (torch enforces it).

ScottTodd (Collaborator, Author) commented May 8, 2024

Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions.
edit: specifically with llvm/torch-mlir#3277 too

ScottTodd (Collaborator, Author)

Compilation correctness

  • Still need the torch-mlir patch to avoid the compiler OOM-ing/crashing/never finishing when compiling the "prefill" stage on llvm-cpu
  • Can compile "prefill" for llvm-cpu with --iree-opt-strip-assertions (sounds like we should fix the frontend to omit asserts from aten.view)
  • Still failing to compile "prefill" for vulkan-spirv (some i1 handling)

GGUF version 2 vs version 3

  • Confirmed GGUF version 2 on https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf/tree/main?show_file_info=open-llama-3b-v2-f16.gguf

  • Found version upgrade instructions at https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#prepare-and-quantize, requires building llama.cpp/examples/quantize.cpp from source?

  • Found version 3 in https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/tree/main?show_file_info=Meta-Llama-3-8B.Q8_0.gguf, going to try that: huggingface-cli download --local-dir /tmp/huggingface/llama3_8B QuantFactory/Meta-Llama-3-8B-GGUF Meta-Llama-3-8B.Q8_0.gguf

    • Export fails for that (I think someone else was specifically looking at llama-3), see errors:

      Full stderr output from:

      python -m sharktank.examples.export_paged_llm_v1 --hf-dataset=llama3_8B_q8_0 --output=/tmp/llama3_8B_q8_0.mlir

      Exporting decode_bs4
      Traceback (most recent call last):
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_dynamo\utils.py", line 1766, in run_node
      return getattr(args[0], node.target)(*args[1:], **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch\utils_stats.py", line 20, in wrapper
      return fn(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_subclasses\fake_tensor.py", line 896, in torch_dispatch
      return self.dispatch(func, types, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_subclasses\fake_tensor.py", line 1241, in dispatch
      return self.cached_dispatch_impl(func, types, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_subclasses\fake_tensor.py", line 974, in cached_dispatch_impl
      output = self.dispatch_impl(func, types, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_subclasses\fake_tensor.py", line 1393, in dispatch_impl
      return decomposition_table[func](*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_refs_init
      .py", line 4547, in view
      return reshape_view_helper(a, *shape, allow_copy=False)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_refs_init
      .py", line 3629, in reshape_view_helper
      shape = utils.infer_size(shape, a.numel())
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_prims_common_init
      .py", line 891, in infer_size
      if d == -1:
      ^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_init
      .py", line 374, in bool
      return self.node.bool
      ()
      ^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch\fx\experimental\sym_node.py", line 432, in bool

      return self.guard_bool("", 0)
      ^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch\fx\experimental\sym_node.py", line 374, in guard_bool
      r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch\fx\experimental\recording.py", line 231, in wrapper
      return fn(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 4138, in evaluate_expr
      raise self._make_data_dependent_error(
      torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, -1) (unhinted: Eq(u0, -1)). (Size-like symbols: none)

      Potential framework code culprit (scroll up for full backtrace):
      File "D:\dev\projects\sharktank.venv\Lib\site-packages\torch_prims_common_init_.py", line 891, in infer_size
      if d == -1:

  • https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF is also version 2

  • Looked into the versions and found that GGUF version 3 just "added" big-endian support, so we should be able to support both version 2 and version 3. Trying that with open_llama_3b_v2_gguf again.

Running just decode, with zeroed arguments:

Vulkan:

iree-run-module --module=/tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf

EXEC @decode_bs4
D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:363: INVALID_ARGUMENT; source and target ranges overlap within the same buffer; stack:
  0x00007ff6e1f7238f iree-run-module <iree_hal_command_buffer_copy_buffer_validation+0x23f> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:361)
  0x00007ff6e1f67e18 iree-run-module <iree_hal_command_buffer_copy_buffer+0xa8> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer.c:458)
  0x00007ff6e1f00a72 iree-run-module <iree_hal_module_command_buffer_copy_buffer+0xc2> (D:\dev\projects\iree\runtime\src\iree\modules\hal\module.c:798)
  0x00007ff6e1f16642 iree-run-module <iree_vm_shim_rrIrII_v+0x82> (D:\dev\projects\iree\runtime\src\iree\vm\shims.c:65)
  0x00007ff6e1f19754 iree-run-module <iree_vm_native_module_issue_call+0x84> (D:\dev\projects\iree\runtime\src\iree\vm\native_module.c:342)

CPU (local-task): assert hit, --trace-execution output: https://gist.github.com/ScottTodd/8c215d943f6f27fa480a8ba5ed328cb3

iree-run-module --module=/tmp/open_llama_3b_v2_f16_cpu.vmfb --device=local-sync --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf --function=decode_bs4 --trace-execution

...
[module.decode_bs4$async+000410C2]    %r0 = vm.call @hal.command_buffer.create(%r266(!hal.device/0x0000015AD90B22D0), %i206(1), %i206(1), %i83(0))
[module.decode_bs4$async+000410D6]    vm.call @hal.command_buffer.copy_buffer(%r0(!hal.command_buffer/0x0000015CAECF9450), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %i100(10649600))

--- assert hit ---
ucrtbase.dll!00007ff8674d286e() (Unknown Source:0)
iree-run-module.exe!iree_abort() Line 26 (d:\dev\projects\iree\runtime\src\iree\base\assert.h:26)
iree-run-module.exe!iree_vm_buffer_deinitialize(iree_vm_buffer_t * buffer) Line 79 (d:\dev\projects\iree\runtime\src\iree\vm\buffer.c:79)
iree-run-module.exe!iree_vm_bytecode_module_destroy(void * self) Line 152 (d:\dev\projects\iree\runtime\src\iree\vm\bytecode\module.c:152)
iree-run-module.exe!iree_vm_context_release_modules(iree_vm_context_t * context, unsigned __int64 start, unsigned __int64 end) Line 288 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:288)
iree-run-module.exe!iree_vm_context_destroy(iree_vm_context_t * context) Line 362 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:362)
iree-run-module.exe!iree_tooling_run_module_with_data(iree_vm_instance_t * instance, iree_string_view_t default_device_uri, iree_const_byte_span_t module_contents, iree_allocator_t host_allocator, int * out_exit_code) Line 422 (d:\dev\projects\iree\runtime\src\iree\tooling\run_module.c:422)

Next: figure out the runtime errors. Miscompile? Going over some runtime limits? local-sync and local-task have different errors. Look at the VM IR and see if anything stands out.
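One way to look at the VM IR without digging through intermediate dumps is to stop compilation at the VM phase. A hedged sketch, assuming the --compile-to=vm phase flag available in recent iree-compile builds:

import subprocess

# Emit the program at the VM level (textual MLIR) instead of a .vmfb, so the
# hal.command_buffer.copy_buffer call flagged above can be inspected directly.
subprocess.run(
    ["iree-compile", "/tmp/open_llama_3b_v2_f16_decode_only.mlir",
     "--iree-hal-target-backends=llvm-cpu",
     "--iree-opt-strip-assertions",
     "--compile-to=vm",
     "-o", "/tmp/open_llama_3b_v2_f16_decode_only.vm.mlir"],
    check=True)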

ScottTodd (Collaborator, Author)

I created a mock version of open_llama_3b_v2_f16.mlir here: https://gist.github.com/ScottTodd/ee0cd9d6ab80e4814edad353235cf664. That just returns 1 for all values (no math/kernels/etc.).

Compile with:

iree-compile \
  mock_open_llama_3b_v2_f16.mlir \
  --iree-hal-target-backends=llvm-cpu \
  -o mock_open_llama_3b_v2_f16_cpu.vmfb

Run prefill with:

iree-run-module \
  --module=mock_open_llama_3b_v2_f16_cpu.vmfb \
  --device=local-sync \
  --function=prefill_bs4 \
  --input=4x1xi64 \
  --input=4xi64 \
  --input=4x1xi64 \
  --input=1x2662400xf32 \
  --parameters=model=open-llama-3b-v2-f16.gguf

Run decode with:

iree-run-module \
  --module=mock_open_llama_3b_v2_f16_cpu.vmfb \
  --device=local-sync \
  --function=decode_bs4 \
  --input=4x1xi64 \
  --input=4xi64 \
  --input=4xi64 \
  --input=4x1xi64 \
  --input=1x2662400xf32 \
  --parameters=model=open-llama-3b-v2-f16.gguf

I'm planning on loading that into Python and standing up an IREE version of https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/examples/paged_llm_v1.py . Once the real model compiles, I'll substitute it.
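A minimal sketch of that loading step with the iree.runtime Python bindings, assuming its load_vm_flatbuffer_file helper is available; the argument names below just mirror the CLI inputs above (they are not the real exported signature), and parameter (GGUF) registration is left out since the mock returns constants:

import numpy as np
import iree.runtime as ireert

# Load the compiled mock module on the CPU (local-task) driver.
module = ireert.load_vm_flatbuffer_file(
    "mock_open_llama_3b_v2_f16_cpu.vmfb", driver="local-task")

# Zero-filled inputs matching the decode_bs4 CLI invocation above.
tokens = np.zeros((4, 1), dtype=np.int64)
seq_lens = np.zeros((4,), dtype=np.int64)
start_positions = np.zeros((4,), dtype=np.int64)
seq_block_ids = np.zeros((4, 1), dtype=np.int64)
cache_state = np.zeros((1, 2662400), dtype=np.float32)

logits = module.decode_bs4(tokens, seq_lens, start_positions, seq_block_ids, cache_state)
print(np.asarray(logits))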

stellaraccident (Contributor) commented May 10, 2024

Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions. edit: specifically with llvm/torch-mlir#3277 too

This upstream patch removes these assertions and implements a more direct lowering (no more switchy stuff): llvm/torch-mlir#3319

ScottTodd (Collaborator, Author)

Latest attempt:


Compile for Vulkan: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-unknown-unknown -o /tmp/open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3

Run on Vulkan: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf

Vulkan output:

EXEC @decode_bs4
result[0]: hal.buffer_view
4x1x32000xf32=[[NAN NAN NAN NAN NAN NAN NAN ...

Compile for CPU: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=llvm-cpu -o /tmp/open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3

Run on CPU: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --device=local-task --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf

CPU crashes inside a dispatch (iree_elf_call_i_ppp).


Will trace execution and look at individual dispatches to go deeper.

ScottTodd (Collaborator, Author)

Currently still debugging a runtime crash in decode with @rsuderman.

We suspect that the in-place scatter operations are writing out of bounds. The exported programs had a sequence of back-to-back scatters, so Rob has a branch (https://github.com/rsuderman/sharktank/tree/rework_update) that makes the key-value store updates use a single scatter (if I'm understanding correctly). The model fails to compile after those changes.

I have a reduced test case of just a single index_put_ (an in-place operation that lowers to scatter and uses torch.overwrite.tensor.contents) here: https://gist.github.com/ScottTodd/df0d426a351a6737e16f507b187a210b . It looks like an issue in the torch-mlir lowering, since it reproduces with torch-mlir-opt --pass-pipeline="builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))"
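For context, the pattern being reduced looks roughly like the hypothetical module below (not the exact case from the gist): a single in-place index_put_ into a cache tensor, which is what lowers through torch.overwrite.tensor.contents to a scatter.

import torch

class CacheUpdate(torch.nn.Module):
    def forward(self, cache, page_ids, values):
        # cache: [num_pages, page_size], page_ids: [n] int64, values: [n, page_size].
        # In-place write of selected pages; exported through torch-mlir this becomes
        # torch.overwrite.tensor.contents feeding a scatter op.
        cache.index_put_((page_ids,), values)
        return cache

m = CacheUpdate()
out = m(torch.zeros(8, 4), torch.tensor([1, 3]), torch.ones(2, 4))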

ScottTodd (Collaborator, Author)

A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493

Some debugging around the call stack also pointed at scatter ops:

--- areOpsFusable ---
  producer:
%48 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_27, %47 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%expanded_21 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
  iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
  consumer:
%51 = iree_linalg_ext.scatter {__root_op__ = 17 : i64} dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_35, %50 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%48 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
  iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>

I'm not sure if that is worth debugging further; it may have been a buggy test case reduction. Going to follow up on the minimal index_put_ compilation error above next.

ScottTodd (Collaborator, Author)

Filed llvm/torch-mlir#3433 for the index_put_ lowering that fails. Not sure if that's unique to our reduced test cases or if it appears in the full model too. Building out more test coverage and confidence in the operations end-to-end will help either way.

ScottTodd (Collaborator, Author)

A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493

This occurs in the full model too. Can work around it by disabling all dispatch region fusions (add a return false around here). Should file a reproducer upstream - the compiler must not crash (assert) on valid input. If the input is invalid then we'd need to update the frontend (torch-mlir / iree-turbine / sharktank).

ScottTodd (Collaborator, Author)

Compiling with --iree-input-demote-i64-to-i32 works around the runtime crash with the decode() function. We're trying to update the model definition (in the Python source) to use i32 while also digging into why the runtime crashes with i64.

ScottTodd (Collaborator, Author)

Tried to change dtypes in the model from i64 to i32 (https://github.com/nod-ai/sharktank/compare/main...ScottTodd:llama-i32?expand=1), ran into errors compiling after export like this:

~/iree-build/tools/iree-compile ~/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir -o ~/scratch/open_llama_3b_v2_f16_i32more3_1block_asan.vmfb --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-link-embedded=false --iree-llvmcpu-sanitize=address
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: error: 'arith.cmpi' op requires all operands to have the same type
    %43 = torch.aten.index.Tensor %0, %42 : !torch.vtensor<[2048,50],complex<f32>>, !torch.list<optional<vtensor>> -> !torch.vtensor<[4,1,50],complex<f32>>
          ^
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: note: see current operation: %3741 = "arith.cmpi"(%arg275, %3740) <{predicate = 2 : i64}> : (i32, i64) -> i1

It sounds like iree-org/iree#17696 fixes decode crashes while still using i32 types.

ScottTodd (Collaborator, Author)

Confirmed that these patches help.

Altogether, I see decode appearing to work (outputs appear sensible and aligned with prefill). We can continue to validate.

ScottTodd (Collaborator, Author) commented Jun 20, 2024

Ideas for next steps / follow-up tasks:

ScottTodd (Collaborator, Author)

Still seeing a crash in decode on Windows with these args:

iree-run-module \
  --module=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb \
  --function=decode_bs4 \
  --device=local-task \
  --input=4x1xi64=0 \
  --input=4xi64=1 \
  --input=4xi64=1 \
  --input=4x1xi64=0,1,2,3 \
  --input=1x2662400xf16 \
  --parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf

I'll wrap all my repro steps (documented here: #69) into a script and run that script across my machines. Hopefully just a case of needing the cache (that --input=1x2662400xf16 arg) to be populated.

ScottTodd added a commit that referenced this issue Jun 27, 2024
Progress on #22

TODOs sprinkled throughout. Immediate next steps I'm considering:

1. Add a test / CI workflow that follows these steps (likely in Bash to
start, then later in Python once more pieces are connected seamlessly)
2. Sanity check with other models from Hugging Face, other IREE backend
targets, different batch sizes, etc. (could parameterize a test script
on those options 🤔)
3. Extract some real inputs/outputs for use with `iree-run-module` then
plug in to
https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests to get
presubmit coverage for `iree-compile` (guarding against compilation
correctness regressions in LLVM integrates and other changes)
ScottTodd added a commit that referenced this issue Jun 28, 2024
Progress on #22

Sample runs on my fork:
* https://github.com/ScottTodd/sharktank/actions/runs/9670685134
* https://github.com/ScottTodd/sharktank/actions/runs/9715408887

I decided to run this on a nightly `schedule` and on
`workflow_dispatch`. It takes around 10 minutes so it _could_ run on
`pull_request` if we want to.

As these components stabilize and we spend less time hacking on
individual steps using the full toolkit (python -> manual `iree-compile`
vs. using the in-process compiler API) we can switch the test from a
bash script to a pytest file. Need to start somewhere :)
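A pytest version of that bash flow could stay a thin subprocess wrapper over the same steps; a rough sketch, where the --hf-dataset value, the prefill_bs4 entry point name, and the parameter path are assumptions/placeholders:

import subprocess
import pytest

@pytest.mark.parametrize("backend,device", [("llvm-cpu", "local-task")])
def test_open_llama_3b_v2_f16(tmp_path, backend, device):
    mlir = tmp_path / "open_llama_3b_v2_f16.mlir"
    vmfb = tmp_path / "open_llama_3b_v2_f16.vmfb"
    gguf = "/tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf"
    # Export the model to MLIR (dataset name assumed to be registered in sharktank).
    subprocess.run(
        ["python", "-m", "sharktank.examples.export_paged_llm_v1",
         "--hf-dataset=open_llama_3b_v2_f16_gguf", f"--output={mlir}"],
        check=True)
    # Compile for the requested backend.
    subprocess.run(
        ["iree-compile", str(mlir),
         f"--iree-hal-target-backends={backend}",
         "--iree-opt-strip-assertions",
         "-o", str(vmfb)],
        check=True)
    # Smoke-test prefill: crashless execution only, no numerics yet.
    subprocess.run(
        ["iree-run-module", f"--module={vmfb}", f"--device={device}",
         "--function=prefill_bs4",
         "--input=4x1xi64", "--input=4xi64", "--input=4x1xi64",
         "--input=1x2662400xf32",
         f"--parameters=model={gguf}"],
        check=True)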
ScottTodd added a commit to nod-ai/SHARK-TestSuite that referenced this issue Jun 28, 2024
Progress on nod-ai/sharktank#22

This adds one test for a llama model running through
https://github.com/nod-ai/sharktank. That project is still getting set
up, so new docs for this particular workflow are coming in at
nod-ai/sharktank#69 and tests in that repo are
in nod-ai/sharktank#70.

Specifically, this exercises:
*
[`sharktank/models/llama/llama.py`](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/models/llama/llama.py)
*
[`sharktank/examples/export_paged_llm_v1.py`](https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/examples/export_paged_llm_v1.py)
with batch sizes == [4]
* The `open-llama-3b-v2-f16.gguf` file from
https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf
* Compilation and crashless execution, _not_ numerical correctness (yet)

Ideas for future work:

* Test cases for the same model/parameters
  * Other batch sizes
  * `decode()` as well as `prefill()`
* Real inputs with expected outputs (`decode()` crashes on some faked
inputs still 🤔)
* Other flag combinations and target configurations (starting simple
though)
* Test cases for other models/parameters
  * 8b / 70b parameter models
  * Mistral, Mixtral, Gemma, etc.
ScottTodd added a commit to iree-org/iree that referenced this issue Jul 1, 2024
Progress on nod-ai/sharktank#22. See
nod-ai/SHARK-TestSuite#272 for the specifics of
what the new test is exercising.

The "models" tests now include `pytorch/models/` and `sharktank/`, so
all test names are qualified relative to `iree_tests/` in the test suite
repo. (Totally inflating my commit stats here, sorry :P)

ci-exactly: build_packages,regression_test