From 9e28b002070276a852de6b5508224d35d2547d51 Mon Sep 17 00:00:00 2001 From: Brian Wieder Date: Fri, 4 Oct 2024 05:54:44 -0700 Subject: [PATCH] Clear caches on jax exit. PiperOrigin-RevId: 682288160 --- xla/python/pjit.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index a64ab9808459b..f4c8d242a469e 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -135,7 +135,10 @@ class PjitFunctionCache { int Size() const { return lru_list_.Size(); } int Capacity() const { return lru_list_.Capacity(); } - void Clear() { lru_list_.Clear(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } private: struct Key { @@ -347,6 +350,7 @@ class PjitFunctionStore { for (auto* function : compiled_functions_) { function->ClearCache(); } + compiled_functions_.clear(); } private: