Repeat layers to create FrankenModels #275
base: master
Conversation
Have you observed test_inference to show better perplexity results doing that? I tested the PR out on some 13B's with various repeat methods and every single time it's very substantially worse. Kinda feels like a really expensive way to add noise. |
No arguments from me that this is better, but at least it's easier to experiment and find out. I find the results from 70b models feel nicer; maybe this effect is only apparent in large models? Quick note: the current argument does not match the notation used by Mergekit, so the results will be different. I will update this code to match their implementation. |
I'm a little preoccupied these days so I haven't had a chance to look at this. But are you just repeating forward passes? If so, do you also create new layers in the K/V cache? Otherwise you're not going to get sensible results. |
@turboderp Thanks for the tip. Surprisingly, although there are no new layers in the KV cache, it's really not bad. Weird, right? I'm looking at the output of TinyLlama-1.1B-Chat-v1.0-5.0bpw-h6-exl2, with and without the middle 6 layers repeated once, and there is no obvious degradation in performance. I will try adding in extra KV cache layers now. I think we need a benchmark to see how this is affecting things, something like a Chatbot Arena, so that we get real human comparisons. For a quick test: Update: Yes, it feels like the extra layers raise the temperature, but increasing the layer repeats and simultaneously lowering the temperature seems to generate very nice text. |
Inference simply won't work correctly without extra layers in the cache. It won't be equivalent to an actual Frankenstein model, as you'll be overwriting keys/values from the repeated layers. To actually add layers to the cache adds a lot of complications with multi-GPU splitting, though. |
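To illustrate the point about the cache with a toy example (plain PyTorch, not exllamav2 code): a KV cache is indexed by layer_idx and token position, so a repeated visit to the same layer, without its own cache layer, simply overwrites what the first visit stored.

```python
import torch

# Toy KV cache: one preallocated K/V tensor per layer index, indexed by token position.
d, max_seq_len, num_layers = 8, 4, 2
cache_k = [torch.zeros(max_seq_len, d) for _ in range(num_layers)]
cache_v = [torch.zeros(max_seq_len, d) for _ in range(num_layers)]

def attn_forward(layer_idx, pos, hidden, w_k, w_v):
    # K/V for this token land in the slot for (layer_idx, pos) -- there is no notion
    # of "which visit" to the layer this is.
    cache_k[layer_idx][pos] = hidden @ w_k
    cache_v[layer_idx][pos] = hidden @ w_v
    return hidden + 1.0  # stand-in for the real attention/MLP transformation

torch.manual_seed(0)
w_k, w_v = torch.randn(d, d), torch.randn(d, d)
x = torch.randn(d)

x = attn_forward(0, 0, x, w_k, w_v)            # first visit to layer 0
k_first = cache_k[0][0].clone()
x = attn_forward(0, 0, x, w_k, w_v)            # repeated visit, same layer_idx
print(torch.allclose(k_first, cache_k[0][0]))  # False: the first visit's K was overwritten
```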
I've increased the number of cache layers to match the total number of new layers, but I'm not sure how the inference pass uses and updates the cached k and v tensors. I have modified the forward pass to use updated 'layer_idx' values whilst keeping the other module attributes shared (tensor weights etc.), but it's not working. I'm getting: |
I'm pretty sure the keys/values in cache are calculated from just the input tokens (independent from the previous layers), so repeated layers would have identical cache anyway. |
@dnhkng The right approach would be to apply the new layer index while loading the model, allocating a cache layer for each but then creating a reference layer rather than an actual layer whenever possible. It might still be necessary to duplicate layers across device boundaries, or you could end up with the hidden state bouncing back and forth between devices. That would at least have to be benchmarked to see if the overhead is acceptable or not. @silphendio The keys and values are computed (along with the queries) from the hidden state, not from the input tokens. So they're different for every layer, even if two layers happen to have the same weights. |
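A quick way to see the second point (a minimal PyTorch sketch, not the exllamav2 code path): K and V are projections of the hidden state entering the layer, so two copies of a layer with identical weights still produce different K/V whenever their inputs differ.

```python
import torch

torch.manual_seed(0)
d = 16
w_k, w_v = torch.randn(d, d), torch.randn(d, d)   # shared projection weights

x_first  = torch.randn(1, d)                      # hidden state entering the original layer
x_repeat = x_first + 0.1 * torch.randn(1, d)      # hidden state entering the repeated copy

k1, v1 = x_first  @ w_k, x_first  @ w_v
k2, v2 = x_repeat @ w_k, x_repeat @ w_v           # same weights, different input

print(torch.allclose(k1, k2), torch.allclose(v1, v2))  # False False
```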
I'm using a wrapper to mask the layer_idx. Leaving the cache at its original size seems to yield better results, but maybe that's because it behaves more like the original model. |
I will try and get a single GPU model working first.
|
Some measured numbers using
All done using torch 2.1.2+rocm5.6. I know perplexity isn't everything but I feel it's good for pre-screening to see if a method is moving in the right direction. |
@zpin Can you check this gist? I tried your method, and although the extra cache layers are created, they appear to be unused. The script just prints out the first value of each cache tensor after an inference, and only layers up to layer 22 (the size of the input model) contain values; all the layers past that are always zeros. Maybe it's just a bug on my part, though.
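For reference, the check is roughly equivalent to this (the attribute names key_states / value_states are an assumption about how the cache stores its per-layer tensors):

```python
def print_cache_summary(cache):
    """Print the first element of each per-layer K/V tensor; all-zero entries suggest
    the corresponding cache layer is never written to."""
    # Assumption: the cache keeps one tensor per layer in key_states / value_states.
    for i, (k, v) in enumerate(zip(cache.key_states, cache.value_states)):
        print(f"layer {i:3d}  k[0] = {k.flatten()[0].item(): .6f}  v[0] = {v.flatten()[0].item(): .6f}")
```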
That would be an interesting finding, maybe? At each repeated layer, you use the KV cache of the previous repeat, forcing the model to stay on track, even with slightly different new input... weird...

UPDATE: Found the issue:

```python
class ExLlamaV2AttentionWrapper(ExLlamaV2Attention):
    def __init__(self, obj, new_idx):
        object.__setattr__(self, '_obj', obj)
        object.__setattr__(self, '_new_idx', new_idx)

    def __getattribute__(self, name):
        if name == 'layer_idx':
            return object.__getattribute__(self, '_new_idx')
        # Delegate all other attributes to the wrapped object
        try:
            return getattr(object.__getattribute__(self, '_obj'), name)
        except AttributeError:
            return object.__getattribute__(self, name)
```

This code block only reports the new 'layer_idx' externally via the __getattribute__ method, but not internally for the attn class! i.e. the class itself only sees the original 'layer_idx', not '_new_idx', in its place. So, when the cache is built, it's only based on the original 'layer_idx' values. |
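The failure mode can be reproduced with plain Python, independent of exllamav2: delegating through __getattribute__ hands back methods bound to the wrapped object, so code running inside the class still sees the original attribute.

```python
class Inner:
    def __init__(self):
        self.layer_idx = 3
    def which_layer(self):
        # Runs with self bound to the Inner instance, so it sees the original value.
        return self.layer_idx

class Wrapper:
    def __init__(self, obj, new_idx):
        object.__setattr__(self, "_obj", obj)
        object.__setattr__(self, "_new_idx", new_idx)
    def __getattribute__(self, name):
        if name == "layer_idx":
            return object.__getattribute__(self, "_new_idx")
        return getattr(object.__getattribute__(self, "_obj"), name)

w = Wrapper(Inner(), new_idx=7)
print(w.layer_idx)      # 7 -> the masked value, seen from outside
print(w.which_layer())  # 3 -> internal code still uses the original layer_idx
```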
Please try with these changes; the wrapped object didn't always use the masked layer_idx:
Yeah, I don't understand this enough to draw any conclusions. But it might also have been something else, since you noticed that the cache isn't used correctly. |
@Beinsezii Using the proper cache system and Beinsezii_MythoMax-L2-13B-EXL2_4k_hb8_b8, we get the following perplexities:
@zpin I've tested both, and you get the same perplexity. As mentioned by @turboderp, this only works on single-GPU models so far. UPDATE: It's probably a stupid bug, but if anyone has time: https://gist.github.com/dnhkng/34e78b6082ec26124d72624dc3f6f666 |
A real test would probably be against a static self-merge made with mergekit, to see if it's comparable to this PR. |
@dnhkng I updated my gist https://gist.github.com/silphendio/535cd9c1821aa1290aa10d587b76a49c Instead of using the AttentionWrapper, I just copied the layer with the standard copy function and then set the layer_idx. @zpin in my (admittedly limited) tests, keeping the cache at its original size tends to result in more spelling mistakes. The model is also prone to leaps of logic, where it just omits stuff, and then it loses the train of thought. But it varies greatly based on the random seed. @turboderp You're right of course, I wonder how I got this silly idea. On a side note, it's slightly confusing that |
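A minimal sketch of that copy-based approach (illustrative only; the attribute names are assumptions about the model internals, see the gist for the working version):

```python
import copy

def duplicate_layer(layer, new_idx):
    """Shallow-copy a decoder layer: the weight tensors stay shared, but the copy
    gets its own layer_idx, so it reads/writes its own slot in the (enlarged) KV cache."""
    dup = copy.copy(layer)   # standard copy, as described above; tensors are not duplicated
    dup.layer_idx = new_idx
    return dup
```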
I've started a full test of the output (TinyStories-style), to get an understanding of the effects of layer duplication, with hundreds of layering combinations. It will be interesting to see the results. |
OK, it seems that reusing the cache has quite interesting effects on the output. Clearly, the results are very different! If you reuse the cache layers AND include the first few layers, the results are always very bad (constantly repeated words or \n symbols). However, reusing the cache in the middle of the LLM, while keeping the repeated section short, does not seem to hurt performance significantly. But overall, when we reuse the cache, only 6 repeat variants get a rating comparable to the baseline, vs 27 for the unique cache. However, if we can find a setting where cache reuse doesn't hurt AND the results are better, that would be a big win. I'll continue the tests on a larger model and see how that goes. When I start seeing a solid pattern, I'll run many prompts and use GPT-4 for evaluation. Examples of the best generation for each method, based on the prompt: "Imagine what alien communication might be like and create a hypothetical scenario for initial contact." Baseline:
Using shared-cache layers, and repeating layers 15-18:
Using Unique-cache layers, and repeating 12-20:
The results for TinyLlama-1.1B-Chat-v1.0 show that there are many configurations where the rated story is equivalent to the baseline, which itself is very surprising. This model was trained on 3T tokens, so it is very well trained. Using a larger model, like a 70B, is where things will get interesting. |
Can someone confirm which code I should test for coding-related tasks? This gist: https://gist.github.com/silphendio/535cd9c1821aa1290aa10d587b76a49c ? Additionally, we are not going to have the same tensors as with Goliath when using EXL2, as tensors are modified during quantization. Calibration is not negligible - for example, this week I noticed that DeepSeek Coder Instruct can get SOTA HumanEval in 8-bit (up to ~82.5%), beating the results from WizardCoder 33B 1.1 in fp16 :) (and Wizard cannot go above 82% with the same calibration set in 8-bit). BTW, I have not seen a correlation between perplexity and "comprehension" capabilities; this is also visible with Goliath, which makes spelling mistakes but follows instructions like a big model. |
this is really awesome. |
I have been doing some tests and it is a real time sink :D I have been using one of the ugly prompt test cases I have; it confuses models, and also seems to be challenging for some of them. Turbo 3.5 does not really pass it, GPT-4 always gets it, and DeepSeek 67B also does a pretty good job. Note that "BUG" is without a number and it should stay like this; the model should rename variables, remove one append, and remove the test/assertion part. I tried to prompt-engineer the model to obey: sometimes it confuses the model even more, sometimes it helps, as with the CodeLlamas - the DeepSeek models seem to not even notice it's written there (usually the related logits are less than 1% probable). There was a combination that did most of the things (though it still made some spelling mistakes), but that was before I started testing things methodically instead of having fun, and I cannot replicate it - you'll have to trust me that it was a pretty good reply. My observations so far:
Model: DeepSeek Coder Instruct 33B, 62 layers
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19, 16, 17, 18, 19, 20, 21, 22, 23, 20, 21, 22, 23, 20, 21, 22, 23, 24, 25, 26, 27, 24, 25, 26, 27, 24, 25, 26, 27, 28, 29, 30, 31, 28, 29, 30, 31, 28, 29, 30, 31, 32, 33, 34, 35, 32, 33, 34, 35, 32, 33, 34, 35, 36, 37, 38, 39, 36, 37, 38, 39, 36, 37, 38, 39, 40, 41, 42, 43, 40, 41, 42, 43, 40, 41, 42, 43, 44, 45, 46, 47, 44, 45, 46, 47, 44, 45, 46, 47, 48, 49, 50, 51, 48, 49, 50, 51, 48, 49, 50, 51, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19, 20, 21, 22, 23, 20, 21, 22, 23, 24, 25, 26, 27, 24, 25, 26, 27, 28, 29, 30, 31, 28, 29, 30, 31, 32, 33, 34, 35, 32, 33, 34, 35, 36, 37, 38, 39, 36, 37, 38, 39, 40, 41, 42, 43, 40, 41, 42, 43, 44, 45, 46, 47, 44, 45, 46, 47, 48, 49, 50, 51, 48, 49, 50, 51, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
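For anyone who wants to reproduce these patterns, a small helper like this (an illustrative sketch, not part of the PR) expands (start, end, repeats) blocks into explicit layer orders similar to the lists above:

```python
def build_layer_order(num_layers, blocks):
    """Expand (start, end_exclusive, repeats) blocks into an explicit layer order.
    Layers outside the blocks are visited once, in their normal position."""
    order, cursor = [], 0
    for start, end, repeats in sorted(blocks):
        order.extend(range(cursor, start))               # layers before this block, once
        order.extend(list(range(start, end)) * repeats)  # the block itself, repeated
        cursor = end
    order.extend(range(cursor, num_layers))              # the remaining layers, once
    return order

# e.g. repeating every 4-layer block between layers 12 and 52 twice, roughly the
# pattern of the second list above:
print(build_layer_order(62, [(s, s + 4, 2) for s in range(12, 52, 4)]))
```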
Ideally I would be aiming to get a 67B-level reply from the 33B model. The 67B reply is:
As for:
|
We have examples of models for which the strategy works: Venus-120b-v1.2, which is an interleaving of lizpreciatior/lzlv_70b_fp16_hf, and MegaDolphin-120b, which is an interleaving of cognitivecomputations/dolphin-2.2-70b. |
So in other words, this strategy applied to lzlv 70B would get the same or similar output to Venus 120B without increasing the size of the 70B? |
Yes exactly
|
That's exactly the idea |
I have nothing valuable to add, except that there are a lot of people who know nothing about the effectiveness of this but are interested in the results, so a detailed explanation of why it was rejected (if that happens) would be helpful. Alternatively, occasional updates to let us all know it hasn't died would keep our hopes alive. On a related note, some people are discussing work that might be related to this PR here (with a mention of various papers). |
In case someone wants to evaluate if self-merging really improves performance or not, I'm adding another model to the few available options. Since I couldn't get it done with this PR (despite trying), I've created self-merged models of miqu-1-70b in the same way as Venus and MegaDolphin 120B.
|
Very nice model. |
I am trying out Stephan's 2.4 model. Wasn't able to get the 2.65 to fit. But 2.4 fits at 16k context, no OOM on two 3090's. It is solid and strong. I'm not sure if it is any smarter than plain miqu 70B, but it feels equivalent, with more of a LZLV style of prose. Good for RP with a single narrator character. No problems so far, no detectable alignment. Very nice. |
@St33lMouse have you been able to run it using the code from this PR, or the code snippet? You can try the same or a higher quant. |
I'm using the 2.4 bpw quant (exllama2). All I did was download that model and run it with Ooba using exllama2. I can't run a heavier quant because it won't fit on my cards. By the way, I'm getting a little repetition around 15k context. |
Any progress here? dnhkng's branch still works, and I just used it for some tests, but it would be very useful to have that in the official exllamav2 and by extension tabbyAPI (which would allow much better tests through common frontends). |
@turboderp Can you recommend a way to save frankenmerge models I have created by manually stacking layers? |
First of all, you all are awesome. Really great work. EDIT: Sorry, update: this doesn't work as expected. I think that due to the shallow copy, the layer LoRA pointers just reference the last LoRA loaded for that layer. Working on a fix. |
Description
This slightly modifies the forward pass to reuse layers, allowing 'Frankenmodels' to be created and used quickly and easily.
The format of the new argument is:
```
python test_inference.py -m /models/lzlv_70b_fp16_hf-4.0bpw-h6-exl2 -p "Once upon a time:" -gs 18,18 --repeats '[(0,20),(10,30),(20,40),(30,50),(40,60),(50,70),(60,79)]'
```
This would generate the nsfwthrowitaway69/Venus-120b-v1.2 Frankenmodel dynamically, while reducing VRAM usage by the equivalent of roughly 50B parameters compared to loading the static merge.
The repeats parameter is a string list of tuples. As the final layers in most models are model.norm and lm_head (which are not numbered decoder layers), the last value in the last tuple should be one lower than the final layer number. As long as this is the case, the code will extend the last range so that all layers are included, i.e. [(0,20),(10,28)] and [(0,20),(10,30)]
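As an illustration of one plausible reading of these semantics (a sketch, not the PR's actual parsing code; the helper name is made up):

```python
import ast

def expand_repeats(repeats_arg, num_layers):
    """Expand a string like '[(0,20),(10,30)]' into the full layer order,
    extending the final range so the run always ends on the model's last decoder layer."""
    ranges = ast.literal_eval(repeats_arg)
    order = []
    for start, end in ranges:
        order.extend(range(start, end))
    if order[-1] != num_layers - 1:                    # last tuple stopped short of the end
        order.extend(range(ranges[-1][1], num_layers))
    return order

# For a hypothetical 30-layer model, both forms give the same layer order:
print(expand_repeats('[(0,20),(10,28)]', 30) == expand_repeats('[(0,20),(10,30)]', 30))  # True
```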
would generate the same Frankenmodel.
Related Discussion
See the discussion at #270 and a discussion on Reddit (r/LocalLLaMA) about the potential of easily creating Frankenstein models using exllama.
Explanation of changes
A new parameter was added to argparse, and in ExLlamaV2.__init__, if the param is used, we build a list of the layer order to use, including repeats. In the ExLlamaV2._forward method, the per-module forward pass is extracted into a private process_module method; this is called in the usual way by looping through self.modules if 'repeats' is not passed, and by looping through self.layers_list if it is (sketched below).
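Conceptually, the modified forward pass looks something like this sketch (names and signatures are simplified and partly invented; see the diff for the actual implementation):

```python
# Hedged sketch of the described change (not the actual exllamav2 code).
def _forward(self, hidden_state, cache, attn_mask):
    if self.layers_list is None:
        # No --repeats: run every module once, in order, exactly as before.
        for module in self.modules:
            hidden_state = self._process_module(module, hidden_state, cache, attn_mask)
    else:
        # --repeats given: walk the precomputed order, revisiting the shared layer modules.
        for idx in self.layers_list:
            hidden_state = self._process_module(self.modules[idx], hidden_state, cache, attn_mask)
    return hidden_state
```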