Skip to content

Commit

Permalink
Add eval() to remove Torch warnings during testing (#472)
Browse files Browse the repository at this point in the history
The Torch eval() function is invoked in the tests to resolve
warnings related to model tracing.

[ reviewed by @MattToast @ashao ]
[ committed by @mellis13 ]
  • Loading branch information
mellis13 authored Feb 2, 2024
1 parent 948d97c commit 106d70f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ To be released at some future point in time

Description

- Updated tests to address Torch warning
- Updated GitHub actions to latest versions in CI/CD
- Dropped Cobalt support
- Override the sphinx-tabs extension background color
Expand All @@ -28,6 +29,8 @@ Description

Detailed Notes

- Tests that were saving Torch models were emitting warnings. These warnings
were addressed by updating the model save test function. (SmartSim-PR472_)
- Some actions in the current GitHub CI/CD workflows were outdated. They were
replaced with their latest versions. (SmartSim-PR446_)
- As the Cobalt workload manager is not used on any system we are aware of,
Expand All @@ -42,12 +45,12 @@ Detailed Notes
all of SmartSim's machine learning backends with Python 3.11.
(SmartSim-PR451_) (SmartSim-PR461_)


.. _SmartSim-PR446: https://github.com/CrayLabs/SmartSim/pull/446
.. _SmartSim-PR448: https://github.com/CrayLabs/SmartSim/pull/448
.. _SmartSim-PR451: https://github.com/CrayLabs/SmartSim/pull/451
.. _SmartSim-PR453: https://github.com/CrayLabs/SmartSim/pull/453
.. _SmartSim-PR461: https://github.com/CrayLabs/SmartSim/pull/461
.. _SmartSim-PR472: https://github.com/CrayLabs/SmartSim/pull/472


0.6.0
Expand Down
1 change: 1 addition & 0 deletions tests/backends/test_dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def create_tf_cnn():

def save_torch_cnn(path, file_name):
n = PyTorchNet()
n.eval()
example_forward_input = torch.rand(1, 1, 28, 28)
module = torch.jit.trace(n, example_forward_input)
torch.jit.save(module, path + "/" + file_name)
Expand Down

0 comments on commit 106d70f

Please sign in to comment.