diff --git a/src/base.jl b/src/base.jl index 8ad6b7f..f02bf8d 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,8 +1,15 @@ @functor Base.RefValue +@functor Base.Pair + +@functor Base.Generator # aka Iterators.map + functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) +@functor Base.Fix1 +@functor Base.Fix2 + ### ### Array wrappers ### @@ -36,3 +43,26 @@ end _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm) _PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent _PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm) + +### +### Iterators +### + +@functor Iterators.Accumulate +# Count +@functor Iterators.Cycle +@functor Iterators.Drop +@functor Iterators.DropWhile +@functor Iterators.Enumerate +@functor Iterators.Filter +@functor Iterators.Flatten +# IterationCutShort +@functor Iterators.PartitionIterator +@functor Iterators.ProductIterator +@functor Iterators.Repeated +@functor Iterators.Rest +@functor Iterators.Reverse +# Stateful +@functor Iterators.Take +@functor Iterators.TakeWhile +@functor Iterators.Zip diff --git a/test/base.jl b/test/base.jl index 4dba6c5..3e3d2ef 100644 --- a/test/base.jl +++ b/test/base.jl @@ -34,6 +34,14 @@ end @test fmap(x -> x + 10, f1 ∘ f2) == Foo(11.1, 12.2) ∘ Bar(13.3) end +@testset "Pair, Fix12" begin + @test fmap(sqrt, 4 => 9) === (2.0 => 3.0) + + exclude = x -> x isa Number + @test fmap(sqrt, Base.Fix1(/, 4); exclude)(10) == 0.2 + @test fmap(sqrt, Base.Fix2(/, 4); exclude)(10) == 5.0 +end + @testset "LinearAlgebra containers" begin @test fmapstructure(identity, [1,2,3]') == (parent = [1, 2, 3],) @test fmapstructure(identity, transpose([1,2,3])) == (parent = [1, 2, 3],) @@ -84,3 +92,63 @@ end @test fmapstructure(identity, PermutedDimsArray([1 2; 3 4], (2,1))) == (parent = [1 2; 3 4],) @test fmap(exp, PermutedDimsArray([1 2; 3 4], (2,1))) isa PermutedDimsArray{Float64} end + +@testset "Iterators" begin + exclude = x -> x isa Array + + x = fmap(complex, Iterators.map(sqrt, [1,2,3]); exclude) # Base.Generator + @test x.iter isa Vector{<:Complex} + @test collect(x) isa Vector{<:Complex} + + x = fmap(complex, Iterators.accumulate(/, [1,2,3]); exclude) + @test x.itr isa Vector{<:Complex} + @test collect(x) isa Vector{<:Complex} + + x = fmap(complex, Iterators.cycle([1,2,3])) + @test x.xs isa Vector{<:Complex} + @test first(x) isa Complex + + x = fmap(complex, Iterators.drop([1,2,3], 1); exclude) + @test x.xs isa Vector{<:Complex} + @test collect(x) isa Vector{<:Complex} + + + x = fmap(complex, Iterators.drop([1,2,3], 1); exclude) + @test x.xs isa Vector{<:Complex} + @test collect(x) isa Vector{<:Complex} + + x = fmap(float, Iterators.dropwhile(<(2), [1,2,3]); exclude) + @test x.xs isa Vector{Float64} + @test collect(x) isa Vector{Float64} + + x = fmap(complex, enumerate([1,2,3])) + @test first(x) === (1, 1+0im) + + x = fmap(float, Iterators.filter(<(3), [1,2,3]); exclude) + @test collect(x) isa Vector{Float64} + + x = fmap(complex, Iterators.flatten(([1,2,3], [4,5]))) + @test collect(x) isa Vector{<:Complex} + + x = fmap(complex, Iterators.partition([1,2,3],2); exclude) + @test first(x) isa AbstractVector{<:Complex} + + x = fmap(complex, Iterators.product([1,2,3],[4,5])) + @test first(x) === (1 + 0im, 4 + 0im) + + x = fmap(complex, Iterators.repeated([1,2,3], 4); exclude) # Iterators.Take{Iterators.Repeated} + @test first(x) isa Vector{<:Complex} + + x = fmap(complex, Iterators.rest([1,2,3], 2); exclude) + @test collect(x) isa Vector{<:Complex} + + x = fmap(complex, Iterators.reverse([1,2,3])) + @test collect(x) isa Vector{<:Complex} + + x = fmap(float, Iterators.takewhile(<(2), [1,2,3]); exclude) + @test collect(x) isa Vector{Float64} + + x = fmap(complex, zip([1,2,3], [4,5])) + @test x.is[1] isa Vector{<:Complex} + @test collect(x) isa Vector{<:Tuple{Complex, Complex}} +end