Skip to content

Commit

Permalink
implement slerp
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Mar 3, 2024
1 parent a032bb6 commit 10c477b
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11389,12 +11389,6 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
llm_load_arch(*ml, *model);
llm_load_hparams(*ml, *model);

if (i > 0 && models[i-1]->hparams != model->hparams) {
LLAMA_LOG_ERROR("hparams of input models are different, aborting...");
clean_up();
return -1;
}

models.push_back(std::move(model));
mls.push_back(std::move(ml));
}
Expand Down Expand Up @@ -11521,6 +11515,8 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
std::vector<no_init<float>> f32_in_buf1; // dequant it internally
std::vector<float> f32_out_buf(n_elements, 0.0); // do not resize!
std::vector<uint8_t> out_buf(ggml_nbytes(out_tensor)); // do not resize!
const int n_per_row = out_tensor->ne[0];
const int n_rows = n_elements / n_per_row;

if (ins.method == LLAMA_MERGE_COPY) {
LLAMA_LOG_INFO("copy\n");
Expand Down Expand Up @@ -11565,22 +11561,56 @@ int32_t llama_merge_models(const struct llama_merge_config * config) {
}

if (ins.method == LLAMA_MERGE_SLERP) {
// Python code: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
LLAMA_LOG_INFO("slerp ");
float * in0 = (float *) f32_in_buf0.data();
float * in1 = (float *) f32_in_buf1.data();
float * dest = (float *) f32_out_buf.data();
for (size_t i = 0; i < n_elements; i++) {
//dest[i] = in0[i] * ins.t + in1[i] * 0;
dest[i] = in0[i];
static const float dot_threshold = 0.9995;
auto lerp_row = [](float * in0, float * in1, float * out, size_t nelem, float t) {
for (size_t i = 0; i < nelem; i++) {
out[i] = in0[i] * (1.0 - t) + in1[i] * t;
}
};
auto slerp_row = [&lerp_row](float * in0, float * in1, float * out, size_t nelem, float t) {
float norm0 = std::sqrt(std::inner_product(in0, in0 + nelem, in0, 0.0));
float norm1 = std::sqrt(std::inner_product(in1, in1 + nelem, in1, 0.0));
// Normalize the vectors to get the directions and angles
std::vector<float> v0(nelem);
std::vector<float> v1(nelem);
for (size_t i = 0; i < nelem; i++) {
v0[i] = in0[i] / norm0;
v1[i] = in1[i] / norm1;
}
// Dot product with the normalized vectors
float dot = std::inner_product(v0.begin(), v0.end(), v1.begin(), 0.0);
// If absolute value of dot product is almost 1, vectors are ~colineal, so use lerp
if (std::abs(dot) > dot_threshold) {
return lerp_row(in0, in1, out, nelem, t);
}
// Calculate initial angle between v0 and v1
float theta_0 = std::acos(dot);
float sin_theta_0 = std::sin(theta_0);
// Angle at timestep t
float theta_t = theta_0 * t;
float sin_theta_t = std::sin(theta_t);
// Finish the slerp algorithm
float s0 = std::sin(theta_0 - theta_t) / sin_theta_0;
float s1 = sin_theta_t / sin_theta_0;
for (size_t i = 0; i < nelem; i++) {
out[i] = in0[i] * s0 + in1[i] * s1;
}
};
for (int r = 0; r < n_rows; r++) {
float * in0 = (float *) f32_in_buf0.data();
float * in1 = (float *) f32_in_buf1.data();
float * dest = (float *) f32_out_buf.data();
size_t offset = n_per_row * r;
slerp_row(in0 + offset, in1 + offset, dest + offset, n_per_row, ins.t);
}
}

// re-quantize it
{
LLAMA_LOG_INFO("requant\n");
std::array<int64_t, 1 << 4> hist_cur = {};
const int n_per_row = out_tensor->ne[0];
const int n_rows = n_elements / n_per_row;
static const int min_chunk_size = 32 * 512;
const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
size_t new_size = llama_tensor_quantize_internal(
Expand Down

0 comments on commit 10c477b

Please sign in to comment.