fix: allow dict type `module_outputs` in `infer_module_output_dtypes` #3133
Conversation
@peri044 Isn't AOTExport supposed to handle this for us? We are trying to get away from dummy runs for output type calculations.
@jiwoong-choi Can you provide a model where you have seen this issue? We have tested some Hugging Face models and haven't encountered it. From my understanding, the graph module should take tensors/scalars in and produce tensors/scalars out. The nested structures should be handled by the pytree flatten/unflatten that's encoded in the in_spec and out_spec of the graph module.
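A minimal sketch of the pytree round-trip described above — note that `tree_flatten`/`tree_unflatten` live in the private `torch.utils._pytree` module, and the dict below is a stand-in for a real model output:

```python
import torch
import torch.utils._pytree as pytree

# A plain dict standing in for a nested model output.
nested = {
    "last_hidden_state": torch.zeros(1, 128, 768),
    "pooler_output": torch.zeros(1, 768),
}

# Export flattens nested structures into a list of tensor leaves plus a spec...
flat, out_spec = pytree.tree_flatten(nested)
print([t.shape for t in flat])

# ...and the wrapped graph module unflattens them back on the way out.
rebuilt = pytree.tree_unflatten(flat, out_spec)
assert isinstance(rebuilt, dict)
```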
The graph module can indeed output pretty much any nested tensor format (internally, the nesting is reconstructed via pytree unflatten). Here is a minimal reproduction:

```python
import torch
from transformers import BertModel

def main():
    model = BertModel.from_pretrained("bert-base-uncased")
    example_kwargs = {
        "input_ids": torch.randint(100, 1000, (1, 128), dtype=torch.int64),
        "token_type_ids": torch.zeros(1, 128, dtype=torch.int64),
        "attention_mask": torch.ones(1, 128, dtype=torch.int64),
        "return_dict": True,
    }
    exported_program = torch.export.export(
        model, args=(), kwargs=example_kwargs,
    )
    graph_module = exported_program.module()
    outputs = graph_module(**example_kwargs)
    print(type(outputs))

if __name__ == "__main__":
    main()
```

The output of the code is a `dict` subclass rather than a plain tuple of tensors.
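For context, a quick check of why this lands in the dict branch — this relies on an assumption about transformers internals (worth verifying against your installed version), namely that model outputs derive from `ModelOutput`, which subclasses `OrderedDict`:

```python
# Hedged check: ModelOutput is an OrderedDict subclass in recent
# transformers releases, so BertModel's return value is dict-typed.
from collections import OrderedDict
from transformers.utils import ModelOutput

print(issubclass(ModelOutput, OrderedDict))  # expected: True
print(issubclass(ModelOutput, dict))         # expected: True
```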
FYI if you run …

As you can see, the output tensors …
@jiwoong-choi Sorry for the delay. The exported graph ends like this:

```
...
...
%layer_norm_24 : [num_users=2] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_25, [768], %encoder_layer_11_output_layer_norm_weight, %encoder_layer_11_output_layer_norm_bias, 1e-12), kwargs = {})
%slice_5 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%layer_norm_24, 0, 0, 9223372036854775807), kwargs = {})
%select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%slice_5, 1, 0), kwargs = {})
%linear_72 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%select, %pooler_dense_weight, %pooler_dense_bias), kwargs = {})
%tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%linear_72,), kwargs = {})
return (layer_norm_24, tanh)
```

The graph itself returns a flat tuple of tensors; those outputs are passed to pytree.unflatten, which reconstructs the nested `ModelOutput` object. We are planning to refactor this function and compute output datatypes from the output node metadata in the graph (`node.meta["val"]`). This way, we don't have to handle the nesting or custom objects. This refactor will be part of #3212, so we can close this PR. Let me know if you have further questions.
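A rough sketch of what that metadata-based approach could look like — an illustration, not the actual #3212 change, and `infer_output_dtypes_from_meta` is a hypothetical helper name:

```python
import torch
from torch.fx import GraphModule, Node

def infer_output_dtypes_from_meta(gm: GraphModule) -> list:
    """Read output dtypes from the FakeTensors recorded during export."""
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    dtypes = []
    # args[0] holds the graph's flattened return values (tensors/scalars only;
    # any nested ModelOutput is rebuilt later by pytree unflatten).
    for ret in output_node.args[0]:
        if isinstance(ret, Node):
            fake = ret.meta.get("val")  # FakeTensor captured during tracing
            dtypes.append(fake.dtype if fake is not None else None)
        else:
            dtypes.append(torch.tensor(ret).dtype)  # constant return value
    return dtypes
```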
@peri044 That sounds like the right way to go. I'm looking forward to the refactoring.
Description
A graph module's output might have nested structures depending on the model implementation. For example, many models from transformers return outputs of type ModelOutput (e.g. CausalLMOutputWithPast).
This PR doesn't aim to handle all possible nested pytree structures imposed by graph module outputs. However, this simple fix at least allows the model output to be a non-nested dictionary (or a subclass of `dict`) of tensors (or of values that can be converted to tensors via `torch.tensor`). A sketch of the idea is shown below.
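A minimal sketch of the behavior the description outlines — simplified, hypothetical code rather than the PR's actual diff, with `_flatten_dict_outputs` as an illustrative name:

```python
import torch

def _flatten_dict_outputs(module_outputs):
    # Hypothetical helper: accept a non-nested dict (or dict subclass,
    # e.g. transformers' ModelOutput) and treat its values as the outputs.
    if isinstance(module_outputs, dict):
        module_outputs = tuple(module_outputs.values())
    if not isinstance(module_outputs, (list, tuple)):
        module_outputs = (module_outputs,)
    # Convert tensor-convertible values (ints, floats, ...) to tensors.
    return tuple(
        out if isinstance(out, torch.Tensor) else torch.tensor(out)
        for out in module_outputs
    )
```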
P.S. The comment in the subsequent for loop in the code, `# We don't need to check if output is nested here because the input module will be flattened`, is misleading; we actually need to handle nested outputs there.

Type of change
Please delete options that are not relevant and/or add your own.
Checklist: