Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port openlibm log1pf as log1p #239

Merged
merged 4 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Math function mappings to Metal intrinsics

using Base: FastMath
using Base.Math: throw_complex_domainerror

# TODO:
# - wrap all intrinsics from include/metal/metal_math
Expand Down Expand Up @@ -101,6 +102,110 @@ using Base: FastMath
@device_override Base.log10(x::Float32) = ccall("extern air.log10.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.log10(x::Float16) = ccall("extern air.log10.f16", llvmcall, Float16, (Float16,), x)

# Implementation of `log1p(::Float32)` from openlibm's `log1pf`
# https://github.com/JuliaMath/openlibm
const ln2_hi = 0.6931381f0
const ln2_lo = 9.058001f-6
const Lp1 = 0.6666667f0
const Lp2 = 0.4f0
const Lp3 = 0.2857143f0
const Lp4 = 0.22222199f0
const Lp5 = 0.18183573f0
const Lp6 = 0.15313838f0
const Lp7 = 0.14798199f0

@device_override function Base.Math.log1p(x::Float32)
hx = reinterpret(Int32, x)
ax = hx & 0x7fffffff # |x|

k = 1
if hx < 0x3ed413d0 # x < sqrt(2) - 1
if ax >= 0x3f800000 # |x| ≥ 1
if x == -1
return -Inf32
elseif isnan(x)
return NaN32
else # x < -1
# TODO: switch to throw_complex_domainerror_neg1 for next Julia release
throw_complex_domainerror(:log1p, x)
end
end

if ax < 0x38000000 # |x| < 2^-15
if ax < 0x33800000 # |x| < 2^-24
return x # Inexact
else
return x - x*x*0.5f0
end
end

if hx>0||hx<=reinterpret(Int32, 0xbe95f619) # (sqrt(2)/2)-1 <= x
k = 0
f = x
hu = 1f0
end
end # hx < 0x3ed413d0

if hx >= 0x7f800000
return x+x
end

if k ≠ 0
if hx < 0x5a000000
u = 1f0 + x
hu = reinterpret(Int32, u)
k = (hu>>23) - 127
c = k>0 ? 1f0-(u-x) : x-(u-1f0)
c /= u
else
u = x
hu = reinterpret(Int32, u)
k = (hu>>23) - 127
c = 0f0
end

hu &= 0x007fffff

if hu < 0x3504f4 # u < sqrt(2)
u = reinterpret(Float32, hu|0x3f800000)
else
k += 1
u = reinterpret(Float32, hu|0x3f000000)
hu = (0x00800000-hu)>>2
end
f = u-1f0
end

hfsq = 0.5f0*f*f

if hu == 0 # |f| < 2^-20
if f == 0
if k == 0
return 0f0
else
c += k*ln2_lo
return k*ln2_hi+c
end
end
R = hfsq*(1f0-Lp1*f)
if k == 0
return f-R
else
return k*ln2_hi - ((R-(k*ln2_lo+c))-f)
end
end

s = f/(2f0+f)
z = s*s
R = z*(Lp1+z*(Lp2+z*(Lp3+z*(Lp4+z*(Lp5+z*(Lp6+z*Lp7))))))
if k == 0
return f-(hfsq-s*(hfsq+R))
else
return k*ln2_hi-((hfsq-(s*(hfsq+R)+(k*ln2_lo+c)))-f)
end
end


@device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern air.fast_pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float16, y::Float16) = ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)
Expand Down
5 changes: 5 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ end
synchronize()
@test vecA ≈ sin.(a)
@test vecB ≈ cos.(a)

b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
bufferC = MtlArray(b)
vecC = Array(log1p.(bufferC))
@test vecC ≈ log1p.(b)
end

############################################################################################
Expand Down