From 6b2237c27e85d667b56a1d8b6de32258c45004ff Mon Sep 17 00:00:00 2001 From: Brian Wieder Date: Wed, 2 Oct 2024 11:31:38 -0700 Subject: [PATCH] Clear caches on jax exit. PiperOrigin-RevId: 681529928 --- xla/python/pjit.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc index a64ab9808459b..f4c8d242a469e 100644 --- a/xla/python/pjit.cc +++ b/xla/python/pjit.cc @@ -135,7 +135,10 @@ class PjitFunctionCache { int Size() const { return lru_list_.Size(); } int Capacity() const { return lru_list_.Capacity(); } - void Clear() { lru_list_.Clear(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } private: struct Key { @@ -347,6 +350,7 @@ class PjitFunctionStore { for (auto* function : compiled_functions_) { function->ClearCache(); } + compiled_functions_.clear(); } private: