[tests] tighten compilation tests for quantization #12002
base: main
Conversation
```diff
@@ -847,6 +847,10 @@ def quantization_config(self):
         components_to_quantize=["transformer", "text_encoder_2"],
     )

+    @pytest.mark.xfail(
```
@matthewdouglas I get:
- 0/0: expected type of 'module._modules['norm_out']._modules['linear']._parameters['weight'].CB' to be a tensor type, but found <class 'NoneType'>
For the time being I'm not sure that we can do a whole lot to avoid this for bnb int8; at the very least it's not a high priority for us. I'm not 100% sure, but it's possible you could get around this by making a forward pass through the model prior to compiling it.
I will note this down in the xfail reason, then.
@anijain2305 I am planning to switch to
This is to get rid of:
However, doing so fails with:

```
E torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
E Explanation: Dynamo developers have intentionally marked that the function `current_accelerator` in file `/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/accelerator/__init__.py` should not be traced.
E Hint: Avoid calling the function `current_accelerator`.
E Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `current_accelerator` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
E Hint: Please file an issue to PyTorch.
E
E Developer debug context: module: torch.accelerator, qualname: current_accelerator, skip reason: <missing reason>
E
E
E from user code:
E   File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 448, in forward
E     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
E   File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 168, in forward
E     emb = self.linear(self.silu(emb))
E   File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
E     return forward_call(*args, **kwargs)
E   File "/fsx/sayak/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
E     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
E   File "/fsx/sayak/diffusers/src/diffusers/hooks/group_offloading.py", line 339, in pre_forward
E     self.group.onload_()
E   File "/fsx/sayak/diffusers/src/diffusers/hooks/group_offloading.py", line 213, in onload_
E     getattr(torch, torch.accelerator.current_accelerator().type)
E
E Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

I also had to comment out:
Do you have any recommendations?
What does this PR do?

With `torch.compile` and `fullgraph=True` on quantized models, we don't want any recompilations to get in the way of performance. This PR ensures that. `fullgraph=True` is better in terms of cold-start and also overall execution time. When `_repeated_blocks` is available for a model class, we make use of `compile_repeated_blocks()` instead of `compile()`.