Skip to content

Commit

Permalink
Merge pull request #984 from JuliaAI/pipelines-with-supervised-transf…
Browse files Browse the repository at this point in the history
…ormers

Add support in pipelines for `Unsupervised` models for which `target_in_fit`  is `true`
  • Loading branch information
ablaom authored Jul 2, 2024
2 parents 370b3da + 63344b7 commit b44e7cf
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.5.0"
version = "1.6"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -47,7 +47,7 @@ DelimitedFiles = "1"
Distributions = "0.25.3"
InvertedIndices = "1"
LearnAPI = "0.1"
MLJModelInterface = "1.10"
MLJModelInterface = "1.11"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
Expand All @@ -58,7 +58,7 @@ Reexport = "1.2"
ScientificTypes = "3"
StatisticalMeasures = "0.1.1"
StatisticalMeasuresBase = "0.1.1"
StatisticalTraits = "3.3"
StatisticalTraits = "3.4"
Statistics = "1"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1.0"
Expand Down
18 changes: 16 additions & 2 deletions src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,15 @@ implements it (some clustering models). Similarly, calling `transform`
on a supervised pipeline calls `transform` on the supervised
component.
### Transformers that need a target in training
Some transformers that have type `Unsupervised` (so that the output of `transform` is
propagated in pipelines) may require a target variable for training. An example are
so-called target encoders (which transform categorical input features, based on some
target observations). Provided they appear before any `Supervised` component in the
pipelines, such models are supported. Of course a target must be provided whenever
training such a pipeline, whether or not it contains a `Supervised` component.
### Optional key-word arguments
- `prediction_type` -
Expand Down Expand Up @@ -444,9 +453,13 @@ function extend(front::Front{Pred}, ::Static, name, cache, args...)
Front(transform(mach, active(front)), front.transform, Pred())
end

function extend(front::Front{Trans}, component::Unsupervised, name, cache, args...)
function extend(front::Front{Trans}, component::Unsupervised, name, cache, ::Any, sources...)
a = active(front)
mach = machine(name, a; cache=cache)
if target_in_fit(component)
mach = machine(name, a, first(sources); cache=cache)
else
mach = machine(name, a; cache=cache)
end
Front(predict(mach, a), transform(mach, a), Trans())
end

Expand Down Expand Up @@ -598,6 +611,7 @@ function MMI.iteration_parameter(pipe::SupervisedPipeline)
end

MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(p))
MMI.target_in_fit(p::SomePipeline) = any(target_in_fit, components(p))

MMI.package_name(::Type{<:SomePipeline}) = "MLJBase"
MMI.load_path(::Type{<:SomePipeline}) = "MLJBase.Pipeline"
Expand Down
35 changes: 35 additions & 0 deletions test/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ end
# inverse transform:
p = Pipeline(UnivariateBoxCoxTransformer,
UnivariateStandardizer)
@test !target_in_fit(p)
xtrain = rand(rng, 10)
mach = machine(p, xtrain)
fit!(mach, verbosity=0)
Expand Down Expand Up @@ -702,6 +703,40 @@ end
@test Set(features) == Set(keys(X))
end

struct SupervisedTransformer <: Unsupervised end

MLJBase.fit(::SupervisedTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
MLJBase.transform(::SupervisedTransformer, fitresult, X) =
fitresult*MLJBase.matrix(X) |> MLJBase.table
MLJBase.target_in_fit(::Type{<:SupervisedTransformer}) = true

struct DummyTransformer <: Unsupervised end
MLJBase.fit(::DummyTransformer, verbosity, X) = (nothing, nothing, nothing)
MLJBase.transform(::DummyTransformer, fitresult, X) = X

@testset "supervised transformers in a pipeline" begin
X = MLJBase.table((a=fill(10.0, 3),))
y = fill(2, 3)
pipe = SupervisedTransformer() |> DeterministicConstantRegressor()
@test target_in_fit(pipe)
mach = machine(pipe, X, y)
fit!(mach, verbosity=0)
@test predict(mach, X) == fill(2.0, 3)

pipe2 = DummyTransformer |> pipe
@test target_in_fit(pipe2)
mach = machine(pipe2, X, y)
fit!(mach, verbosity=0)
@test predict(mach, X) == fill(2.0, 3)

pipe3 = DummyTransformer |> SupervisedTransformer |> DummyTransformer
@test target_in_fit(pipe3)
mach = machine(pipe3, X, y)
fit!(mach, verbosity=0)
@test transform(mach, X).x1 == fill(20.0, 3)
end


end # module

true

0 comments on commit b44e7cf

Please sign in to comment.