Skip to content

Commit

Permalink
fix Resampler update bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jan 23, 2024
1 parent 193ad67 commit 9742d44
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 9 additions & 2 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1591,8 +1591,15 @@ function MLJModelInterface.update(
mach, e = fitresult
train_test_rows = e.train_test_rows

measures = e.measure
operations = e.operation
# since `resampler.model` could have changed, so might the actual measures and
# operations that should be passed to the (low level) `evaluate!`:
measures = _actual_measures(resampler.measure, resampler.model)
operations = _actual_operations(
resampler.operation,
measures,
resampler.model,
verbosity
)

# update the model:
mach2 = _update!(mach, resampler.model)
Expand Down
20 changes: 17 additions & 3 deletions test/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#module TestResampling

using Distributed
import ComputationalResources: CPU1, CPUProcesses, CPUThreads
using .TestUtilities
Expand Down Expand Up @@ -876,5 +874,21 @@ end
@test contains(printed_evaluations, "N/A")
end

#end
@testset_accelerated "issue with Resampler #954" acceleration begin
knn = KNNClassifier()
cnst =DeterministicConstantClassifier()
X, y = make_blobs(10)

resampler = MLJBase.Resampler(
;model=knn,
measure=accuracy,
operation=nothing,
acceleration,
)
mach = machine(resampler, X, y) |> fit!

resampler.model = cnst
fit!(mach)
end

true

0 comments on commit 9742d44

Please sign in to comment.