Skip to content

Commit 35973b5

Browse files
authored
Quality of life smart validate Improvements (#458)
Quality of life `smart validate` improvements: - Set `CUDA_VISIBLE_DEVICES` environment variable within `smart validate` prior to importing any ML deps to prevent false negatives on multi-GPU systems - Move SmartRedis logs from standard out to dedicated log file in the validation temporary directory - Suppress `sklearn` deprecation warning by pinning `KMeans` constructor argument - Move TF test to last as TF may reserve the GPUs it uses [ committed by @MattToast ] [ reviewed by @al-rigazzi @ashao ]
1 parent cab2ef8 commit 35973b5

File tree

2 files changed

+82
-11
lines changed

2 files changed

+82
-11
lines changed

smartsim/_core/_cli/validate.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
import argparse
28+
import contextlib
2829
import io
2930
import multiprocessing as mp
3031
import os
32+
import os.path
3133
import socket
3234
import tempfile
3335
import typing as t
34-
from contextlib import contextmanager
3536
from types import TracebackType
3637

3738
import numpy as np
@@ -52,8 +53,6 @@
5253

5354

5455
if t.TYPE_CHECKING:
55-
# Pylint disables needed for old version of pylint w/ TF 2.6.2
56-
# pylint: disable-next=unused-import
5756
from multiprocessing.connection import Connection
5857

5958
# pylint: disable-next=unsubscriptable-object
@@ -89,12 +88,23 @@ def execute(
8988
simple experiment
9089
"""
9190
backends = installed_redisai_backends()
91+
device: _TCapitalDeviceStr = args.device.upper()
9292
try:
93-
with _VerificationTempDir(dir=os.getcwd()) as temp_dir:
93+
with contextlib.ExitStack() as ctx:
94+
temp_dir = ctx.enter_context(_VerificationTempDir(dir=os.getcwd()))
95+
validate_env = {
96+
"SR_LOG_LEVEL": os.environ.get("SR_LOG_LEVEL", "INFO"),
97+
"SR_LOG_FILE": os.environ.get(
98+
"SR_LOG_FILE", os.path.join(temp_dir, "smartredis.log")
99+
),
100+
}
101+
if device == "GPU":
102+
validate_env["CUDA_VISIBLE_DEVICES"] = "0"
103+
ctx.enter_context(_env_vars_set_to(validate_env))
94104
test_install(
95105
location=temp_dir,
96106
port=args.port,
97-
device=args.device.upper(),
107+
device=device,
98108
with_tf="tensorflow" in backends,
99109
with_pt="torch" in backends,
100110
with_onnx="onnxruntime" in backends,
@@ -147,18 +157,40 @@ def test_install(
147157
logger.info("Verifying Tensor Transfer")
148158
client.put_tensor("plain-tensor", np.ones((1, 1, 3, 3)))
149159
client.get_tensor("plain-tensor")
150-
if with_tf:
151-
logger.info("Verifying TensorFlow Backend")
152-
_test_tf_install(client, location, device)
153160
if with_pt:
154161
logger.info("Verifying Torch Backend")
155162
_test_torch_install(client, device)
156163
if with_onnx:
157164
logger.info("Verifying ONNX Backend")
158165
_test_onnx_install(client, device)
166+
if with_tf: # Run last in case TF locks an entire GPU
167+
logger.info("Verifying TensorFlow Backend")
168+
_test_tf_install(client, location, device)
169+
logger.info("Success!")
159170

160171

161-
@contextmanager
172+
@contextlib.contextmanager
173+
def _env_vars_set_to(
174+
evars: t.Mapping[str, t.Optional[str]]
175+
) -> t.Generator[None, None, None]:
176+
envvars = tuple((var, os.environ.pop(var, None), val) for var, val in evars.items())
177+
for var, _, tmpval in envvars:
178+
_set_or_del_env_var(var, tmpval)
179+
try:
180+
yield
181+
finally:
182+
for var, origval, _ in reversed(envvars):
183+
_set_or_del_env_var(var, origval)
184+
185+
186+
def _set_or_del_env_var(var: str, val: t.Optional[str]) -> None:
187+
if val is not None:
188+
os.environ[var] = val
189+
else:
190+
os.environ.pop(var, None)
191+
192+
193+
@contextlib.contextmanager
162194
def _make_managed_local_orc(
163195
exp: Experiment, port: int
164196
) -> t.Generator[Client, None, None]:
@@ -243,9 +275,18 @@ def __init__(self) -> None:
243275
def forward(self, x: torch.Tensor) -> torch.Tensor:
244276
return self.conv(x)
245277

278+
if device == "GPU":
279+
device_ = torch.device("cuda")
280+
else:
281+
device_ = torch.device("cpu")
282+
246283
net = Net()
247-
forward_input = torch.rand(1, 1, 3, 3)
284+
net.to(device_)
285+
net.eval()
286+
287+
forward_input = torch.rand(1, 1, 3, 3).to(device_)
248288
traced = torch.jit.trace(net, forward_input) # type: ignore[no-untyped-call]
289+
249290
buffer = io.BytesIO()
250291
torch.jit.save(traced, buffer) # type: ignore[no-untyped-call]
251292
model = buffer.getvalue()
@@ -261,7 +302,7 @@ def _test_onnx_install(client: Client, device: _TCapitalDeviceStr) -> None:
261302
from sklearn.cluster import KMeans
262303

263304
data = np.arange(20, dtype=np.float32).reshape(10, 2)
264-
model = KMeans(n_clusters=2)
305+
model = KMeans(n_clusters=2, n_init=10)
265306
model.fit(data)
266307

267308
kmeans = to_onnx(model, data, target_opset=11)

tests/test_cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,33 @@ def test_cli_validation_test_execute(
828828

829829
assert expected_stdout in caplog.text
830830
assert actual_retval == expected_retval
831+
832+
833+
def test_validate_correctly_sets_and_restores_env(monkeypatch):
834+
monkeypatch.setenv("FOO", "BAR")
835+
monkeypatch.setenv("SPAM", "EGGS")
836+
monkeypatch.delenv("TICK", raising=False)
837+
monkeypatch.delenv("DNE", raising=False)
838+
839+
assert os.environ["FOO"] == "BAR"
840+
assert os.environ["SPAM"] == "EGGS"
841+
assert "TICK" not in os.environ
842+
assert "DNE" not in os.environ
843+
844+
with smartsim._core._cli.validate._env_vars_set_to(
845+
{
846+
"FOO": "BAZ", # Redefine
847+
"SPAM": None, # Delete
848+
"TICK": "TOCK", # Add
849+
"DNE": None, # Delete already missing
850+
}
851+
):
852+
assert os.environ["FOO"] == "BAZ"
853+
assert "SPAM" not in os.environ
854+
assert os.environ["TICK"] == "TOCK"
855+
assert "DNE" not in os.environ
856+
857+
assert os.environ["FOO"] == "BAR"
858+
assert os.environ["SPAM"] == "EGGS"
859+
assert "TICK" not in os.environ
860+
assert "DNE" not in os.environ

0 commit comments

Comments
 (0)