diff --git a/src/TableTransforms.jl b/src/TableTransforms.jl index 46d6a295..67f4cf1f 100644 --- a/src/TableTransforms.jl +++ b/src/TableTransforms.jl @@ -39,6 +39,7 @@ export Coerce, Levels, OneHot, + NarrowTypes, Identity, Center, Scale, diff --git a/src/transforms.jl b/src/transforms.jl index 2183face..b2ea1394 100644 --- a/src/transforms.jl +++ b/src/transforms.jl @@ -221,6 +221,7 @@ include("transforms/coalesce.jl") include("transforms/coerce.jl") include("transforms/levels.jl") include("transforms/onehot.jl") +include("transforms/narrowtypes.jl") include("transforms/identity.jl") include("transforms/center.jl") include("transforms/scale.jl") diff --git a/src/transforms/narrowtypes.jl b/src/transforms/narrowtypes.jl new file mode 100644 index 00000000..da22d8fb --- /dev/null +++ b/src/transforms/narrowtypes.jl @@ -0,0 +1,18 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + NarrowTypes() + +Converts the element type of columns with generic types to more specific types. +""" +struct NarrowTypes <: Colwise end + +isrevertible(::Type{NarrowTypes}) = true + +colcache(::NarrowTypes, x) = eltype(x) + +colapply(::NarrowTypes, x, c) = identity.(x) + +colrevert(::NarrowTypes, y, c) = collect(c, y) diff --git a/test/transforms.jl b/test/transforms.jl index afbe2960..7baa6323 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -1197,6 +1197,24 @@ @test_throws AssertionError apply(OneHot("c"), t) end + @testset "NarrowTypes" begin + a = Integer[4, 6, 5, 9, 3, 1] + b = Number[1, 0, 1, true, false, true] + c = Any[5, 6, 7, 1.5, 1.6, 1.7] + t = Table(; a, b, c) + + T = NarrowTypes() + n, c = apply(T, t) + @test eltype(n.a) == Int + @test eltype(n.b) == Integer + @test eltype(n.c) == Real + tₒ = revert(T, n, c) + @test eltype(tₒ.a) == Integer + @test eltype(tₒ.b) == Number + @test eltype(tₒ.c) == Any + @test t == tₒ + end + @testset "Identity" begin x = rand(4000) y = rand(4000)