Introduce EntityEmbeddings #267
Codecov Report
@@            Coverage Diff             @@
##              dev     #267      +/-   ##
==========================================
+ Coverage   92.42%   96.48%   +4.06%
==========================================
  Files          11       14       +3
  Lines         330      512     +182
==========================================
+ Hits          305      494     +189
+ Misses         25       18       -7
The two classifiers and two regressors.
rng::Union{AbstractRNG, Int64}
optimiser_changes_trigger_retraining::Bool
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
embedding_dims::Dict{Symbol, Real}
@ablaom I handle the value differently depending on whether it's an integer or a float, as described in the docs.
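To illustrate what handling integers and floats differently could look like, here is a minimal sketch. It assumes an integer is taken as the embedding dimension itself and a float is read as a fraction of the column's number of levels; that interpretation, and the helper name `resolve_embedding_dim`, are hypothetical and not taken from the PR's code.

```julia
# Hypothetical helper (not MLJFlux's actual implementation).
# Assumption: an integer is used directly as the embedding dimension.
resolve_embedding_dim(dim::Integer, nlevels::Integer) = Int(dim)

# Assumption: a float is a fraction of the number of levels,
# rounded up and never smaller than 1.
resolve_embedding_dim(dim::AbstractFloat, nlevels::Integer) =
    max(1, ceil(Int, dim * nlevels))

# Same container type as the `embedding_dims` field in the diff above.
embedding_dims = Dict{Symbol, Real}(:color => 2, :country => 0.5)

# Illustrative level counts for each categorical column.
nlevels = Dict(:color => 4, :country => 10)

# Dispatch picks the integer or float method per value at run time.
resolved = Dict(k => resolve_embedding_dim(v, nlevels[k]) for (k, v) in embedding_dims)
```

Using two methods rather than an `if dim isa Integer` branch keeps each rule in one place, which is one common way to express this in Julia.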
Co-authored-by: Anthony Blaom, PhD <[email protected]>
Project.toml
 [targets]
-test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
+test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "ScientificTypes", "Test"]
I think MLJBase, which depends on ScientificTypes, re-exports all the public ScientificTypes methods, so you may be able to dump it here.
@ablaom So replace `using ScientificTypes: coerce, Multiclass, OrderedFactor` with `using MLJBase: coerce, Multiclass, OrderedFactor` in the test file? This is minor because it's only in the tests, right? And it shouldn't re-download an already downloaded package?
No, I think `using MLJBase` suffices. All those objects are re-exported.
@@ -47,6 +50,7 @@ A private method that returns the shape of the input and output of the model for
 data `X` and `y`.
 """
 shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
+is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true
As you are now acting on instances instead of types, I'd change the name of your trait from `is_embedding_enabled_type` to `is_embedding_enabled`, but this is just a suggestion.
Sure thing.
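For readers unfamiliar with the pattern under discussion, a trait defined on instances can be sketched as follows. The model types here are hypothetical stand-ins rather than MLJFlux's, and the `false` fallback method is an assumption, not something the diff shows.

```julia
# Hypothetical stand-in types; the real models live in MLJFlux.
abstract type AbstractModel end
struct EmbeddingModel <: AbstractModel end
struct PlainModel <: AbstractModel end

# Trait methods dispatch on the instance, hence a name without the `_type` suffix.
is_embedding_enabled(::AbstractModel) = false   # assumed fallback
is_embedding_enabled(::EmbeddingModel) = true

# A caller can return early for models without embeddings,
# mirroring the `|| return Xnew` idiom used in the PR's `transform`.
function maybe_transform(model::AbstractModel, X)
    is_embedding_enabled(model) || return X
    return "embedded($X)"   # placeholder for the real transformation
end
```

Calling `maybe_transform(PlainModel(), "X")` returns the input unchanged, while `maybe_transform(EmbeddingModel(), "X")` takes the embedding branch.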
src/entity_embedding_utils.jl
@@ -125,7 +127,8 @@ function MLJModelInterface.transform(
     fitresult,
     Xnew,
 )
-    is_embedding_enabled_type(transformer) || return Xnew
+    # if it doesn't have the property its not an entity-enabled model
+    hasproperty(transformer, :embedding_dims) || return Xnew
     ordinal_mappings, embedding_matrices = fitresult[3:4]
     Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
     Xnew_transf = embedding_transform(Xnew, embedding_matrices)
Also, do we have any test for `transform(::MLJFluxModel, ...)`?
I think `transform` is tested in entity_embedding.jl.
Okay, I see it's covered in Codecov.
Can I ask that this method (and the corresponding test) be moved to "mlj_model_interface.jl", where the other implementations of `MLJModelInterface` methods (such as `fit` and `predict`) live?
Thanks for the speedy response to my review.
A couple of minor points left. Note in particular the need to annotate the type of the model (transformer) in the `MLJModelInterface.transform` overloading. And we should have a test for this method for one of the models.
@ablaom I have addressed everything (except removing ScientificTypes from the tests). I noticed that after adding categorical variables to the integration tests, GPU testing fails :(. I will look into that when I get the chance.
@EssamWisam You can try raising the tolerance near test/classifier.jl:114, which seems to be where the failure occurs.
Hooray, it worked 🎉 @ablaom
Thanks for the updates.
Just a couple more details, as flagged.
@ablaom let's finalize this!
Thanks again @EssamWisam for this valuable contribution.
Just made a few doc string tweaks after reviewing these again. Ready to go. 🎉
Basic Description
This PR extends the `NeuralNetworkClassifier`, `NeuralNetworkBinaryClassifier`, `NeuralNetworkRegressor`, and `MultitargetNeuralNetworkRegressor` models of the `MLJFlux.jl` library such that:
- These models now support tables with categorical columns. Iff any are present, an extra entity embedding layer is introduced after the input, as described in the paper Entity Embeddings of Categorical Variables by Cheng Guo and Felix Berkhahn.
- After training any of these models, it is possible to transform the categorical columns of a new sample, whether or not it was seen in training, in order to encode those columns for a further model or transformer in a pipeline.
See the updated documentation to learn more.
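The idea of an entity embedding layer can be sketched in plain Julia. This is only an illustration under stated assumptions: the level names, the embedding matrix values, and the helper `embed` are made up for the example, and in the actual models the matrix entries would be learned during training rather than fixed.

```julia
# Illustrative categorical levels and their ordinal codes (assumed data).
levels = ["red", "green", "blue"]
ordinal = Dict(l => i for (i, l) in enumerate(levels))

# Embedding matrix with one column per level; here the embedding dimension is 2.
# The values are arbitrary placeholders standing in for learned weights.
E = [0.1 0.2 0.3;
     0.4 0.5 0.6]

# Embedding a level is a column lookup ...
embed(level) = E[:, ordinal[level]]

# ... which equals multiplying E by the level's one-hot vector.
onehot = [0.0, 1.0, 0.0]            # one-hot encoding of "green"
same_as_lookup = E * onehot

# The embedded categorical feature is concatenated with continuous features
# before being passed to the rest of the network.
x_continuous = [1.5]
x = vcat(x_continuous, embed("green"))
```

After training, the same lookup can be applied to new samples on its own, which is what makes it possible to reuse the learned encoding for a downstream model in a pipeline.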
Implementation Plan
The following was my plan for implementing this feature, which is less trivial than it may seem at first glance.
`MLJFlux` with more formal tests that compare a mathematical implementation to the actual one
`EntityEmbedder` input
`MLJInterface.fit` insert `EntityEmbedder` into the model chain when needed (input has categorical columns)
`MLJInterface.update` accordingly
`classifier.jl` and `regressor.jl`
`MLJInterface.fit` (refactoring)
`MLJModelsInterface.jl` for less redundancy and more organization
`EntityEmbedder`
`types.jl`
`EntityEmbedder`
`EntityEmbedder` (plan to also make a tutorial(s) later)
`entity-embedding.jl`, `entity-embedding-utils.jl` and `encoders.jl`