Introduce CausalLMModel interface and add IREE numerics test for Llama 3.1 8B FP16 TP8 #375
Conversation
We do not have a clearly defined interface for LMs: decode and prefill have different signatures when exporting to IREE. This change adds a new ABC, CausalLMModel, that makes a distinction between the two variants. BaseCausalLMModel provides a default implementation for the new prefill_from_seq_lens and decode_from_seq_lens methods.

The export script export_paged_llm_v1 does too much in its exported functions: first it computes the attention mask, then it shards its arguments and results. This change moves that logic into a separate class that conforms to the CausalLMModel interface.

Introduce a new CausalLMIreeModel that conforms to CausalLMModel but is backed by an IREE module. It is not performant and is only meant for testing, as it marshals tensors and uses the IREE Python bindings. It can then be used, for example, in paged_llm_v1.TorchGenerator or anywhere else an LM is expected.

Refactor the sharded Llama tests: increase code reuse and use the TorchGenerator in the toy-sized tests. Use the shard_llm_dataset and export_paged_llm_v1 scripts in the test flow to increase their test coverage.

Introduce a Llama 3.1 8B FP16 TP8 test that appears to not have good numerical accuracy. It is compared against an fp64 unsharded torch variant to ensure that the reference itself is of high accuracy.
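A rough sketch of how such an interface might look is below. This is a sketch only: the method names prefill_from_seq_lens and decode_from_seq_lens come from this PR, but the exact argument lists are assumptions, not the actual sharktank signatures.

    from abc import ABC, abstractmethod

    import torch


    class CausalLMModel(ABC):
        """Interface that keeps prefill and decode as distinct entry points."""

        @abstractmethod
        def prefill_from_seq_lens(
            self, tokens: torch.Tensor, *, seq_lens: torch.Tensor, cache_state: list
        ) -> torch.Tensor:
            """Process the whole prompt and return logits (argument list assumed)."""

        @abstractmethod
        def decode_from_seq_lens(
            self,
            tokens: torch.Tensor,
            *,
            seq_lens: torch.Tensor,
            start_positions: torch.Tensor,
            cache_state: list,
        ) -> torch.Tensor:
            """Process one new token per sequence and return logits (argument list assumed)."""

Under this split, BaseCausalLMModel would supply default implementations of the two methods, and CausalLMIreeModel would implement them by marshaling tensors through the IREE Python bindings.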
There is way too much in one PR, making this almost impossible to review reasonably. Can you please break this into multiple PRs with clear goals? At this scale of change it is almost impossible to tell whether it will trigger more failures.
@@ -27,7 +27,7 @@
 ################################################################################

-class PagedLlamaModelV1(BaseCausalLMModel):
+class PagedLlamaModelV1(BaseCausalLMModel, CausalLMModel):
Double inheritance is a giant red flag to me. It feels extremely wrong to have two sets of CausalLMModel.
I guess the naming is not good. One of them is an implementation and the other is an ABC. I will rename them to something clearer. I did it like that because BaseCausalLMModel implements just a part of the whole CausalLMModel interface.
 def main():
+def dtype_from_str(s: str) -> torch.dtype:
Use a map rather than string manipulations. It is as simple as
{
    "fp8": torch.fp8,
    "f32": torch.f32,
    ....
}
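A runnable version of that suggestion might look like the following. This is a sketch: the set of accepted names is an assumption, and note that the actual torch attributes are torch.float16 / torch.float32 / torch.float64 rather than torch.f32.

    import torch

    # Map of accepted dtype spellings to torch dtypes (accepted names are assumed).
    _DTYPE_MAP = {
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
    }


    def dtype_from_str(s: str) -> torch.dtype:
        # Plain dictionary lookup instead of string manipulation.
        try:
            return _DTYPE_MAP[s]
        except KeyError as e:
            raise ValueError(f"Unsupported dtype name: {s!r}") from e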
@@ -59,8 +69,18 @@ def main():
         default="decomposed",
         choices=["decomposed", "torch"],
     )
+    parser.add_argument(
+        "--attention-dtype",
Can you provide a justification for adding these in the first place? They should be inferred from the data types of the functions.
There is one test that uses fp32.
@rsuderman here is one part of this PR split out as a separate PR: #383.
@rsuderman, per your request I have broken this up into multiple PRs. The last one is #394, which references all its dependencies.