diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 43242e0c0..900c1a9ae 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -74,7 +74,7 @@ def main():
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
-        kv_cache_type="paged" if args.bs == [1] else "paged",
+        kv_cache_type="direct" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
     )
     llama_config.fake_quant = args.fake_quant
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 6cf79402e..88f5c344c 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -144,7 +144,7 @@ class LlamaModelConfig:
 
     # Block sequence stride for a paged KV cache. This must divide evenly
     # into the context length.
-    block_seq_stride: int = 32
+    block_seq_stride: int = 16
 
     # Either "paged" or "direct".
     kv_cache_type: str = "paged"
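
For context, a minimal sketch of the two behaviors the diff touches: the export script's ternary now selects the "direct" KV cache only when exporting the single batch size `[1]`, and the config comment requires that `block_seq_stride` divide the context length evenly. The helper functions below are hypothetical illustrations, not part of sharktank.

```python
# Hypothetical helpers illustrating the constraints in this diff (assumed names).

def select_kv_cache_type(bs: list[int]) -> str:
    # Mirrors the export script's ternary: "direct" only for a single batch
    # size of 1, otherwise the paged KV cache.
    return "direct" if bs == [1] else "paged"

def num_cache_blocks(context_length: int, block_seq_stride: int = 16) -> int:
    # Per the config comment, the stride must divide the context length evenly.
    if context_length % block_seq_stride != 0:
        raise ValueError(
            f"block_seq_stride={block_seq_stride} must divide "
            f"context length {context_length}"
        )
    return context_length // block_seq_stride

assert select_kv_cache_type([1]) == "direct"
assert select_kv_cache_type([4]) == "paged"
assert num_cache_blocks(4096, 16) == 256
```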