🐛 [Bug] [weight-stripped engine] doesn't work for torch.compile #3216

Closed
zewenli98 opened this issue Oct 6, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@zewenli98
Collaborator

Bug Description

There are three interfaces for compiling a model: 1) torch export + torch_trt.dynamo.compile, 2) torch_trt.compile(ir="dynamo"), and 3) torch.compile(backend="tensorrt").
PR #3167 adds support for weight-stripped engines. It works for 1) and 2) but not for 3), and I'm not sure why.

I observed the issue in the test:

# 1. Compile with torch_trt.dynamo.compile
gm1 = torch_trt.dynamo.compile(
    exp_program,
    example_inputs,
    **settings,
)
gm1_output = gm1(*example_inputs)

# 2. Compile with torch_trt.compile using dynamo backend
gm2 = torch_trt.compile(
    pyt_model, ir="dynamo", inputs=example_inputs, **settings
)
gm2_output = gm2(*example_inputs)

# 3. Compile with torch.compile using tensorrt backend
gm3 = torch.compile(
    pyt_model,
    backend="tensorrt",
    options=settings,
)
gm3_output = gm3(*example_inputs)

assertions.assertEqual(
    gm1_output.sum(), 0, msg="gm1_output should be all zeros"
)
assertions.assertEqual(
    gm2_output.sum(), 0, msg="gm2_output should be all zeros"
)
assertions.assertEqual(
    gm3_output.sum(), 0, msg="gm3_output should be all zeros"
)

The CI test reports the error:

FAILED models/test_weight_stripped_engine.py::TestWeightStrippedEngine::test_three_ways_to_compile_weight_stripped_engine - AssertionError: tensor(0.5406, device='cuda:0') != 0 : gm3_output should be all zeros

Only torch.compile failed, so I printed gm3_output: it is the same as the output of the weight-included engine. I therefore suspect that the engine did not get its weights stripped. I also tried torch_trt.compile(ir="torch_compile"), which is just another way to call torch.compile, so it failed as well. I'm not sure why.

@zewenli98 zewenli98 added the bug Something isn't working label Oct 6, 2024
@narendasan
Collaborator

Have you pulled the engines out to see what they look like?

@zewenli98
Collaborator Author

@narendasan It seems the main difference between the two workflows is the entry point: torch.compile() calls py/torch_tensorrt/dynamo/backend/backends.py::pretraced_backend(), while torch_trt.compile() calls py/torch_tensorrt/_compile.py::compile(). Both then call py/torch_tensorrt/dynamo/_compiler.py::compile_module().

I pulled the engines out and printed some attributes.

  • When building the engine with strip_engine_weights=True:

    • torch.compile():
      • engine.refittable: True
      • engine.num_layers: 186
      • engine size: 2038852
    • torch_trt.compile():
      • engine.refittable: True
      • engine.num_layers: 276
      • engine size: 2802668
  • When building the engine with strip_engine_weights=False:

    • torch.compile():
      • engine.refittable: True
      • engine.num_layers: 186
      • engine size: 2038332
    • torch_trt.compile():
      • engine.refittable: True
      • engine.num_layers: 276
      • engine size: 77122580

We can see that torch.compile() produces fewer layers and that its engine sizes are almost the same in both cases, whereas we expect the engine to be smaller when strip_engine_weights=True. Is this because of pretraced_backend()?
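To make the size comparison concrete, here is a minimal sketch that computes the stripped-to-full size ratio for each workflow from the numbers reported above. The helper name `stripped_ratio` is made up for illustration, and the sizes are assumed to be in bytes (the unit is not stated in the measurements).

```python
# Engine sizes copied from the measurements above.
# Assumption: the values are serialized engine sizes in bytes.
sizes = {
    "torch.compile": {"stripped": 2038852, "full": 2038332},
    "torch_trt.compile": {"stripped": 2802668, "full": 77122580},
}


def stripped_ratio(entry):
    """Return the stripped engine size as a fraction of the weight-included size."""
    return entry["stripped"] / entry["full"]


for name, entry in sizes.items():
    print(f"{name}: stripped/full = {stripped_ratio(entry):.4f}")
```

A ratio close to 1 means stripping had essentially no effect on the serialized engine, which is what the torch.compile() numbers show. The torch_trt.compile() ratio is a small fraction, as expected for a weight-stripped engine.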

@narendasan
Collaborator

torch.compile and torch.export graphs aren't going to be identical; that is expected. Is the argument being parsed correctly from torch.compile(options=...)?

@zewenli98
Collaborator Author

I'm calling the three compile interfaces like this:

settings = {
    "use_python_runtime": True,
    "enabled_precisions": {torch.float},
    "debug": False,
    "min_block_size": 1,
    "strip_engine_weights": True,
    "refit_identical_engine_weights": False,
}

# 1. Compile with torch_trt.dynamo.compile
gm1 = torch_trt.dynamo.compile(
    exp_program,
    example_inputs,
    **settings,
)
gm1_output = gm1(*example_inputs)

# 2. Compile with torch_trt.compile using dynamo backend
gm2 = torch_trt.compile(
    pyt_model, ir="dynamo", inputs=example_inputs, **settings
)
gm2_output = gm2(*example_inputs)

# 3. Compile with torch.compile using tensorrt backend
gm3 = torch.compile(
    pyt_model,
    backend="tensorrt",
    options=settings,
)
gm3_output = gm3(*example_inputs)

I also printed strip_engine_weights after building the TRT engine, and its value is as expected.

@zewenli98
Collaborator Author

After discussion with @narendasan and @peri044: torch.compile's behavior has changed so that all the weights are now registered as inputs, as shown below:

Inputs: List[Tensor: (64, 3, 7, 7)@float32, Tensor: (100, 3, 224, 224)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (128, 64, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 64, 1, 1)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (256, 128, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 128, 1, 1)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (512, 256, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: 
(512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 256, 1, 1)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (1000, 512)@float32, Tensor: (1000)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (100, 3, 224, 224)@float32, Tensor: (64, 3, 7, 7)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64, 64, 3, 3)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (64)@float32, Tensor: (128, 64, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 64, 1, 1)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128, 128, 3, 3)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (128)@float32, Tensor: (256, 128, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 128, 1, 1)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256, 256, 3, 3)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (256)@float32, Tensor: (512, 256, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, 
Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 256, 1, 1)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512, 512, 3, 3)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (512)@float32, Tensor: (1000, 512)@float32, Tensor: (1000)@float32]
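The input lists above can be checked mechanically instead of by eye. The following is a rough sketch that counts the tensors in an "Inputs: List[...]" log line. The regular expression and the helper name `parse_inputs` are assumptions based only on the log format shown here, not on any Torch-TensorRT API.

```python
import re

# Matches entries such as "Tensor: (64, 3, 7, 7)@float32" in the log text.
# \s* after "Tensor:" also tolerates the line wraps seen in the pasted log.
TENSOR_RE = re.compile(r"Tensor:\s*\(([\d,\s]*)\)@(\w+)")


def parse_inputs(log_text):
    """Return a list of (shape, dtype) tuples for every tensor in the log text."""
    result = []
    for dims, dtype in TENSOR_RE.findall(log_text):
        shape = tuple(int(d) for d in dims.split(",") if d.strip())
        result.append((shape, dtype))
    return result


# Shortened sample in the same format as the log above.
sample = (
    "Engine Inputs: List[Tensor: (100, 3, 224, 224)@float32, "
    "Tensor: (64, 3, 7, 7)@float32, Tensor: (64)@float32]"
)
inputs = parse_inputs(sample)
print(len(inputs), inputs[0])
```

For a model that is called with a single example input, a count far above one suggests that the parameters are being passed to the engine as runtime inputs rather than stored inside it as weights.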

This might explain why weight stripping does not work properly for torch.compile. We're going to emit a warning that it is not supported for torch.compile and then return the compiled module "with weights".
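The planned fallback could look roughly like the sketch below. This is not the actual Torch-TensorRT change: `resolve_strip_setting` and its `is_torch_compile` flag are hypothetical names, and only the `strip_engine_weights` key comes from the settings shown earlier in this thread.

```python
import warnings


def resolve_strip_setting(settings, is_torch_compile):
    """Hypothetical helper: return a copy of the settings in which
    strip_engine_weights is disabled, with a warning, on the torch.compile path."""
    resolved = dict(settings)  # do not mutate the caller's dict
    if is_torch_compile and resolved.get("strip_engine_weights", False):
        warnings.warn(
            "strip_engine_weights is not supported for torch.compile; "
            "building the engine with weights instead.",
            UserWarning,
            stacklevel=2,
        )
        resolved["strip_engine_weights"] = False
    return resolved


settings = {"strip_engine_weights": True, "min_block_size": 1}
print(resolve_strip_setting(settings, is_torch_compile=False))
```

Returning a modified copy keeps the user's original settings dict intact, so the same dict can still be passed to the other two compile interfaces with stripping enabled.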
