Skip to content

Commit

Permalink
pb2 and pfwd for sparsesymmprod
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Jun 12, 2024
1 parent 00f9be5 commit 7f382d4
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 206 deletions.
10 changes: 10 additions & 0 deletions src/ace/simpleprodbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ end
# we only provide the evaluate functionality itself to test the DAG
# gradients can just be checked by finite differences

# Output element type of `basis` when it is evaluated on inputs whose
# element type is `T`. For a SimpleProdBasis this is `T` itself.
function _valtype(basis::SimpleProdBasis, ::Type{T}) where {T}
    return T
end

# Allocation descriptor for `evaluate!` on a SimpleProdBasis:
# returns the tuple (element type, length) of the output vector, where the
# element type comes from `_valtype` and the length is `length(basis)`.
function whatalloc(::typeof(evaluate!),
                   basis::SimpleProdBasis, A::AbstractVector{T}) where {T}
    return (_valtype(basis, T), length(basis))
end




function evaluate!(AA, basis::SimpleProdBasis, A::AbstractVector)
for i = 1:length(basis)
Expand Down
180 changes: 85 additions & 95 deletions src/ace/sparsesymmprod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ end

# Number of basis functions: the total number of entries over all
# `basis.specs`, plus `basis.hasconst` for the optional constant term.
function Base.length(basis::SparseSymmProd)
    nspec = sum(length, basis.specs)
    return nspec + basis.hasconst
end

# Compact one-line display showing the order parameter `ORD` and the
# length of the basis.
function Base.show(io::IO, basis::SparseSymmProd{ORD}) where {ORD}
    print(io, "SparseSymmProd(order=$(ORD), length = $(length(basis)))")
    return nothing
end

function reconstruct_spec(basis::SparseSymmProd)
spec = [ [ bb... ] for bb in vcat(basis.specs...) ]
if basis.hasconst
Expand Down Expand Up @@ -136,7 +139,8 @@ function _evaluate_AA!(AA, spec::Vector{NTuple{N, Int}}, A::AbstractMatrix) wher
end


# -------
# -----------------------------------
# pullback

import ChainRulesCore: rrule, NoTangent

