From 558ccfa31ceac2837a499f600d9a780131efad52 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Mon, 25 Mar 2024 00:10:36 -0300 Subject: [PATCH 1/4] Improvements to MPSDataType --- lib/mps/matrix.jl | 67 ++++++++++++++++++++++++++++++----------------- lib/mps/vector.jl | 4 +-- 2 files changed, 45 insertions(+), 26 deletions(-) diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 3156da953..68f6a87f1 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -1,17 +1,56 @@ # # matrix enums # - -@cenum MPSDataType::UInt32 begin +@cenum MPSDataTypeBits::UInt32 begin MPSDataTypeComplexBit = UInt32(0x01000000) MPSDataTypeFloatBit = UInt32(0x10000000) MPSDataTypeSignedBit = UInt32(0x20000000) MPSDataTypeNormalizedBit = UInt32(0x40000000) MPSDataTypeAlternateEncodingBit = UInt32(0x80000000) end + +@enum MPSDataType::UInt32 begin + MPSDataTypeInvalid = UInt32(0) + + MPSDataTypeUInt8 = UInt32(8) + MPSDataTypeUInt16 = UInt32(16) + MPSDataTypeUInt32 = UInt32(32) + MPSDataTypeUInt64 = UInt32(64) + + MPSDataTypeInt8 = MPSDataTypeSignedBit | UInt32(8) + MPSDataTypeInt16 = MPSDataTypeSignedBit | UInt32(16) + MPSDataTypeInt32 = MPSDataTypeSignedBit | UInt32(32) + MPSDataTypeInt64 = MPSDataTypeSignedBit | UInt32(64) + + MPSDataTypeFloat16 = MPSDataTypeFloatBit | UInt32(16) + MPSDataTypeFloat32 = MPSDataTypeFloatBit | UInt32(32) + + MPSDataTypeComplexF16 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16) + MPSDataTypeComplexF32 = MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) + + MPSDataTypeUnorm1 = MPSDataTypeNormalizedBit | UInt32(1) + MPSDataTypeUnorm8 = MPSDataTypeNormalizedBit | UInt32(8) + + MPSDataTypeBool = MPSDataTypeAlternateEncodingBit | UInt32(8) + MPSDataTypeBFloat16 = MPSDataTypeAlternateEncodingBit | MPSDataTypeFloatBit | UInt32(16) +end ## bitwise operations lose type information, so allow conversions Base.convert(::Type{MPSDataType}, x::Integer) = MPSDataType(x) +# Conversions for MPSDataTypes with Julia equivalents +const jl_mps_to_typ = Dict{MPSDataType, DataType}() +for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,ComplexF16,ComplexF32,Bool] + @eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type)) + @eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type +end +# BFloat is only supported in MPS starting in MacOS 14 +if macos_version() >= v"14" && isdefined(Core, :BFloat16) + Base.convert(::Type{MPSDataType}, ::Type{Core.BFloat16}) = MPSDataTypeBFloat16 + jl_mps_to_typ[MPSDataTypeBFloat16] = Core.BFloat16 +end +Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] + + # # matrix descriptor # @@ -29,31 +68,11 @@ export MPSMatrixDescriptor @autoproperty matrixBytes::NSUInteger end - -# Mapping from Julia types to the Performance Shader bitfields -const jl_typ_to_mps = Dict{DataType,MPSDataType}( - UInt8 => UInt32(8), - UInt16 => UInt32(16), - UInt32 => UInt32(32), - UInt64 => UInt32(64), - - Int8 => MPSDataTypeSignedBit | UInt32(8), - Int16 => MPSDataTypeSignedBit | UInt32(16), - Int32 => MPSDataTypeSignedBit | UInt32(32), - Int64 => MPSDataTypeSignedBit | UInt32(64), - - Float16 => MPSDataTypeFloatBit | UInt32(16), - Float32 => MPSDataTypeFloatBit | UInt32(32), - - ComplexF16 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(16), - ComplexF32 => MPSDataTypeFloatBit | MPSDataTypeComplexBit | UInt32(32) -) - function MPSMatrixDescriptor(rows, columns, rowBytes, dataType) desc = @objc [MPSMatrixDescriptor matrixDescriptorWithRows:rows::NSUInteger columns:columns::NSUInteger rowBytes:rowBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor} + dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor} obj = MPSMatrixDescriptor(desc) # XXX: who releases this object? return obj @@ -65,7 +84,7 @@ function MPSMatrixDescriptor(rows, columns, matrices, rowBytes, matrixBytes, dat matrices:matrices::NSUInteger rowBytes:rowBytes::NSUInteger matrixBytes:matrixBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSMatrixDescriptor} + dataType:dataType::MPSDataType]::id{MPSMatrixDescriptor} obj = MPSMatrixDescriptor(desc) # XXX: who releases this object? return obj diff --git a/lib/mps/vector.jl b/lib/mps/vector.jl index e2ad87efb..2d4bf9bf3 100644 --- a/lib/mps/vector.jl +++ b/lib/mps/vector.jl @@ -12,7 +12,7 @@ end function MPSVectorDescriptor(length, dataType) desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor} + dataType:dataType::MPSDataType]::id{MPSVectorDescriptor} obj = MPSVectorDescriptor(desc) # XXX: who releases this object? return obj @@ -22,7 +22,7 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType) desc = @objc [MPSVectorDescriptor vectorDescriptorWithLength:length::NSUInteger vectors:vectors::NSUInteger vectorBytes:vectorBytes::NSUInteger - dataType:jl_typ_to_mps[dataType]::MPSDataType]::id{MPSVectorDescriptor} + dataType:dataType::MPSDataType]::id{MPSVectorDescriptor} obj = MPSVectorDescriptor(desc) # XXX: who releases this object? return obj From 2ab7554c795d4ba5081a7ff9adab947e6b9f8b6a Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:03:22 -0300 Subject: [PATCH 2/4] BFloat16 extension --- Project.toml | 4 ++++ ext/BFloat16sExt.jl | 12 ++++++++++++ lib/mps/matrix.jl | 6 +----- src/Metal.jl | 4 ++++ test/Project.toml | 1 + 5 files changed, 22 insertions(+), 5 deletions(-) create mode 100644 ext/BFloat16sExt.jl diff --git a/Project.toml b/Project.toml index f67e8da57..b390a6cb6 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -28,14 +29,17 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [weakdeps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] SpecialFunctionsExt = "SpecialFunctions" +BFloat16sExt = "BFloat16s" [compat] Adapt = "4" Artifacts = "1" +BFloat16s = "0.5" CEnum = "0.4, 0.5" CodecBzip2 = "0.8" ExprTools = "0.1" diff --git a/ext/BFloat16sExt.jl b/ext/BFloat16sExt.jl new file mode 100644 index 000000000..d281d306d --- /dev/null +++ b/ext/BFloat16sExt.jl @@ -0,0 +1,12 @@ +module BFloat16sExt + +using Metal: MPS.MPSDataType, MPS.MPSDataTypeBFloat16, MPS.jl_mps_to_typ, macos_version +using BFloat16s + +# BFloat is only supported in MPS starting in MacOS 14 +if macos_version() >= v"14" + Base.convert(::Type{MPSDataType}, ::Type{BFloat16}) = MPSDataTypeBFloat16 + jl_mps_to_typ[MPSDataTypeBFloat16] = BFloat16 +end + +end # module diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index 68f6a87f1..e6813b070 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -43,11 +43,7 @@ for type in [UInt8,UInt16,UInt32,UInt64,Int8,Int16,Int32,Int64,Float16,Float32,C @eval Base.convert(::Type{MPSDataType}, ::Type{$type}) = $(Symbol(:MPSDataType, type)) @eval jl_mps_to_typ[$(Symbol(:MPSDataType, type))] = $type end -# BFloat is only supported in MPS starting in MacOS 14 -if macos_version() >= v"14" && isdefined(Core, :BFloat16) - Base.convert(::Type{MPSDataType}, ::Type{Core.BFloat16}) = MPSDataTypeBFloat16 - jl_mps_to_typ[MPSDataTypeBFloat16] = Core.BFloat16 -end + Base.convert(::Type{DataType}, mpstyp::MPSDataType) = jl_mps_to_typ[mpstyp] diff --git a/src/Metal.jl b/src/Metal.jl index b6a70977a..7c5a4e29b 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -67,4 +67,8 @@ include("MetalKernels.jl") import .MetalKernels: MetalBackend export MetalBackend +@static if !isdefined(Base, :get_extension) + include("../ext/BFloat16sExt.jl") +end + end # module diff --git a/test/Project.toml b/test/Project.toml index a3000c676..ddc517610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" From 98a4fc703606f276d7dccacb4e386976e26b1675 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Sat, 30 Mar 2024 02:11:00 -0300 Subject: [PATCH 3/4] Increase test timeout limit (again) --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index e184ae32a..c5b489659 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -245,7 +245,7 @@ try # catch timeouts pid = remotecall_fetch(getpid, wrkr) - timer = Timer(360) do _ + timer = Timer(480) do _ @warn "Test timed out: $test" t1 = rmprocs(wrkr, waitfor=0) From 0fd9eaad5237ec35bdb3fb29b84decb7c9eec5be Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 4 Apr 2024 08:27:07 +0200 Subject: [PATCH 4/4] Remove unneeded resolve. --- .buildkite/pipeline.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 760745b7a..1391028ff 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -19,11 +19,6 @@ steps: queue: "juliaecosystem" os: "macos" arch: "aarch64" - commands: | - julia --project -e ' - # make sure the 1.8-era Manifest works on this Julia version - using Pkg - Pkg.resolve()' if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 matrix: