25
25
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
26
27
27
import argparse
28
+ import contextlib
28
29
import io
29
30
import multiprocessing as mp
30
31
import os
32
+ import os .path
31
33
import socket
32
34
import tempfile
33
35
import typing as t
34
- from contextlib import contextmanager
35
36
from types import TracebackType
36
37
37
38
import numpy as np
52
53
53
54
54
55
if t .TYPE_CHECKING :
55
- # Pylint disables needed for old version of pylint w/ TF 2.6.2
56
- # pylint: disable-next=unused-import
57
56
from multiprocessing .connection import Connection
58
57
59
58
# pylint: disable-next=unsubscriptable-object
@@ -89,12 +88,23 @@ def execute(
89
88
simple experiment
90
89
"""
91
90
backends = installed_redisai_backends ()
91
+ device : _TCapitalDeviceStr = args .device .upper ()
92
92
try :
93
- with _VerificationTempDir (dir = os .getcwd ()) as temp_dir :
93
+ with contextlib .ExitStack () as ctx :
94
+ temp_dir = ctx .enter_context (_VerificationTempDir (dir = os .getcwd ()))
95
+ validate_env = {
96
+ "SR_LOG_LEVEL" : os .environ .get ("SR_LOG_LEVEL" , "INFO" ),
97
+ "SR_LOG_FILE" : os .environ .get (
98
+ "SR_LOG_FILE" , os .path .join (temp_dir , "smartredis.log" )
99
+ ),
100
+ }
101
+ if device == "GPU" :
102
+ validate_env ["CUDA_VISIBLE_DEVICES" ] = "0"
103
+ ctx .enter_context (_env_vars_set_to (validate_env ))
94
104
test_install (
95
105
location = temp_dir ,
96
106
port = args .port ,
97
- device = args . device . upper () ,
107
+ device = device ,
98
108
with_tf = "tensorflow" in backends ,
99
109
with_pt = "torch" in backends ,
100
110
with_onnx = "onnxruntime" in backends ,
@@ -147,18 +157,40 @@ def test_install(
147
157
logger .info ("Verifying Tensor Transfer" )
148
158
client .put_tensor ("plain-tensor" , np .ones ((1 , 1 , 3 , 3 )))
149
159
client .get_tensor ("plain-tensor" )
150
- if with_tf :
151
- logger .info ("Verifying TensorFlow Backend" )
152
- _test_tf_install (client , location , device )
153
160
if with_pt :
154
161
logger .info ("Verifying Torch Backend" )
155
162
_test_torch_install (client , device )
156
163
if with_onnx :
157
164
logger .info ("Verifying ONNX Backend" )
158
165
_test_onnx_install (client , device )
166
+ if with_tf : # Run last in case TF locks an entire GPU
167
+ logger .info ("Verifying TensorFlow Backend" )
168
+ _test_tf_install (client , location , device )
169
+ logger .info ("Success!" )
159
170
160
171
161
- @contextmanager
172
+ @contextlib .contextmanager
173
+ def _env_vars_set_to (
174
+ evars : t .Mapping [str , t .Optional [str ]]
175
+ ) -> t .Generator [None , None , None ]:
176
+ envvars = tuple ((var , os .environ .pop (var , None ), val ) for var , val in evars .items ())
177
+ for var , _ , tmpval in envvars :
178
+ _set_or_del_env_var (var , tmpval )
179
+ try :
180
+ yield
181
+ finally :
182
+ for var , origval , _ in reversed (envvars ):
183
+ _set_or_del_env_var (var , origval )
184
+
185
+
186
+ def _set_or_del_env_var (var : str , val : t .Optional [str ]) -> None :
187
+ if val is not None :
188
+ os .environ [var ] = val
189
+ else :
190
+ os .environ .pop (var , None )
191
+
192
+
193
+ @contextlib .contextmanager
162
194
def _make_managed_local_orc (
163
195
exp : Experiment , port : int
164
196
) -> t .Generator [Client , None , None ]:
@@ -243,9 +275,18 @@ def __init__(self) -> None:
243
275
def forward (self , x : torch .Tensor ) -> torch .Tensor :
244
276
return self .conv (x )
245
277
278
+ if device == "GPU" :
279
+ device_ = torch .device ("cuda" )
280
+ else :
281
+ device_ = torch .device ("cpu" )
282
+
246
283
net = Net ()
247
- forward_input = torch .rand (1 , 1 , 3 , 3 )
284
+ net .to (device_ )
285
+ net .eval ()
286
+
287
+ forward_input = torch .rand (1 , 1 , 3 , 3 ).to (device_ )
248
288
traced = torch .jit .trace (net , forward_input ) # type: ignore[no-untyped-call]
289
+
249
290
buffer = io .BytesIO ()
250
291
torch .jit .save (traced , buffer ) # type: ignore[no-untyped-call]
251
292
model = buffer .getvalue ()
@@ -261,7 +302,7 @@ def _test_onnx_install(client: Client, device: _TCapitalDeviceStr) -> None:
261
302
from sklearn .cluster import KMeans
262
303
263
304
data = np .arange (20 , dtype = np .float32 ).reshape (10 , 2 )
264
- model = KMeans (n_clusters = 2 )
305
+ model = KMeans (n_clusters = 2 , n_init = 10 )
265
306
model .fit (data )
266
307
267
308
kmeans = to_onnx (model , data , target_opset = 11 )
0 commit comments