From 2d36061d0ae88b409403ffc120b5d1c7ef4e583c Mon Sep 17 00:00:00 2001
From: Sotiris Lamprinidis
Date: Tue, 15 Aug 2023 11:43:09 +0200
Subject: [PATCH 1/4] Port openlibm log1pf as log1p

---
 src/device/intrinsics/math.jl | 105 ++++++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index 1179bff1a..ad0ab20a2 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -1,6 +1,7 @@
 # Math function mappings to Metal intrinsics
 
 using Base: FastMath
+using Base.Math: throw_complex_domainerror
 
 # TODO:
 # - wrap all intrinsics from include/metal/metal_math
@@ -101,6 +102,110 @@ using Base: FastMath
 @device_override Base.log10(x::Float32) = ccall("extern air.log10.f32", llvmcall, Cfloat, (Cfloat,), x)
 @device_override Base.log10(x::Float16) = ccall("extern air.log10.f16", llvmcall, Float16, (Float16,), x)
 
+# Implementation of `log1p(::Float32)` from openlibm's `log1pf`
+# https://github.com/JuliaMath/openlibm
+const ln2_hi = 0.6931381f0 # ln2_hi + ln2_lo ≈ ln(2)
+const ln2_lo = 9.058001f-6
+const Lp1 = 0.6666667f0 # Lp1..Lp7: coefficients of the polynomial `R` in log1p
+const Lp2 = 0.4f0
+const Lp3 = 0.2857143f0
+const Lp4 = 0.22222199f0
+const Lp5 = 0.18183573f0
+const Lp6 = 0.15313838f0
+const Lp7 = 0.14798199f0
+
+@device_override function Base.Math.log1p(x::Float32)
+    hx = reinterpret(Int32, x)
+    ax = hx & 0x7fffffff # |x|
+
+    k = 1 # stays nonzero while x still needs the argument reduction below
+    if hx < 0x3ed413d0 # x < sqrt(2) - 1
+        if ax >= 0x3f800000 # |x| ≥ 1
+            if x == -1
+                return -Inf32
+            elseif isnan(x)
+                return NaN32
+            else # x < -1
+                # TODO: switch to throw_complex_domainerror_neg1 for next Julia release
+                throw_complex_domainerror(:log1p, x)
+            end
+        end
+
+        if ax < 0x38000000 # |x| < 2^-15
+            if ax < 0x33800000 # |x| < 2^-24
+                return x # Inexact
+            else
+                return x - x*x*0.5f0
+            end
+        end
+
+        if hx>0||hx<=reinterpret(Int32, 0xbe95f619) # (sqrt(2)/2)-1 <= 1+x
+            k = 0
+            f = x
+            hu = 0x00000001 # nonzero UInt32 sentinel so the `hu == 0` test below fails
+        end
+    end # hx < 0x3ed413d0
+
+    if hx >= 0x7f800000
+        return x+x # x is +Inf or a NaN with the sign bit clear
+    end
+
+    if k ≠ 0
+        if hx < 0x5a000000 # x < 2^53
+            u = 1f0 + x # rounded sum; `c` below becomes ((1 + x) - u)/u
+            hu = reinterpret(Int32, u)
+            k = (hu>>23) - 127
+            c = k>0 ? 1f0-(u-x) : x-(u-1f0)
+            c = c/u
+        else
+            u = x
+            hu = reinterpret(Int32, u)
+            k = (hu>>23) - 127
+            c = 0f0
+        end
+
+        hu = hu & 0x007fffff
+
+        if hu < 0x3504f4 # u < sqrt(2)
+            u = reinterpret(Float32, hu|0x3f800000)
+        else
+            k += 1
+            u = reinterpret(Float32, hu|0x3f000000)
+            hu = (0x00800000-hu)>>2
+        end
+        f = u-1f0
+    end
+
+    hfsq = 0.5f0*f*f
+
+    if hu == 0 # |f| < 2^-20
+        if f == 0
+            if k == 0
+                return 0f0
+            else
+                c += k*ln2_lo
+                return k*ln2_hi+c
+            end
+        end
+        R = hfsq*(1f0-Lp1*f)
+        if k == 0
+            return f-R
+        else
+            return k*ln2_hi - ((R-(k*ln2_lo+c))-f)
+        end
+    end
+
+    s = f/(2f0+f)
+    z = s*s
+    R = z*(Lp1+z*(Lp2+z*(Lp3+z*(Lp4+z*(Lp5+z*(Lp6+z*Lp7))))))
+    if k == 0
+        return f-(hfsq-s*(hfsq+R))
+    else
+        return k*ln2_hi-((hfsq-(s*(hfsq+R)+(k*ln2_lo+c)))-f)
+    end
+end
+
+
 @device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern air.fast_pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
 @device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
 @device_override Base.:(^)(x::Float16, y::Float16) = ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)

