[Usability bug]: KV-cache LLMs are difficult to add to MLAgility #315

Open · jeremyfowers opened this issue Jun 6, 2023 · 5 comments
Labels: bug, models, new

jeremyfowers (Contributor) commented on Jun 6, 2023:

Unlike most Transformer models, which have a simple forward() signature, KV-cache LLMs have a complex signature that is difficult to encode into MLAgility's model template.

For example, LLaMA with no KV-cache:

inputs = {
    "hidden_states": torch.ones(
        batch_size, max_seq_length, config.hidden_size, dtype=torch.float
    ),
    "attention_mask": torch.ones(
        batch_size, 1, max_seq_length, max_seq_length, dtype=torch.float
    ),
}

LLaMA with KV-cache enabled:

inputs = {
    "hidden_states": torch.ones(
        batch_size,
        1,
        config.hidden_size,
        dtype=torch.float,
    ),
    "position_ids": [[0]],
    "past_key_value": [
        torch.ones(
            batch_size,
            config.num_attention_heads,
            config.max_position_embeddings - 1,
            model.self_attn.head_dim,
            dtype=torch.float,
        ),
        torch.ones(
            batch_size,
            config.num_attention_heads,
            config.max_position_embeddings - 1,
            model.self_attn.head_dim,
            dtype=torch.float,
        ),
    ],
}

It was painful to figure out the details of the latter code, especially the specific value that needed to be assigned to position_ids to make everything work.

The reason for this interface complexity in the first place is that huggingface transformers doesn't expect anyone to invoke a KV-cache transformer a single time like this. It expects an app that actually maintains the cache. And since the cache inputs come from the model's outputs, app developers don't have to think about how to format those values (woo python).
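For context, here is roughly what such an app looks like (a minimal sketch, not MLAgility code; the checkpoint name is a placeholder and greedy argmax decoding is used just to keep it short):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "huggyllama/llama-7b"  # placeholder; any causal LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids

with torch.no_grad():
    # Prefill: the whole prompt is processed at once, no cache yet.
    outputs = model(input_ids=prompt_ids, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token = outputs.logits[:, -1:].argmax(dim=-1)

    # Generation: one token per step; the cache comes straight from the
    # previous step's outputs, so the app never has to format it by hand.
    for _ in range(10):
        outputs = model(
            input_ids=next_token,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1:].argmax(dim=-1)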

However, mlagility would also not work well with such an app because the first invocation of the model would be a prefill invocation (no KV cache used) and then the subsequent invocations would be generation (KV-cache used). These two invocation modes generate completely different ONNX files and benchmark results, yet MLAgility doesn't offer a clear way to distinguish between the two (that I know of).

Filing this issue to keep track of the problem and any potential solutions. cc @danielholanda @ramkrishna2910

jeremyfowers added the bug, models, and new labels on Jun 6, 2023

jeremyfowers (Contributor, Author) commented:

Addendum: the "position_ids": [0] above in the KV-cache inputs works in pure pytorch, but the ONNX exporter somehow misinterprets it as [tensor(0)], which is not equivalent and causes the ONNX export to fail in a confusing way.

Setting position_ids: [[0]] works in both the pure pytorch and ONNX exporter cases, but this was confusing and difficult to figure out.
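For reference, both forms side by side (the tensor variant is my assumption about an equivalent, arguably less ambiguous way to express the same thing, since huggingface builds position_ids as a 2-D LongTensor internally):

import torch

# Nested-list form that worked in both pure pytorch and the ONNX exporter:
position_ids = [[0]]

# Assumed-equivalent explicit form: a LongTensor of shape (batch_size, 1)
# holding the position index of the single new token.
position_ids_tensor = torch.tensor([[0]], dtype=torch.long)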

I see this as further evidence that the approach of supplying dummy inputs is not scalable as models become increasingly complex.

jeremyfowers (Contributor, Author) commented:

@danielholanda curious to get your thoughts on this now that you're back!

danielholanda (Contributor) commented:

Let me know if I'm understanding this correctly.

The challenge you are describing is the amount of work needed to create template models like mlagility/models/llm_layer/llama_layer_prototype.py. Those templates are needed because we can't use mlagility out of the box, since mlagility will only use the "first" inputs used by the model (and KV-caching is only enabled when the model is called for the second time).

If that is the case, our plans to take the shape of model inputs into account when calculating the hash should solve this issue.
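Not mlagility's actual implementation, but a sketch of what folding input shapes into the hash might look like (the function name and hashing scheme are made up for illustration):

import hashlib
import torch

def variant_hash(model, inputs):
    """Hypothetical sketch: hash the model's identity together with the
    shapes/dtypes of its tensor inputs, so the same model invoked with
    different input shapes (prefill vs. KV-cached generation) lands under
    a different hash."""
    signature = sorted(
        (name, tuple(value.shape), str(value.dtype))
        for name, value in inputs.items()
        if isinstance(value, torch.Tensor)
    )
    payload = f"{type(model).__name__}|{signature}".encode()
    return hashlib.sha256(payload).hexdigest()[:8]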

jeremyfowers (Contributor, Author) commented:

> Let me know if I'm understanding this correctly.
>
> The challenge you are describing is the amount of work needed to create template models like mlagility/models/llm_layer/llama_layer_prototype.py. Those templates are needed because we can't use mlagility out of the box, since mlagility will only use the "first" inputs used by the model (and KV-caching is only enabled when the model is called for the second time).
>
> If that is the case, our plans to take the shape of model inputs into account when calculating the hash should solve this issue.

@danielholanda interesting! Is that because we could pass a true application in, which would perform prefill followed by KV-cached-generation, and Analysis would detect that there are two models and produce two benchmarks? That would be very cool...

Some follow-up questions/challenges though:

  • How does the benchit user tell the difference between the prefill and generation benchmarks? They will only be identified by different hashes (which users can't comprehend) and different input shapes (and the whole point was for the user to not have to think about input shapes). It would be super nice if the results were literally labeled like llama_prefill and llama_generation, but I don't know off the top of my head how that would be accomplished.
  • A true LLM application will typically generate many output tokens, and the sequence length grows every time a new token is generated. So if we take an initial prompt of size 128 and generate 300 tokens, will we end up with 300 benchmarks (see the sketch after this list)? We likely just want 2. I think that can be accomplished by setting the maximum output size for the application, but that feels a bit hacky.
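To make the second point concrete, a back-of-the-envelope sketch (dimensions are LLaMA-7B-ish placeholders, and it assumes the cache tensors count toward the input-shape signature):

# Why shape-based variant detection could see one "variant" per generated
# token: the cache's sequence dimension grows by one every step.
prompt_len, num_new_tokens = 128, 300
batch, heads, head_dim = 1, 32, 128

past_key_value_shapes = {
    (batch, heads, prompt_len + step, head_dim)
    for step in range(num_new_tokens)
}
print(len(past_key_value_shapes))  # 300 distinct cache shapes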

danielholanda (Contributor) commented:

> @danielholanda interesting! Is that because we could pass a true application in, which would perform prefill followed by KV-cached-generation, and Analysis would detect that there are two models and produce two benchmarks? That would be very cool...

Exactly. If the same model is executed with N different input shapes, then N models will be detected.

It is also true that the only information we provide to the user to differentiate between those models is the hash. We might also want to display input shapes when the same model is executed more than once with different inputs.

> It would be super nice if the results were literally labeled like llama_prefill and llama_generation, but I don't know off the top of my head how that would be accomplished.

I agree that this would be nice as long as we can do it in a programmatic way.

> A true LLM application will typically generate many output tokens, and the sequence length grows every time a new token is generated. So if we take an initial prompt of size 128 and generate 300 tokens, will we end up with 300 benchmarks? We likely just want 2. I think that can be accomplished by setting the maximum output size for the application, but that feels a bit hacky.

Very interesting point. Another alternative here is to set the maximum number of "model variants" to execute per model.

Note: model variant = same model with different input shapes