Expand All @@ -151,6 +155,7 @@ end
∂AA, basis::SparseSymmProd{ORD}, A
) where {ORD}
quote
fill!(∂A, zero(eltype(∂A)))
@nexprs $ORD N -> _pb_evaluate_pbAA!(
∂A,
__view_AA(∂AA, basis, N),
Expand Down Expand Up @@ -192,122 +197,107 @@ function _pb_evaluate_pbAA!(gA, ΔN::AbstractMatrix,
end


function rrule(::typeof(pullback), ∂AA, basis::SparseSymmProd, A)
∂A = pullback(∂AA, basis, A)
function pb(∂∂A)
∂²∂AA, ∂²A = pullback2(∂∂A, ∂AA, basis, A)
return NoTangent(), ∂²∂AA, NoTangent(), ∂²A
end
return ∂A, pb
# -----------------------------------
# reverse-over-reverse
#
# AA = evaluate(basis, A)
# ∇_A = pullback(∂AA, basis, A)
# ∇_∂AA, ∇_A = pullback2(∂∇A, ∂AA, basis, A)
#

# Allocation descriptors for the two outputs of `pullback2!`.
#   ∂∇A : cotangent to be pulled back
#   ∂AA : cotangent from pullback(∂AA, basis, A)
# Both outputs share the promoted element type of ∂∇A, ∂AA and A; the first
# has the size of ∂AA and the second the size of A.
function whatalloc(::typeof(pullback2!),
                   ∂∇A, ∂AA,
                   basis::SparseSymmProd, A)
    TG = promote_type(eltype(∂∇A), eltype(∂AA), eltype(A))
    alloc_∂AA = (TG, size(∂AA)...)
    alloc_A = (TG, size(A)...)
    return alloc_∂AA, alloc_A
end


# Second-order pullback for a SparseSymmProd of order ORD.
#
# Allocates two zero-initialised gradient arrays, gΔAA (size of ΔAA) and
# gA (size of A), fills them by calling `_pullback2_AA!` once per
# correlation order, and returns the pair (gΔAA, gA).
#
# Implemented as a @generated function so that the per-order calls are
# written out explicitly for N = 1, ..., ORD (assuming `@nexprs` is
# Base.Cartesian's macro).
@generated function pullback2(Δ², ΔAA, basis::SparseSymmProd{ORD}, A) where {ORD}
quote
# common element type for both outputs
TG = promote_type(eltype(Δ²), eltype(ΔAA), eltype(A))
gΔAA = zeros(TG, size(ΔAA))
gA = zeros(TG, size(A))
# one call per order N; `__view_AA(x, basis, N)` presumably selects the
# part of x belonging to order N -- TODO confirm against its definition
@nexprs $ORD N -> _pullback2_AA!(basis.specs[N],
__view_AA(gΔAA, basis, N), gA, # outputs (gradients)
Δ², # differential
__view_AA(ΔAA, basis, N), A, # inputs
)
return gΔAA, gA
end
end

function _pullback2_AA!(spec::Vector{NTuple{N, Int}},
gΔAA, gA,
Δ²,
ΔN::AbstractVector, A::AbstractVector) where {N}
# We wish to compute ∇_Δ and ∇_A w.r.t. the expression
# ∑ₖ Δ²ₖ * ∇_{Aₖ} (Δ ⋅ AA) (Δ = ΔN)
# = ∑ᵢ Δᵢ * ∇̃ AA[i]
# where ∇̃ = ∑_k Δ²ₖ * ∇_Aₖ
# here k = 1,...,#A and i = 1,...,#AA

@assert size(gA) == size(A)
@assert length(gΔAA) >= length(spec)
@assert length(ΔN) >= length(spec)
@assert length(Δ²) >= length(A)

@inbounds for (i, ϕ) in enumerate(spec)
A_ϕ = ntuple(t -> A[ϕ[t]], N)
Δ²_ϕ = ntuple(t -> Δ²[ϕ[t]], N)
p_i, g_i, u_i = _pb_grad_static_prod(Δ²_ϕ, A_ϕ)
gΔAA[i] = sum(g_i .* Δ²_ϕ)
for t = 1:N
gA[ϕ[t]] += u_i[t] * ΔN[i]
function pullback2!(∇_∂AA, ∇_A,
∂∇A, ∂AA, basis, A)
@assert size(∂∇A) == size(A)
T = promote_type(eltype(∂∇A), eltype(∂AA), eltype(A))
d = Dual{T}(zero(T), one(T))
DT = typeof(d)
@no_escape begin
A_d = @alloc(DT, size(A)...)
@inbounds for i = 1:length(A)
A_d[i] = A[i] + d * ∂∇A[i]
end
end
return nothing
end


function _pullback2_AA!(spec::Vector{NTuple{N, Int}},
gΔAA, gA,
Δ²,
ΔN::AbstractMatrix, A::AbstractMatrix) where {N}
nX = size(A, 1)
for (i, ϕ) in enumerate(spec)
for j = 1:nX
A_ϕ = ntuple(t -> A[j, ϕ[t]], N)
Δ²_ϕ = ntuple(t -> Δ²[j, ϕ[t]], N)
p_i, g_i, u_i = _pb_grad_static_prod(Δ²_ϕ, A_ϕ)
gΔAA[j, i] = sum(g_i .* Δ²_ϕ)
for t = 1:N
gA[j, ϕ[t]] += u_i[t] * ΔN[j, i]
end
AA_d = @withalloc evaluate!(basis, A_d)
# ∇A_d = pullback(∂AA, basis, A_d)
∇A_d = @withalloc pullback!(∂AA, basis, A_d)
@inbounds for i = 1:length(AA_d)
∇_∂AA[i] = extract_derivative(eltype(∇_∂AA), AA_d[i])
end
@inbounds for i = 1:length(∇A_d)
∇_A[i] = extract_derivative(eltype(∇_A), ∇A_d[i])
end
end
return nothing
return ∇_∂AA, ∇_A
end



# -------------- Pushforwards / frules

# TODO: REMOVE
# function pushforward(basis::SparseSymmProd,
# A::AbstractVector{<: Number},
# ΔA::AbstractMatrix)
# nAA = length(basis)
# TAA = eltype(A)
# AA = zeros(TAA, nAA)
# T∂AA = _my_promote_type(TAA, eltype(ΔA))
# ∂AA = zeros(T∂AA, nAA, size(ΔA, 2))
# pushforward!(AA, ∂AA, basis, A, ΔA)
# return AA, ∂AA
# end
using ForwardDiff

function whatalloc(::typeof(pushforward!),
basis::SparseSymmProd, A, ΔA)
nAA = length(basis)
basis::SparseSymmProd,
A::AbstractVector, ΔA::AbstractVector)
nAA = length(basis)
TAA = eltype(A)
T∂AA = _my_promote_type(TAA, eltype(ΔA))
return (TAA, nAA), (T∂AA, nAA, size(ΔA, 2))
T∂AA = promote_type(TAA, eltype(ΔA))
return (TAA, nAA), (T∂AA, nAA)
end

@generated function pushforward!(AA, ∂AA, basis::SparseSymmProd{NB}, A, ΔA) where {NB}
quote
if basis.hasconst; error("no implementation with hasconst"); end
Base.Cartesian.@nexprs $NB N -> _pfwd_AA_N!(AA, ∂AA, A, ΔA, basis.ranges[N], basis.specs[N])
return AA, ∂AA
end
# Allocation descriptors for `pushforward!` with matrix inputs.
# Returns two (eltype, dims...) tuples, both of shape
# (size(A, 1), length(basis)): the first uses the element type of A, the
# second the promoted element type of A and ΔA.
function whatalloc(::typeof(pushforward!),
                   basis::SparseSymmProd,
                   A::AbstractMatrix, ΔA::AbstractMatrix)
    dims = (size(A, 1), length(basis))
    T_val = eltype(A)
    T_der = promote_type(T_val, eltype(ΔA))
    return (T_val, dims...), (T_der, dims...)
end

function _pfwd_AA_N!(AA, ∂AA, A, ΔA,
rg_N, spec_N::Vector{NTuple{N, Int}}) where {N}
nX = size(ΔA, 2)
for (i, bb) in zip(rg_N, spec_N)
aa = ntuple(t -> A[bb[t]], N)
∏aa, ∇∏aa = Polynomials4ML._static_prod_ed(aa)
AA[i] = ∏aa
for t = 1:N, j = 1:nX
∂AA[i, j] += ∇∏aa[t] * ΔA[bb[t], j]

function pushforward!(AA, ∂AA, basis::SparseSymmProd, A, ∂A)
@assert size(∂A) == size(A)
@assert size(∂AA) == size(AA)
T = promote_type(eltype(A), eltype(∂A))
d = Dual{T}(zero(T), one(T))
DT = typeof(d)
@no_escape begin
A_d = @alloc(DT, size(A)...)
@inbounds for i = 1:length(A)
A_d[i] = A[i] + d * ∂A[i]
end
AA_d = @withalloc evaluate!(basis, A_d)
@assert length(AA_d) <= length(AA) && length(AA_d) == length(∂AA)
@inbounds for i = 1:length(AA_d)
AA[i] = ForwardDiff.value(AA_d[i])
∂AA[i] = ForwardDiff.extract_derivative(eltype(∂AA), AA_d[i])
end
end
end
return AA, ∂AA
end


# ------------------------------------------
# ChainRules integration

# ChainRules reverse rule for `pullback(∂AA, basis, A)`.
# The primal result is the first-order pullback. The returned closure maps
# a cotangent ∂∂A of that result to cotangents for (pullback, ∂AA, basis, A)
# via `pullback2`; the function itself and `basis` get NoTangent().
function rrule(::typeof(pullback), ∂AA, basis::SparseSymmProd, A)
    function pullback_pb(∂∂A)
        ∂²∂AA, ∂²A = pullback2(∂∂A, ∂AA, basis, A)
        return NoTangent(), ∂²∂AA, NoTangent(), ∂²A
    end
    return pullback(∂AA, basis, A), pullback_pb
end



Expand Down
3 changes: 3 additions & 0 deletions test/ace/test_prodpool_mult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# This is an experimental testset for working with batched
# pooled products. This isn't really supported yet and is not
# working properly, hence it is not part of runtests.
#
# Note also that this hasn't yet been updated to the new interface.
#

@info("PooledSparseProduct - Multiple evaluations")

Expand Down
Loading

0 comments on commit 7f382d4

Please sign in to comment.