-
Notifications
You must be signed in to change notification settings - Fork 154
fix: regression in non-fast scalar indexing support #760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
9cbe0e8
095a9d1
0ed3a82
cbc1661
2c97323
5b8ffab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,36 +72,46 @@ end | |
|
|
||
| function seed!(duals::AbstractArray{Dual{T,V,N}}, x, | ||
| seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N} | ||
| if isbitstype(V) | ||
| for idx in structural_eachindex(duals, x) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| end | ||
| else | ||
| for idx in structural_eachindex(duals, x) | ||
| if isassigned(x, idx) | ||
| if supports_fast_scalar_indexing(duals) | ||
| if isbitstype(V) | ||
| for idx in structural_eachindex(duals, x) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| else | ||
| for idx in structural_eachindex(duals, x) | ||
| if isassigned(x, idx) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| end | ||
| end | ||
| else | ||
| idxs = collect(structural_eachindex(duals, x)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe let's throw a descriptive error in this branch if |
||
| duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) | ||
| end | ||
| return duals | ||
| end | ||
|
|
||
| function seed!(duals::AbstractArray{Dual{T,V,N}}, x, | ||
| seeds::NTuple{N,Partials{N,V}}) where {T,V,N} | ||
| if isbitstype(V) | ||
| for (i, idx) in zip(1:N, structural_eachindex(duals, x)) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| end | ||
| else | ||
| for (i, idx) in zip(1:N, structural_eachindex(duals, x)) | ||
| if isassigned(x, idx) | ||
| if supports_fast_scalar_indexing(duals) | ||
| if isbitstype(V) | ||
| for (i, idx) in zip(1:N, structural_eachindex(duals, x)) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| else | ||
| for (i, idx) in zip(1:N, structural_eachindex(duals, x)) | ||
| if isassigned(x, idx) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| end | ||
| end | ||
| else | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar here, maybe let's throw an error in this branch if |
||
| idxs = collect(Iterators.take(structural_eachindex(duals, x), N)) | ||
| duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) | ||
| end | ||
| return duals | ||
| end | ||
|
|
@@ -110,18 +120,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index, | |
| seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N} | ||
| offset = index - 1 | ||
| idxs = Iterators.drop(structural_eachindex(duals, x), offset) | ||
| if isbitstype(V) | ||
| for idx in idxs | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| end | ||
| else | ||
| for idx in idxs | ||
| if isassigned(x, idx) | ||
| if supports_fast_scalar_indexing(duals) | ||
| if isbitstype(V) | ||
| for idx in idxs | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| else | ||
| for idx in idxs | ||
| if isassigned(x, idx) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seed) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| end | ||
| end | ||
| else | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also here. |
||
| idxs = collect(idxs) | ||
| duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) | ||
| end | ||
| return duals | ||
| end | ||
|
|
@@ -130,18 +145,23 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index, | |
| seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N} | ||
| offset = index - 1 | ||
| idxs = Iterators.drop(structural_eachindex(duals, x), offset) | ||
| if isbitstype(V) | ||
| for (i, idx) in zip(1:chunksize, idxs) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| end | ||
| else | ||
| for (i, idx) in zip(1:chunksize, idxs) | ||
| if isassigned(x, idx) | ||
| if supports_fast_scalar_indexing(duals) | ||
| if isbitstype(V) | ||
| for (i, idx) in zip(1:chunksize, idxs) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| else | ||
| for (i, idx) in zip(1:chunksize, idxs) | ||
| if isassigned(x, idx) | ||
| duals[idx] = Dual{T,V,N}(x[idx], seeds[i]) | ||
| else | ||
| Base._unsetindex!(duals, idx) | ||
| end | ||
| end | ||
| end | ||
| else | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And here. |
||
| idxs = collect(Iterators.take(idxs, chunksize)) | ||
| duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) | ||
| end | ||
| return duals | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # overload for array types that | ||
| supports_fast_scalar_indexing(::Array) = true | ||
|
|
||
| function supports_fast_scalar_indexing(x::AbstractArray) | ||
| return parent(x) !== x && supports_fast_scalar_indexing(parent(x)) | ||
| end | ||
|
|
||
| # Helper function for broadcasting | ||
| struct PartialsFn{T,D<:Dual} | ||
| dual::D | ||
| end | ||
| PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual) | ||
|
|
||
| (f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this actually true?
StaticArrayis the abstract supertype, is it required that all subtypes support fast scalar indexing?