From b1df62fd91be03bd1c0c6bd0fcc2582eb9fdc2f3 Mon Sep 17 00:00:00 2001
From: Sotiris Lamprinidis
Date: Tue, 15 Aug 2023 12:22:37 +0200
Subject: [PATCH 2/4] Basic test for custom log1p

---
 test/device/intrinsics.jl | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 3d3a166d5..d386b7278 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -138,6 +138,11 @@ end
     synchronize()
     @test vecA ≈ sin.(a)
     @test vecB ≈ cos.(a)
+
+    b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
+    bufferC = MtlArray{eltype(b),length(size(b)),Shared}(b)
+    vecC = Array(log1p.(bufferC))
+    @test vecC ≈ log1p.(b)
 end
 
 ############################################################################################

From b608ed2ed05d08bcd06914dae17bbd5991411136 Mon Sep 17 00:00:00 2001
From: Sotiris Lamprinidis
Date: Tue, 15 Aug 2023 12:23:49 +0200
Subject: [PATCH 3/4] Modify log1p test

---
 test/device/intrinsics.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index d386b7278..23a93b88f 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -140,7 +140,7 @@ end
     @test vecB ≈ cos.(a)
 
     b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
-    bufferC = MtlArray{eltype(b),length(size(b)),Shared}(b)
+    bufferC = MtlArray(b)
     vecC = Array(log1p.(bufferC))
     @test vecC ≈ log1p.(b)
 end

From 57bfbe129b54cf20a1e4f3e4901cb2f3f5ab259a Mon Sep 17 00:00:00 2001
From: Sotiris Lamprinidis
Date: Tue, 15 Aug 2023 12:34:25 +0200
Subject: [PATCH 4/4] Fix comment and use compound assignments

---
 src/device/intrinsics/math.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index ad0ab20a2..b2b902d24 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -139,7 +139,7 @@ const Lp7 = 0.14798199f0
             end
         end
 
-        if hx>0||hx<=reinterpret(Int32, 0xbe95f619) # (sqrt(2)/2)-1 <= 1+x
+        if hx>0||hx<=reinterpret(Int32, 0xbe95f619) # (sqrt(2)/2)-1 <= x
             k = 0
             f = x
             hu = 0x00000001 # nonzero UInt32 sentinel so the `hu == 0` test below fails
@@ -156,7 +156,7 @@ const Lp7 = 0.14798199f0
             hu = reinterpret(Int32, u)
             k = (hu>>23) - 127
             c = k>0 ? 1f0-(u-x) : x-(u-1f0)
-            c = c/u
+            c /= u
         else
             u = x
             hu = reinterpret(Int32, u)
@@ -164,7 +164,7 @@ const Lp7 = 0.14798199f0
             c = 0f0
         end
 
-        hu = hu & 0x007fffff
+        hu &= 0x007fffff
 
         if hu < 0x3504f4 # u < sqrt(2)
             u = reinterpret(Float32, hu|0x3f800000)