
fix: allow dict type module_outputs in infer_module_output_dtypes #3133

Conversation

jiwoong-choi
Contributor

Description

A graph module's output might have a nested structure depending on the implementation. For example, many models from transformers return outputs of type ModelOutput (e.g. CausalLMOutputWithPast).
This PR doesn't aim to handle all possible nested pytree structures imposed by graph module outputs. However, this simple fix at least allows the model output to be a non-nested dictionary (or a subclass of dict) of tensors (or of values that can be converted to tensors via torch.tensor).

P.S. The comment in the subsequent for loop in the code, # We don't need to check if output is nested here because the input module will be flattened, is misleading: we actually do need to handle nested outputs there.
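
For concreteness, the spirit of the change is roughly the following (a minimal sketch only, not the actual diff; the helper name _to_flat_tensor_list is made up here, and the real infer_module_output_dtypes in the dynamo conversion utilities may differ):

import torch

def _to_flat_tensor_list(module_outputs):
    # Hypothetical helper illustrating the fix: accept a non-nested dict
    # (or dict subclass such as transformers' ModelOutput) of tensors, in
    # addition to the tensor / tuple / list cases handled before.
    if isinstance(module_outputs, dict):
        module_outputs = list(module_outputs.values())
    elif not isinstance(module_outputs, (list, tuple)):
        module_outputs = [module_outputs]
    return [
        out if isinstance(out, torch.Tensor) else torch.tensor(out)
        for out in module_outputs
    ]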

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions bot added the component: conversion, component: api [Python], and component: dynamo labels on Aug 30, 2024
@narendasan
Collaborator

narendasan commented Aug 30, 2024

@peri044 Isn't AOTExport supposed to handle this for us? We are trying to get away from dummy runs for output type calculations.

@peri044
Collaborator

peri044 commented Aug 30, 2024

@jiwoong-choi Can you provide a model where you have seen this issue? We have tested some Hugging Face models and haven't encountered it. From my understanding, the graph module should have tensors/scalars in and tensors/scalars out. The nested ones should be handled by the pytree flatten/unflatten that's encoded in the in_spec and out_spec of the graph module.
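
For readers unfamiliar with the mechanism being referenced: the in_spec/out_spec round trip can be illustrated with PyTorch's internal pytree helpers (a standalone sketch, not Torch-TensorRT code):

import torch
import torch.utils._pytree as pytree

# A non-nested dict of tensors, the shape of output this PR is concerned with.
nested = {
    "last_hidden_state": torch.randn(1, 128, 768),
    "pooler_output": torch.randn(1, 768),
}

# tree_flatten yields the leaf tensors plus a TreeSpec describing the container;
# an exported graph module stores an equivalent spec as its out_spec.
leaves, spec = pytree.tree_flatten(nested)
print([t.shape for t in leaves])

# tree_unflatten rebuilds the original container from the flat leaves,
# which is what the generated forward() does with its out_spec.
rebuilt = pytree.tree_unflatten(leaves, spec)
print(type(rebuilt))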

@jiwoong-choi
Contributor Author

The graph module can indeed output pretty much any nested tensor format (internally using pytree.unflatten, in my understanding).
Here's a simple example that demonstrates this.

import torch
from transformers import BertModel

def main():
    model = BertModel.from_pretrained("bert-base-uncased")
    example_kwargs = {
        "input_ids": torch.randint(100, 1000, (1, 128), dtype=torch.int64),
        "token_type_ids": torch.zeros(1, 128, dtype=torch.int64),
        "attention_mask": torch.ones(1, 128, dtype=torch.int64),
        "return_dict": True
    }
    exported_program = torch.export.export(
        model, args=(), kwargs=example_kwargs,
    )
    graph_module = exported_program.module()
    outputs = graph_module(**example_kwargs)
    print(type(outputs))

if __name__ == "__main__":
    main()

The output of the code is <class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>, which is a subclass of OrderedDict.

@jiwoong-choi
Contributor Author

FYI, if you run graph_module.print_readable() from the above example code, the last few lines of the graph module's code look as follows:

... (omitted)

# File: /home/hdd/jiwoongchoi/micromamba/envs/torch-trt/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:746 in forward, code: pooled_output = self.dense(first_token_tensor)
linear_72: "f32[1, 768]" = torch.ops.aten.linear.default(select, pooler_dense_weight, pooler_dense_bias);  select = pooler_dense_weight = pooler_dense_bias = None
        
# File: /home/hdd/jiwoongchoi/micromamba/envs/torch-trt/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:747 in forward, code: pooled_output = self.activation(pooled_output)
tanh: "f32[1, 768]" = torch.ops.aten.tanh.default(linear_72);  linear_72 = None
return pytree.tree_unflatten((layer_norm_24, tanh), self._out_spec)

As you can see, the output tensors layer_norm_24 and tanh are passed to pytree.tree_unflatten, which reconstructs the final output according to the original output spec encapsulated in self._out_spec.
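
Continuing the BertModel snippet above, the flat outputs and the spec can be inspected directly (a small illustrative addition; it assumes the unlifted module still exposes the spec as the private attribute _out_spec, as the printed code suggests, and that transformers has registered its ModelOutput classes with pytree, which it must have done for the export above to reconstruct them):

import torch.utils._pytree as pytree

# Continuing from the earlier example: `graph_module` and `outputs` as above.
print(graph_module._out_spec)  # TreeSpec for BaseModelOutputWithPoolingAndCrossAttentions

# Flattening the dict-like output recovers the plain tensors the graph produced
# before tree_unflatten was applied.
flat_outputs, out_spec = pytree.tree_flatten(outputs)
print([t.dtype for t in flat_outputs])  # the dtypes infer_module_output_dtypes ultimately needs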

@peri044
Collaborator

peri044 commented Oct 18, 2024

@jiwoong-choi Sorry for the delay. The graph_module.graph here is as follows

...
...
...
%layer_norm_24 : [num_users=2] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_25, [768], %encoder_layer_11_output_layer_norm_weight, %encoder_layer_11_output_layer_norm_bias, 1e-12), kwargs = {})
%slice_5 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%layer_norm_24, 0, 0, 9223372036854775807), kwargs = {})
%select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_5, 1, 0), kwargs = {})
%linear_72 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%select, %pooler_dense_weight, %pooler_dense_bias), kwargs = {})
%tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%linear_72,), kwargs = {})
return (layer_norm_24, tanh)

The outputs are passed to pytree.unflatten, which constructs the BaseModelOutputWithPoolingAndCrossAttentions. From a TensorRT compilation point of view, we are only interested in the graph up to pytree.unflatten, which receives a flat list of tensors; the unflattening part is left to PyTorch. The infer_module_output_dtypes call is meant to deduce output types, but the way it currently does so is incorrect. In particular, in the case you described, the kwarg inputs are not being passed correctly.

We are planning to refactor this function and compute output datatypes using the output node metadata in the graph (node.meta["val"]). This way, we don't have to handle the nesting or custom objects. This refactor will be part of #3212, so we can close this PR. Let me know if you have further questions.
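
For anyone following along, the metadata-based approach being described would look roughly like this (a sketch under the assumption that each value feeding the output node carries a FakeTensor in node.meta["val"]; the actual refactor is tracked in #3212 and may differ):

from torch.fx import GraphModule

def infer_output_dtypes_from_metadata(gm: GraphModule):
    # Read dtypes from the FakeTensor recorded in node.meta["val"] on each
    # value feeding the graph's output node, instead of running the module
    # with dummy inputs.
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    dtypes = []
    for out in output_node.args[0]:  # args[0] is the flat tuple of outputs
        val = out.meta.get("val", None) if hasattr(out, "meta") else None
        dtypes.append(val.dtype if val is not None else None)
    return dtypes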

@jiwoong-choi
Contributor Author

jiwoong-choi commented Nov 6, 2024

> We are planning to refactor this function and compute output datatypes using output node metadata in the graph (node.meta["val"]). This way, we don't have to handle the nesting or custom objects. This refactor will be a part of #3212. So, we can close this PR. Let me know if you have further questions.

@peri044 That sounds like the right way to go. I'm looking forward to the refactoring.
