Skip to content

Commit 43316b9

Browse files
committed
fix some ComponentArray grads reverting to Vector
1 parent 6d18650 commit 43316b9

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/autodiff.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
# this does basis promotion, unlike Zygote's default for AbstractArrays
33
Zygote.accum(a::Field, b::Field) = a+b
4+
Zygote.accum(a::FieldTuple, b::FieldTuple) = Zygote.accum.(a,b)
45
# this may create a LazyBinaryOp, unlike Zygote's
56
Zygote.accum(a::FieldOp, b::FieldOp) = a+b
67

src/dataset.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ function load_sim(;
230230
@warn "`rfid` will be removed in a future version. Use `fiducial_θ=(r=...,)` instead."
231231
fiducial_θ = merge(fiducial_θ,(r=rfid,))
232232
end
233-
Aϕ₀ = get(fiducial_θ, :Aϕ, 1)
233+
Aϕ₀ = T(get(fiducial_θ, :Aϕ, 1))
234234
fiducial_θ = Base.structdiff(fiducial_θ, NamedTuple{(:Aϕ,)}) # remove Aϕ key if present
235235
if (Cℓ == nothing)
236236
Cℓ = camb(;fiducial_θ..., ℓmax=ℓmax)
@@ -241,7 +241,7 @@ function load_sim(;
241241
error("ℓmax of `Cℓ` argument should be higher than $ℓmax for this configuration.")
242242
end
243243
end
244-
r₀ = Cℓ.params.r
244+
r₀ = T(Cℓ.params.r)
245245

246246
# noise Cℓs (these are non-debeamed, hence beamFWHM=0 below; the beam comes in via the B operator)
247247
if (Cℓn == nothing)
@@ -264,7 +264,7 @@ function load_sim(;
264264
Cf̃ = Cℓ_to_Cov(pol, proj, (Cℓ.total[k] for k in ks)...)
265265
Cn̂ = Cℓ_to_Cov(pol, proj, (Cℓn[k] for k in ks)...)
266266
if (Cn == nothing); Cn = Cn̂; end
267-
Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + T(r/r₀)*Cft))
267+
Cf = ParamDependentOp((;r=r₀, _...)->(Cfs + (T(r)/r₀)*Cft))
268268
= ParamDependentOp((;Aϕ=Aϕ₀, _...)->(T(Aϕ) * Cϕ₀))
269269

270270
# data mask

0 commit comments

Comments
 (0)