Skip to content

Commit

Permalink
Reflection: Figure out kernel names by looking at metallib section. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Jul 16, 2024
1 parent 37bcfda commit 3b16be4
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 19 deletions.
44 changes: 37 additions & 7 deletions src/compiler/library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ tag_value_io["RLST"] = (
@NamedTuple{offset::UInt64, size::UInt64},
(io, _) -> (; offset=read(io, UInt64), size=read(io, UInt64)),
(io, val) -> write(io, UInt64[val.offset, val.size]))
tag_value_io["SLST"] = (
    # Location of the script list section, stored as two consecutive 64-bit
    # unsigned integers: the offset first, then the size
    @NamedTuple{offset::UInt64, size::UInt64},
    # reader: the second argument is ignored
    function (io, _)
        off = read(io, UInt64)
        len = read(io, UInt64)
        return (; offset=off, size=len)
    end,
    # writer: emit both fields back-to-back as UInt64
    (io, val) -> write(io, UInt64[val.offset, val.size]))
tag_value_io["UUID"] = (
# UUID of the module; 16 bytes
UUID,
Expand Down Expand Up @@ -393,7 +398,16 @@ tag_value_io["SARC"] = (
tag_value_io["RBUF"] = (
    # Reflection buffer: kept as raw bytes, passed through unmodified in both
    # directions
    Vector{UInt8},
    # reader: consume up to `nb` bytes from `io`
    (io, nb) -> begin
        read(io, nb)
    end,
    # writer: emit the bytes as-is
    (io, val) -> write(io, val))
## script lists
tag_value_io["SBUF"] = (
    # Script buffer, kept as raw bytes
    Vector{UInt8},
    # XXX: these are probably flatbuffers; decode here?
    (io, nb) -> read(io, nb),
    (io, val) -> write(io, val))
Expand All @@ -420,9 +434,9 @@ function Base.read!(io::IO, tg::TagGroup)
value_size = read(io, tg.size_type)
tg.offsets[tag_name] = position(io)

# XXX: there's a 2 byte mismatch between the reflection list size, and the
# next token... bug in air-lld?
if tag_name == "RBUF"
# XXX: the size recorded for list buffers is 2 bytes short of where the
# next token starts. bug in air-lld, or are we missing something?
if tag_name in ["RBUF", "SBUF"]
value_size += 2
end

Expand Down Expand Up @@ -469,9 +483,9 @@ function Base.write(io::IO, tg::TagGroup)
end
value_size = sizeof(value_bytes)

# XXX: there's a 2 byte mismatch between the reflection list size, and the
# next token... bug in air-lld?
if tag == "RBUF"
# XXX: the size recorded for list buffers is 2 bytes short of where the
# next token starts. bug in air-lld, or are we missing something?
if tag in ["RBUF", "SBUF"]
value_size -= 2
end

Expand Down Expand Up @@ -648,6 +662,22 @@ function Base.read(io::IO, ::Type{MetalLib})
@assert position(io) == header_ex["RLST"].offset + header_ex["RLST"].size
end

# script list
#
# this section contains pipeline descriptor scripts
if header_ex !== nothing && haskey(header_ex, "SLST")
seek(io, header_ex["SLST"].offset)
script_count = read(io, UInt32)
for i in 1:script_count
script_buf = read!(io, TagGroup())

# script_buf["SBUF"] contains flatbuffer data; compare with:
# $ metal-source -flatbuffers=binary
# $ metal-source -flatbuffers=json
end
@assert position(io) == header_ex["SLST"].offset + header_ex["SLST"].size
end

# embedded source
#
# there can be fewer sources than functions, so preserve the function -> source mapping
Expand Down
37 changes: 25 additions & 12 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,36 @@ function extract_gpu_code(f, binary)
end
arch == nothing && error("Could not find GPU architecture in universal binary")

# the GPU binary contains several sections (metallib, descriptor, reflection, compute?,
# fragment?, vertex?); extract the compute section, which is another Mach-O binary
# the GPU binary contains several sections...
## ... extract the compute section, which is another Mach-O binary
compute_section = findfirst(Sections(fat_handle[arch]), "__TEXT,__compute")
compute_section === nothing && error("Could not find __compute section in GPU binary")
compute_binary = read(compute_section)
native_handle = only(readmeta(IOBuffer(compute_binary)))

# the start of the section should also alias with a symbol in the universal binary,
# which we can use to identify the name of the kernel
compute_symbol = nothing
for symbol in Symbols(fat_handle[arch])
symbol_value(symbol) == section_offset(compute_section) || continue
endswith(symbol_name(symbol), "_begin") || continue
compute_symbol = symbol
## ... extract the metallib section, which is a Metal library
metallib_section = findfirst(Sections(fat_handle[arch]), "__TEXT,__metallib")
metallib_section === nothing && error("Could not find __metallib section in GPU binary")
metallib_binary = read(metallib_section)
metallib = read(IOBuffer(metallib_binary), Metal.MetalLib)
# TODO: use this to implement a do-block device_code_air like CUDA.jl?

# identify the kernel name
kernel_name = "unknown_kernel"
# XXX: does it happen that these metallibs contain multiple functions?
if length(metallib.functions) == 1
kernel_name = metallib.functions[1].name
end
compute_symbol === nothing && error("Could not find symbol for __compute section")
kernel_name = symbol_name(compute_symbol)[1:end-6]
# XXX: we used to be able to identify the kernel by looking at symbols in
# the fat binary, one of which aliased with the start of the compute
# section. these symbols have disappeared on macOS 15.
#compute_symbol = nothing
#for symbol in Symbols(fat_handle[arch])
# symbol_value(symbol) == section_offset(compute_section) || continue
# endswith(symbol_name(symbol), "_begin") || continue
# compute_symbol = symbol
#end
#compute_symbol === nothing && error("Could not find symbol for __compute section")
#kernel_name = symbol_name(compute_symbol)[1:end-6]

# within the native GPU binary, isolate the section containing code
section = findfirst(Sections(native_handle), "__TEXT,__text")
Expand Down

0 comments on commit 3b16be4

Please sign in to comment.