Thunder slower than eager for PEFT LoRA configs with small batch sizes #1738
Comments
Looking at the standalone script, it looks like we are checking this on CPU. Is that intentional?
Also, it looks like the MLP is missing an activation function.
Yes, my bad, I've now updated the script and results. Regarding the activation function, I left that out on purpose to compare just the linear layers. It is possible to benchmark one single layer, but I think putting multiple ones helps to reduce noise in the measurements. Though I am open to suggestions!
It's a great start to have matching performance without activation functions. Can you please change the name and update the example snippets to avoid confusion? Adding "wo_act_fn" is enough. Reducing noise in the measurements should be handled by specifying different pytest-benchmark options. Even if the execution is too fast for a timer, it's possible to set a resolution for pytest-benchmark to find the correct number of runs, as explained in https://pytest-benchmark.readthedocs.io/en/latest/calibration.html. What does the execution trace look like, and is it different from the PyTorch eager code? If the computation function is executed directly without the prologue, how do the timings look?
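For instance, the round count and timer resolution can be tuned through pytest-benchmark's calibration flags; a possible invocation (the flag values here are illustrative assumptions, see the calibration docs linked above):

```sh
pytest thunder/benchmarks/targets.py -k "lora_linear and not inference" \
  --benchmark-timer=torch.utils.benchmark.utils.timer.timer \
  --benchmark-min-time=0.000025 --benchmark-calibration-precision=10 --benchmark-min-rounds=20
```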
Interestingly enough, by adding the following instrumentation we can use the otherwise unused `cs.last_computation_execution_start` and `cs.last_computation_execution_stop` fields:

```diff
diff --git a/thunder/__init__.py b/thunder/__init__.py
index 2ba53cdf..74585dd1 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -682,6 +682,7 @@ def jit(
         computation_trc = transform_to_torch_types(computation_trc)
         comp = computation_trc.python_callable()
+        comp = computation_execution_timer(comp)
 
         # TODO RC1 Update the cache
         cache_entry = CacheEntry(
@@ -710,6 +711,16 @@ def jit(
         return wrapped
 
+    def computation_execution_timer(fn):
+        def wrapped(*args, **kwargs):
+            cs.last_computation_execution_start = time.perf_counter_ns()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                cs.last_computation_execution_stop = time.perf_counter_ns()
+
+        return wrapped
+
     def prologue_execution_timer(fn):
         def wrapped(*args, **kwargs):
             cs.last_prologue_execution_start = time.perf_counter_ns()
```

I'll continue looking to find where the rest of the runtime is spent.
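For completeness, a minimal sketch of how those timestamps could be read back after a call. It assumes the patch above is applied; `thunder.compile_stats` and the field names are taken here as assumptions about the internals rather than guaranteed public API.

```python
import torch
import thunder

def fn(x):
    return x @ x

jfn = thunder.jit(fn)
jfn(torch.randn(256, 256))

# The compile stats object is assumed to carry the timestamps set by the timers above.
cs = thunder.compile_stats(jfn)
prologue_ns = cs.last_prologue_execution_stop - cs.last_prologue_execution_start
computation_ns = cs.last_computation_execution_stop - cs.last_computation_execution_start
print(f"prologue: {prologue_ns / 1e3:.1f} us, computation: {computation_ns / 1e3:.1f} us")
```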
I've updated the name of this issue because it looks like the observed slowdown is due to the latency of launching computation from Thunder. In the graph below, the x axis is the size of the input tensor and the y axis is the runtime. As can be seen, there are two main regimes: one on the left of the graph, where Thunder pays a big penalty compared to Inductor, and the other on the right, after the size of the input surpasses ~35 million. I've isolated the runs and tried without nvFuser, but interestingly enough the behaviour is the same. In conclusion, I see two similar problems here.
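As a rough way to probe the two regimes independently, one could time eager vs. thunder.jit across input sizes with torch.utils.benchmark; this is an illustrative sketch (the toy function, shapes, and sizes are assumptions, not the script used for the plots above):

```python
import torch
import torch.utils.benchmark as benchmark
import thunder

def toy_mlp(x, w1, w2):
    # Two stacked matmuls standing in for the LoRA MLP without activations.
    return (x @ w1) @ w2

jitted = thunder.jit(toy_mlp)

dim = 1024
w1 = torch.randn(dim, dim, device="cuda")
w2 = torch.randn(dim, dim, device="cuda")

for rows in (16, 256, 4096, 65536):  # sweep the input size (rows * dim elements)
    x = torch.randn(rows, dim, device="cuda")
    for label, fn in (("eager", toy_mlp), ("thunder", jitted)):
        t = benchmark.Timer(stmt="fn(x, w1, w2)", globals={"fn": fn, "x": x, "w1": w1, "w2": w2})
        print(f"{label:8s} rows={rows:6d} {t.timeit(100)}")
```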
🐛 Bug
This example from #1720 shows a performance difference between eager and a Thunder-compiled LoRA version of a simple MLP.
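For context, a minimal sketch of the kind of module being compared: a LoRA-augmented linear layer stacked into a small MLP without activation functions. This is an illustrative approximation, not the exact script from #1720; ranks, dimensions, and names are assumptions.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Frozen base weight plus a trainable low-rank update:
    # y = x W^T + scaling * B(A(x))
    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

# A small MLP of stacked LoRA linears, intentionally without activation functions,
# matching the benchmark's focus on the linear layers alone.
mlp = nn.Sequential(*[LoRALinear(4096, 4096) for _ in range(4)])
```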
To Reproduce
You can either use the available benchmarks in the hf-nemo-benchmark branch by calling pytest with:

```sh
pytest thunder/benchmarks/targets.py -k "lora_linear and not inference" --benchmark-timer=torch.utils.benchmark.utils.timer.timer --benchmark-warmup=on --benchmark-group-by=param:compute_type --benchmark-warmup-iterations=10
```
or run the following script:
Standalone script
Environment
Additional context
Here are the measurements from a run on an H200 (same as in #1720):
And the results from the script: