
Thunder slower than eager for PEFT LoRA configs with small input sizes #1738

@riccardofelluga

Description


🐛 Bug

The example from #1720 shows a performance gap between eager mode and the Thunder-compiled LoRA version of a simple MLP.
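For context, PEFT LoRA wraps each target nn.Linear so the forward pass adds a scaled low-rank update on top of the frozen weight. With r=16 the adapter matmuls are tiny, so at these shapes the workload is dominated by launch and framework overhead rather than compute. A minimal sketch of the adapter math (the standard LoRA formula under this issue's config, not the PEFT implementation):

import torch

# y = x @ W.T + (lora_alpha / r) * (x @ A.T) @ B.T
# For linear_proj (3584 -> 256) with r=16: A is (16, 3584), B is (256, 16),
# so the adapter matmuls are small compared to the base projection.
x = torch.randn(1, 4096, 3584)
W = torch.randn(256, 3584)   # frozen base weight
A = torch.randn(16, 3584)    # lora_A
B = torch.randn(256, 16)     # lora_B
scaling = 32 / 16            # lora_alpha / r (use_rslora=False)
y = x @ W.T + scaling * (x @ A.T) @ B.T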

To Reproduce

You can either run the benchmarks from the hf-nemo-benchmark branch with pytest:

pytest thunder/benchmarks/targets.py -k "lora_linear and not inference" --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type --benchmark-warmup-iterations=10

or run the following script:

Standalone script
import torch
import torch.nn as nn
import torch.utils.benchmark
from transformers.configuration_utils import PretrainedConfig
from peft import LoraConfig, TaskType, get_peft_model
from thunder.dynamo import thunderfx

a = torch.randn(1, 4096, 3584, requires_grad=False, device="cuda")

q_config = PretrainedConfig(model_type="llm")

class HFMLP(nn.Module):
    def __init__(self):
        super().__init__()

        self.config = q_config

        self.linear_proj = nn.Linear(3584, 256)
        self.linear_fc1 = nn.Linear(256, 256)
        self.linear_fc2 = nn.Linear(256, 3584)

    def forward(self, input_ids, **kwargs):
        y = self.linear_proj(input_ids)
        y = self.linear_fc1(y)
        return self.linear_fc2(y)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Stub so PEFT's CAUSAL_LM wrapper can hook into generation; unused here.
        pass


m = HFMLP()

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    bias="none",
    use_rslora=False,
    target_modules=["linear_proj", "linear_fc1", "linear_fc2"],
)

hf_peft_model = get_peft_model(m, peft_config, autocast_adapter_dtype=False).to("cuda")

print(hf_peft_model)

# Thunder (thunderfx): the first call triggers compilation (warm-up).
cpeft_hf = thunderfx(hf_peft_model)
cpeft_hf(a)

def fwd_hf():
    y = cpeft_hf(a)

# Unlike timeit, torch.utils.benchmark.Timer synchronizes CUDA around measurements.
timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA thunder", measurement)


# torch.compile (Inductor): the first call triggers compilation (warm-up).
cpeft_hf = torch.compile(hf_peft_model, fullgraph=True)
cpeft_hf(a)

def fwd_hf():
    y = cpeft_hf(a)

timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA inductor", measurement)

# Eager baseline.
def fwd_hf():
    y = hf_peft_model(a)

timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA eager", measurement)

Environment

  • GPU: NVIDIA H200

Additional context

Here are the measurements from a run on an H200 (same as in #1720). Relative to the fastest configuration in each group, the thunderfx variants are roughly 2x slower on the backward pass and close to 4x slower on the forward pass:

---------------------------------------------------------------------------- benchmark 'compute_type=ComputeType.TRAINING_BACKWARD': 6 tests -----------------------------------------------------------------------------
Name (time in us)                                  Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nemo_lora_linear[backward-inductor]      279.9155 (1.0)      1,028.0015 (1.97)     307.5476 (1.0)      35.6273 (1.12)     300.4275 (1.0)       9.8018 (1.0)        88;108        3.2515 (1.0)        1705           2
test_hf_lora_linear[backward-inductor]        282.4720 (1.01)       521.4545 (1.0)      308.3400 (1.00)     31.7722 (1.0)      301.8200 (1.00)     10.4370 (1.06)       91;110        3.2432 (1.00)       1670           2
test_nemo_lora_linear[backward-eager]         308.9520 (1.10)     1,209.2750 (2.32)     329.3746 (1.07)     36.2003 (1.14)     322.4815 (1.07)     12.6656 (1.29)        82;86        3.0361 (0.93)       1569           2
test_hf_lora_linear[backward-eager]           328.8140 (1.17)     1,041.9055 (2.00)     358.5735 (1.17)     34.2569 (1.08)     353.2260 (1.18)      9.9289 (1.01)       82;120        2.7888 (0.86)       1478           2
test_hf_lora_linear[backward-thunderfx]       601.0530 (2.15)     2,034.0050 (3.90)     667.3585 (2.17)     94.1061 (2.96)     641.8175 (2.14)     29.9585 (3.06)      103;164        1.4984 (0.46)       1520           1
test_nemo_lora_linear[backward-thunderfx]     615.3560 (2.20)     1,146.2360 (2.20)     669.0339 (2.18)     85.2972 (2.68)     647.3741 (2.15)     24.5128 (2.50)       76;138        1.4947 (0.46)       1523           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------- benchmark 'compute_type=ComputeType.TRAINING_FORWARD': 6 tests -----------------------------------------------------------------------------
Name (time in us)                                 Min                   Max                Mean              StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nemo_lora_linear[forward-eager]         204.8484 (1.0)        359.6930 (1.0)      217.1435 (1.0)       26.8580 (1.0)      210.0090 (1.0)       4.4708 (1.0)       100;143        4.6052 (1.0)        1581           3
test_nemo_lora_linear[forward-inductor]      214.6163 (1.05)       391.2380 (1.09)     230.6942 (1.06)      29.3198 (1.09)     222.4005 (1.06)      6.2602 (1.40)       85;144        4.3347 (0.94)       1500           3
test_hf_lora_linear[forward-inductor]        229.4977 (1.12)       533.6100 (1.48)     248.2199 (1.14)      31.6476 (1.18)     239.2725 (1.14)      8.1190 (1.82)       85;133        4.0287 (0.87)       1410           3
test_hf_lora_linear[forward-eager]           266.8125 (1.30)     2,128.4015 (5.92)     287.4951 (1.32)      61.3338 (2.28)     274.2370 (1.31)      6.5187 (1.46)      114;172        3.4783 (0.76)       1807           2
test_hf_lora_linear[forward-thunderfx]       775.8250 (3.79)     2,610.0180 (7.26)     839.2712 (3.87)     113.6345 (4.23)     813.5395 (3.87)     25.9290 (5.80)       82;110        1.1915 (0.26)       1250           1
test_nemo_lora_linear[forward-thunderfx]     782.9630 (3.82)     1,231.3460 (3.42)     845.2695 (3.89)      91.7294 (3.42)     822.9025 (3.92)     27.6695 (6.19)       80;108        1.1831 (0.26)       1196           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean

And here are the results from the standalone script:

HF LoRA thunder <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c6f6623f0>
fwd_hf()
  709.00 us
  1 measurement, 10 runs , 1 thread
HF LoRA inductor <torch.utils.benchmark.utils.common.Measurement object at 0x7f9b4b9162d0>
fwd_hf()
  215.16 us
  1 measurement, 10 runs , 1 thread
HF LoRA eager <torch.utils.benchmark.utils.common.Measurement object at 0x7f9b29814d40>
fwd_hf()
  250.23 us
  1 measurement, 10 runs , 1 thread
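
To narrow down where the extra time goes in the thunderfx path, profiling a single forward call can separate device kernel time from host-side overhead. A minimal sketch with torch.profiler, assuming cpeft_hf and a from the script above are still in scope; if kernel time is comparable to eager but CPU time is much larger, the gap is host-side dispatch/launch overhead rather than slower kernels:

from torch.profiler import profile, ProfilerActivity

# Profile one forward call of the Thunder-compiled model.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    cpeft_hf(a)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))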
