Skip to content

Commit

Permalink
Merge pull request #326 from christiangnrd/improvements
Browse files Browse the repository at this point in the history
BFloat16s.jl extension and related improvements
  • Loading branch information
maleadt authored Apr 4, 2024
2 parents bc2131e + 0fd9eaa commit e057d13
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 32 deletions.
5 changes: 0 additions & 5 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@ steps:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
commands: |
julia --project -e '
# make sure the 1.8-era Manifest works on this Julia version
using Pkg
Pkg.resolve()'
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
matrix:
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.0.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand All @@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SpecialFunctionsExt = "SpecialFunctions"
BFloat16sExt = "BFloat16s"

[compat]
Adapt = "4"
Artifacts = "1"
BFloat16s = "0.5"
CEnum = "0.4, 0.5"
CodecBzip2 = "0.8"
ExprTools = "0.1"
Expand Down
12 changes: 12 additions & 0 deletions ext/BFloat16sExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module BFloat16sExt

using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version
using BFloat16s

# BFloat is only supported in MPS starting in MacOS 14
if macos_version() >= v"14"
Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16
jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16
end

end # module
63 changes: 39 additions & 24 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
#
# matrix enums
#

@cenum MPSDataType::UInt32 begin
@cenum MPSDataTypeBits::UInt32 begin
MPSDataTypeComplexBit = UInt32(0x01000000)
MPSDataTypeFloatBit = UInt32(0x10000000)
MPSDataTypeSignedBit = UInt32(0x20000000)
MPSDataTypeNormalizedBit = UInt32(0x40000000)
MPSDataTypeAlternateEncodingBit = UInt32(0x80000000)
end

@enum MPSDataType::UInt32 begin
MPSDataTypeInvalid = UInt32(0)

MPSDataTypeUInt8 = UInt32(8)
MPSDataTypeUInt16 = UInt32(16)
MPSDataTypeUInt32 = UInt32(32)
MPSDataTypeUInt64 = UInt32(64)

MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8)
MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16)
MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32)
MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64)

MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16)
MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32)

MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16)
MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)

MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1)
MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8)

MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8)
MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16)
end
## bitwise operations lose type information, so allow conversions
Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x)

# Conversions for MPSDataTypes with Julia equivalents
const jl_mps_to_typ = Dict{MPSDataType, DataType}()
for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool]
@eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type))
@eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type
end

Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp]


#
# matrix descriptor
#
Expand All @@ -29,31 +64,11 @@ export MPSMatrixDescriptor
@autoproperty matrixBytes::NSUInteger
end


# Mapping from Julia types to the Performance Shader bitfields
const jl_typ_to_mps = Dict{DataType,MPSDataType}(
UInt8 => UInt32(8),
UInt16 => UInt32(16),
UInt32 => UInt32(32),
UInt64 => UInt32(64),

Int8 => MPSDataTypeSignedBit | UInt32(8),
Int16 => MPSDataTypeSignedBit | UInt32(16),
Int32 => MPSDataTypeSignedBit | UInt32(32),
Int64 => MPSDataTypeSignedBit | UInt32(64),

Float16 => MPSDataTypeFloatBit | UInt32(16),
Float32 => MPSDataTypeFloatBit | UInt32(32),

ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16),
ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32)
)

function MPSMatrixDescriptor(rows, columns, rowBytes, dataType)
desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger
columns:columns::NSUInteger
rowBytes:rowBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
Expand All @@ -65,7 +80,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat
matrices:matrices::NSUInteger
rowBytes:rowBytes::NSUInteger
matrixBytes:matrixBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor}
dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor}
obj = MPSMatrixDescriptor(desc)
# XXX: who releases this object?
return obj
Expand Down
4 changes: 2 additions & 2 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

function MPSVectorDescriptor(length, dataType)
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
obj = MPSVectorDescriptor(desc)
# XXX: who releases this object?
return obj
Expand All @@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType)
desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger
vectors:vectors::NSUInteger
vectorBytes:vectorBytes::NSUInteger
dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor}
dataType:dataType::MPSDataType]::id{MPSVectorDescriptor}
obj = MPSVectorDescriptor(desc)
# XXX: who releases this object?
return obj
Expand Down
4 changes: 4 additions & 0 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,8 @@ include("MetalKernels.jl")
import .MetalKernels: MetalBackend
export MetalBackend

@static if !isdefined(Base, :get_extension)
include("../ext/BFloat16sExt.jl")
end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ try

# catch timeouts
pid = remotecall_fetch(getpid, wrkr)
timer = Timer(360) do _
timer = Timer(480) do _
@warn "Test timed out: $test"
t1 = rmprocs(wrkr, waitfor=0)

Expand Down

0 comments on commit e057d13

Please sign in to comment.