minor: calculate total VRAM offloading via FFN splitting
hodlen committed Dec 19, 2023
1 parent 83cb2fb commit 9c35efe
Showing 2 changed files with 29 additions and 46 deletions.
73 changes: 28 additions & 45 deletions llama.cpp
@@ -2729,7 +2729,7 @@ struct llama_augmentation_model_loader {
int model_layer = model->layers.size();
int ffn_dim = model->layers[0].ffn_up->ne[1];
const size_t ggml_aux_tensor_size = 4 * (100 * 100 + model_layer*ffn_dim*sizeof(float) * ggml_tensor_overhead() );
printf("augmentation buffer: %ld\n", ggml_aux_tensor_size);

struct ggml_init_params params = {
/*.mem_size =*/ ggml_aux_tensor_size,
/*.mem_buffer =*/ nullptr,
@@ -2781,37 +2781,29 @@ struct llama_augmentation_model_loader {
#endif
}

- void slice_ffn_mat_to_gpu(llama_layer & layer) {
+ size_t slice_ffn_mat_to_gpu(llama_layer & layer) {
std::vector<uint8_t> work_buffer;
ggml_cgraph * tmp_sum_gf = ggml_new_graph(aux_ctx);
ggml_tensor * gpu_idx = layer.gpu_idx;

// calculate the size of tensor to be copied
ggml_tensor * sum_t = ggml_sum(aux_ctx, gpu_idx);
ggml_build_forward_expand(tmp_sum_gf, sum_t);
ggml_graph_compute_helper(work_buffer, tmp_sum_gf, 2);
int64_t gpu_rows = *ggml_get_data_i32(sum_t);


- int64_t gpu_index_len = gpu_idx->ne[0];
- // ggml_tensor * gpu_bucket = ggml_new_tensor_1d(aux_ctx, GGML_TYPE_I32, gpu_rows);
- // make bucket a reverse index back to unstriped mat
- // int32_t * pbucket_data = (int32_t *)gpu_bucket->data;
- // for (int i = 0; i < gpu_index_len; i++) {
- // if (ggml_get_data_i32(gpu_idx)[i] == 0) {
- // continue;
- // }
- // *pbucket_data = i;
- // ++pbucket_data;
- // }
- // layer.gpu_bucket = gpu_bucket;
ggml_tensor *gpu_bucket = layer.gpu_bucket;
+ size_t offloaded_bytes = 0;

layer.ffn_gate_gpu = create_striped_mat_to_gpu(layer.ffn_gate, gpu_bucket);
layer.ffn_up_gpu = create_striped_mat_to_gpu(layer.ffn_up, gpu_bucket);
- layer.ffn_down_gpu = create_striped_mat_to_gpu(layer.ffn_down_t, gpu_bucket);
+ layer.ffn_down_gpu = create_striped_mat_to_gpu(layer.ffn_down, gpu_bucket);

+ if (layer.ffn_gate_gpu) {
+ offloaded_bytes += ggml_nbytes(layer.ffn_gate_gpu);
+ }
+ if (layer.ffn_up_gpu) {
+ offloaded_bytes += ggml_nbytes(layer.ffn_up_gpu);
+ }
+ if (layer.ffn_down_gpu) {
+ offloaded_bytes += ggml_nbytes(layer.ffn_down_gpu);
+ }
+ return offloaded_bytes;
}
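
The accounting is simple: each striped slice that actually lands on the GPU contributes its ggml_nbytes() to the layer total. A minimal standalone sketch of the same arithmetic (the F16 type and LLaMA-7B-like dims below are illustrative assumptions, not values from this patch):

#include "ggml.h"
#include <cstdio>

int main() {
    // metadata-only context: with no_alloc=true, no tensor data is allocated
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // assume 4096-dim embeddings, 11008 FFN rows, half of them offloaded
    const int64_t n_embd = 4096, gpu_rows = 11008 / 2;
    struct ggml_tensor * gate = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd, gpu_rows);
    struct ggml_tensor * up   = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd, gpu_rows);
    struct ggml_tensor * down = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd, gpu_rows);

    // the same sum that slice_ffn_mat_to_gpu accumulates per layer
    size_t offloaded_bytes = ggml_nbytes(gate) + ggml_nbytes(up) + ggml_nbytes(down);
    printf("per-layer FFN slices: %.2f MiB\n", offloaded_bytes / 1024.0 / 1024.0);

    ggml_free(ctx);
    return 0;
}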

- int apply_augmentation_to_base_model(llama_model * model) {
+ size_t offload_ffn_split(llama_model * model) {
LLAMA_LOG_INFO("%s: applying augmentation to model - please wait ...\n", __func__);
const int64_t t_start_aug_us = ggml_time_us();
std::vector<uint8_t> work_buffer;
@@ -2823,6 +2815,7 @@
#endif

// load gpu_idx and slice mat to gpu
+ size_t offloaded_bytes = 0;
for (llama_layer &model_layer : model -> layers) {
// gpu_idx load
if (model_layer.gpu_idx == NULL && model_layer.gpu_bucket == NULL) {
@@ -2832,12 +2825,12 @@
ggml_tensor * gpu_bucket = ggml_new_tensor_1d(aux_ctx, GGML_TYPE_I32, 0);
model_layer.gpu_bucket = gpu_bucket;
}
- slice_ffn_mat_to_gpu(model_layer);
+ offloaded_bytes += slice_ffn_mat_to_gpu(model_layer);
LLAMA_LOG_INFO(".");
}

LLAMA_LOG_INFO(" done (%.2f ms)\n", (ggml_time_us() - t_start_aug_us) / 1000.0);
- return 0;
+ return offloaded_bytes;
}
};

@@ -2941,22 +2934,17 @@ static bool llm_load_gpu_split_with_budget(llama_model_loader & ml, llama_model
return load_gpu_split_from_split_file(model, cached_split_path, vram_allocatable_bytes);
}

- static void llm_generate_empty_gpu_split(llama_model_loader &ml, llama_model & model) {
- // TODO: move code here & remove augmentation ml
- llama_model_apply_augmentation(&model);
- }

static void llm_load_gpu_split(llama_model_loader & ml, llama_model & model, size_t vram_budget_bytes, bool no_cache) {
#if defined(GGML_USE_CUBLAS)
if (vram_budget_bytes >= 256ull * 1024 * 1024) {
- if (llm_load_gpu_split_with_budget(ml, model, vram_budget_bytes, no_cache)) {
- return;
- }
+ if (!llm_load_gpu_split_with_budget(ml, model, vram_budget_bytes, no_cache)) {
+ LLAMA_LOG_ERROR("%s: error: failed to generate gpu split, an empty one will be used\n", __func__);
+ }
}
#endif
- // Fall back to the empty GPU split
- llm_generate_empty_gpu_split(ml, model);
+ // Apply GPU index and split FFNs to GPU
+ size_t ffn_offloaded_bytes = llama_model_offload_ffn_split(&model);
+ LLAMA_LOG_INFO("%s: offloaded %.2f MiB of FFN weights to GPU\n", __func__, ffn_offloaded_bytes / 1024.0 / 1024.0);
}
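
With this change, llm_load_gpu_split no longer routes failures through a separate empty-split helper: the budgeted solver runs only when at least 256 MiB of VRAM is available, and the FFN split is then applied unconditionally, so a failed or skipped solve degrades to an empty GPU index and zero offloaded bytes. A condensed sketch of the resulting control flow (try_budgeted_split and apply_ffn_split are hypothetical stand-ins for llm_load_gpu_split_with_budget and llama_model_offload_ffn_split):

#include <cstddef>
#include <cstdio>

// hypothetical stand-ins, stubbed so the sketch compiles on its own
static bool try_budgeted_split(size_t /*budget*/) { return false; }
static size_t apply_ffn_split() { return 0; }

static void load_gpu_split_sketch(size_t vram_budget_bytes) {
    const size_t min_budget = 256ull * 1024 * 1024;  // below 256 MiB the solver is skipped
    if (vram_budget_bytes >= min_budget && !try_budgeted_split(vram_budget_bytes)) {
        fprintf(stderr, "failed to generate gpu split, an empty one will be used\n");
    }
    // always runs: an empty GPU index simply yields zero offloaded bytes
    size_t bytes = apply_ffn_split();
    fprintf(stderr, "offloaded %.2f MiB of FFN weights to GPU\n", bytes / 1024.0 / 1024.0);
}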

static void llm_load_sparse_model_tensors(
@@ -9670,15 +9658,10 @@ int llama_model_apply_gpu_idx_from_file(struct llama_model * model, const char *
return 0;
}

- // Apply postprocessing steps for PowerInfer derived models
- int llama_model_apply_augmentation(struct llama_model * model) {
+ size_t llama_model_offload_ffn_split(struct llama_model * model) {
llama_augmentation_model_loader * aug_ml = new llama_augmentation_model_loader(model);
- if (aug_ml -> apply_augmentation_to_base_model(model) > 0) {
- LLAMA_LOG_ERROR("%s: failed to apply augmentation adapter\n", __func__);
- return 1;
- }
model -> aug_model_loader = std::unique_ptr<llama_augmentation_model_loader>(aug_ml);
- return 0;
+ size_t offloaded_bytes = aug_ml->offload_ffn_split(model);
+ return offloaded_bytes;
}
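
For callers, the rename also changes the contract: the function now returns the number of bytes offloaded rather than a 0/1 status code, which the loader above converts to the MiB figure in its log line. A hypothetical direct use of the public API (placeholder model path; within llama.cpp the call is made automatically during llm_load_gpu_split):

#include "llama.h"
#include <cstdio>

int main() {
    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    size_t ffn_bytes = llama_model_offload_ffn_split(model);
    printf("FFN split offloaded %.2f MiB to VRAM\n", ffn_bytes / 1024.0 / 1024.0);

    llama_free_model(model);
    return 0;
}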

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
2 changes: 1 addition & 1 deletion llama.h
@@ -348,7 +348,7 @@ extern "C" {
const char * path_mlp,
bool use_mmap);

- LLAMA_API int llama_model_apply_augmentation(struct llama_model * model);
+ LLAMA_API size_t llama_model_offload_ffn_split(struct llama_model * model);

//
// KV cache