From 7afa1fc4b9af85a82b24465ba8a8933b93ba57b9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 16:00:21 +1300 Subject: [PATCH 1/5] add a test to catch serialization bug fix test oops --- src/builtins/ThresholdPredictors.jl | 14 +++++++ test/builtins/ThresholdPredictors.jl | 55 ++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/src/builtins/ThresholdPredictors.jl b/src/builtins/ThresholdPredictors.jl index c8c85593..5e11e1a1 100644 --- a/src/builtins/ThresholdPredictors.jl +++ b/src/builtins/ThresholdPredictors.jl @@ -305,6 +305,20 @@ function _predict_threshold(yhat::UnivariateFiniteArray{S,V,R,P,N}, end +## SERIALIZATION + +# function MMI.save(model::ThresholdUnion, fitresult) +# atomic_fitresult, threshold = fitresult +# atom = model.model +# return MMI.save(atom, atomic_fitresult), threshold +# end +# function MMI.restore(model::ThresholdUnion, serializable_fitresult) +# atomic_serializable_fitresult, threshold = serializable_fitresult +# atom = model.model +# return MMI.restore(atom, atomic_serializable_fitresult), threshold +# end + + ## TRAITS # Note: input traits are inherited from the wrapped model diff --git a/test/builtins/ThresholdPredictors.jl b/test/builtins/ThresholdPredictors.jl index 069add61..6c3b2b82 100644 --- a/test/builtins/ThresholdPredictors.jl +++ b/test/builtins/ThresholdPredictors.jl @@ -268,6 +268,61 @@ MMI.input_scitype(::Type{<:NaiveClassifier}) = Table(Continuous) mode(Distributions.fit(MLJBase.UnivariateFinite, y[I])), MLJBase.nrows(X) ) end + + +# define a probabilistic classifier with non-persistent `fitresult`, but which addresses +# this by overloading `save`/`restore`: +thing = [] +struct EphemeralClassifier <: MLJBase.Probabilistic end +function MLJBase.fit(::EphemeralClassifier, verbosity, X, y) + # if I serialize/deserialized `thing` then `view` below changes: + view = objectid(thing) + p = Distributions.fit(UnivariateFinite, y) + fitresult = (thing, view, p) + return fitresult, nothing, NamedTuple() +end +function MLJBase.predict(::EphemeralClassifier, fitresult, X) + thing, view, p = fitresult + return view == objectid(thing) ? fill(p, MLJBase.nrows(X)) : + throw(ErrorException("dead fitresult")) +end +MLJBase.target_scitype(::Type{<:EphemeralClassifier}) = AbstractVector{OrderedFactor{2}} +function MLJBase.save(::EphemeralClassifier, fitresult) + thing, _, p = fitresult + return (thing, p) +end +function MLJBase.restore(::EphemeralClassifier, serialized_fitresult) + thing, p = serialized_fitresult + view = objectid(thing) + return (thing, view, p) +end + +# X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true) +# mach = machine(EphemeralClassifier(), X, y) |> fit! +# io = IOBuffer() +# MLJBase.save(io, mach) +# seekstart(io) +# mach2 = machine(io) +# predict(mach2, X) + +@testset "serialization for atomic models with non-persistent fitresults" begin + # https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true) + deterministic_classifier = BinaryThresholdPredictor( + EphemeralClassifier(), + threshold=0.5, + ) + mach = MLJBase.machine(deterministic_classifier, X, y) + MLJBase.fit!(mach, verbosity=0) + yhat = MLJBase.predict(mach, MLJBase.selectrows(X, 1:2)) + io = IOBuffer() + MLJBase.save(io, mach) + seekstart(io) + mach2 = MLJBase.machine(io) + close(io) + @test_broken MLJBase.predict(mach2, (; x = rand(2))) == yhat +end + end # module true From 52dda58ea136d75b2751ed178467cfd98e10747c Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 16:21:32 +1300 Subject: [PATCH 2/5] overload save and restore to fix broken test --- src/builtins/ThresholdPredictors.jl | 20 ++++++++++---------- test/builtins/ThresholdPredictors.jl | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/builtins/ThresholdPredictors.jl b/src/builtins/ThresholdPredictors.jl index 5e11e1a1..a64e08bf 100644 --- a/src/builtins/ThresholdPredictors.jl +++ b/src/builtins/ThresholdPredictors.jl @@ -307,16 +307,16 @@ end ## SERIALIZATION -# function MMI.save(model::ThresholdUnion, fitresult) -# atomic_fitresult, threshold = fitresult -# atom = model.model -# return MMI.save(atom, atomic_fitresult), threshold -# end -# function MMI.restore(model::ThresholdUnion, serializable_fitresult) -# atomic_serializable_fitresult, threshold = serializable_fitresult -# atom = model.model -# return MMI.restore(atom, atomic_serializable_fitresult), threshold -# end +function MMI.save(model::ThresholdUnion, fitresult) + atomic_fitresult, threshold = fitresult + atom = model.model + return MMI.save(atom, atomic_fitresult), threshold +end +function MMI.restore(model::ThresholdUnion, serializable_fitresult) + atomic_serializable_fitresult, threshold = serializable_fitresult + atom = model.model + return MMI.restore(atom, atomic_serializable_fitresult), threshold +end ## TRAITS diff --git a/test/builtins/ThresholdPredictors.jl b/test/builtins/ThresholdPredictors.jl index 6c3b2b82..8945436e 100644 --- a/test/builtins/ThresholdPredictors.jl +++ b/test/builtins/ThresholdPredictors.jl @@ -320,7 +320,7 @@ end seekstart(io) mach2 = MLJBase.machine(io) close(io) - @test_broken MLJBase.predict(mach2, (; x = rand(2))) == yhat + @test MLJBase.predict(mach2, (; x = rand(2))) == yhat end end # module From 784b160d5fd9e22c2247fe6d8ffc3ce51c38c7a5 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 16:25:04 +1300 Subject: [PATCH 3/5] bump 0.16.16 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e5d1e9a8..63ba7cfc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJModels" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" authors = ["Anthony D. Blaom "] -version = "0.16.15" +version = "0.16.16" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" From fa7297e06afd3b2332e91919b533067b2f47038d Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 16:41:43 +1300 Subject: [PATCH 4/5] rm nightly from ci --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cac96abe..9b5d87fb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,6 @@ jobs: version: - '1.6' - '1' - - 'nightly' os: - ubuntu-latest arch: From 5dc7eb282ec1774354a134ca503fe71f2f2e27fa Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 7 Mar 2024 07:57:01 +1300 Subject: [PATCH 5/5] view -> id --- test/builtins/ThresholdPredictors.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/builtins/ThresholdPredictors.jl b/test/builtins/ThresholdPredictors.jl index 8945436e..16a81f25 100644 --- a/test/builtins/ThresholdPredictors.jl +++ b/test/builtins/ThresholdPredictors.jl @@ -275,15 +275,15 @@ end thing = [] struct EphemeralClassifier <: MLJBase.Probabilistic end function MLJBase.fit(::EphemeralClassifier, verbosity, X, y) - # if I serialize/deserialized `thing` then `view` below changes: - view = objectid(thing) + # if I serialize/deserialized `thing` then `id` below changes: + id = objectid(thing) p = Distributions.fit(UnivariateFinite, y) - fitresult = (thing, view, p) + fitresult = (thing, id, p) return fitresult, nothing, NamedTuple() end function MLJBase.predict(::EphemeralClassifier, fitresult, X) - thing, view, p = fitresult - return view == objectid(thing) ? fill(p, MLJBase.nrows(X)) : + thing, id, p = fitresult + return id == objectid(thing) ? fill(p, MLJBase.nrows(X)) : throw(ErrorException("dead fitresult")) end MLJBase.target_scitype(::Type{<:EphemeralClassifier}) = AbstractVector{OrderedFactor{2}} @@ -293,8 +293,8 @@ function MLJBase.save(::EphemeralClassifier, fitresult) end function MLJBase.restore(::EphemeralClassifier, serialized_fitresult) thing, p = serialized_fitresult - view = objectid(thing) - return (thing, view, p) + id = objectid(thing) + return (thing, id, p) end # X, y = (; x = rand(8)), categorical(collect("OXXXXOOX"), ordered=true)