LoRA Implementation for Parameter-Efficient Fine-Tuning on new datasets #159

Open · wants to merge 17 commits into main
Conversation

@sidhantls commented Oct 31, 2024

Addresses #158 and the request for PEFT mentioned in the README.

Feature

This PR adds PEFT support with Low-Rank Adaptation (LoRA) for fine-tuning Parler-TTS on new datasets.

LoRA is applied to the Parler-TTS decoder Transformer, with adapters attached to its Linear projection layers. With the T5 encoder of Parler-Mini frozen, PEFT trains only 8% of the parameters, compared to 50% without PEFT.
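
For context, a minimal sketch of how the adapters can be attached with the `peft` library is below. The `target_modules` names and the `r`/`lora_alpha` values are illustrative assumptions, not necessarily the exact choices in this PR.

```python
# Sketch only: hyperparameters and module names are assumptions, not the PR's exact settings.
from peft import LoraConfig, get_peft_model
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")

# Freeze the T5 text encoder so only the decoder is adapted.
for param in model.text_encoder.parameters():
    param.requires_grad = False

# Attach low-rank adapters to the decoder's linear projection layers.
lora_config = LoraConfig(
    r=8,                 # adapter rank (assumed value)
    lora_alpha=16,       # scaling factor (assumed value)
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],  # assumed layer names
)
model.decoder = get_peft_model(model.decoder, lora_config)
model.decoder.print_trainable_parameters()  # reports the small trainable fraction
```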

Benefits

  • PEFT enables fine-tuning of Parler-Mini on an 8 GB GPU (possibly with limited offloading).
    • On my Windows machine, fine-tuning takes 29.52 s/it with PEFT and 117.50 s/it without PEFT.
  • PEFT enables fine-tuning of Parler-TTS Large (2.3B) on Google Colab. Without PEFT, I get an OOM error.

Results:


  • Training plots with and without PEFT can be found in this Weights and Biases Report, along with audio samples generated by the two models. Both runs fine-tune on Jenny-6H.
  • Audio Samples Comparison: The audio samples sound very similar. In some cases, the PEFT model produces better output (better pronunciation and no cut-off at the end): refer to indices 47 and 49.

Reproduce:

  • Google Colab: Notebook. To turn off PEFT, set USE_PEFT=false.

To do:

  • Analysis to see whether Parameter-Efficient Fine-Tuning will be useful to the community @ylacombe @sanchit-gandhi
  • Complete PR code: saving the model:
    • After PEFT training is complete, the adapters need to be merged out and the model saved. I have implemented this, but it is an in-place operation, so how should we save intermediate checkpoints? For the final model, we can use this in-place operation, which replaces the adapters with regular layers so that the model can be loaded the same way without any other code changes (see the sketch after this list).
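
For reference, a hedged sketch of what the merge-and-save step could look like with `peft`'s standard API is below; whether this matches the in-place operation implemented in this PR is an assumption, and the output path is hypothetical.

```python
# Sketch only: assumes model.decoder is a peft-wrapped PeftModel, as in the sketch above.
# merge_and_unload() folds the LoRA weights into the base Linear layers and drops the adapters.
model.decoder = model.decoder.merge_and_unload()

# The merged model then loads through the usual code path, with no PEFT-specific changes.
model.save_pretrained("parler-tts-mini-jenny-lora-merged")  # hypothetical output path
```

For intermediate checkpoints, one option might be to save only the adapter weights via the PeftModel's `save_pretrained`, which keeps checkpoints small and defers the in-place merge until the final export.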


logger = logging.getLogger(__name__)

os.environ["WANDB_MODE"] = "offline"
Author comment on the os.environ["WANDB_MODE"] = "offline" line:
to do: remove this line
