diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index b234af6a804..9afbf8095f0 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -459,7 +459,7 @@ def _cache_values(fun): def new_fun(self, netname=None): __dict__ = self.__dict__ - _cache = __dict__["_cache"] + _cache = __dict__.setdefault("_cache", {}) attr_name = name if netname is not None: attr_name += "_" + netname