llama: (proposal) propagating the results of graph_compute to the user interface #9525

Merged · 7 commits · Nov 13, 2024
4 changes: 2 additions & 2 deletions include/llama.h
@@ -797,15 +797,15 @@ extern "C" {
// Processes a batch of tokens with the encoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);

// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
// < 0 - error
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
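To make the new contract concrete, here is a minimal caller-side sketch (the `handle_decode` helper is hypothetical and not part of this PR) showing how an application might react to the documented `llama_decode` return codes, including the `2` for an aborted compute that the llama.cpp changes below introduce:

```cpp
// Hypothetical caller-side helper (not part of this PR), assuming llama.h is
// on the include path and ctx/batch were prepared with the usual llama.cpp APIs.
#include "llama.h"
#include <cstdio>

// returns true if the batch was decoded successfully
static bool handle_decode(llama_context * ctx, const llama_batch & batch) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 0) {
        return true; // success
    }
    if (ret == 1) {
        // warning: could not find a KV slot for the batch; per the header
        // comment, try a smaller batch or a larger context
        fprintf(stderr, "decode: no KV slot found for the batch\n");
        return false;
    }
    if (ret == 2) {
        // compute was aborted; the KV cache has been rolled back
        fprintf(stderr, "decode: graph compute aborted\n");
        return false;
    }
    // ret < 0: error; per the header comment, the KV cache state has been
    // restored to the state before this call
    fprintf(stderr, "decode: error %d\n", ret);
    return false;
}
```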
120 changes: 107 additions & 13 deletions src/llama.cpp
@@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init(
return true;
}

// a structure that holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found = false;   // whether a slot was found

explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

operator bool() const { return found; }
};
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};

// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
struct llama_kv_cache & cache,
const struct llama_ubatch & batch) {
const uint32_t n_tokens = batch.n_tokens;
@@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
// too big seq_id
// TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
return false;
return llama_kv_cache_slot_info_failed;
}
if (j > 0) {
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot(
// allow getting the range of used cells, from head to head + n
cache.head = min;
cache.n = max - min + 1;
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
[](const llama_kv_cell& cell){ return !cell.is_empty(); });

// sanity check
return cache.n >= n_seqs;
return llama_kv_cache_slot_info(cache.n >= n_seqs);
}
// otherwise, one cell per token.

if (n_tokens > cache.size) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
return false;
return llama_kv_cache_slot_info_failed;
}

uint32_t n_tested = 0;
@@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot(

if (n_tested >= cache.size) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false;
return llama_kv_cache_slot_info_failed;
}
}

@@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot(

cache.used += n_tokens;

return true;
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
}

// find how many cells are currently in use
@@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
return cparams.flash_attn ? 256u : 32u;
}

// saves the kv_cache state for future recovery.
// used to roll back changes made by llama_kv_cache_find_slot.
struct llama_kv_slot_restorer {
struct llama_kv_cache_state {
uint32_t head = 0;
uint32_t n = 0;
} old_state;

// for non-recurrent models only
// list of slots to restore
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;

bool do_restore = false;

explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}

    // saves a slot's information for future restoration
void save(const struct llama_kv_cache_slot_info & slot) {
if (slot) {
do_restore = true;
if (slot.boundaries.first != slot.boundaries.second) {
slot_boundaries.push_back(slot.boundaries);
}
}
}

    // must be explicitly called to restore the kv_cache state
    // and roll back changes from all llama_kv_cache_find_slot calls
void restore(struct llama_kv_cache & cache) {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;

if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
llama_kv_cache_seq_rm(cache, -1, -1, -1);
} else {
for (auto & slot : slot_boundaries) {
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
}
}
}
}
};

//
// model loading and saving
//
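The `llama_kv_slot_restorer` above is driven by `llama_decode_internal` further down in this diff. As a reading aid, here is a condensed, non-standalone excerpt of that save-then-restore flow (simplified: the real code also maps `GGML_STATUS_ALLOC_FAILED` to `-2`, and the batching loop is omitted):

```cpp
// Condensed mirror of the llama_decode_internal changes below (not standalone
// code): save each slot claimed for a ubatch, roll everything back on failure.
llama_kv_slot_restorer kv_slot_restorer(kv_self);   // snapshot head/n up front

// ... for each ubatch of the batch ...
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
    return 1;                          // warning: no KV slot for this ubatch
}
kv_slot_restorer.save(slot);           // remember the claimed [begin, end) cell range

const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
if (compute_status != GGML_STATUS_SUCCESS) {
    kv_slot_restorer.restore(kv_self); // reset head/n and remove the saved cell ranges
    return compute_status == GGML_STATUS_ABORTED ? 2 : -3;
}
```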
@@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
}
}

static void llama_graph_compute(
// returns the result of ggml_backend_sched_graph_compute_async execution
static enum ggml_status llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
int n_threads,
@@ -17196,15 +17259,20 @@ static void llama_graph_compute(
set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
}

auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
if (err != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
}

// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));

return status;
}

// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - batch: batch to evaluate
@@ -17254,6 +17322,7 @@ static int llama_decode_internal(
lctx.n_queued_tokens += n_tokens_all;

auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);

const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
@@ -17338,9 +17407,11 @@ static int llama_decode_internal(
kv_self.head = 0;
}

if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
return 1;
}
kv_slot_restorer.save(slot);

if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17387,7 +17458,19 @@ static int llama_decode_internal(

llama_set_inputs(lctx, ubatch);

llama_graph_compute(lctx, gf, n_threads, threadpool);
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
if (compute_status != GGML_STATUS_SUCCESS) {
kv_slot_restorer.restore(kv_self);
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
}
}

// update the kv ring buffer
{
@@ -17624,7 +17707,18 @@ static int llama_encode_internal(

llama_set_inputs(lctx, ubatch);

llama_graph_compute(lctx, gf, n_threads, threadpool);
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
switch (compute_status) {
case GGML_STATUS_SUCCESS:
break;
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
}

// extract embeddings
if (embd) {