Skip to content

Commit

Permalink
add shortcut to dot_general op (#330)
Browse files Browse the repository at this point in the history
* implement `dot_general` shortcut op

* fix 1-based indexing to 0-based indexing conversion

* test `dot_general`

* fix checks

* last fixes

* add deprecation notice in `einsum`

* fix test
  • Loading branch information
mofeing authored Dec 6, 2024
1 parent 74174ae commit 390b585
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 25 deletions.
184 changes: 159 additions & 25 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,38 +546,172 @@ end
# return TracedRArray{T,N}((), res, size(lhs))
# end

# function dot_general(
# lhs::TracedRArray{T,N},
# rhs::TracedRArray{T,N};
# dimension_numbers,
# lhs_contracting_dimensions,
# rhs_contracting_dimensions,
# result_permutation,
# location=mlir_stacktrace(
# "dot_general", @__FILE__, @__LINE__
# ),
# ) where {T,N}
# res = MLIR.IR.result(
# stablehlo.dot_general(
# lhs.mlir_data,
# rhs.mlir_data;
# result=mlir_type(TracedRArray{T,N}, ...), # TODO size of result
# dimension_numbers,
# lhs_contracting_dimensions,
# rhs_contracting_dimensions,
# result_permutation,
# location,
# ),
# )
# return TracedRArray{T,N}((), res, size(lhs))
# end
function dot_general(
lhs::TracedRArray{T},
rhs::TracedRArray{T};
contracting_dimensions,
batching_dimensions=(Int[], Int[]),
precision_config=nothing,
precision_type=nothing,
accumulation_type=nothing,
component_count=nothing,
num_primitive_operations=nothing,
allow_imprecise_accumulation=nothing,
location=mlir_stacktrace("dot_general", @__FILE__, @__LINE__),
) where {T}
# C1 + C2
@assert length(batching_dimensions) == 2 && splat(==)(length.(batching_dimensions))
@assert length(contracting_dimensions) == 2 &&
splat(==)(length.(contracting_dimensions))

# C3 + C4
@assert all(eltype.(contracting_dimensions) .<: Int64)
@assert all(eltype.(batching_dimensions) .<: Int64)
@assert all(isdisjoint.(contracting_dimensions, batching_dimensions))

lhs_contracting_dimensions, rhs_contracting_dimensions = contracting_dimensions
lhs_batching_dimensions, rhs_batching_dimensions = batching_dimensions

# C5 + C6 + C7 + C8
@assert all(lhs_batching_dimensions .<= ndims(lhs))
@assert all(rhs_batching_dimensions .<= ndims(rhs))
@assert all(lhs_contracting_dimensions .<= ndims(lhs))
@assert all(rhs_contracting_dimensions .<= ndims(rhs))

# C9 + C10
@assert size.(Ref(lhs), lhs_batching_dimensions) ==
size.(Ref(rhs), rhs_batching_dimensions)
@assert size.(Ref(lhs), lhs_contracting_dimensions) ==
size.(Ref(rhs), rhs_contracting_dimensions)

# C11
@assert isnothing(precision_config) || length(precision_config) == 2

@assert isnothing(precision_type) ||
length(precision_type) == 2 && eltype(precision_type) <: AbstractFloat
@assert isnothing(accumulation_type) || accumulation_type <: AbstractFloat

# C22 + C23
@assert isnothing(component_count) ||
length(component_count) == 2 &&
eltype(component_count) <: Int32 &&
all(0 .<= component_count)

# C24
@assert isnothing(num_primitive_operations) ||
num_primitive_operations isa Int32 && num_primitive_operations > 0
@assert isnothing(allow_imprecise_accumulation) || allow_imprecise_accumulation isa Bool

ctx = MLIR.IR.context()

# from C12
lhs_result_dimensions = setdiff(
1:ndims(lhs), lhs_batching_dimensions, lhs_contracting_dimensions
)
rhs_result_dimensions = setdiff(
1:ndims(rhs), rhs_batching_dimensions, rhs_contracting_dimensions
)

ressize = vcat(
size.(Ref(lhs), lhs_batching_dimensions),
size.(Ref(lhs), lhs_result_dimensions),
size.(Ref(rhs), rhs_result_dimensions),
)

# fix 1-indexing
lhs_batching_dimensions = lhs_batching_dimensions .- 1
rhs_batching_dimensions = rhs_batching_dimensions .- 1
lhs_contracting_dimensions = lhs_contracting_dimensions .- 1
rhs_contracting_dimensions = rhs_contracting_dimensions .- 1

dot_dimension_numbers = GC.@preserve lhs_contracting_dimensions rhs_contracting_dimensions lhs_batching_dimensions rhs_batching_dimensions begin
MLIR.IR.Attribute(
MLIR.API.stablehloDotDimensionNumbersGet(
ctx,
length(lhs_batching_dimensions),
lhs_batching_dimensions,
length(rhs_batching_dimensions),
rhs_batching_dimensions,
length(lhs_contracting_dimensions),
lhs_contracting_dimensions,
length(rhs_contracting_dimensions),
rhs_contracting_dimensions,
),
)
end

if !isnothing(precision_config)
precision_config = MLIR.IR.Attribute([
MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[1]),
MLIR.API.stablehloPrecisionAttrGet(ctx, precision_config[2]),
])
end

