From 10c477b8a84a3fd1a48377e9effff4c48f45f9b7 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sun, 3 Mar 2024 18:58:42 +0100
Subject: [PATCH] implement slerp

---
 llama.cpp | 58 ++++++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 102657ac95600..c786bc5778aa4 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -11389,12 +11389,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         llm_load_arch(*ml, *model);
         llm_load_hparams(*ml, *model);
 
-        if (i > 0 && models[i-1]->hparams != model->hparams) {
-            LLAMA_LOG_ERROR("hparams of input models are different, aborting...");
-            clean_up();
-            return -1;
-        }
-
         models.push_back(std::move(model));
         mls.push_back(std::move(ml));
     }
@@ -11521,6 +11515,8 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         std::vector<no_init<float>> f32_in_buf1;  // dequant it internally
         std::vector<float> f32_out_buf(n_elements, 0.0); // do not resize!
         std::vector<uint8_t> out_buf(ggml_nbytes(out_tensor)); // do not resize!
+        const int n_per_row = out_tensor->ne[0];
+        const int n_rows = n_elements / n_per_row;
 
         if (ins.method == LLAMA_MERGE_COPY) {
             LLAMA_LOG_INFO("copy\n");
@@ -11565,13 +11561,49 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         }
 
         if (ins.method == LLAMA_MERGE_SLERP) {
+            // Python code: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
             LLAMA_LOG_INFO("slerp ");
-            float * in0 = (float *) f32_in_buf0.data();
-            float * in1 = (float *) f32_in_buf1.data();
-            float * dest = (float *) f32_out_buf.data();
-            for (size_t i = 0; i < n_elements; i++) {
-                //dest[i] = in0[i] * ins.t + in1[i] * 0;
-                dest[i] = in0[i];
+            static const float dot_threshold = 0.9995;
+            auto lerp_row = [](float * in0, float * in1, float * out, size_t nelem, float t) {
+                for (size_t i = 0; i < nelem; i++) {
+                    out[i] = in0[i] * (1.0 - t) + in1[i] * t;
+                }
+            };
+            auto slerp_row = [&lerp_row](float * in0, float * in1, float * out, size_t nelem, float t) {
+                float norm0 = std::sqrt(std::inner_product(in0, in0 + nelem, in0, 0.0));
+                float norm1 = std::sqrt(std::inner_product(in1, in1 + nelem, in1, 0.0));
+                // Normalize the vectors to get the directions and angles
+                std::vector<float> v0(nelem);
+                std::vector<float> v1(nelem);
+                for (size_t i = 0; i < nelem; i++) {
+                    v0[i] = in0[i] / norm0;
+                    v1[i] = in1[i] / norm1;
+                }
+                // Dot product with the normalized vectors
+                float dot = std::inner_product(v0.begin(), v0.end(), v1.begin(), 0.0);
+                // If absolute value of dot product is almost 1, vectors are ~collinear, so use lerp
+                if (std::abs(dot) > dot_threshold) {
+                    return lerp_row(in0, in1, out, nelem, t);
+                }
+                // Calculate initial angle between v0 and v1
+                float theta_0 = std::acos(dot);
+                float sin_theta_0 = std::sin(theta_0);
+                // Angle at timestep t
+                float theta_t = theta_0 * t;
+                float sin_theta_t = std::sin(theta_t);
+                // Finish the slerp algorithm
+                float s0 = std::sin(theta_0 - theta_t) / sin_theta_0;
+                float s1 = sin_theta_t / sin_theta_0;
+                for (size_t i = 0; i < nelem; i++) {
+                    out[i] = in0[i] * s0 + in1[i] * s1;
+                }
+            };
+            for (int r = 0; r < n_rows; r++) {
+                float * in0 = (float *) f32_in_buf0.data();
+                float * in1 = (float *) f32_in_buf1.data();
+                float * dest = (float *) f32_out_buf.data();
+                size_t offset = n_per_row * r;
+                slerp_row(in0 + offset, in1 + offset, dest + offset, n_per_row, ins.t);
             }
         }
 
@@ -11579,8 +11611,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
         {
             LLAMA_LOG_INFO("requant\n");
            std::array<int64_t, 1 << 4> hist_cur = {};
-            const int n_per_row = out_tensor->ne[0];
-            const int n_rows = n_elements / n_per_row;
             static const int min_chunk_size = 32 * 512;
             const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
             size_t new_size = llama_tensor_quantize_internal(
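
For reference, the per-row SLERP added above computes
out = sin((1 - t) * theta) / sin(theta) * in0 + sin(t * theta) / sin(theta) * in1,
where theta is the angle between the two rows, and it falls back to plain linear
interpolation when the rows are nearly collinear. The standalone C++ sketch below
repeats the same steps on a toy row so the math can be checked in isolation; the
function name slerp_row_sketch and the test values are illustrative only and are
not part of the patch.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// Same steps as the slerp_row lambda in the patch: measure the angle between
// the two rows, then interpolate along the arc; fall back to lerp when the
// rows are close enough to (anti-)parallel that acos/sin become unstable.
static void slerp_row_sketch(const float * in0, const float * in1, float * out, size_t nelem, float t) {
    static const float dot_threshold = 0.9995f;
    float norm0 = std::sqrt(std::inner_product(in0, in0 + nelem, in0, 0.0));
    float norm1 = std::sqrt(std::inner_product(in1, in1 + nelem, in1, 0.0));
    std::vector<float> v0(nelem), v1(nelem);
    for (size_t i = 0; i < nelem; i++) {
        v0[i] = in0[i] / norm0;
        v1[i] = in1[i] / norm1;
    }
    float dot = std::inner_product(v0.begin(), v0.end(), v1.begin(), 0.0);
    if (std::abs(dot) > dot_threshold) {
        // nearly collinear rows: plain linear interpolation
        for (size_t i = 0; i < nelem; i++) {
            out[i] = in0[i] * (1.0f - t) + in1[i] * t;
        }
        return;
    }
    float theta_0     = std::acos(dot);      // angle between the two rows
    float sin_theta_0 = std::sin(theta_0);
    float theta_t     = theta_0 * t;         // angle at interpolation step t
    float s0 = std::sin(theta_0 - theta_t) / sin_theta_0;
    float s1 = std::sin(theta_t) / sin_theta_0;
    for (size_t i = 0; i < nelem; i++) {
        out[i] = in0[i] * s0 + in1[i] * s1;
    }
}

int main() {
    const float a[4] = {1.0f, 0.0f, 0.0f, 0.0f};
    const float b[4] = {0.0f, 1.0f, 0.0f, 0.0f};
    float out[4];
    slerp_row_sketch(a, b, out, 4, 0.5f);    // halfway between the two rows
    for (float v : out) {
        printf("%.4f ", v);                  // expected: 0.7071 0.7071 0.0000 0.0000
    }
    printf("\n");
    return 0;
}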