Skip to content

Commit

Permalink
Fix MPS.synchronize_state
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Oct 1, 2024
1 parent f605bcb commit 69c8d2a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ for R in [:MPSMatrixRandomMTGP32, :MPSMatrixRandomPhilox]
end

synchronize_state(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandBuffer) =
@objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing
@objc [kern::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing


@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2};
Expand Down
7 changes: 7 additions & 0 deletions test/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,10 @@ using .MPS: MPSMatrixFindTopK
@test topk.sourceColumns == cols
@test topk.sourceRows == rows
end

# Ensure that the function does not error
@testset "MPSMatrixRandom sync state" begin
cmdbuf = MTL.MTLCommandBuffer(global_queue(device()))
rng = MPS.MPSMatrixRandomMTGP32(device())
@test isnothing(MPS.synchronize_state(rng, cmdbuf))
end

0 comments on commit 69c8d2a

Please sign in to comment.