
Introduce CausalLMModel interface and add IREE numerics test for Llama 3.1 8B FP16 TP8 #375

Closed

Conversation

sogartar
Contributor

We do not have a clearly defined interface for LMs; decode and prefill have different signatures when exporting to IREE.
This change adds a new ABC, CausalLMModel, that makes a distinction between the two variants.
The BaseCausalLMModel provides a default implementation for the new prefill_from_seq_lens and decode_from_seq_lens methods.

The export script export_paged_llm_v1 does too much in its exported functions: first it computes the attention mask, then it shards its arguments and results.
This change moves that logic into a separate class that conforms to the CausalLMModel interface.

Introduce a new CausalLMIreeModel that conforms to CausalLMModel but is backed by an IREE module.
It is not performant and is only meant for testing, as it marshals tensors and uses the IREE Python bindings.
It can then be used, for example, in paged_llm_v1.TorchGenerator or in other places where an LM is expected.

Refactor the sharded Llama tests: increase code reuse and use the TorchGenerator in the toy-sized tests. Use the shard_llm_dataset and export_paged_llm_v1 scripts in the test flow to increase their test coverage.

Introduce a Llama 3.1 8B FP16 TP8 test that appears not to have good numerical accuracy. It is compared against an unsharded fp64 torch variant to ensure that the reference is of high accuracy.
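As a rough illustration, the prefill/decode split described above might look like the sketch below. The method names prefill_from_seq_lens and decode_from_seq_lens come from this PR, but the parameter lists and the toy implementation are assumptions made for illustration, not the actual signatures.

```python
# Hypothetical sketch of the CausalLMModel ABC; parameter lists are
# illustrative assumptions, not the actual signatures from this PR.
from abc import ABC, abstractmethod


class CausalLMModel(ABC):
    """Interface that distinguishes the prefill and decode variants."""

    @abstractmethod
    def prefill_from_seq_lens(self, tokens, seq_lens):
        """Process whole prompts in one pass, given per-sequence lengths."""

    @abstractmethod
    def decode_from_seq_lens(self, tokens, seq_lens):
        """Produce a single next-token step, given per-sequence lengths."""


class ToyCausalLM(CausalLMModel):
    """Minimal concrete model, only to show the interface can be satisfied."""

    def prefill_from_seq_lens(self, tokens, seq_lens):
        # Echo the last valid token of each (possibly padded) sequence.
        return [t[n - 1] for t, n in zip(tokens, seq_lens)]

    def decode_from_seq_lens(self, tokens, seq_lens):
        # Echo the most recent token of each sequence.
        return [t[-1] for t in tokens]
```

A separate backing implementation (such as the IREE-backed CausalLMIreeModel) would satisfy the same two abstract methods, which is what lets callers like TorchGenerator stay agnostic to the backend.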

@sogartar sogartar marked this pull request as draft October 30, 2024 15:15
@sogartar sogartar requested a review from dan-garvey October 30, 2024 15:16
@sogartar sogartar marked this pull request as ready for review October 30, 2024 15:16
@sogartar sogartar requested a review from IanNod October 30, 2024 15:38
Contributor

@rsuderman rsuderman left a comment

There is way too much in one PR, making it almost impossible to review reasonably. Can you please break this into multiple PRs with clear goals? At this scale of change it is almost impossible to tell whether it will trigger more failures.

@@ -27,7 +27,7 @@
################################################################################


-class PagedLlamaModelV1(BaseCausalLMModel):
+class PagedLlamaModelV1(BaseCausalLMModel, CausalLMModel):
Contributor

Double inheritance is a giant red flag to me. It feels extremely wrong to have two sets of CausalLMModel.

Contributor Author

I guess the naming is not good. One of them is an implementation and the other is an ABC. I will rename them to something clearer. I did it like that because BaseCausalLMModel implements just a part of the whole CausalLMModel interface.



def main():
def dtype_from_str(s: str) -> torch.dtype:
Contributor

Use a map rather than string manipulations. It is as simple as

{
    "fp8": torch.float8_e4m3fn,
    "f32": torch.float32,
....
}
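A runnable version of the suggested lookup-table approach might look like this sketch; the set of supported names is an assumption for illustration, not the option list from the PR.

```python
# Sketch of dtype_from_str backed by a lookup table instead of string
# manipulation; the supported names here are illustrative assumptions.
import torch

_DTYPE_BY_NAME = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64,
}


def dtype_from_str(s: str) -> torch.dtype:
    try:
        return _DTYPE_BY_NAME[s]
    except KeyError:
        raise ValueError(f"unsupported dtype name: {s!r}") from None
```

Keeping the mapping in one dict also gives a single place to list the valid names in error messages or argparse choices.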

@@ -59,8 +69,18 @@ def main():
default="decomposed",
choices=["decomposed", "torch"],
)
parser.add_argument(
"--attention-dtype",
Contributor

Can you provide a justification for adding these in the first place? They should be inferred from the data types of the functions.

Contributor Author

There is one test that uses fp32.
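For context, wiring such a flag through argparse on top of a dtype lookup table could look like the sketch below. The flag name comes from the diff above; the supported names and the default are assumptions for illustration.

```python
# Hypothetical wiring of --attention-dtype via argparse; the flag name is
# from the diff above, the names and default here are assumptions.
import argparse

import torch

DTYPES = {"float16": torch.float16, "float32": torch.float32}

parser = argparse.ArgumentParser()
parser.add_argument("--attention-dtype", choices=sorted(DTYPES), default="float16")

args = parser.parse_args(["--attention-dtype", "float32"])
attention_dtype = DTYPES[args.attention_dtype]
```

Using `choices` keyed on the dict keeps the accepted names and the conversion table from drifting apart.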

@sogartar
Contributor Author

@rsuderman here is one part of this PR as a separate PR, #383.

@sogartar
Contributor Author

@rsuderman, per your request I have broken this up into multiple PRs. The last one is #394, which references all its dependencies.

@sogartar sogartar closed this Oct 31, 2024