
Documentation about the reverse rule customization needs to be improved #2132

Open
GiggleLiu opened this issue Nov 28, 2024 · 3 comments

@GiggleLiu
Contributor

@vchuravy got the following code for differentiating the einsum! function in OMEinsum working. He also pointed out that the relevant documentation could be improved. I hope this code snippet helps.

using Enzyme, Enzyme.EnzymeRules, OMEinsum

function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, ::Type, 
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    @assert sx.val == 1 && sy.val == 0 "Only α = 1 and β = 0 are supported, got: α = $(sx.val), β = $(sy.val)"
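    # Compute the primal. einsum! writes into ys.val in place, so it must run
    # even when Enzyme does not need the returned value.
    func.val(code.val, xs.val, ys.val, sx.val, sy.val, size_dict.val)
    primal = EnzymeRules.needs_primal(config) ? ys.val : nothing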
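    # Save copies of the input arrays in the tape if they may be overwritten
    # before the reverse pass. overwritten(config) has one entry per argument,
    # starting with func, so index 3 refers to xs.
    if EnzymeRules.overwritten(config)[3]
        tape = map(copy, xs.val)
    else
        tape = nothing
    end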
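    # einsum! returns ys.val itself, so the shadow of the return value must be
    # ys.dval (the same memory), not a freshly allocated array.
    shadow = EnzymeRules.needs_shadow(config) ? ys.dval : nothing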
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

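function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(einsum!)}, dret::Type{<:Annotation}, tape,
        code::Const, xs::Duplicated, ys::Duplicated, sx::Const, sy::Const, size_dict::Const)
    # Use the saved copies of xs if the original arrays may have been overwritten.
    xval = EnzymeRules.overwritten(config)[3] ? tape : xs.val
    # Accumulate (never assign) the gradient of each input into its shadow.
    for i = 1:length(xs.val)
        xs.dval[i] .+= OMEinsum.einsum_grad(OMEinsum.getixs(code.val),
            xval, OMEinsum.getiy(code.val), size_dict.val, conj(ys.dval), i)
    end
    # With β = 0 the previous contents of y do not affect the result, so the
    # adjoint stored in ys.dval has been fully consumed and must be reset.
    ys.dval .= 0
    # One entry per argument after func: code, xs, ys, sx, sy, size_dict.
    return (nothing, nothing, nothing, nothing, nothing, nothing)
end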

x = randn(3, 3);
gx = zero(x);

function testf2(x)
    y = zeros(size(x, 1))
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i'=>3))
    return sum(y)
end

autodiff(ReverseWithPrimal, testf2, Duplicated(x, gx))
gx  # sum(y) is the trace of x, so gx should be the 3×3 identity matrix
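As an independent sanity check of what gx should contain, the sketch below differentiates the same computation with central finite differences, without Enzyme or the custom rules. The helper names trace_via_einsum and fd_gradient are hypothetical, introduced only for this illustration; because the function is linear in x, the finite-difference result should match the identity matrix up to rounding.

```julia
using OMEinsum, LinearAlgebra

# Same computation as testf2, without any Enzyme machinery (hypothetical name).
function trace_via_einsum(x)
    y = zeros(size(x, 1))
    einsum!(ein"ii->i", (x,), y, 1, 0, Dict('i' => size(x, 1)))
    return sum(y)
end

# Central finite differences, perturbing one entry of x at a time (hypothetical helper).
function fd_gradient(f, x; h = 1e-6)
    g = zero(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(3, 3)
g_fd = fd_gradient(trace_via_einsum, x)
```

Comparing g_fd with the gx produced by autodiff is a quick way to catch a wrong or missing accumulation in the reverse rule.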

The function signature of einsum! is

einsum!(code::EinCode, xs::Tuple, y, sx, sy, size_dict::Dict=get_size_dict(getixs(code), xs))

The output array y is modified in place, and the function returns y itself.
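A small sketch of these in-place semantics, assuming the BLAS-like convention y ← sx·einsum(code, xs) + sy·y that the α/β naming in the rule's assertion suggests. With sy = 0 the old contents of y have no influence on the result, which is why the reverse rule above can reset the adjoint of y after propagating it into xs. The variable names here are illustrative only.

```julia
using OMEinsum, LinearAlgebra

x = randn(3, 3)
code = ein"ii->i"            # extract the diagonal of a square matrix
size_dict = Dict('i' => 3)

# sx = 1, sy = 0: y is overwritten with diag(x), and the same array is returned.
y = ones(3)
ret = einsum!(code, (x,), y, 1, 0, size_dict)

# sx = 2, sy = 1 (assumed convention): y accumulates 2 .* diag(x) on top of its old value.
y_old = copy(y)
einsum!(code, (x,), y, 2, 1, size_dict)
```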

@wsmoses
Member

wsmoses commented Nov 28, 2024

perhaps make a PR on OMEinsum.jl with your rule?

@GiggleLiu
Contributor Author

> perhaps make a PR on OMEinsum.jl with your rule?

Yeah, this is exactly what we were doing. @vchuravy mentioned that a few things need to be documented, e.g.

  1. how to create the shadow correctly.
  2. how to correctly handle the output of an in-place function; here the array y is mutated and also returned by the function.
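Both points can be illustrated on a much smaller function. The sketch below uses a hypothetical in-place function double!, which overwrites y with 2x and returns y, mirroring how einsum! returns its output argument. It follows the same EnzymeRules pattern as the einsum! rule above and is an illustration under these assumptions, not a rule taken from Enzyme's documentation.

```julia
using Enzyme, Enzyme.EnzymeRules

# Hypothetical in-place function: overwrites y with 2x and returns y itself.
function double!(y, x)
    y .= 2 .* x
    return y
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(double!)}, ::Type{<:Annotation},
        y::Duplicated, x::Duplicated)
    # The mutation of y must happen even when the primal result is not needed.
    func.val(y.val, x.val)
    primal = EnzymeRules.needs_primal(config) ? y.val : nothing
    # The return value aliases y, so its shadow must alias y.dval.
    shadow = EnzymeRules.needs_shadow(config) ? y.dval : nothing
    # The derivative does not depend on x (the map is linear), so no tape is needed.
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(double!)}, dret::Type{<:Annotation}, tape,
        y::Duplicated, x::Duplicated)
    # y.dval holds the adjoint of the returned (aliased) y; accumulate into x.dval.
    x.dval .+= 2 .* y.dval
    # The old contents of y were overwritten, so their adjoint is zero.
    y.dval .= 0
    return (nothing, nothing)
end

function loss(x)
    y = similar(x)
    double!(y, x)
    return sum(y)
end

x = [1.0, 2.0, 3.0]
dx = zero(x)
autodiff(Reverse, loss, Active, Duplicated(x, dx))
```

The points to note are that the primal call runs unconditionally, the returned shadow aliases y.dval instead of being allocated, and the reverse pass accumulates into x.dval before zeroing y.dval.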

Do you want to add more? @vchuravy

@wsmoses
Member

wsmoses commented Nov 29, 2024

oh yeah for sure we definitely need more docs on custom rules.

Since you went through the first time process recently, would you be interested in giving it a go?

I think this would be the place to add the text: https://github.com/EnzymeAD/Enzyme.jl/blob/main/examples/custom_rule.jl

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants