From 059e78cf2d9262ceea0d01a973848b699e01fd84 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 24 Sep 2024 21:12:47 +0200 Subject: [PATCH] llama: reverting kv_cache in case of failed compute --- src/llama.cpp | 59 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index ea7c417340fa62..073a27a3225ab2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2815,6 +2815,42 @@ struct llama_kv_cache { } }; +class llama_kv_cache_state { + struct llama_kv_cache_state_short { + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; + uint32_t n = 0; + + std::vector cells; + } old_state; + + bool saved = false; + +public: + void save_state(const llama_kv_cache& cache) { + old_state.head = cache.head; + old_state.size = cache.size; + old_state.used = cache.used; + old_state.n = cache.n; + old_state.cells = cache.cells; + + saved = true; + } + + void restore(llama_kv_cache& cache) { + if (saved) { + cache.head = old_state.head; + cache.size = old_state.size; + cache.used = old_state.used; + cache.n = old_state.n; + cache.cells = std::move(old_state.cells); + + saved = false; + } + } +}; + struct llama_control_vector { std::vector tensors; // per layer std::vector ctxs; @@ -17184,6 +17220,7 @@ static int llama_decode_internal( lctx.n_queued_tokens += n_tokens_all; auto & kv_self = lctx.kv_self; + llama_kv_cache_state kv_cache_state_holder; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -17261,6 +17298,7 @@ static int llama_decode_internal( // non-causal masks do not use the KV cache if (hparams.causal_attn) { llama_kv_cache_update(&lctx); + kv_cache_state_holder.save_state(kv_self); // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it @@ -17318,16 +17356,17 @@ static int llama_decode_internal( llama_set_inputs(lctx, ubatch); const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (compute_status != GGML_STATUS_SUCCESS) { + kv_cache_state_holder.restore(kv_self); + switch (compute_status) { + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } } // update the kv ring buffer