diff --git a/include/hip/amd_detail/amd_hip_atomic.h b/include/hip/amd_detail/amd_hip_atomic.h index 80e0b3f2..193ba4db 100644 --- a/include/hip/amd_detail/amd_hip_atomic.h +++ b/include/hip/amd_detail/amd_hip_atomic.h @@ -434,13 +434,15 @@ float atomicMin(float* address, float val) { #else unsigned int tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - float value = __uint_as_float(tmp); + float assumed = __uint_as_float(tmp); + float old = assumed; - while (val < value) { - value = atomicCAS(address, value, val); + while (val < assumed && __float_as_uint(assumed) != __float_as_uint(old)) { + assumed = old; + old = atomicCAS(address, assumed, val); } - return value; + return old; } __device__ @@ -452,13 +454,15 @@ float atomicMin_system(float* address, float val) { #else unsigned int tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - float value = __uint_as_float(tmp); + float assumed = __uint_as_float(tmp); + float old = assumed; - while (val < value) { - value = atomicCAS_system(address, value, val); + while (val < assumed && __float_as_uint(assumed) != __float_as_uint(old)) { + assumed = old; + old = atomicCAS_system(address, assumed, val); } - return value; + return old; } __device__ @@ -470,13 +474,15 @@ double atomicMin(double* address, double val) { #else unsigned long long tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - double value = __longlong_as_double(tmp); + double assumed = __longlong_as_double(tmp); + double old = assumed; - while (val < value) { - value = atomicCAS(address, value, val); + while (val < assumed && __double_as_longlong(assumed) != __double_as_longlong(old)) { + assumed = old; + old = atomicCAS(address, assumed, val); } - return value; + return old; } __device__ @@ -488,13 +494,15 @@ double atomicMin_system(double* address, double val) { #else unsigned long long tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - double value = __longlong_as_double(tmp); + double assumed = __longlong_as_double(tmp); + double old = assumed; - while (val < value) { - value = atomicCAS_system(address, value, val); + while (val < assumed && __double_as_longlong(assumed) != __double_as_longlong(old)) { + assumed = old; + old = atomicCAS_system(address, assumed, val); } - return value; + return old; } __device__ @@ -554,13 +562,15 @@ float atomicMax(float* address, float val) { #else unsigned int tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - float value = __uint_as_float(tmp); + float assumed = __uint_as_float(tmp); + float old = assumed; - while (value < val) { - value = atomicCAS(address, value, val); - } + while (val > assumed && __float_as_uint(assumed) != __float_as_uint(old)) { + assumed = old; + old = atomicCAS(address, assumed, val); + } - return value; + return old; } __device__ @@ -572,13 +582,15 @@ float atomicMax_system(float* address, float val) { #else unsigned int tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - float value = __uint_as_float(tmp); + float assumed = __uint_as_float(tmp); + float old = assumed; - while (value < val) { - value = atomicCAS_system(address, value, val); - } + while (val > assumed && __float_as_uint(assumed) != __float_as_uint(old)) { + assumed = old; + old = atomicCAS_system(address, assumed, val); + } - return value; + return old; } __device__ @@ -590,13 +602,15 @@ double atomicMax(double* address, double val) { #else unsigned long long tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - double value = __longlong_as_double(tmp); + double assumed = __longlong_as_double(tmp); + double old = assumed; - while (value < val) { - value = atomicCAS(address, value, val); + while (val > assumed && __double_as_longlong(assumed) != __double_as_longlong(old)) { + assumed = old; + old = atomicCAS(address, assumed, val); } - return value; + return old; } __device__ @@ -608,13 +622,15 @@ double atomicMax_system(double* address, double val) { #else unsigned long long tmp {__atomic_load_n(uaddr, __ATOMIC_RELAXED)}; #endif - double value = __longlong_as_double(tmp); + double assumed = __longlong_as_double(tmp); + double old = assumed; - while (value < val) { - value = atomicCAS_system(address, value, val); + while (val > assumed && __double_as_longlong(assumed) != __double_as_longlong(old)) { + assumed = old; + old = atomicCAS_system(address, assumed, val); } - return value; + return old; } __device__