From 4c005e28f0d0954b5df5ea3bd6e7b940facb15c0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 19:45:32 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- sota-implementations/td3/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 071bad6c68b..cdf52927158 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -142,11 +142,13 @@ def make_replay_buffer( ): if compile: prefetch = 0 - with ( - tempfile.TemporaryDirectory() - if scratch_dir in ("", None) - else nullcontext(scratch_dir) - ) as scratch_dir: + if scratch_dir in ("", None): + ctx = nullcontext(None) + elif scratch_dir == "temp": + ctx = tempfile.TemporaryDirectory() + else: + ctx = nullcontext(scratch_dir) + with ctx as scratch_dir: storage_cls = ( functools.partial(LazyTensorStorage, device=device, compilable=compile) if not scratch_dir