diff --git a/smartsim/_core/_cli/validate.py b/smartsim/_core/_cli/validate.py index bda254859..5794e768f 100644 --- a/smartsim/_core/_cli/validate.py +++ b/smartsim/_core/_cli/validate.py @@ -25,13 +25,14 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse +import contextlib import io import multiprocessing as mp import os +import os.path import socket import tempfile import typing as t -from contextlib import contextmanager from types import TracebackType import numpy as np @@ -52,8 +53,6 @@ if t.TYPE_CHECKING: - # Pylint disables needed for old version of pylint w/ TF 2.6.2 - # pylint: disable-next=unused-import from multiprocessing.connection import Connection # pylint: disable-next=unsubscriptable-object @@ -89,12 +88,23 @@ def execute( simple experiment """ backends = installed_redisai_backends() + device: _TCapitalDeviceStr = args.device.upper() try: - with _VerificationTempDir(dir=os.getcwd()) as temp_dir: + with contextlib.ExitStack() as ctx: + temp_dir = ctx.enter_context(_VerificationTempDir(dir=os.getcwd())) + validate_env = { + "SR_LOG_LEVEL": os.environ.get("SR_LOG_LEVEL", "INFO"), + "SR_LOG_FILE": os.environ.get( + "SR_LOG_FILE", os.path.join(temp_dir, "smartredis.log") + ), + } + if device == "GPU": + validate_env["CUDA_VISIBLE_DEVICES"] = "0" + ctx.enter_context(_env_vars_set_to(validate_env)) test_install( location=temp_dir, port=args.port, - device=args.device.upper(), + device=device, with_tf="tensorflow" in backends, with_pt="torch" in backends, with_onnx="onnxruntime" in backends, @@ -147,18 +157,40 @@ def test_install( logger.info("Verifying Tensor Transfer") client.put_tensor("plain-tensor", np.ones((1, 1, 3, 3))) client.get_tensor("plain-tensor") - if with_tf: - logger.info("Verifying TensorFlow Backend") - _test_tf_install(client, location, device) if with_pt: logger.info("Verifying Torch Backend") _test_torch_install(client, device) if with_onnx: logger.info("Verifying ONNX Backend") 
@contextlib.contextmanager
def _env_vars_set_to(
    evars: t.Mapping[str, t.Optional[str]]
) -> t.Generator[None, None, None]:
    """Temporarily override environment variables for the enclosed block.

    Every key of ``evars`` is set to its mapped value while the ``with``
    block runs; a mapped value of ``None`` removes the variable instead.
    When the block exits, normally or through an exception, each named
    variable is returned to the value it held on entry, or removed if it
    was not set on entry.

    :param evars: Mapping of variable name to its temporary value, or to
                  ``None`` to unset the variable inside the block
    """
    # Snapshot the current values before touching anything so the restore
    # step always has the true pre-entry state to go back to.
    saved = tuple((var, os.environ.get(var)) for var in evars)
    try:
        # Apply the overrides inside the ``try``: if one assignment is
        # rejected part-way through, the variables already changed are
        # still rolled back by the ``finally`` clause.
        for var, tmpval in evars.items():
            _set_or_del_env_var(var, tmpval)
        yield
    finally:
        for var, origval in reversed(saved):
            _set_or_del_env_var(var, origval)


def _set_or_del_env_var(var: str, val: t.Optional[str]) -> None:
    """Set ``var`` to ``val`` in ``os.environ``, or remove it if ``val`` is ``None``.

    Removing a variable that is not currently set is a no-op.

    :param var: Name of the environment variable
    :param val: Value to assign, or ``None`` to delete the variable
    """
    if val is not None:
        os.environ[var] = val
    else:
        os.environ.pop(var, None)
def test_validate_correctly_sets_and_restores_env(monkeypatch):
    """Check that ``_env_vars_set_to`` applies overrides and then undoes them.

    Exercises four cases at once: redefining a set variable, deleting a set
    variable, adding an unset variable, and deleting an unset variable.
    """

    def assert_baseline_env():
        # State expected both before entering and after leaving the context
        assert os.environ["FOO"] == "BAR"
        assert os.environ["SPAM"] == "EGGS"
        assert "TICK" not in os.environ
        assert "DNE" not in os.environ

    monkeypatch.setenv("FOO", "BAR")
    monkeypatch.setenv("SPAM", "EGGS")
    monkeypatch.delenv("TICK", raising=False)
    monkeypatch.delenv("DNE", raising=False)
    assert_baseline_env()

    overrides = {
        "FOO": "BAZ",  # Redefine
        "SPAM": None,  # Delete
        "TICK": "TOCK",  # Add
        "DNE": None,  # Delete already missing
    }
    with smartsim._core._cli.validate._env_vars_set_to(overrides):
        assert os.environ["FOO"] == "BAZ"
        assert "SPAM" not in os.environ
        assert os.environ["TICK"] == "TOCK"
        assert "DNE" not in os.environ

    assert_baseline_env()