From 40cc616909ba35dc68e8d11d4570091a4b723a5a Mon Sep 17 00:00:00 2001 From: cdzhan Date: Tue, 21 May 2024 03:20:13 +0000 Subject: [PATCH] =?UTF-8?q?Fix=20caching=20allocator=20of=20out-of-tree=20?= =?UTF-8?q?device=20is=20destructed=20before=20the=20=E2=80=A6=20(#126677)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …destruction of tensors cached by autocast ## Root Cause For out-of-tree device extension it is loaded after torch (different .so), so the global variable `cached_casts` may be constructed before caching allocator and then destructed in reversed order when exit. ## Fix Lazily initialize `cached_casts` to correct the order. ## How to Reproduce && Test Modify the testcase `TestAutocastGPU.test_cast_cache_is_global` in test/test_autocast.py to run on your out-of-tree device. You will see following failure in the end of test. ```bash ---------------------------------------------------------------------- Ran 1 test in 4.812s OK free: 0x30080ff44000400 terminate called after throwing an instance of 'c10::Error' what(): invalid device pointer: 0x30080ff44000400 Exception raised from free at /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/framework/core/caching_allocator.cpp:1609 (most recent call first): frame #0: + 0x118fe1 (0x7ffaef4d3fe1 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #1: + 0x11b1c4 (0x7ffaef4d61c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #2: + 0x117677 (0x7ffaef4d2677 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #3: + 0x11a2bf (0x7ffaef4d52bf in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #4: + 0x11a186 (0x7ffaef4d5186 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #5: + 0x119fde (0x7ffaef4d4fde in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #6: + 0x119d2e (0x7ffaef4d4d2e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #7: + 0x119be0 (0x7ffaef4d4be0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #8: + 0x119977 (0x7ffaef4d4977 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #9: + 0x119313 (0x7ffaef4d4313 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #10: + 0x118b4c (0x7ffaef4d3b4c in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #11: c10::Error::Error(c10::SourceLocation, std::string) + 0x34 (0x7ffaef4d27c4 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #12: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x7f (0x7ffaef4d04ed in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #13: torch_mlu::MLUCachingAllocator::Native::NativeCachingAllocator::free(void*) + 0xe6 (0x7ff9a8eeb112 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame #14: torch_mlu::MLUCachingAllocator::Native::local_raw_delete(void*) + 0x3b (0x7ff9a8ed9480 in /projs/framework/betterman/code/pytorch_new/catch/torch_mlu/csrc/lib/libtorch_mlu.so) frame #15: std::unique_ptr::~unique_ptr() + 0x50 (0x7ffb0a5ea322 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #16: + 0x1269890 (0x7ffb0a5e4890 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #17: + 0x1269928 (0x7ffb0a5e4928 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #18: + 0x127572c (0x7ffb0a5f072c in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #19: + 0x1275758 (0x7ffb0a5f0758 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_python.so) frame #20: + 0xb9bc7 (0x7ffaef474bc7 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #21: + 0xb97bc (0x7ffaef4747bc in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #22: + 0xdbc50 (0x7ffaef496c50 in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #23: c10::TensorImpl::~TensorImpl() + 0x82 (0x7ffaef49157e in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #24: c10::TensorImpl::~TensorImpl() + 0x1c (0x7ffaef4915aa in /projs/framework/betterman/code/pytorch_new/torch/lib/libc10.so) frame #25: + 0x2f596d9 (0x7ffaf24fc6d9 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #26: + 0x2f589c2 (0x7ffaf24fb9c2 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #27: + 0x2f57b92 (0x7ffaf24fab92 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #28: + 0x2f5c228 (0x7ffaf24ff228 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #29: + 0x30f3f70 (0x7ffaf2696f70 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #30: + 0x30f3f90 (0x7ffaf2696f90 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #31: + 0x30f5004 (0x7ffaf2698004 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #32: + 0x30f5024 (0x7ffaf2698024 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #33: + 0x31207f0 (0x7ffaf26c37f0 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #34: + 0x3120814 (0x7ffaf26c3814 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #35: + 0x30f51e8 (0x7ffaf26981e8 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #36: + 0x30f5148 (0x7ffaf2698148 in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #37: + 0x316ecea (0x7ffaf2711cea in /projs/framework/betterman/code/pytorch_new/torch/lib/libtorch_cpu.so) frame #38: + 0x468a7 (0x7ffb0c9ed8a7 in /lib/x86_64-linux-gnu/libc.so.6) frame #39: on_exit + 0 (0x7ffb0c9eda60 in /lib/x86_64-linux-gnu/libc.so.6) frame #47: __libc_start_main + 0xf3 (0x7ffb0c9cb083 in /lib/x86_64-linux-gnu/libc.so.6) Aborted (core dumped) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/126677 Approved by: https://github.com/ezyang --- aten/src/ATen/autocast_mode.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 2d01bdeca500b..f0c73cde2dda3 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -35,7 +35,11 @@ namespace { // directly against incoming TensorImpl*s. using weakref_type = c10::weak_intrusive_ptr; using val_type = std::tuple; -ska::flat_hash_map cached_casts; + +static ska::flat_hash_map& get_cached_casts() { + static ska::flat_hash_map cached_casts; + return cached_casts; +} std::mutex cached_casts_mutex; @@ -82,7 +86,7 @@ thread_local bool cache_enabled = true; void clear_cache() { const std::lock_guard lock(cached_casts_mutex); - cached_casts.clear(); + get_cached_casts().clear(); } int increment_nesting() { @@ -124,12 +128,12 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_ if (can_try_cache) { const std::lock_guard lock(cached_casts_mutex); - auto it = cached_casts.find(arg.unsafeGetTensorImpl()); - if (it != cached_casts.end()) { + auto it = get_cached_casts().find(arg.unsafeGetTensorImpl()); + if (it != get_cached_casts().end()) { return std::get<1>(it->second); } else { auto casted_arg = arg.to(to_type); - cached_casts.emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); + get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); return casted_arg; } } else {