# all or nothing: if one is set, all must be set
# TODO maybe be more flexible, by setting some defaults?
if any(
!isnothing,
(
precision_type,
accumulation_type,
component_count,
num_primitive_operations,
allow_imprecise_accumulation,
),
)
@assert all(
!isnothing,
(
precision_type...,
accumulation_type,
component_count...,
num_primitive_operations,
allow_imprecise_accumulation,
),
)
lhs_precision_type, rhs_precision_type = precision_type
lhs_component_count, rhs_component_count = component_count
algorithm = GC.@preserve begin
MLIR.IR.Attribute(
MLIR.API.stablehloDotAlgorithmGet(
ctx,
lhs_precision_type,
rhs_precision_type,
accumulation_type,
lhs_component_count,
rhs_component_count,
num_primitive_operations,
allow_imprecise_accumulation,
),
)
end
else
algorithm = nothing
end

res = MLIR.IR.result(
stablehlo.dot_general(
lhs.mlir_data,
rhs.mlir_data;
result_0=mlir_type(TracedRArray{T,length(ressize)}, ressize),
dot_dimension_numbers,
precision_config,
algorithm,
location,
),
)
return TracedRArray{T,length(ressize)}((), res, ressize)
end

function einsum(
lhs::TracedRArray{T},
rhs::TracedRArray{T};
equation::String,
location=mlir_stacktrace("einsum", @__FILE__, @__LINE__),
) where {T}
Base.depwarn(
"`stablehlo.einsum` is on deprecation process; use `dot_general` instead", :einsum
)
ins, ic = split(equation, "->")
ia, ib = split(ins, ",")

Expand Down
31 changes: 31 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ end
end
end

@testset "dot_general" begin
# dot product of first dim
f1(x, y) = Ops.dot_general(x, y; contracting_dimensions=[[1], [1]])

# outer product
fouter(x, y) = Ops.dot_general(x, y; contracting_dimensions=[Int[], Int[]])

# outer product, batch first dim
fouter_batch1(x, y) = Ops.dot_general(
x, y; contracting_dimensions=[Int[], Int[]], batching_dimensions=[[1], [1]]
)

for (a, b) in [
(ConcreteRArray([1, 2, 3, 4]), ConcreteRArray([5, 6, -7, -8])),
(ConcreteRArray([1.0, 2.0, 3.0, 4.0]), ConcreteRArray([5.0, 6.0, -7.0, -8.0])),
(
ConcreteRArray([1.0, 2.0im, 3.0, 4.0im]),
ConcreteRArray([5.0, 6.0im, -7.0im, -8.0]),
),
]
# NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
@test sum(a .* b) @jit f1(a, b)
@test kron(reshape(a, length(a), 1), reshape(b, 1, length(b))) @jit fouter(a, b)
@test a .* b @jit fouter_batch1(a, b)
end

a = ConcreteRArray([1 2; 3 4])
b = ConcreteRArray([5 6; -7 -8])
@test a' * b == @jit f1(a, b)
end

@testset "einsum" begin
f1(a, b) = Ops.einsum(a, b; equation="i,i->i")
f2(a, b) = Ops.einsum(a, b; equation="i,j->ij")
Expand Down

0 comments on commit 390b585

Please sign in to comment.