Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into torchserve_tutorial_aws
Browse files Browse the repository at this point in the history
  • Loading branch information
Viditagarwal7479 authored Nov 10, 2023
2 parents 0b88b33 + 16e4f2a commit eed4a97
Show file tree
Hide file tree
Showing 12 changed files with 1,387 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .pyspelling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ matrix:
- open: '\.\.\s+(figure|literalinclude|math|image|grid)::'
close: '\n'
# Exclude roles:
- open: ':(?:(class|py:mod|mod|func)):`'
- open: ':(?:(class|py:mod|mod|func|meth|obj)):`'
content: '[^`]*'
close: '`'
# Exclude reStructuredText hyperlinks
Expand Down Expand Up @@ -70,7 +70,7 @@ matrix:
- open: ':figure:.*'
close: '\n'
# Ignore reStructuredText roles
- open: ':(?:(class|file|func|math|ref|octicon)):`'
- open: ':(?:(class|file|func|math|ref|octicon|meth|obj)):`'
content: '[^`]*'
close: '`'
- open: ':width:'
Expand Down
Binary file added _static/img/pendulum.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added _static/img/rollout_recurrent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
912 changes: 912 additions & 0 deletions advanced_source/pendulum.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions advanced_source/static_quantization_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,15 @@ Note: this code is taken from
# Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
# This operation does not change the numerics
def fuse_model(self):
def fuse_model(self, is_qat=False):
fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
for m in self.modules():
if type(m) == ConvBNReLU:
torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
fuse_modules(m, ['0', '1', '2'], inplace=True)
if type(m) == InvertedResidual:
for idx in range(len(m.conv)):
if type(m.conv[idx]) == nn.Conv2d:
torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
2. Helper functions
-------------------
Expand Down Expand Up @@ -533,7 +534,7 @@ We fuse modules as before
.. code:: python
qat_model = load_model(saved_model_dir + float_model_file)
qat_model.fuse_model()
qat_model.fuse_model(is_qat=True)
optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
# The old 'fbgemm' is still available but 'x86' is the recommended default.
Expand Down
5 changes: 3 additions & 2 deletions beginner_source/introyt/tensorboardyt_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,14 @@ def forward(self, x):
# Check against the validation set
running_vloss = 0.0

net.train(False) # Don't need to track gradents for validation
# In evaluation mode some model specific operations can be omitted eg. dropout layer
net.train(False) # Switching to evaluation mode, eg. turning off regularisation
for j, vdata in enumerate(validation_loader, 0):
vinputs, vlabels = vdata
voutputs = net(vinputs)
vloss = criterion(voutputs, vlabels)
running_vloss += vloss.item()
net.train(True) # Turn gradients back on for training
net.train(True) # Switching back to training mode, eg. turning on regularisation

avg_loss = running_loss / 1000
avg_vloss = running_vloss / len(validation_loader)
Expand Down
1 change: 1 addition & 0 deletions en-wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Colab
Conv
ConvNet
ConvNets
customizable
DCGAN
DCGANs
DDP
Expand Down
15 changes: 14 additions & 1 deletion index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,26 @@ What's new in PyTorch tutorials?
:link: intermediate/mario_rl_tutorial.html
:tags: Reinforcement-Learning

.. customcarditem::
:header: Recurrent DQN
:card_description: Use TorchRL to train recurrent policies
:image: _static/img/rollout_recurrent.png
:link: intermediate/dqn_with_rnn_tutorial.html
:tags: Reinforcement-Learning

.. customcarditem::
:header: Code a DDPG Loss
:card_description: Use TorchRL to code a DDPG Loss
:image: _static/img/half_cheetah.gif
:link: advanced/coding_ddpg.html
:tags: Reinforcement-Learning


.. customcarditem::
:header: Writing your environment and transforms
:card_description: Use TorchRL to code a Pendulum
:image: _static/img/pendulum.gif
:link: advanced/pendulum.html
:tags: Reinforcement-Learning

.. Deploying PyTorch Models in Production
Expand Down Expand Up @@ -951,6 +963,7 @@ Additional Resources
intermediate/reinforcement_q_learning
intermediate/reinforcement_ppo
intermediate/mario_rl_tutorial
advanced/pendulum

.. toctree::
:maxdepth: 2
Expand Down
Loading

0 comments on commit eed4a97

Please sign in to comment.