Skip to content

Commit

Permalink
Merge pull request #48 from JuliaAI/add-tests-categorical-features
Browse files Browse the repository at this point in the history
Add generic MLJ interface tests for categorical features
  • Loading branch information
ablaom authored Feb 28, 2024
2 parents be1a390 + 886deee commit b4c5f65
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,41 @@ using MLJGLMInterface
using GLM: coeftable
import GLM
import MLJTestInterface
using Tables

using Distributions: Normal, Poisson, Uniform
import StableRNGs
using Tables

expit(X) = 1 ./ (1 .+ exp.(-X))

# TODO: Add more datasets to the following generic interface tests after #45 is merged
# synthesize small data sets with mixed features:

n = 100
X_regression, y_regression = MLJBase.make_regression(n, 3)
outlook = categorical(rand(["sunny", "overcast", "rainy"], n))
temperature = categorical(
rand(["cold", "mild", "hot"], n),
ordered=true,
levels=["cold", "mild", "hot"],
)
X = merge(
Tables.columntable(X_regression),
(; outlook, temperature),
)
y_binary = categorical((temperature .== "mild") .| (outlook .== "sunny"))
y_count = map(X.x1) do x
floor(Int, 10*abs(x))
end
mixed_binary = (X, y_binary)
mixed_count = (X, y_count)
mixed_regression = (X, y_regression)

@testset "generic interface tests" begin
@testset "LinearRegressor" begin
for data in [
MLJTestInterface.make_regression(),
MLJTestInterface.make_regression(),
mixed_regression,
]
failures, summary = MLJTestInterface.test(
[LinearRegressor,],
Expand All @@ -36,6 +58,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X))
@testset "LinearCountRegressor" begin
for data in [
MLJTestInterface.make_count(),
mixed_count,
]
failures, summary = MLJTestInterface.test(
[LinearCountRegressor,],
Expand All @@ -51,6 +74,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X))
@testset "LinearBinaryClassifier" begin
for data in [
MLJTestInterface.make_binary(),
mixed_binary,
]
failures, summary = MLJTestInterface.test(
[LinearBinaryClassifier,],
Expand Down

0 comments on commit b4c5f65

Please sign in to comment.