Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BFloat16s.jl extension and related improvements #326

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ steps:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
commands: |
julia --project -e '
# make sure the 1.8-era Manifest works on this Julia version
using Pkg
Pkg.resolve()'
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
matrix:
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.0.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand All @@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SpecialFunctionsExt = "SpecialFunctions"
BFloat16sExt = "BFloat16s"

[compat]
Adapt = "4"
Artifacts = "1"
BFloat16s = "0.5"
CEnum = "0.4, 0.5"
CodecBzip2 = "0.8"
ExprTools = "0.1"
Expand Down
12 changes: 12 additions & 0 deletions ext/BFloat16sExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module BFloat16sExt

using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version
using BFloat16s

# BFloat is only supported in MPS starting in MacOS 14
if macos_version() >= v"14"
Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16
jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16
end

end # module
63 changes: 39 additions & 24 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
#
# matrix enums
#

@cenum MPSDataType::UInt32 begin
@cenum MPSDataTypeBits::UInt32 begin
MPSDataTypeComplexBit = UInt32(0x01000000)
MPSDataTypeFloatBit = UInt32(0x10000000)
MPSDataTypeSignedBit = UInt32(0x20000000)
MPSDataTypeNormalizedBit = UInt32(0x40000000)
MPSDataTypeAlternateEncodingBit = UInt32(0x80000000)
end

@enum MPSDataType::UInt32 begin
MPSDataTypeInvalid = UInt32(0)

MPSDataTypeUInt8 = UInt32(8)
MPSDataTypeUInt16 = UInt32(16)
MPSDataTypeUInt32 = UInt32(32)
MPSDataTypeUInt64 = UInt32(64)

MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)

MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)

MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16)
MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)

MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)

MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
end
## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

# Conversions for MPSDataTypes with Julia equivalents
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
end

Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]


#
# matrix descriptor
#
Expand All @@ -29,31 +64,11 @@ export MPSMatrixDescriptor
@autoproperty matrixBytes::NSUInteger
end


# Mapping from Julia types to the Performance Shader bitfields
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
UInt8 => UInt32(8),
UInt16 => UInt32(16),
UInt32 => UInt32(32),
UInt64 => UInt32(64),

Int8 => MPSDataTypeSignedBit | UInt32(8),
Int16 => MPSDataTypeSignedBit | UInt32(16),
Int32 => MPSDataTypeSignedBit | UInt32(32),
Int64 => MPSDataTypeSignedBit | UInt32(64),

Float16 => MPSDataTypeFloatBit | UInt32(16),
Float32 => MPSDataTypeFloatBit | UInt32(32),

ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16),
ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
)

function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
columns:columns::NSUInteger
rowBytes:rowBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
Expand All @@ -65,7 +80,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat
matrices:matrices::NSUInteger
rowBytes:rowBytes::NSUInteger
matrixBytes:matrixBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
Expand Down
4 changes: 2 additions & 2 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

function MPSVectorDescriptor(length, dataType)
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
obj = MPSVectorDescriptor(desc)
# XXX: who releases this object?
return obj
Expand All @@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType)
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
vectors:vectors::NSUInteger
vectorBytes:vectorBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
obj = MPSVectorDescriptor(desc)
# XXX: who releases this object?
return obj
Expand Down
4 changes: 4 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,8 @@ include("MetalKernels.jl")
import .MetalKernels: MetalBackend
export MetalBackend

@static if !isdefined(Base, :get_extension)
include("../ext/BFloat16sExt.jl")
end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ try

# catch timeouts
pid = remotecall_fetch(getpid, wrkr)
timer = Timer(360) do _
timer = Timer(480) do _
@warn "Test timed out: $test"
t1 = rmprocs(wrkr, waitfor=0)

Expand Down