
LoRA fine-tuning using PEFT Library #204

Open · wants to merge 2 commits into main

Conversation

@sidhantls commented Feb 14, 2025

Adds a LoRA implementation for parameter-efficient fine-tuning of Parler-TTS.

Addresses #183, #158, and other requests.

Feature

This PR adds PEFT support with Low-Rank Adapters (LoRA) for fine-tuning Parler-TTS on new datasets.

LoRA is applied to the Parler-TTS decoder Transformer, where PEFT wraps the linear projection layers. Fine-tuning with LoRA trains only ~0.5% of the parameters of Parler Mini.
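For reference, here is a minimal sketch of what such a PEFT/LoRA setup looks like. The target_modules listed are an assumption for illustration and may differ from the exact set wrapped in this PR:

from peft import LoraConfig, get_peft_model
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# Assumed names for the decoder's linear projection layers; the exact
# target_modules used by this PR may differ.
lora_config = LoraConfig(
    r=8,                # rank of the low-rank update matrices
    lora_alpha=16,      # scaling factor applied to the update
    lora_dropout=0.05,  # dropout on the LoRA branch
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # reports the small (~0.5%) trainable fraction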

Benefits

  • PEFT enables fine-tuning of Parler Mini on an 8GB GPU (with limited offloading?).
    • On my Windows machine, fine-tuning takes 29.52 s/it with PEFT versus 117.50 s/it without.
  • PEFT enables fine-tuning of Parler-TTS Large (2.3B) on Google Colab; without it, I get an OOM error.

This is an alternative implementation of PR #159: it enables training with LoRA, loading checkpoints, and loading the final LoRA model. Unlike #159, which used a custom implementation, it builds on the peft library, and it supports loading saved checkpoints, which was not possible in #159.

How to use:

Fine-Tuning:

When running accelerate launch ./training/run_parler_tts_training.py for fine-tuning, pass --use_lora true --lora_r 8 --lora_alpha 16 --lora_dropout 0.05.
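For example, a full launch command might look like the sketch below; --model_name_or_path and --output_dir follow the repo's training README, and the usual dataset/training flags are omitted for brevity:

accelerate launch ./training/run_parler_tts_training.py \
    --model_name_or_path "parler-tts/parler_tts_mini_v0.1" \
    --output_dir "./output_dir_training/" \
    --use_lora true \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05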

Loading Checkpoints:

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from peft import PeftModel
import torch

device = "cuda"

# Base model (pretrained)
base_model_name = "parler-tts/parler_tts_mini_v0.1"
peft_model_path = "output_dir_training/"  # PEFT model path/checkpoint path

torch_dtype = torch.float16

# Load base model
base_model = ParlerTTSForConditionalGeneration.from_pretrained(base_model_name).to(device, dtype=torch_dtype)

# Load PEFT model on top of the base model
peft_model = PeftModel.from_pretrained(base_model, peft_model_path)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Merge LoRA adapters into the base model (~50% faster inference)
model = peft_model.merge_and_unload()

# Regular inference with the merged model
from IPython.display import Audio

prompt = "It was a bright cold day in April."
description = "Jenny speaks with a monotone voice, in a very close-sounding environment with almost no noise. She speaks slightly fast."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Generate audio
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()

# Use the model's configured sampling rate (44.1 kHz for Parler Mini v0.1)
Audio(data=audio_arr, rate=model.config.sampling_rate)
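After merging, the model can be saved like any regular Transformers model, so later inference does not require peft at all. A minimal sketch, with a placeholder output path:

# Save the merged weights and tokenizer; the directory name is a placeholder.
merged_dir = "parler_tts_mini_lora_merged/"
model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)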
