From 106d70f1267c11d8951d478f04bb8623d9b72de8 Mon Sep 17 00:00:00 2001 From: Matt Ellis Date: Thu, 1 Feb 2024 20:49:28 -0800 Subject: [PATCH] Add eval() to remove Torch warnings during testing (#472) The Torch eval() function is invoked in the tests to resolve warnings related to model tracing. [ reviewed by @MattToast @ashao ] [ committed by @mellis13 ] --- doc/changelog.rst | 5 ++++- tests/backends/test_dbmodel.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index ace497a73..cc16fa095 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -19,6 +19,7 @@ To be released at some future point in time Description +- Updated tests to address Torch warning - Updated GitHub actions to latest versions in CI/CD - Dropped Cobalt support - Override the sphinx-tabs extension background color @@ -28,6 +29,8 @@ Description Detailed Notes +- Tests that were saving Torch models were emitting warnings. These warnings + were addressed by updating the model save test function. (SmartSim-PR472_) - Some actions in the current GitHub CI/CD workflows were outdated. They were replaced with their latest versions. (SmartSim-PR446_) - As the Cobalt workload manager is not used on any system we are aware of, @@ -42,12 +45,12 @@ Detailed Notes all of SmartSim's machine learning backends with Python 3.11. (SmartSim-PR451_) (SmartSim-PR461_) - .. _SmartSim-PR446: https://github.com/CrayLabs/SmartSim/pull/446 .. _SmartSim-PR448: https://github.com/CrayLabs/SmartSim/pull/448 .. _SmartSim-PR451: https://github.com/CrayLabs/SmartSim/pull/451 .. _SmartSim-PR453: https://github.com/CrayLabs/SmartSim/pull/453 .. _SmartSim-PR461: https://github.com/CrayLabs/SmartSim/pull/461 +.. _SmartSim-PR472: https://github.com/CrayLabs/SmartSim/pull/472 0.6.0 diff --git a/tests/backends/test_dbmodel.py b/tests/backends/test_dbmodel.py index 1cfc1efcb..84e708f76 100644 --- a/tests/backends/test_dbmodel.py +++ b/tests/backends/test_dbmodel.py @@ -138,6 +138,7 @@ def create_tf_cnn(): def save_torch_cnn(path, file_name): n = PyTorchNet() + n.eval() example_forward_input = torch.rand(1, 1, 28, 28) module = torch.jit.trace(n, example_forward_input) torch.jit.save(module, path + "/" + file_name)