
[tests] tighten compilation tests for quantization #12002


Open
wants to merge 4 commits into main

Conversation

sayakpaul (Member)

What does this PR do?

  1. When not using any kind of offloading but just torch.compile with fullgraph=True on quantized models, we don't want any recompilations to get in the way of performance. This PR ensures that.
  2. When using offloading, regional compilation with fullgraph=True is better in terms of both cold-start and overall execution time. When _repeated_blocks is available for a model class, we make use of compile_repeated_blocks() instead of compile() (see the sketch below).
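A minimal sketch of the two paths described above; `model`, `dummy_inputs`, and `uses_offloading` are placeholders and not taken from this PR's diff:

```python
import torch
import torch._dynamo

# Path 1: no offloading. Compile the whole model with fullgraph=True and treat
# any recompilation during the forward pass as an error.
# Path 2: offloading. Use regional compilation via compile_repeated_blocks()
# when the model class declares _repeated_blocks.
if not uses_offloading:
    model = torch.compile(model, fullgraph=True)
    with torch._dynamo.config.patch(error_on_recompile=True):
        _ = model(**dummy_inputs)
elif getattr(model, "_repeated_blocks", None):
    model.compile_repeated_blocks(fullgraph=True)
    _ = model(**dummy_inputs)
```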

@@ -847,6 +847,10 @@ def quantization_config(self):
components_to_quantize=["transformer", "text_encoder_2"],
)

@pytest.mark.xfail(
sayakpaul (Member Author)

@matthewdouglas I get:

- 0/0: expected type of 'module._modules['norm_out']._modules['linear']._parameters['weight'].CB' to be a tensor type, ' but found <class 'NoneType'>

matthewdouglas (Member)

For the time being, I'm not sure we can do a whole lot to avoid this for bnb int8. At the very least, it is not a high priority for us. I'm not 100% sure, but it's possible you could get around this by making a forward pass through the model prior to compiling it.
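A hedged sketch of that workaround, assuming `model` and `dummy_inputs` as placeholders:

```python
import torch

# Warm-up: one eager forward pass so the bnb int8 state (e.g. `weight.CB`)
# is materialized before compilation and guard installation.
with torch.no_grad():
    _ = model(**dummy_inputs)

# Only then compile with fullgraph=True.
model = torch.compile(model, fullgraph=True)
_ = model(**dummy_inputs)
```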

sayakpaul (Member Author)

I will note this down in the xfail reason, then.
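As a hypothetical shape of that marker on the test method (the exact reason string and `strict` setting may differ in the final diff):

```python
import pytest

@pytest.mark.xfail(
    reason=(
        "torch.compile guard fails on bnb int8 params whose `weight.CB` is None; "
        "a prior eager forward pass may work around it."
    ),
    strict=True,
)
def test_torch_compile(self):
    ...
```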

sayakpaul (Member Author)

@anijain2305 I am planning to switch to compile_repeated_blocks(fullgraph=True) for:

def test_torch_compile_with_group_offload_leaf(self, use_stream=False):

This is to get rid of:

torch._dynamo.config.cache_size_limit = 1000
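Concretely, the change I have in mind looks roughly like this (a sketch; `pipe` and `inputs` stand in for the actual test fixtures):

```python
# Before: bump the dynamo cache size to tolerate recompilations coming from the
# offloading hooks, then compile the whole transformer.
# torch._dynamo.config.cache_size_limit = 1000
# pipe.transformer.compile()

# After: regional compilation of the repeated transformer blocks only.
pipe.transformer.compile_repeated_blocks(fullgraph=True)
_ = pipe(**inputs)
```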

However, doing so with pytest tests/quantization/gguf/test_gguf.py::GGUFCompileTests::test_torch_compile_with_group_offload_leaf results in:

E               torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
E                 Explanation: Dynamo developers have intentionally marked that the function `current_accelerator` in file `/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/accelerator/__init__.py` should not be traced.
E                 Hint: Avoid calling the function `current_accelerator`.
E                 Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `current_accelerator` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.
E                 Hint: Please file an issue to PyTorch.
E               
E                 Developer debug context: module: torch.accelerator, qualname: current_accelerator, skip reason: <missing reason>
E               
E               
E               from user code:
E                  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 448, in forward
E                   norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
E                 File "/fsx/sayak/diffusers/src/diffusers/models/normalization.py", line 168, in forward
E                   emb = self.linear(self.silu(emb))
E                 File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
E                   return forward_call(*args, **kwargs)
E                 File "/fsx/sayak/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
E                   args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
E                 File "/fsx/sayak/diffusers/src/diffusers/hooks/group_offloading.py", line 339, in pre_forward
E                   self.group.onload_()
E                 File "/fsx/sayak/diffusers/src/diffusers/hooks/group_offloading.py", line 213, in onload_
E                   getattr(torch, torch.accelerator.current_accelerator().type)
E               
E               Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

I also had to comment out:

@torch.compiler.disable()

Do you have any recommendations?
