```bash
python ts_scripts/install_dependencies.py --cuda=cu118
pip install torchserve torch-model-archiver
```

## torch.compile

PyTorch 2.0 supports several compiler backends; you pick which one you want by passing an optional `model_config.yaml` file during model packaging.

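A minimal config might look like the following (the `pt2` key is an assumption about torchserve's config schema here; check the torchserve docs for the exact layout):

```yaml
# Assumed schema: select the torch.compile backend for this model
pt2: "inductor"
```
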
The exact same approach works with any other model; what's going on under the hood is shown below:

```python
import torch

# 1. Compile the module (`mod` is your eager nn.Module)
opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
# 3. Save the optimized module's state dict
torch.save(opt_mod.state_dict(), "model.pt")

# 4. Reload the saved weights into the model
mod.load_state_dict(torch.load("model.pt"))

# 5. Compile the module and then run inferences with it
opt_mod = torch.compile(mod)
```

torchserve takes care of steps 4 and 5 for you, while the remaining steps are your responsibility. You can do the exact same thing with the vast majority of TIMM or HuggingFace models; a packaging sketch follows below.
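
As a sketch of how packaging might look (the model name, handler, and file names are illustrative assumptions, not prescribed by torchserve; `--config-file` passes the `model_config.yaml` from above):

```bash
# Package the saved weights plus the compile config into a .mar archive
torch-model-archiver --model-name mymodel --version 1.0 \
    --model-file model.py --serialized-file model.pt \
    --handler image_classifier --config-file model_config.yaml \
    --export-path model_store

# Start torchserve and register the model
torchserve --start --model-store model_store --models mymodel.mar
```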

## torch.export.export

Export your model from a training script. Keep in mind that an exported model cannot have graph breaks.

```python
import io
import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

ep = torch.export.export(MyModule(), (torch.randn(5),))

# Save to a file
# torch.export.save(ep, 'exported_program.pt2')

# Optionally bundle extra files alongside the exported program
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)
```
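
The "no graph breaks" restriction above means, for example, that data-dependent Python control flow in `forward` will fail to export. A minimal sketch (the exact exception type varies across PyTorch versions):

```python
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        # Branching on a tensor's runtime value cannot be captured in a single graph
        if x.sum() > 0:
            return x + 1
        return x - 1

try:
    torch.export.export(Branchy(), (torch.randn(5),))
except Exception as e:
    print(f"export failed: {type(e).__name__}")
```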

Serve your exported model from a custom handler

```python
import io
import torch

# from initialize()
ep = torch.export.load('exported_program.pt2')

# Or load from an io.BytesIO buffer, retrieving the bundled extra files
extra_files = {'foo.txt': ''}  # keys name the files to extract
with open('exported_program.pt2', 'rb') as f:
    buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer, extra_files=extra_files)

# Make sure everything looks good
print(ep)
print(extra_files['foo.txt'])

# from inference()
print(ep(torch.randn(5)))
```
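
Putting the pieces together, a minimal custom handler could look like the sketch below. The class and wiring follow torchserve's standard `BaseHandler` convention, but the handler name and the assumption that the `.pt2` file sits in the model directory are illustrative:

```python
import torch
from ts.torch_handler.base_handler import BaseHandler

class ExportedProgramHandler(BaseHandler):
    def initialize(self, context):
        # The unpacked .mar contents live in model_dir
        model_dir = context.system_properties.get("model_dir")
        self.ep = torch.export.load(f"{model_dir}/exported_program.pt2")
        self.initialized = True

    def inference(self, data, *args, **kwargs):
        # An ExportedProgram is called like a regular module
        return self.ep(data)
```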


For now, PyTorch 2.0 has mostly been focused on accelerating training, so production-grade applications should instead opt for TensorRT, which is also natively supported in torchserve, for accelerated inference performance. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. You can learn more at https://github.com/pytorch/serve/blob/master/docs/performance_guide.md
