NF4 quantization of linear layers without LoRA applied #1119

Open · wants to merge 4 commits into base: main
Conversation

winglian (Collaborator)

Context

What is the purpose of this PR? Is it to

  • [x] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses. #1093

Changelog

Reverts #658 to bring back FrozenNF4Linear. When quantize_base is set to true, all base weights for linear layers are quantized, even if they do not have LoRA applied to them.
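
For context, the restored FrozenNF4Linear boils down to an nn.Linear whose frozen base weight is re-registered as an NF4 tensor. A minimal sketch of the idea, assuming torchao's to_nf4 and linear_nf4 helpers (illustrative only, not the exact class this PR restores):

```python
import torch
import torch.nn as nn

# Assumed torchao helpers for NF4 quantization and the NF4-aware matmul.
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class FrozenNF4Linear(nn.Linear):
    """Illustrative sketch: a linear layer whose frozen base weight is stored in NF4."""

    def __init__(self, in_dim: int, out_dim: int, device=None, dtype=None):
        super().__init__(in_dim, out_dim, bias=False, device=device, dtype=dtype)
        # Freeze the base weight, quantize it to NF4, and re-register it so the
        # NF4 weight shows up as expected in .parameters(), state_dict(), etc.
        self.weight.requires_grad_(False)
        self.weight = nn.Parameter(to_nf4(self.weight), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # linear_nf4 dequantizes the NF4 weight on the fly for the matmul.
        return linear_nf4(input=x, weight=self.weight)
```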

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

@winglian added the enhancement (New feature or request) label on Jun 25, 2024

pytorch-bot bot commented Jun 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1119

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0eb4ad7 with merge base b317c8f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 25, 2024
@joecummings (Contributor)

cc @msaroufim

@janeyx99 left a comment

Clarification question: does this mean that quantize_base is the boolean with which one specifies the q in qLoRA, which is globally applied to all projs? (versus how LoRA could be applied to a subset of the parameters, e.g. q, k, v and not o).

@winglian (Collaborator, Author) commented Jul 1, 2024

Running tune run lora_finetune_single_device --config llama3/8B_qlora_single_device, but with lora_attn_modules: ['q_proj'] instead of all the attention linear layers, uses 15810 MiB on main and 15000 MiB on this branch.

EDIT: enabling lora_attn_modules across all 4 attention modules uses 15044 MiB on this branch, which is expectedly a bit more than the 15000 MiB above due to the extra optimizer state, and still less than main because all of the base linear weights are quantized.

@janeyx99 left a comment

The rest of the PR looks fine; we should implement FrozenNF4Linear as cleanly as possible, though.

Comment on lines 42 to 46
self.weight.requires_grad_(False)
self.nf4_weight = to_nf4(self.weight.data)
# re-register self.weight as the nf4 weight, so that the nf4 weight
# shows up as expected in .parameters, state_dict, etc.
self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False)

Could we avoid accessing .data of self.weight?

I'm curious if there's a way to use swap_tensors here to accomplish this more cleanly. https://pytorch.org/docs/stable/generated/torch.utils.swap_tensors.html cc @mikaylagawarecki any ideas?
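
For context, torch.utils.swap_tensors exchanges the contents of two tensors in place, so the registered weight parameter could in principle be replaced by an NF4-backed one without touching .data. A rough, untested sketch of that idea, assuming torchao's to_nf4 helper (not what this PR ships):

```python
import torch
from torch.utils import swap_tensors

# Assumed torchao helper for the NF4 conversion.
from torchao.dtypes.nf4tensor import to_nf4

linear = torch.nn.Linear(64, 64, bias=False)

# Build the NF4 replacement as a frozen Parameter.
nf4_weight = torch.nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)

# Exchange the registered parameter with the NF4 one in place; the module's
# weight now holds the NF4 tensor, with no mutation of .data.
swap_tensors(linear.weight, nf4_weight)
```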


@winglian Could you explain the rationale for the .data access?


Yeah can we do something similar to what's currently done in LoRALinear for this?

winglian (Collaborator, Author)

This is simply reverting the code that was deleted in https://github.com/pytorch/torchtune/pull/658/files#diff-74b0d911936ebd1d0e216004577afed27b84b6bfdff9c6a9a1a28f6fac054850L45

Removing the .data does seem to work as well, so I've checked in that change.

@janeyx99 left a comment

Also, what test cases can we write to verify functionality?

@ebsmothers (Contributor) left a comment

Overall the changes look reasonable to me. Re @janeyx99's testing comment, a couple of things I think we should add here:

(1) a unit test for nf4_linear (can probably look at #465 for ideas; see the sketch after this comment)
(2) some kind of e2e test. I am thinking: run one of our QLoRA recipes before and after the change, and confirm that (a) the peak memory reduces as expected, and (b) we see no regression in eval metrics on the resulting fine-tuned checkpoint.

It's also likely that this will break some of our existing QLoRA tests, so e.g. the expected values here would need to be updated.
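
A minimal sketch of the unit test suggested in (1), assuming FrozenNF4Linear is importable from torchtune.modules.low_precision with an (in_dim, out_dim, dtype=...) constructor; the real test should match wherever the module actually lands:

```python
import torch

# Assumed import paths; adjust to wherever the PR actually exposes these.
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune.modules.low_precision import FrozenNF4Linear


def test_frozen_nf4_linear_quantizes_and_freezes_base_weight():
    torch.manual_seed(0)
    layer = FrozenNF4Linear(64, 64, dtype=torch.bfloat16)

    # The base weight should be NF4-quantized and frozen.
    assert isinstance(layer.weight, NF4Tensor)
    assert not layer.weight.requires_grad

    # The forward pass should still produce the expected shape and dtype.
    x = torch.randn(2, 64, dtype=torch.bfloat16)
    out = layer(x)
    assert out.shape == (2, 64)
    assert out.dtype == torch.bfloat16
```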


Labels: CLA Signed, enhancement (New feature or request)
Projects: none yet
Development: successfully merging this pull request may close these issues: none yet
6 participants