diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 760745b7a..1391028ff 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -19,11 +19,6 @@ steps: queue: "juliaecosystem" os: "macos" arch: "aarch64" - commands: | - julia --project -e ' - # make sure the 1.8-era Manifest works on this Julia version - using Pkg - Pkg.resolve()' if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: diff --git a/Project.toml b/Project.toml index f67e8da57..b390a6cb6 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] SpecialFunctionsExt = "SpecialFunctions" +BFloat16sExt = "BFloat16s" [compat] Adapt = "4" Artifacts = "1" +BFloat16s = "0.5" CEnum = "0.4, 0.5" CodecBzip2 = "0.8" ExprTools = "0.1" diff --git a/ext/BFloat16sExt.jl b/ext/BFloat16sExt.jl new file mode 100644 index 000000000..d281d306d --- /dev/null +++ b/ext/BFloat16sExt.jl @@ -0,0 +1,12 @@ +module BFloat16sExt + +using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version +using BFloat16s + +# BFloat is only supported in MPS starting in MacOS 14 +if macos_version() >= v"14" + Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16 + jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16 +end + +end # module diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 3156da953..e6813b070 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -1,17 +1,52 @@ # # matrix enums # - -@cenum MPSDataType::UInt32 begin +@cenum MPSDataTypeBits::UInt32 begin MPSDataTypeComplexBit = UInt32(0x01000000) MPSDataTypeFloatBit = UInt32(0x10000000) MPSDataTypeSignedBit = UInt32(0x20000000) MPSDataTypeNormalizedBit = UInt32(0x40000000) MPSDataTypeAlternateEncodingBit = UInt32(0x80000000) end + +@enum MPSDataType::UInt32 begin + MPSDataTypeInvalid = UInt32(0) + + MPSDataTypeUInt8 = UInt32(8) + MPSDataTypeUInt16 = UInt32(16) + MPSDataTypeUInt32 = UInt32(32) + MPSDataTypeUInt64 = UInt32(64) + + MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8) + MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16) + MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32) + MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64) + + MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16) + MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32) + + MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16) + MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) + + MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1) + MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8) + + MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8) + MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16) +end ## bitwise operations lose type information, so allow conversions Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) +# Conversions for MPSDataTypes with Julia equivalents +const jl_mps_to_typ = Dict{MPSDataType, DataType}() +for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool] + @eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type)) + @eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type +end + +Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] + + # # matrix descriptor # @@ -29,31 +64,11 @@ export MPSMatrixDescriptor @autoproperty matrixBytes::NSUInteger end - -# Mapping from Julia types to the Performance Shader bitfields -const jl_typ_to_mps = Dict{DataType,MPSDataType}( - UInt8 => UInt32(8), - UInt16 => UInt32(16), - UInt32 => UInt32(32), - UInt64 => UInt32(64), - - Int8 => MPSDataTypeSignedBit | UInt32(8), - Int16 => MPSDataTypeSignedBit | UInt32(16), - Int32 => MPSDataTypeSignedBit | UInt32(32), - Int64 => MPSDataTypeSignedBit | UInt32(64), - - Float16 => MPSDataTypeFloatBit | UInt32(16), - Float32 => MPSDataTypeFloatBit | UInt32(32), - - ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16), - ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) -) - function MPSMatrixDescriptor(rows, columns, rowBytes, dataType) desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger columns:columns::NSUInteger rowBytes:rowBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor} + dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor} obj = MPSMatrixDescriptor(desc) # XXX: who releases this object? return obj @@ -65,7 +80,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat matrices:matrices::NSUInteger rowBytes:rowBytes::NSUInteger matrixBytes:matrixBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor} + dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor} obj = MPSMatrixDescriptor(desc) # XXX: who releases this object? return obj diff --git a/lib/mps/vector.jl b/lib/mps/vector.jl index e2ad87efb..2d4bf9bf3 100644 --- a/lib/mps/vector.jl +++ b/lib/mps/vector.jl @@ -12,7 +12,7 @@ end function MPSVectorDescriptor(length, dataType) desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor} + dataType:dataType::MPSDataType]::id{MPSVectorDescriptor} obj = MPSVectorDescriptor(desc) # XXX: who releases this object? return obj @@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType) desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger vectors:vectors::NSUInteger vectorBytes:vectorBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor} + dataType:dataType::MPSDataType]::id{MPSVectorDescriptor} obj = MPSVectorDescriptor(desc) # XXX: who releases this object? return obj diff --git a/src/Metal.jl b/src/Metal.jl index b6a70977a..7c5a4e29b 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -67,4 +67,8 @@ include("MetalKernels.jl") import .MetalKernels: MetalBackend export MetalBackend +@static if !isdefined(Base, :get_extension) + include("../ext/BFloat16sExt.jl") +end + end # module diff --git a/test/Project.toml b/test/Project.toml index a3000c676..ddc517610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/runtests.jl b/test/runtests.jl index e184ae32a..c5b489659 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -245,7 +245,7 @@ try # catch timeouts pid = remotecall_fetch(getpid, wrkr) - timer = Timer(360) do _ + timer = Timer(480) do _ @warn "Test timed out: $test" t1 = rmprocs(wrkr, waitfor=0)