* Almost working except the mask; need to rebase onto main to pick up the ring buffer support, then fix the mask. Int8 updates are also included but not tested.
* Fixed test_model_impl for llama, but test_llama_e2e is still failing.
* Adds lazy_cache_update and restructures the cache flags (see the lazy-update sketch after this list).
* Disables all the prints. Fixes engine creation.
* Fix typos and minor errors.
* Fixes engine creation.
* Adds new_cache_stacked and fixes cache update.
* Fix cache update when new_cache_stacked is False.
* Fix the cache manager and make all unit tests pass except one.
* Updates the exportable model to return cache.
* Removes the fori_loop in cache finalize. Moves cache.finalize() to the end of the existing-cache attention.
* Tries using shard_map for the cache update (see the shard_map sketch after this list).
* Fix the single-cache-line update in cache.finalize().
* Adds int8 support.
* Int8 left-aligned lazy cache update is working; performance is still not good enough.
* Fix the stacked cache introduced in the previous couple of commits.
* Put original ragged attention back.
* Add the original ragged attention kernel (a plain-JAX reference sketch of ragged attention appears after this list).
* Fixes the bf16/int8 cache stack.
* Fix int8 stacked cache insertion in engine and finalization.
* Fixes int8 with lazy cache update.
* Updates the int8 test.
* Fix the int8 ragged attention output sharding.
* Fix group query attention broadcasting issue.
* Fix a shard_map input issue: variables not listed as inputs are frozen into the jitted function as constants.
* Fix the flash attention mask shape; fix the quantized version of the single-cache-line update.
* Adds the kv cache test.
* Replace the quantized cache's "pos" with "input_pos" to align with the bf16 cache. Fix the kv cache quantization test.
* Fix the prefill cache insertion issue for the stacked cache; change the quantization reduce dims from (1, 3) to (-3, -1) so they stay robust to an extra stacking dimension (see the quantization sketch after this list).
* Adds lazy cache update with the generate cache stacked and the new cache unstacked, for performance validation.
* Fix the shard_map sharding for the stacked generate cache and the unstacked new cache.
* Use the JAX slicing API instead of PyTorch index slicing (see the slicing sketch after this list).
* Adds stacked cache support in ragged attention reference kernel.
* Adds stacked cache support for the modified ragged kernel.
* Llama2 70b int8 optimization done. Output not correct yet.
* Remove testing temp output files.
* Fix the llama 70b output accuracy issue caused by GQA.
* Fixes the attention output slicing issue when not using flash attention. Refactors to use only one flash attention kernel. Updates the modified ring-buffer ragged attention kernel to handle quantization, the layer dimension, etc.
* Fix the Pallas kernel out-of-bounds issue.
* Fix tests; fix lint issues.
* Fix the interactive script.
* Fix lint errors.
* Fix errors.
* Address review comments.
* Apply fixes based on review comments; fix all the unit tests.
* Fix the remaining pylint errors.
* Default the ring buffer back to true so that test_run_server and run_interactive work in CPU mode. When we later default the ring buffer to false, we should add flags to the run_interactive CI to set test mode to true so that the Pallas kernel can run (see the flags sketch after this list).
* Fix all the lint errors.
* Remove the deps/JetStream changes.
* Fix merge errors, fix lint errors.
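
A minimal sketch of the lazy cache update referenced in the list above, assuming a simplified (batch, heads, seq, head_dim) layout; the `LazyKVCache` class and its fields are illustrative, not the repo's actual cache manager:

```python
import jax.numpy as jnp

class LazyKVCache:
    """Defers cache writes: update() only records the new K/V, and
    finalize() performs a single write at the end of the decode step."""

    def __init__(self, cache_k, cache_v):
        self.cache_k, self.cache_v = cache_k, cache_v
        self.new_k = self.new_v = self.input_pos = None

    def update(self, k, v, input_pos):
        # Lazy path: remember the new entries; attention reads the stale
        # cache plus (k, v) separately instead of a freshly written cache.
        self.new_k, self.new_v, self.input_pos = k, v, input_pos
        return self.cache_k, self.cache_v, k, v

    def finalize(self):
        # One dynamic write per cache, instead of a write inside every
        # attention call.
        pos = self.input_pos
        self.cache_k = self.cache_k.at[:, :, pos, :].set(self.new_k[:, :, 0, :])
        self.cache_v = self.cache_v.at[:, :, pos, :].set(self.new_v[:, :, 0, :])

cache = LazyKVCache(jnp.zeros((1, 8, 16, 4)), jnp.zeros((1, 8, 16, 4)))
k = v = jnp.ones((1, 8, 1, 4))   # one new token's K/V
cache.update(k, v, 3)            # no cache write yet
cache.finalize()                 # single write at position 3
```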
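
A minimal sketch of the shard_map cache update and the input-freezing pitfall noted above, assuming a 1-D mesh and hypothetical shapes:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("x",))

def update(cache, new_line, pos):
    # Runs per shard; writes one cache line at position `pos`.
    return jax.lax.dynamic_update_slice(cache, new_line, (0, pos, 0))

# Every array the body reads must appear in in_specs. A value captured by
# closure instead of passed as an argument gets frozen into the jitted
# function as a constant -- the bug fixed in the item above.
sharded_update = shard_map(
    update,
    mesh=mesh,
    in_specs=(P("x", None, None), P("x", None, None), P()),
    out_specs=P("x", None, None),
)

n = jax.device_count()
cache = jnp.zeros((n * 2, 16, 4))   # (heads, seq, head_dim), sharded on heads
new_line = jnp.ones((n * 2, 1, 4))
cache = sharded_update(cache, new_line, jnp.int32(5))
```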
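
A plain-JAX sketch of what a ragged attention reference computes, under the assumption of left-aligned caches where each batch row has its own valid length; the real kernel here is a Pallas implementation with more machinery:

```python
import jax
import jax.numpy as jnp

def ragged_attention_ref(q, k_cache, v_cache, lengths):
    # q: (batch, heads, 1, dim); caches: (batch, heads, seq, dim);
    # lengths[b] = number of valid (left-aligned) tokens in row b.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k_cache) / jnp.sqrt(q.shape[-1])
    positions = jnp.arange(k_cache.shape[2])
    mask = positions[None, None, None, :] < lengths[:, None, None, None]
    scores = jnp.where(mask, scores, -1e9)   # mask out unwritten cache lines
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v_cache)

q = jnp.ones((2, 8, 1, 4))
k = v = jnp.ones((2, 8, 16, 4))
out = ragged_attention_ref(q, k, v, jnp.array([3, 16]))
```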
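
A minimal sketch of the reduce-dim change for int8 quantization, assuming a (batch, heads, seq, head_dim) cache; with negative axes the same reduction still targets heads and head_dim after a layer axis is stacked on the front:

```python
import jax.numpy as jnp

def quantize_int8(x, reduce_dims=(-3, -1)):
    # Scales reduced over heads (-3) and head_dim (-1). With negative axes
    # the same code handles both layouts below.
    scale = jnp.max(jnp.abs(x), axis=reduce_dims, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)   # avoid divide-by-zero
    q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return q, scale

unstacked = jnp.ones((2, 8, 16, 4))    # (batch, heads, seq, head_dim)
stacked = jnp.ones((32, 2, 8, 16, 4))  # (layer, batch, heads, seq, head_dim)
quantize_int8(unstacked)  # dims (1, 3) would also work here...
quantize_int8(stacked)    # ...but would break once the layer axis is added
```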
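
A minimal sketch of the PyTorch-to-JAX slicing change, with illustrative shapes:

```python
import jax.numpy as jnp

def insert_cache_line(cache, new_line, input_pos):
    # PyTorch index slicing would be `cache[:, :, input_pos, :] = new_line`,
    # but JAX arrays are immutable and `input_pos` may be a traced value.
    # `.at[...].set(...)` lowers to a dynamic update slice and returns a
    # new array.
    return cache.at[:, :, input_pos, :].set(new_line)

cache = jnp.zeros((2, 8, 16, 4))   # (batch, heads, seq, head_dim)
new_line = jnp.ones((2, 8, 4))     # K or V for one position
cache = insert_cache_line(cache, new_line, jnp.int32(5))
```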
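
A minimal sketch of the flag arrangement described in the ring-buffer item; the names `ring_buffer` and `test_mode` are assumptions rather than the repo's actual flags, shown with absl-style flags:

```python
from absl import flags

# Hypothetical flag names; the repo's actual flags may differ.
flags.DEFINE_bool(
    "ring_buffer", True,
    "Use the ring-buffer cache. Defaulting to True keeps test_run_server "
    "and run_interactive working in CPU mode.")
flags.DEFINE_bool(
    "test_mode", False,
    "Set to True in run_interactive CI when ring_buffer is False, so the "
    "Pallas kernel can run.")
```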