🐛 Bug
This example from #1720 shows a performance difference between the eager and Thunder-compiled LoRA versions of a simple MLP.
To Reproduce
You can either use the benchmarks available on the hf-nemo-benchmark branch by calling pytest with:
pytest thunder/benchmarks/targets.py -k "lora_linear and not inference" --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type --benchmark-warmup-iterations=10
or run the following script:
Standalone script
import torch
import torch.nn as nn
import torch.utils.benchmark
from transformers.configuration_utils import PretrainedConfig
from peft import LoraConfig, TaskType, get_peft_model
from thunder.dynamo import thunderfx
a = torch.randn(1, 4096, 3584, requires_grad=False, device="cuda")
q_config = PretrainedConfig(
    model_type="llm",
)
# Minimal HF-style module; the config attribute and the
# prepare_inputs_for_generation stub let PEFT treat it like a transformers model.
class HFMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = q_config
        self.linear_proj = nn.Linear(3584, 256)
        self.linear_fc1 = nn.Linear(256, 256)
        self.linear_fc2 = nn.Linear(256, 3584)

    def forward(self, input_ids, **kwargs):
        y = self.linear_proj(input_ids)
        y = self.linear_fc1(y)
        return self.linear_fc2(y)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        pass
m = HFMLP()
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.0,
    bias="none",
    use_rslora=False,
    target_modules=["linear_proj", "linear_fc1", "linear_fc2"],
)
hf_peft_model = get_peft_model(m, peft_config, autocast_adapter_dtype=False).to("cuda")
print(hf_peft_model)
# Thunder (thunderfx) version; the first call triggers compilation, so it
# doubles as the warm-up before timing.
cpeft_hf = thunderfx(hf_peft_model)
cpeft_hf(a)

def fwd_hf():
    y = cpeft_hf(a)
timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA thunder", measurement)
# Inductor (torch.compile) version; again warm up once before timing.
cpeft_hf = torch.compile(hf_peft_model, fullgraph=True)
cpeft_hf(a)

def fwd_hf():
    y = cpeft_hf(a)
timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA inductor", measurement)
# Eager baseline.
def fwd_hf():
    y = hf_peft_model(a)
timer = torch.utils.benchmark.Timer("fwd_hf()", globals={"fwd_hf": fwd_hf})
measurement = timer.timeit(number=10)
print("HF LoRA eager", measurement)
Environment
- H200
Additional context
Here are the measurements from a run on an H200 (same as in #1720):
---------------------------------------------------------------------------- benchmark 'compute_type=ComputeType.TRAINING_BACKWARD': 6 tests -----------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nemo_lora_linear[backward-inductor] 279.9155 (1.0) 1,028.0015 (1.97) 307.5476 (1.0) 35.6273 (1.12) 300.4275 (1.0) 9.8018 (1.0) 88;108 3.2515 (1.0) 1705 2
test_hf_lora_linear[backward-inductor] 282.4720 (1.01) 521.4545 (1.0) 308.3400 (1.00) 31.7722 (1.0) 301.8200 (1.00) 10.4370 (1.06) 91;110 3.2432 (1.00) 1670 2
test_nemo_lora_linear[backward-eager] 308.9520 (1.10) 1,209.2750 (2.32) 329.3746 (1.07) 36.2003 (1.14) 322.4815 (1.07) 12.6656 (1.29) 82;86 3.0361 (0.93) 1569 2
test_hf_lora_linear[backward-eager] 328.8140 (1.17) 1,041.9055 (2.00) 358.5735 (1.17) 34.2569 (1.08) 353.2260 (1.18) 9.9289 (1.01) 82;120 2.7888 (0.86) 1478 2
test_hf_lora_linear[backward-thunderfx] 601.0530 (2.15) 2,034.0050 (3.90) 667.3585 (2.17) 94.1061 (2.96) 641.8175 (2.14) 29.9585 (3.06) 103;164 1.4984 (0.46) 1520 1
test_nemo_lora_linear[backward-thunderfx] 615.3560 (2.20) 1,146.2360 (2.20) 669.0339 (2.18) 85.2972 (2.68) 647.3741 (2.15) 24.5128 (2.50) 76;138 1.4947 (0.46) 1523 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------- benchmark 'compute_type=ComputeType.TRAINING_FORWARD': 6 tests -----------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_nemo_lora_linear[forward-eager] 204.8484 (1.0) 359.6930 (1.0) 217.1435 (1.0) 26.8580 (1.0) 210.0090 (1.0) 4.4708 (1.0) 100;143 4.6052 (1.0) 1581 3
test_nemo_lora_linear[forward-inductor] 214.6163 (1.05) 391.2380 (1.09) 230.6942 (1.06) 29.3198 (1.09) 222.4005 (1.06) 6.2602 (1.40) 85;144 4.3347 (0.94) 1500 3
test_hf_lora_linear[forward-inductor] 229.4977 (1.12) 533.6100 (1.48) 248.2199 (1.14) 31.6476 (1.18) 239.2725 (1.14) 8.1190 (1.82) 85;133 4.0287 (0.87) 1410 3
test_hf_lora_linear[forward-eager] 266.8125 (1.30) 2,128.4015 (5.92) 287.4951 (1.32) 61.3338 (2.28) 274.2370 (1.31) 6.5187 (1.46) 114;172 3.4783 (0.76) 1807 2
test_hf_lora_linear[forward-thunderfx] 775.8250 (3.79) 2,610.0180 (7.26) 839.2712 (3.87) 113.6345 (4.23) 813.5395 (3.87) 25.9290 (5.80) 82;110 1.1915 (0.26) 1250 1
test_nemo_lora_linear[forward-thunderfx] 782.9630 (3.82) 1,231.3460 (3.42) 845.2695 (3.89) 91.7294 (3.42) 822.9025 (3.92) 27.6695 (6.19) 80;108 1.1831 (0.26) 1196 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Legend:
Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
OPS: Operations Per Second, computed as 1 / Mean
And results from the script:
HF LoRA thunder <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c6f6623f0>
fwd_hf()
709.00 us
1 measurement, 10 runs , 1 thread
HF LoRA inductor <torch.utils.benchmark.utils.common.Measurement object at 0x7f9b4b9162d0>
fwd_hf()
215.16 us
1 measurement, 10 runs , 1 thread
HF LoRA eager <torch.utils.benchmark.utils.common.Measurement object at 0x7f9b29814d40>
fwd_hf()
250.23 us
1 measurement, 10 runs , 1 thread
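To attribute the gap to GPU kernel time versus host-side overhead, a standard torch.profiler pass over the warmed-up callable can help. This snippet is illustrative only and was not part of the measurements above:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    cpeft_hf(a)
# Top ops by accumulated CUDA time; run once per backend to compare.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))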