From 5909ea1f670d3018ef84623d63f2976f2a17cf08 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 20 Mar 2024 17:30:22 -0700 Subject: [PATCH] Fix dynamic RoPE (#350) --- server/lorax_server/utils/layers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 3bf042c42..091ca0464 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -918,6 +918,14 @@ def get_cos_sin( """ Return cos and sin for the asked position ids """ + + # When using dynamic position embeddings, the max sequence length might exceed + # the max position embeddings of the base model, so we need to update our + # cache during warmup. + # This should never result in a change after warmup, otherwise we break + # cuda graphs. + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) return cos.unsqueeze(1), sin.unsqueeze(1)