From 9bd96e27f7def550fa0f0cd8bf85aed1b84f5586 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 21 Feb 2024 18:07:43 -0500
Subject: [PATCH] handle cache_position kwarg in updated llama modeling

---
 src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4bded9b027..4d96913c52 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -688,6 +688,9 @@ def llama_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[  # pylint: disable=unused-argument
+        torch.LongTensor
+    ] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = (
         output_attentions
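
Note (not part of the patch): a minimal sketch of why the extra keyword is accepted, assuming that newer transformers releases pass cache_position through to the patched forward; without the parameter in the signature, the monkeypatched function would raise a TypeError. The function name patched_forward and the call below are hypothetical and for illustration only.

    from typing import Optional

    import torch


    def patched_forward(
        input_ids: torch.LongTensor,
        cache_position: Optional[  # pylint: disable=unused-argument
            torch.LongTensor
        ] = None,
    ) -> torch.LongTensor:
        # The keyword is accepted but deliberately unused, mirroring the
        # pylint disable comment in the patch above.
        return input_ids


    # A caller that passes the new keyword; without the parameter this would
    # fail with "got an unexpected keyword argument 'cache_position'".
    print(patched_forward(torch.tensor([1, 2, 3]), cache_position=torch.arange(3)))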