Skip to content

Commit

Permalink
llama: reverting kv_cache in case of failed compute
Browse files Browse the repository at this point in the history
  • Loading branch information
Xarbirus committed Oct 14, 2024
1 parent 9edd061 commit 059e78c
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2815,6 +2815,42 @@ struct llama_kv_cache {
}
};

class llama_kv_cache_state {
struct llama_kv_cache_state_short {
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0;
uint32_t n = 0;

std::vector<llama_kv_cell> cells;
} old_state;

bool saved = false;

public:
void save_state(const llama_kv_cache& cache) {
old_state.head = cache.head;
old_state.size = cache.size;
old_state.used = cache.used;
old_state.n = cache.n;
old_state.cells = cache.cells;

saved = true;
}

void restore(llama_kv_cache& cache) {
if (saved) {
cache.head = old_state.head;
cache.size = old_state.size;
cache.used = old_state.used;
cache.n = old_state.n;
cache.cells = std::move(old_state.cells);

saved = false;
}
}
};

struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<struct ggml_context *> ctxs;
Expand Down Expand Up @@ -17184,6 +17220,7 @@ static int llama_decode_internal(
lctx.n_queued_tokens += n_tokens_all;

auto & kv_self = lctx.kv_self;
llama_kv_cache_state kv_cache_state_holder;

const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
Expand Down Expand Up @@ -17261,6 +17298,7 @@ static int llama_decode_internal(
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
kv_cache_state_holder.save_state(kv_self);

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
Expand Down Expand Up @@ -17318,16 +17356,17 @@ static int llama_decode_internal(
llama_set_inputs(lctx, ubatch);

const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
switch (compute_status) {
case GGML_STATUS_SUCCESS:
break;
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
if (compute_status != GGML_STATUS_SUCCESS) {
kv_cache_state_holder.restore(kv_self);
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
}
}

// update the kv ring buffer
Expand Down

0 comments on commit 059e78c

Please sign in to comment.