@vchuravy got the following code for differentiating the `einsum!` function in OMEinsum working. He also pointed out that the relevant documentation could be improved. I hope this code snippet helps.
```julia
using Enzyme, Enzyme.EnzymeRules, OMEinsum

function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 is supported, got: $sx, $sy"
    # Compute the primal. einsum! writes into ys.val, so it has to run even
    # when Enzyme does not need the returned value.
    result = func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    primal = EnzymeRules.needs_primal(config) ? result : nothing
    # Save x in the tape if x will be overwritten; copy each array, because
    # copying the tuple alone would not copy its contents.
    tape = EnzymeRules.overwritten(config)[3] ? map(copy, xs.val) : nothing
    shadow = ys.dval
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
        code::Const,
        xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    # Use the saved copy of x if the original may have been overwritten.
    xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
    # Accumulate the gradient of every input tensor.
    for i = 1:length(xs.val)
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
            xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
    end
    # With sy == 0 the old contents of y do not affect the result,
    # so their adjoint is zero.
    fill!(ys.dval, 0)
    return (nothing, nothing, nothing, nothing, nothing, nothing)
end

x = randn(3, 3);
y = randn(3);
gx = zero(x);
gy = zero(y);

function testf2(x)
    y = zeros(size(x, 1))
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3))
    return sum(y)
end

autodiff(ReverseWithPrimal, testf2, Duplicated(x, gx))
gx
```
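In `testf2`, `ein"ii->i"` extracts the diagonal of `x`, so `sum(y)` is the trace of `x`, and `gx` should come out as the 3×3 identity matrix. A minimal sketch of an independent sanity check uses only Base Julia and `LinearAlgebra` and compares against a central finite-difference gradient. The helper names `testf2_ref` and `fd_grad` are hypothetical and are not part of OMEinsum or Enzyme.

```julia
using LinearAlgebra

# Hypothetical einsum-free stand-in for testf2: ein"ii->i" keeps the diagonal,
# so summing it gives the trace of x.
testf2_ref(x) = sum(diag(x))

# Hypothetical helper: central finite-difference gradient of a scalar function f.
function fd_grad(f, x; h = 1e-6)
    g = zero(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h   # perturb one entry upward
        xm = copy(x); xm[k] -= h   # and downward
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 3)
g = fd_grad(testf2_ref, x)
# The trace has coefficient 1 on each diagonal entry and 0 elsewhere,
# so the gradient is the identity; Enzyme's gx should agree with it.
@assert isapprox(g, Matrix{Float64}(I, 3, 3); atol = 1e-6)
```

Comparing `gx` from the Enzyme run against `g` (or directly against `I`) is a quick way to confirm that the custom rule is being picked up.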
The function signature of `einsum!` is

```julia
einsum!(code::EinCode, xs::Tuple, y, sx, sy, size_dict::Dict=get_size_dict(getixs(code), xs))
```

The input `y` is modified in place, and the return value is `y` itself.
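For the `ein"ii->i"` case used above, a plain-Julia reference makes the in-place behaviour concrete. This is a minimal sketch that assumes the BLAS-style convention suggested by the `@assert` in the rule, namely that `einsum!` stores `sx * einsum(code, xs) + sy * y` into `y`. `diag_einsum_ref!` is a hypothetical helper, not an OMEinsum function.

```julia
# Hypothetical reference for einsum!(ein"ii->i", (x,), y, sx, sy), assuming
# the convention y .= sx .* einsum(code, xs) .+ sy .* y.
function diag_einsum_ref!(y::AbstractVector, x::AbstractMatrix, sx, sy)
    for i in eachindex(y)
        y[i] = sx * x[i, i] + sy * y[i]   # overwrite y element by element
    end
    return y   # the returned object is the same array that was modified
end

x = [1.0 2.0; 3.0 4.0]
y = [10.0, 20.0]
r = diag_einsum_ref!(y, x, 2.0, 0.5)
r === y   # true: no new array is allocated
y         # [7.0, 18.0], i.e. 2 .* [1.0, 4.0] .+ 0.5 .* [10.0, 20.0]
```

With `sx = 1` and `sy = 0`, as in `testf2`, the previous contents of `y` are simply discarded.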