Skip to content

Commit

Permalink
Add more base types (#47)
Browse files Browse the repository at this point in the history
* add more base types

* add tests
  • Loading branch information
mcabbott authored Nov 15, 2022
1 parent 75fef99 commit 04d08aa
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@

@functor Base.RefValue

@functor Base.Pair

@functor Base.Generator # aka Iterators.map

functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)

@functor Base.Fix1
@functor Base.Fix2

###
### Array wrappers
###
Expand Down Expand Up @@ -36,3 +43,26 @@ end
_PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm)
_PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent
_PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm)

###
### Iterators
###

@functor Iterators.Accumulate
# Count
@functor Iterators.Cycle
@functor Iterators.Drop
@functor Iterators.DropWhile
@functor Iterators.Enumerate
@functor Iterators.Filter
@functor Iterators.Flatten
# IterationCutShort
@functor Iterators.PartitionIterator
@functor Iterators.ProductIterator
@functor Iterators.Repeated
@functor Iterators.Rest
@functor Iterators.Reverse
# Stateful
@functor Iterators.Take
@functor Iterators.TakeWhile
@functor Iterators.Zip
68 changes: 68 additions & 0 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ end
@test fmap(x -> x + 10, f1 f2) == Foo(11.1, 12.2) Bar(13.3)
end

@testset "Pair, Fix12" begin
@test fmap(sqrt, 4 => 9) === (2.0 => 3.0)

exclude = x -> x isa Number
@test fmap(sqrt, Base.Fix1(/, 4); exclude)(10) == 0.2
@test fmap(sqrt, Base.Fix2(/, 4); exclude)(10) == 5.0
end

@testset "LinearAlgebra containers" begin
@test fmapstructure(identity, [1,2,3]') == (parent = [1, 2, 3],)
@test fmapstructure(identity, transpose([1,2,3])) == (parent = [1, 2, 3],)
Expand Down Expand Up @@ -84,3 +92,63 @@ end
@test fmapstructure(identity, PermutedDimsArray([1 2; 3 4], (2,1))) == (parent = [1 2; 3 4],)
@test fmap(exp, PermutedDimsArray([1 2; 3 4], (2,1))) isa PermutedDimsArray{Float64}
end

@testset "Iterators" begin
exclude = x -> x isa Array

x = fmap(complex, Iterators.map(sqrt, [1,2,3]); exclude) # Base.Generator
@test x.iter isa Vector{<:Complex}
@test collect(x) isa Vector{<:Complex}

x = fmap(complex, Iterators.accumulate(/, [1,2,3]); exclude)
@test x.itr isa Vector{<:Complex}
@test collect(x) isa Vector{<:Complex}

x = fmap(complex, Iterators.cycle([1,2,3]))
@test x.xs isa Vector{<:Complex}
@test first(x) isa Complex

x = fmap(complex, Iterators.drop([1,2,3], 1); exclude)
@test x.xs isa Vector{<:Complex}
@test collect(x) isa Vector{<:Complex}


x = fmap(complex, Iterators.drop([1,2,3], 1); exclude)
@test x.xs isa Vector{<:Complex}
@test collect(x) isa Vector{<:Complex}

x = fmap(float, Iterators.dropwhile(<(2), [1,2,3]); exclude)
@test x.xs isa Vector{Float64}
@test collect(x) isa Vector{Float64}

x = fmap(complex, enumerate([1,2,3]))
@test first(x) === (1, 1+0im)

x = fmap(float, Iterators.filter(<(3), [1,2,3]); exclude)
@test collect(x) isa Vector{Float64}

x = fmap(complex, Iterators.flatten(([1,2,3], [4,5])))
@test collect(x) isa Vector{<:Complex}

x = fmap(complex, Iterators.partition([1,2,3],2); exclude)
@test first(x) isa AbstractVector{<:Complex}

x = fmap(complex, Iterators.product([1,2,3],[4,5]))
@test first(x) === (1 + 0im, 4 + 0im)

x = fmap(complex, Iterators.repeated([1,2,3], 4); exclude) # Iterators.Take{Iterators.Repeated}
@test first(x) isa Vector{<:Complex}

x = fmap(complex, Iterators.rest([1,2,3], 2); exclude)
@test collect(x) isa Vector{<:Complex}

x = fmap(complex, Iterators.reverse([1,2,3]))
@test collect(x) isa Vector{<:Complex}

x = fmap(float, Iterators.takewhile(<(2), [1,2,3]); exclude)
@test collect(x) isa Vector{Float64}

x = fmap(complex, zip([1,2,3], [4,5]))
@test x.is[1] isa Vector{<:Complex}
@test collect(x) isa Vector{<:Tuple{Complex, Complex}}
end

0 comments on commit 04d08aa

Please sign in to comment.