Skip to content

Commit

Permalink
Merge pull request #257 from FluxML/tiemvanderdeure-binaryclassifier
Browse files Browse the repository at this point in the history
Add binary classifier
  • Loading branch information
ablaom authored Jun 10, 2024
2 parents 3ac787e + 7eae840 commit 05be138
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 66 deletions.
4 changes: 4 additions & 0 deletions docs/src/interface/Classification.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
```@docs
MLJFlux.NeuralNetworkClassifier
```

```@docs
MLJFlux.NeuralNetworkBinaryClassifier
```
1 change: 1 addition & 0 deletions docs/src/interface/Summary.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Model Type | Prediction type | `scitype(X) <: _` | `scitype(y) <: _`
`NeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `AbstractVector{<:Continuous}` (`n_out = 1`)
`MultitargetNeuralNetworkRegressor` | `Deterministic` | `Table(Continuous)` with `n_in` columns | `<: Table(Continuous)` with `n_out` columns
`NeuralNetworkClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite}` with `n_out` classes
`NeuralNetworkBinaryClassifier` | `Probabilistic` | `<:Table(Continuous)` with `n_in` columns | `AbstractVector{<:Finite{2}}` (`n_out = 2`)
`ImageClassifier` | `Probabilistic` | `AbstractVector(<:Image{W,H})` with `n_in = (W, H)` | `AbstractVector{<:Finite}` with `n_out` classes


Expand Down
2 changes: 1 addition & 1 deletion src/MLJFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ include("image.jl")
include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
# Classifier model types made available to users of the package.
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export CUDALibs, CPU1

include("deprecated.jl")
Expand Down
61 changes: 49 additions & 12 deletions src/classifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""
shape(model::NeuralNetworkClassifier, X, y)
A private method that returns the shape of the input and output of the model for given
data `X` and `y`.
"""
function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
X = X isa Matrix ? Tables.table(X) : X
Expand All @@ -14,26 +16,61 @@ function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
end

# builds the end-to-end Flux chain needed, given the `model` and `shape`:
# Build the end-to-end Flux chain for the given `model` and data `shape`:
# the user-supplied builder network followed by the model's finaliser layer
# (e.g. softmax/sigmoid). Shared by both multiclass and binary classifiers.
function MLJFlux.build(
    model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
    rng,
    shape,
)
    return Flux.Chain(build(model.builder, rng, shape...), model.finaliser)
end

# returns the model `fitresult` (see "Adding Models for General Use"
# section of the MLJ manual) which must always have the form `(chain,
# metadata)`, where `metadata` is anything extra needed by `predict`; here
# the metadata is the target's class pool, needed to label predictions.
function MLJFlux.fitresult(
    model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
    chain,
    y,
)
    return (chain, MLJModelInterface.classes(y[1]))
end

"""
    MLJModelInterface.predict(model::NeuralNetworkClassifier, fitresult, Xnew)

Return probabilistic (`UnivariateFinite`) predictions for the observations in
`Xnew`, using the trained `chain` and class pool stored in `fitresult`.
"""
function MLJModelInterface.predict(
    model::NeuralNetworkClassifier,
    fitresult,
    Xnew,
)
    chain, levels = fitresult
    X = reformat(Xnew)
    # One column of `X` per observation; stack the chain's per-observation
    # probability vectors into a (n_observations, n_classes) matrix.
    probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
    return MLJModelInterface.UnivariateFinite(levels, probs)
end

# Register MLJ model-trait metadata (scitypes and load path) for the
# multiclass classifier. Note the current keyword names are `input_scitype`,
# `target_scitype` and `load_path`.
MLJModelInterface.metadata_model(
    NeuralNetworkClassifier,
    input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
    target_scitype=AbstractVector{<:Finite},
    load_path="MLJFlux.NeuralNetworkClassifier",
)

#### Binary Classifier

# A private method returning the (input, output) dimensions of the network
# for the given data. Output dimension is fixed at 1: a binary classifier
# emits a single probability score per observation.
function MLJFlux.shape(model::NeuralNetworkBinaryClassifier, X, y)
    # Accept either a raw matrix or any Tables.jl table; normalise to a table.
    table = X isa Matrix ? Tables.table(X) : X
    n_input = length(Tables.schema(table).names)
    return (n_input, 1)
end

"""
    MLJModelInterface.predict(model::NeuralNetworkBinaryClassifier, fitresult, Xnew)

Return probabilistic (`UnivariateFinite`) predictions for `Xnew`. The chain
emits one score per observation; `augment = true` tells `UnivariateFinite`
these scores are for the second class, with the first class's probability
inferred as the complement.
"""
function MLJModelInterface.predict(
    model::NeuralNetworkBinaryClassifier,
    fitresult,
    Xnew,
)
    (chain, levels) = fitresult
    Xmatrix = reformat(Xnew)
    scores = vec(chain(Xmatrix))
    return MLJModelInterface.UnivariateFinite(levels, scores; augment = true)
end

# Register MLJ model-trait metadata for the binary classifier: tabular or
# matrix `Continuous` input, a two-class `Finite` target, and the load path
# used by MLJ's model registry.
MLJModelInterface.metadata_model(
    NeuralNetworkBinaryClassifier;
    input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous)},
    target_scitype = AbstractVector{<:Finite{2}},
    load_path = "MLJFlux.NeuralNetworkBinaryClassifier",
)
6 changes: 6 additions & 0 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,9 @@ function collate(model, X, y)
ymatrix = reformat(y)
return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
end
# Batch the data for training a binary classifier. Unlike the generic
# `collate`, the target is converted to a Bool row vector (`true` marks the
# second class in the pool) rather than one-hot encoded.
function collate(model::NeuralNetworkBinaryClassifier, X, y)
    batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
    Xmatrix = reformat(X)
    target = (y .== classes(y)[2])' # convert to boolean
    xbatches = [_get(Xmatrix, b) for b in batches]
    ybatches = [_get(target, b) for b in batches]
    return xbatches, ybatches
end
Loading

0 comments on commit 05be138

Please sign in to comment.