From 390b585e5e45a069c556130425f94f6d34cc9bd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:25:44 +0100 Subject: [PATCH] add shortcut to `dot_general` op (#330) * implement `dot_general` shortcut op * fix 1-based indexing to 0-based indexing conversion * test `dot_general` * fix checks * last fixes * add deprecation notice in `einsum` * fix test --- src/Ops.jl | 184 +++++++++++++++++++++++++++++++++++++++++++++------- test/ops.jl | 31 +++++++++ 2 files changed, 190 insertions(+), 25 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index eee3edbd8..e9f3d3252 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -546,31 +546,162 @@ end # return TracedRArray{T,N}((), res, size(lhs)) # end -# function dot_general( -# lhs::TracedRArray{T,N}, -# rhs::TracedRArray{T,N}; -# dimension_numbers, -# lhs_contracting_dimensions, -# rhs_contracting_dimensions, -# result_permutation, -# location=mlir_stacktrace( -# "dot_general", @__FILE__, @__LINE__ -# ), -# ) where {T,N} -# res = MLIR.IR.result( -# stablehlo.dot_general( -# lhs.mlir_data, -# rhs.mlir_data; -# result=mlir_type(TracedRArray{T,N}, ...), # TODO size of result -# dimension_numbers, -# lhs_contracting_dimensions, -# rhs_contracting_dimensions, -# result_permutation, -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(lhs)) -# end +function dot_general( + lhs::TracedRArray{T}, + rhs::TracedRArray{T}; + contracting_dimensions, + batching_dimensions=(Int[], Int[]), + precision_config=nothing, + precision_type=nothing, + accumulation_type=nothing, + component_count=nothing, + num_primitive_operations=nothing, + allow_imprecise_accumulation=nothing, + location=mlir_stacktrace("dot_general", @__FILE__, @__LINE__), +) where {T} + # C1 + C2 + @assert length(batching_dimensions) == 2 && splat(==)(length.(batching_dimensions)) + @assert length(contracting_dimensions) == 2 && + splat(==)(length.(contracting_dimensions)) + + # C3 + C4 + @assert all(eltype.(contracting_dimensions) .<: Int64) + @assert all(eltype.(batching_dimensions) .<: Int64) + @assert all(isdisjoint.(contracting_dimensions, batching_dimensions)) + + lhs_contracting_dimensions, rhs_contracting_dimensions = contracting_dimensions + lhs_batching_dimensions, rhs_batching_dimensions = batching_dimensions + + # C5 + C6 + C7 + C8 + @assert all(lhs_batching_dimensions .<= ndims(lhs)) + @assert all(rhs_batching_dimensions .<= ndims(rhs)) + @assert all(lhs_contracting_dimensions .<= ndims(lhs)) + @assert all(rhs_contracting_dimensions .<= ndims(rhs)) + + # C9 + C10 + @assert size.(Ref(lhs), lhs_batching_dimensions) == + size.(Ref(rhs), rhs_batching_dimensions) + @assert size.(Ref(lhs), lhs_contracting_dimensions) == + size.(Ref(rhs), rhs_contracting_dimensions) + + # C11 + @assert isnothing(precision_config) || length(precision_config) == 2 + + @assert isnothing(precision_type) || + length(precision_type) == 2 && eltype(precision_type) <: AbstractFloat + @assert isnothing(accumulation_type) || accumulation_type <: AbstractFloat + + # C22 + C23 + @assert isnothing(component_count) || + length(component_count) == 2 && + eltype(component_count) <: Int32 && + all(0 .<= component_count) + + # C24 + @assert isnothing(num_primitive_operations) || + num_primitive_operations isa Int32 && num_primitive_operations > 0 + @assert isnothing(allow_imprecise_accumulation) || allow_imprecise_accumulation isa Bool + + ctx = MLIR.IR.context() + + # from C12 + lhs_result_dimensions = setdiff( + 1:ndims(lhs), lhs_batching_dimensions, lhs_contracting_dimensions + ) + rhs_result_dimensions = setdiff( + 1:ndims(rhs), rhs_batching_dimensions, rhs_contracting_dimensions + ) + + ressize = vcat( + size.(Ref(lhs), lhs_batching_dimensions), + size.(Ref(lhs), lhs_result_dimensions), + size.(Ref(rhs), rhs_result_dimensions), + ) + + # fix 1-indexing + lhs_batching_dimensions = lhs_batching_dimensions .- 1 + rhs_batching_dimensions = rhs_batching_dimensions .- 1 + lhs_contracting_dimensions = lhs_contracting_dimensions .- 1 + rhs_contracting_dimensions = rhs_contracting_dimensions .- 1 + + dot_dimension_numbers = GC.@preserve lhs_contracting_dimensions rhs_contracting_dimensions lhs_batching_dimensions rhs_batching_dimensions begin + MLIR.IR.Attribute( + MLIR.API.stablehloDotDimensionNumbersGet( + ctx, + length(lhs_batching_dimensions), + lhs_batching_dimensions, + length(rhs_batching_dimensions), + rhs_batching_dimensions, + length(lhs_contracting_dimensions), + lhs_contracting_dimensions, + length(rhs_contracting_dimensions), + rhs_contracting_dimensions, + ), + ) + end + + if !isnothing(precision_config) + precision_config = MLIR.IR.Attribute([ + MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[1]), + MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[2]), + ]) + end + + # all or nothing: if one is set, all must be set + # TODO maybe be more flexible, by setting some defaults? + if any( + !isnothing, + ( + precision_type, + accumulation_type, + component_count, + num_primitive_operations, + allow_imprecise_accumulation, + ), + ) + @assert all( + !isnothing, + ( + precision_type..., + accumulation_type, + component_count..., + num_primitive_operations, + allow_imprecise_accumulation, + ), + ) + lhs_precision_type, rhs_precision_type = precision_type + lhs_component_count, rhs_component_count = component_count + algorithm = GC.@preserve begin + MLIR.IR.Attribute( + MLIR.API.stablehloDotAlgorithmGet( + ctx, + lhs_precision_type, + rhs_precision_type, + accumulation_type, + lhs_component_count, + rhs_component_count, + num_primitive_operations, + allow_imprecise_accumulation, + ), + ) + end + else + algorithm = nothing + end + + res = MLIR.IR.result( + stablehlo.dot_general( + lhs.mlir_data, + rhs.mlir_data; + result_0=mlir_type(TracedRArray{T,length(ressize)}, ressize), + dot_dimension_numbers, + precision_config, + algorithm, + location, + ), + ) + return TracedRArray{T,length(ressize)}((), res, ressize) +end function einsum( lhs::TracedRArray{T}, @@ -578,6 +709,9 @@ function einsum( equation::String, location=mlir_stacktrace("einsum", @__FILE__, @__LINE__), ) where {T} + Base.depwarn( + "`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum + ) ins, ic = split(equation, "->") ia, ib = split(ins, ",") diff --git a/test/ops.jl b/test/ops.jl index 7334e6e3a..eb4981b80 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -188,6 +188,37 @@ end end end +@testset "dot_general" begin + # dot product of first dim + f1(x, y) = Ops.dot_general(x, y; contracting_dimensions=[[1], [1]]) + + # outer product + fouter(x, y) = Ops.dot_general(x, y; contracting_dimensions=[Int[], Int[]]) + + # outer product, batch first dim + fouter_batch1(x, y) = Ops.dot_general( + x, y; contracting_dimensions=[Int[], Int[]], batching_dimensions=[[1], [1]] + ) + + for (a, b) in [ + (ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])), + (ConcreteRArray([1.0, 2.0, 3.0, 4.0]), ConcreteRArray([5.0, 6.0, -7.0, -8.0])), + ( + ConcreteRArray([1.0, 2.0im, 3.0, 4.0im]), + ConcreteRArray([5.0, 6.0im, -7.0im, -8.0]), + ), + ] + # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation + @test sum(a .* b) ≈ @jit f1(a, b) + @test kron(reshape(a, length(a), 1), reshape(b, 1, length(b))) ≈ @jit fouter(a, b) + @test a .* b ≈ @jit fouter_batch1(a, b) + end + + a = ConcreteRArray([1 2; 3 4]) + b = ConcreteRArray([5 6; -7 -8]) + @test a' * b == @jit f1(a, b) +end + @testset "einsum" begin f1(a, b) = Ops.einsum(a, b; equation="i,i->i") f2(a, b) = Ops.einsum(a, b; equation="i,j->ij")