
in situ auto-Frankenmerges #4718

Open
semiring opened this issue Dec 31, 2023 · 11 comments
Labels: enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@semiring

Feature Description

Modify llama.cpp to support on-the-fly "Frankenmerging" of the model in memory with itself.

Motivation

Frankenmerges, including auto-Frankenmerges, are becoming increasingly popular and appear to have properties that merit further study. It's Rich Sutton's "bitter lesson" in the small: stacking more decoder blocks means a greater total amount of computation in a single inference pass, and, perhaps surprisingly, under the right circumstances that greater accessible computation outweighs the 'noise' induced by performing fairly brutal surgery on the order of decoder blocks.

Right now experimentation happens at the level of building new models with mergekit, which is slow. The ability to mix and match decoder blocks on the fly in llama.cpp would speed up iteration and experimentation, helping to better understand the tradeoff between greater available net computation and the noise induced by decoder surgery.

Possible Implementation

Something like this:

https://github.com/semiring/IRL-llama.cpp/blob/master/llama.cpp#L4346

semiring added the enhancement label Dec 31, 2023
@kalomaze (Contributor) commented Jan 1, 2024

[image: Goliath 120B layer configuration]

In typical Frankenmerge setups (like Goliath 120B, pictured here) I notice a pattern:

  • Compute X layers
  • Then switch models, go back Y layers [from where X ended], compute another X layers from there, and repeat the cycle

If you wanted to emulate the Frankenmerging pattern seen in Goliath on a single 70b model at inference time, it would be:

  • Compute 16 layers, rewind to the point 8 layers back, and compute another 16 layers from there

You could have two experimental hyperparameters for this:

  • recompute_n_layers [which determines how frequently it 'rewinds' and computes layers]
  • rewind_n_layers [which determines how many layers to 'rewind' in the forward pass]

Or something along those lines.
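
As a purely illustrative sketch (not part of this proposal; the helper name is made up), the two hyperparameters could be expanded into an explicit list of layer indices ahead of time, assuming rewind_n_layers is smaller than recompute_n_layers so the loop always makes progress:

#include <algorithm>
#include <cstdio>
#include <vector>

// Expand (recompute_n_layers = X, rewind_n_layers = Y) into an explicit order:
// run X layers, step back Y, run the next X layers from there, and so on.
static std::vector<int> make_layer_order(int n_layer, int recompute_n_layers, int rewind_n_layers) {
    std::vector<int> order;
    int start = 0;
    while (start < n_layer) {
        const int end = std::min(start + recompute_n_layers, n_layer);
        for (int il = start; il < end; ++il) {
            order.push_back(il);          // run this contiguous block of layers
        }
        if (end == n_layer) {
            break;                        // reached the top of the stack
        }
        start = end - rewind_n_layers;    // rewind Y layers and continue from there
    }
    return order;
}

int main() {
    // e.g. an 80-layer 70b-style stack with X = 16, Y = 8 (the Goliath-like pattern above)
    for (int il : make_layer_order(80, 16, 8)) {
        printf("%d ", il);
    }
    printf("\n");
    return 0;
}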

If you wanted to do it with two models at inference time, you could make it "switch" models every time it recomputes, which would completely emulate the Frankenmerging setup.

@kalomaze (Contributor) commented Jan 1, 2024

Also, if you're interested in implementing other ideas for "maximizing compute" on a single model:

I'm interested in seeing what happens when you iteratively compute the same layer multiple times, but weight the change to the hidden state proportionally: for example, doing 4 passes of each layer where each 'pass' contributes a 0.25x weight of change to the hidden state, and so on.
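
A rough sketch of that damped update in plain C++ (illustrative only; layer_fwd is a placeholder for one decoder block's forward pass, not an actual llama.cpp function):

#include <cstddef>
#include <functional>
#include <vector>

using Hidden = std::vector<float>;

// Run the same block n_passes times, blending each output into the hidden
// state with weight 1/n_passes (e.g. 4 passes -> each pass contributes 0.25x).
static Hidden repeated_pass(Hidden h, int n_passes,
                            const std::function<Hidden(const Hidden &)> & layer_fwd) {
    const float w = 1.0f / n_passes;
    for (int i = 0; i < n_passes; ++i) {
        const Hidden out = layer_fwd(h);
        for (size_t j = 0; j < h.size(); ++j) {
            h[j] += w * (out[j] - h[j]);   // damped update of the hidden state
        }
    }
    return h;
}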

@ggerganov (Owner)

We can implement a tool similar to quantize that takes 2 GGUF files and outputs a new GGUF file picking and merging certain layers from the input files.

Regarding the evaluation of a single layer multiple times, I think we can add a general-purpose solution via an optional integer array in the GGUF metadata that specifies the indices of the layers that need to be evaluated. This way, the layer loop:

for (int il = 0; il < n_layer; ++il) {

would become:

for (int iil = 0; iil < n_layer; ++iil) {
    const int il = model.layer_order ? model.layer_order[iil] : iil;

This would be general enough to implement any kind of layer repetition and would be flexible to reconfigure via KV overrides.
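
For illustration only, a hard-coded layer_order for a 32-layer model that runs overlapping 16-layer windows might look like the following (the indices are made up, and note that with repetition the loop bound has to be the length of layer_order, which can exceed n_layer):

#include <cstdint>
#include <cstdio>

static const int32_t layer_order[] = {
     0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
     8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
};

int main() {
    const int n_layer_order = (int) (sizeof(layer_order)/sizeof(layer_order[0]));
    for (int iil = 0; iil < n_layer_order; ++iil) {
        const int il = layer_order[iil];                    // which decoder block to evaluate
        printf("step %2d -> decoder block %2d\n", iil, il); // the graph for block il would be built here
    }
    return 0;
}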

ggerganov added the good first issue label Jan 2, 2024
@kalomaze (Contributor)

@semiring Do you still have interest in pursuing this concept? It would be interesting to get a smaller LoRA adapter for each fine-tuned model and apply it at inference time to save VRAM, instead of loading redundant layers into memory.

@semiring (Author)

@kalomaze I don't have the cycles to work on a properly-engineered solution for this right now; if you're interested, please go ahead!

@xhedit commented Jan 16, 2024

I created a branch with it at https://github.com/xhedit/llama.cpp/tree/xhedit-layer-order. I added a std::string parameter for it to llama.h, and as a result test-c fails to build. I deleted that test from the Makefile in my branch, so it's not suitable for merging until I come up with a way to keep llama.h compatible with C.
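
One possible way to keep llama.h valid C (a sketch only; the field and function names below are invented, not the actual API) is to expose the order as a plain pointer plus a count, and keep the "--layer-order [0,1,...]" string parsing on the C++ side, e.g. in common/:

// hypothetical C-compatible fields for llama.h instead of a std::string:
//     const int32_t * layer_order;   // NULL means the default 0..n_layer-1 order
//     uint32_t        n_layer_order; // number of entries in layer_order

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// parse a string like "[0,1,2,26,20,21]" into a vector passed as pointer + count
static std::vector<int32_t> parse_layer_order(const std::string & arg) {
    std::vector<int32_t> order;
    std::stringstream ss(arg);
    std::string tok;
    while (std::getline(ss, tok, ',')) {
        // strip brackets and spaces before converting the token to an integer
        tok.erase(std::remove_if(tok.begin(), tok.end(),
                  [](char c) { return c == '[' || c == ']' || c == ' '; }),
                  tok.end());
        if (!tok.empty()) {
            order.push_back((int32_t) std::stoi(tok));
        }
    }
    return order;
}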

./main --layer-order "[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,20,21,22,23,24,25,26,27,28,29,30,31]" --model ../models/Hermes-7B-q8.gguf --prompt "Hi." --temp 0

Gives me a decent reply.

./main --layer-order "[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31]" --model ../models/Hermes-7B-q8.gguf --prompt "Hi." --temp 0

Gives the same result as

 ./main --model ../models/Hermes-7B-q8.gguf --prompt "Hi." --temp 0

@sorasoras

That's super interesting. Merging isn't super easy to do.

@jxy (Contributor) commented Feb 14, 2024

Simply changing the order in the layer loop in build_llama is not enough. If a layer is applied twice, the same KV cache is used. We need to allocate additional KV cache for those repeated layers.
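
A toy example of the bookkeeping this implies (illustrative only, not actual llama.cpp code): allocate one KV slot per executed layer and key the cache by the position in layer_order (iil) rather than by the layer index (il), so a repeated layer gets its own cache slot while still sharing weights:

#include <cstdio>
#include <vector>

int main() {
    // made-up order in which layers 8..15 of a 24-layer model are run twice
    std::vector<int> layer_order;
    for (int il = 0; il < 16; ++il) layer_order.push_back(il);
    for (int il = 8; il < 24; ++il) layer_order.push_back(il);

    // one KV slot per executed layer: repeats get separate slots (keyed by iil),
    // while both repeats still read their weights from the same block il
    for (int iil = 0; iil < (int) layer_order.size(); ++iil) {
        const int il = layer_order[iil];
        printf("exec step %2d: weights from layer %2d, KV cache slot %2d\n", iil, il, iil);
    }
    return 0;
}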

@dnhkng commented Feb 28, 2024

> We can implement a tool similar to quantize that takes 2 GGUF files and outputs a new GGUF file picking and merging certain layers from the input files.

@ggerganov Is this being pursued? I started trying to do a GGUF merge with gguf.py, but I immediately hit:
ValueError: Only F32 and F16 tensors are supported for now

Working directly on quantised models seems to make the most sense, as probably no one will be running large merged models at F16.

@ggerganov (Owner)

I think there is some work started in #5741.

Regarding the error, I think you are using gguf-py? It does not seem to support quantized tensor info. Not sure how difficult it would be to add.

Implementing this in C using ggml would make more sense to me.

@dnhkng commented Feb 29, 2024

I have used exllamaV2 for layer merging so far. The issue is that the model shares weights across duplicated layers while keeping a KV cache for every layer, including the duplicates. The model flow might have to bounce back and forth between the cards.

For exllamaV2, Python is great, as you can dynamically modify the layers with just a quick cache rebuild after a modification. I don't think it makes sense to do that in C++ for inference.
