The next tutorials #426

Open · 1 of 7 tasks
msaroufim opened this issue Jun 24, 2024 · 5 comments
Labels: good first issue (Good for newcomers), help wanted (Extra attention is needed)

Comments

msaroufim (Member) commented Jun 24, 2024

From our README.md

torchao is a library to create and integrate high-performance custom data types and layouts into your PyTorch workflows

So far we've done a good job building out the primitive data types along with their corresponding transformed Linear layers: given a new ExoticDtype() we have a playbook for creating ExoticDtypeLinear(). For weight-only transformations this is a perfectly fine workflow, and it's how the majority of quantization libraries operate.

For example:

```python
from torchao.quantization import quantize_, int4_weight_only

m = DownloadModelFromHuggingFace()  # placeholder for your model-loading code
quantize_(m, int4_weight_only())  # swaps out all torch.nn.Linear with a 4-bit Linear
```

We can make the above shine with more accessible blogs, performance benchmarks, and integrations with more partners.

However, this does something of a disservice to explaining the ao value proposition. We're a dtype library, not a dtype-Linear library, so given a dtype it should be easy for us to do a lot more. Some examples I'd like to see next:

  • Quantized optimizers, with the most obvious additions being 8-bit and 4-bit Adam (see the sketch after this list)
  • Quantized KV cache
  • Quantization-aware training with an exotic dtype
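
To make the first bullet concrete, here's a minimal sketch of the core idea behind an 8-bit optimizer: store the Adam moments in int8 with a scale and dequantize them only when applying the update. The helper names and the per-tensor scaling are illustrative, not torchao's API (real 8-bit optimizers typically quantize blockwise).

```python
import torch

def quantize_8bit(x: torch.Tensor):
    # Symmetric per-tensor quantization to int8.
    scale = x.abs().amax().clamp(min=1e-12) / 127.0
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def dequantize_8bit(q: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.float32) * scale

# One step of the exp_avg (first moment) update with quantized storage:
grad = torch.randn(1024)
q_m, m_scale = quantize_8bit(torch.zeros(1024))  # persistent int8 state
m = dequantize_8bit(q_m, m_scale)                # dequantize to update
m = 0.9 * m + 0.1 * grad                         # beta1 = 0.9
q_m, m_scale = quantize_8bit(m)                  # re-quantize for storage
```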

None of the above is "research"; this is very much the way engineering is moving for inference: https://blog.character.ai/optimizing-ai-inference-at-character-ai/
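
For the quantized KV cache, a minimal sketch of the dynamic (post-training) flavor, assuming simple per-token symmetric int8 quantization; the function names are hypothetical:

```python
import torch

def quantize_kv(kv: torch.Tensor):
    # kv: [batch, heads, seq_len, head_dim]; symmetric int8 per token.
    scale = kv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    q = (kv / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor):
    return q.to(torch.bfloat16) * scale.to(torch.bfloat16)

k = torch.randn(1, 8, 128, 64)
q_k, k_scale = quantize_kv(k)        # what the cache would actually hold
k_hat = dequantize_kv(q_k, k_scale)  # dequantized on read for attention
print((k - k_hat.float()).abs().max())  # quantization error
```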

Also, given an exotic quantization scheme I'd like us to be more proactive in helping people benchmark their models, so this should include

  • Flop utilization
  • Memory bandwidth
  • Cache hit rate (for the KV cache only)
  • Roofline analysis
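
As a rough illustration of the flop-utilization metric (not an existing torchao utility), you can time a matmul with CUDA events and compare achieved TFLOP/s against your GPU's peak:

```python
import torch

M = N = K = 4096
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(3):  # warmup
    a @ b
start.record()
iters = 20
for _ in range(iters):
    a @ b
end.record()
torch.cuda.synchronize()

secs = start.elapsed_time(end) / 1e3 / iters  # elapsed_time is in ms
flops = 2 * M * N * K                         # one multiply + one add per element
print(f"achieved: {flops / secs / 1e12:.1f} TFLOP/s")
# flop utilization = achieved / peak for your device (peak is hardware-specific)
```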

jeromeku (Collaborator) commented

@msaroufim

Would love to work on this.

  • Regarding 4-bit and 8-bit Adam: is this already done? I've been working on a Triton-only version of 8-bit AdamW that could potentially serve as a drop-in replacement for the original CUDA version and work more seamlessly with torch.compile and the rest of the torchao primitives. Or are you thinking of integrating the original 8-bit version as a custom CUDA op?
  • Quantized KV cache: lots of research in this area, given the importance of efficient (long-context) inference. Would this include methods that require training-time architectural changes (per the character.ai blog), or post-training (dynamic) methods? For the latter, I think the major buckets are KV-cache offloading, compression/quantization, and eviction (i.e. token pruning). These could be further categorized by whether they compress at the layer, head, token, or hidden-dim level.
  • Regarding profiling, happy to help here as well. I'm pretty familiar with torch.profiler, as well as with extending it for even more fine-grained metrics (see the example after this list).
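
For reference, the kind of baseline torch.profiler usage this would build on (the model and input here are placeholders):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)

# Per-op CPU/CUDA times; finer-grained metrics would extend on top of this.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```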

msaroufim (Member, Author) commented Jun 24, 2024

  • Regarding 8-bit Adam: I was thinking we'd try to code-generate this using int8_weight_only or the dynamically quantized version. Generally I want us to follow the heuristic of first trying torch.compile(), then Triton if that doesn't work, and then integrating the original as a custom op if that doesn't work either. I'd be fine with integrating the original bnb kernel as a custom op if it makes testing in CI easier.
  • For the quantized KV cache I was specifically thinking about inference. However, we're starting to do more with quantization-aware training (fake quantization) in collaboration with the tune team, so you can chat with @andrewor14 (Add support for int4 weight-only QAT #383). Generally, for these cross-repo PRs I'd like to see them work as standalone examples in ao, and we can refer to a more prod-ready version in tune. It also seems like we should be moving toward quantized training (non-fake).
  • Yeah, regarding profiling that'd be very helpful. I think people are having trouble figuring out how far away they are from roofline, so some simple benchmark utils in torchao.utils or torchao.benchmark would do wonders (a rough sketch follows below).
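
A rough sketch of the kind of memory-bandwidth measurement such a util could wrap (nothing here is an existing torchao function); a plain copy reads and writes each byte once:

```python
import torch

x = torch.empty(1 << 28, device="cuda", dtype=torch.uint8)  # 256 MiB
y = torch.empty_like(x)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
y.copy_(x)  # warmup
start.record()
iters = 10
for _ in range(iters):
    y.copy_(x)
end.record()
torch.cuda.synchronize()

secs = start.elapsed_time(end) / 1e3 / iters
bytes_moved = 2 * x.numel()  # one read + one write per element
print(f"achieved bandwidth: {bytes_moved / secs / 1e9:.0f} GB/s")
# Compare against your GPU's peak DRAM bandwidth for a roofline-style check.
```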

jeromeku (Collaborator) commented

@msaroufim

RE: profiling

  • What is the ideal set of metrics for this analysis?
  • Do you have some sample workloads I could use as test cases while developing these utils?

msaroufim (Member, Author) commented

For metrics, the most important ones are memory bandwidth and flop utilization. A good representative workload for now is probably llama2 and llama3: https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py. That script already has good metric instrumentation, so extending it feels natural.

And as for specific algorithms, I'd be most curious about testing out
