diff --git a/ext/TensorKitChainRulesCoreExt/factorizations.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl index d4dc66f7..5680afd9 100644 --- a/ext/TensorKitChainRulesCoreExt/factorizations.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -20,7 +20,8 @@ function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; Ũ, Σ̃, Ṽ⁺ = U, Σ, V⁺ end - function tsvd!_pullback((ΔU, ΔΣ, ΔV⁺, Δϵ)) + function tsvd!_pullback(ΔUSVϵ) + ΔU, ΔΣ, ΔV⁺, = unthunk.(ΔUSVϵ) Δt = similar(t) for (c, b) in blocks(Δt) Uc, Σc, V⁺c = block(U, c), block(Σ, c), block(V⁺, c)