diff --git a/lib/mps/matrix.jl b/lib/mps/matrix.jl index eaf60ba4..db372f6c 100644 --- a/lib/mps/matrix.jl +++ b/lib/mps/matrix.jl @@ -155,14 +155,8 @@ function MPSMatrix(arr::MtlArray{T,3}) where T n_cols, n_rows, n_matrices = size(arr) row_bytes = sizeof(T)*n_cols desc = MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, row_bytes * n_rows, T) - mat = @objc [MPSMatrix alloc]::id{MPSMatrix} - obj = MPSMatrix(mat) offset = arr.offset * sizeof(T) - finalizer(release, obj) - @objc [obj::id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer} - offset:offset::NSUInteger - descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix} - return obj + return MPSMatrix(arr, desc, offset) end #