Introduce EntityEmbeddings #267

Merged
merged 42 commits into from
Sep 9, 2024

EssamWisam (Collaborator) commented Aug 3, 2024

Basic Description

This PR extends the NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, NeuralNetworkRegressor, and MultitargetNeuralNetworkRegressor models of the MLJFlux.jl library so that:

  • These models now support tables with categorical columns. If any are present, an extra entity embedding layer is introduced after the input, as described in the paper Entity Embeddings of Categorical Variables by Cheng Guo and Felix Berkhahn.

  • After training any of these models, it is possible to transform the categorical columns of a new sample, whether seen or unseen in training, so that the learned embeddings encode the categorical columns for a further model or transformer in a pipeline.
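As a rough illustration of the idea behind an entity embedding layer, the sketch below shows that multiplying an embedding matrix by a one-hot vector is the same as looking up one column of that matrix. The level names, matrix values, and the `onehot` helper are made up for this example and are not taken from MLJFlux.

```julia
# Hypothetical illustration only: names and numbers are not from MLJFlux.
levels = ["red", "green", "blue"]      # one categorical feature with 3 levels

# 2×3 embedding matrix: one column (a 2-dimensional vector) per level.
E = [0.1 0.4 0.7;
     0.2 0.5 0.8]

# One-hot vector of length n with a 1.0 in position i.
onehot(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]

i = findfirst(==("green"), levels)            # integer code of the level
via_matmul = E * onehot(i, length(levels))    # linear layer applied to the one-hot input
via_lookup = E[:, i]                          # equivalent column lookup

@assert via_matmul ≈ via_lookup
```

In a trained network the entries of `E` are learned together with the other weights, which is why the columns can later be reused to encode the same feature for another model.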

See the updated documentation to learn more.

Implementation Plan

The following was my plan for implementing this feature, which is less trivial than it may seem at first glance.

  • Implement EntityEmbedder layer and test it with theory-based example
  • Introduce the EntityEmbedder layer to MLJFlux with more formal tests that compare a mathematical implementation to the actual one
  • Introduce OrdinalEncoding with tests (as well as multi-column encoding transformer from MLJTransforms for later)
  • Introduce functionality to infer categorical variables, number of levels and prepare EntityEmbedder input
  • Integrate ordinal encoding and preparing categorical embedding inputs in MLJInterface.fit
  • Let MLJInterface.fit insert EntityEmbedder into the model chain when needed (input has categorical columns)
  • Likewise, adapt MLJInterface.update accordingly
  • Adapt the predict and fitresult for ordinal encoding and storing embedding matrices in classifier.jl and regressor.jl
  • Separate the case where there is no entity embedding in MLJInterface.fit (refactoring)
  • Make categorical embedder completely transparent when no categorical variables are there (instead of being there free of parameters)
  • Refactor code in MLJModelsInterface.jl for less redundancy and more organization
  • Use better default for the new dimensionality of EntityEmbedder
  • Expose new dimensionality argument in types.jl
  • Allow transform for each method that accesses embedding matrices of the EntityEmbedder
  • Better variable names
  • Modify documentation to take into account EntityEmbedder (plan to also make a tutorial(s) later)
  • Ensure existing tests pass with no problem
  • Finish writing tests for the majority, if not all, of the functional components introduced (mainly entity-embedding.jl, entity-embedding-utils.jl and encoders.jl)
  • Finish writing some end-to-end tests for entity embeddings over the four models
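The plan items on inferring categorical columns, counting their levels, and ordinal-encoding them can be sketched roughly as follows. This is a minimal sketch in plain Julia, assuming a table is a NamedTuple of column vectors and that string columns count as categorical; `is_categorical`, `ordinal_mapping`, and `encode` are hypothetical helpers, not the functions added in this PR.

```julia
# Hypothetical helpers; MLJFlux's own encoders may differ in names and details.
# A "table" here is a NamedTuple of column vectors.
X = (color = ["red", "blue", "red", "green"],
     height = [1.2, 3.4, 2.2, 0.5])

# Assumption for this sketch: string columns are the categorical ones.
is_categorical(col) = eltype(col) <: AbstractString

# Map each level of a column to an integer code 1, 2, ... in sorted order.
function ordinal_mapping(col)
    lvls = sort(unique(col))
    return Dict(l => i for (i, l) in enumerate(lvls))
end

cat_names = [name for name in keys(X) if is_categorical(X[name])]
mappings = Dict(name => ordinal_mapping(X[name]) for name in cat_names)
num_levels = Dict(name => length(m) for (name, m) in mappings)

# Replace categorical columns by their integer codes; leave other columns unchanged.
function encode(X, mappings)
    cols = map(keys(X)) do name
        col = X[name]
        haskey(mappings, name) ? [mappings[name][v] for v in col] : col
    end
    return NamedTuple{keys(X)}(cols)
end

Xenc = encode(X, mappings)
```

The integer codes are what an embedding layer would index with, and `num_levels` gives the number of columns each embedding matrix needs.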

codecov bot commented Aug 3, 2024

Codecov Report

Attention: Patch coverage is 96.82540% with 8 lines in your changes missing coverage. Please review.

Project coverage is 96.48%. Comparing base (70dff6e) to head (e5d8141).
Report is 2 commits behind head on dev.

Files with missing lines          Patch %   Lines
src/entity_embedding_utils.jl     94.82%    3 Missing ⚠️
src/encoders.jl                   97.01%    2 Missing ⚠️
src/fit_utils.jl                  92.85%    2 Missing ⚠️
src/mlj_model_interface.jl        97.82%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##              dev     #267      +/-   ##
==========================================
+ Coverage   92.42%   96.48%   +4.06%     
==========================================
  Files          11       14       +3     
  Lines         330      512     +182     
==========================================
+ Hits          305      494     +189     
+ Misses         25       18       -7     

@EssamWisam EssamWisam requested a review from ablaom August 6, 2024 00:10
src/types.jl
rng::Union{AbstractRNG, Int64}
optimiser_changes_trigger_retraining::Bool
acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()`
embedding_dims::Dict{Symbol, Real}
Collaborator Author

@ablaom I handle differently depending on whether it's an integer or float as in the docs.
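One way such integer-versus-float handling could look is sketched below. The rule used here (an integer is taken as the dimension itself, a float as a fraction of the number of levels, rounded up) is an assumption made for illustration; the thread only says the two cases are handled differently "as in the docs". `resolve_dim` and the example features are hypothetical.

```julia
# Assumed rule for this sketch (not confirmed by the thread above):
#   integer value -> used directly as the embedding dimension
#   float value   -> fraction of the number of levels, rounded up
resolve_dim(d::Integer, nlevels::Integer) = d
resolve_dim(d::AbstractFloat, nlevels::Integer) = ceil(Int, d * nlevels)

# Same field type as in the struct: Dict{Symbol, Real} admits both kinds of value.
embedding_dims = Dict{Symbol, Real}(:color => 2, :city => 0.5)

# Hypothetical number of levels per categorical feature.
nlevels = Dict(:color => 3, :city => 7)

resolved = Dict(name => resolve_dim(d, nlevels[name]) for (name, d) in embedding_dims)
```

Dispatching on `Integer` versus `AbstractFloat` keeps the two cases in separate one-line methods instead of a runtime type check.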

Co-authored-by: Anthony Blaom, PhD <[email protected]>
Project.toml (outdated)

[targets]
- test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
+ test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "ScientificTypes", "Test"]
Collaborator

I think MLJBase, which depends on ScientificTypes, re-exports all the public ScientificTypes methods, so you may be able to dump it here.

Collaborator Author

@ablaom So replace `using ScientificTypes: coerce, Multiclass, OrderedFactor` with `using MLJBase: coerce, Multiclass, OrderedFactor` in the test file? This is minor because it's in the tests only, right? And it shouldn't re-download an already downloaded package?

Collaborator

No, I think using MLJBase suffices. All those objects are re-exported.

@@ -47,6 +50,7 @@ A private method that returns the shape of the input and output of the model for
data `X` and `y`.
"""
shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
is_embedding_enabled_type(::MultitargetNeuralNetworkRegressor) = true

Collaborator

As you are now acting on instances instead of types, I'd change the name of your trait from is_embedding_enabled_type to is_embedding_enabled, but this is just a suggestion.

Collaborator Author

Sure thing.
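For context, a trait that acts on instances rather than types might look like the sketch below. The model types are hypothetical stand-ins for the real MLJFlux models, and the `false` fallback is an assumption of this sketch rather than something shown in the diff.

```julia
# Hypothetical model types, standing in for the real MLJFlux models.
abstract type AbstractModel end
struct TabularModel <: AbstractModel end   # supports entity embeddings
struct ImageModel <: AbstractModel end     # does not

# Trait defined on instances: a conservative fallback plus opt-in methods.
is_embedding_enabled(::AbstractModel) = false
is_embedding_enabled(::TabularModel) = true

tabular_ok = is_embedding_enabled(TabularModel())
image_ok = is_embedding_enabled(ImageModel())
```

Because the argument is an instance, the shorter name reads naturally at call sites such as `is_embedding_enabled(model) || return Xnew`.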

@@ -125,7 +127,8 @@ function MLJModelInterface.transform(
     fitresult,
     Xnew,
 )
-    is_embedding_enabled_type(transformer) || return Xnew
+    # if it doesn't have the property its not an entity-enabled model
+    hasproperty(transformer, :embedding_dims) || return Xnew
     ordinal_mappings, embedding_matrices = fitresult[3:4]
     Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
     Xnew_transf = embedding_transform(Xnew, embedding_matrices)
Collaborator

Also, do we have any test for transform(::MLJFluxModel, ...) ?

Collaborator Author

I think transform is tested in entity_embedding.jl

Collaborator

Okay, I see it's covered in Codecov.

Can I ask that this method (and corresponding test) be moved to "mlj_model_interface.jl", where the other implementations of MLJModelInterface methods (such as fit and predict) live?

ablaom (Collaborator) left a comment

Thanks for the speedy response to my review.

A couple of minor points left. Note in particular the need to annotate the type of model (transformer) in MLJModelInterface.transform overloading. And we should have a test for this method for one of the models.
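The point about annotating the model type can be illustrated with a small, self-contained sketch. The `Interface` module and the model types below are hypothetical stand-ins for MLJModelInterface and the MLJFlux models; the sketch only shows why a type-annotated method leaves other packages' models untouched.

```julia
# Hypothetical stand-ins for MLJModelInterface and the MLJFlux model types.
module Interface
abstract type Model end
# Generic fallback owned by the interface package.
transform(model::Model, fitresult, Xnew) =
    error("transform not implemented for $(typeof(model))")
end

abstract type FluxModel <: Interface.Model end
struct MyRegressor <: FluxModel end
struct SomeoneElsesModel <: Interface.Model end

# Annotating the first argument limits the new method to our own models.
Interface.transform(model::FluxModel, fitresult, Xnew) = Xnew

ours = Interface.transform(MyRegressor(), nothing, [1, 2])

# Models from other packages still reach the interface's fallback.
theirs_errors = try
    Interface.transform(SomeoneElsesModel(), nothing, [1, 2])
    false
catch
    true
end
```

Without the `::FluxModel` annotation, the new method would apply to every model type, including ones defined elsewhere.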

EssamWisam (Collaborator, Author)

@ablaom I have addressed everything (except removing ScientificTypes from the tests). I noticed that after adding categorical variables to the integration tests, GPU testing fails :(. Will look into that when I get the chance.

ablaom (Collaborator) commented Sep 1, 2024

@EssamWisam You can try raising the tolerance near test/classifier.jl:114, which seems to be where the failure occurs.
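As a generic illustration of what raising a tolerance means in a Julia test, the sketch below compares two nearly equal numbers with `isapprox` at two relative tolerances. The values are invented, and the actual test at test/classifier.jl:114 is not shown in this thread, so this is not a reproduction of it.

```julia
using Test

# Hypothetical values standing in for a CPU result and a slightly different GPU result.
cpu_loss = 0.731204
gpu_loss = 0.731391

too_strict = isapprox(cpu_loss, gpu_loss; rtol = 1e-4)   # gap exceeds this tolerance
looser = isapprox(cpu_loss, gpu_loss; rtol = 1e-3)       # gap is within this tolerance

@test !too_strict
@test looser
```

Small numerical differences between CPU and GPU runs are common, so a comparison that is tight enough for one device can be too strict for the other.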

EssamWisam (Collaborator, Author)

Hooray it worked 🎉 @ablaom

ablaom (Collaborator) left a comment

Thanks for the updates.

Just a couple more details, as flagged.

EssamWisam (Collaborator, Author)

@ablaom let's finalize this!

ablaom (Collaborator) left a comment

Thanks again @EssamWisam for this valuable contribution.

Just made a few doc string tweaks after reviewing these again. Ready to go. 🎉

@ablaom ablaom merged commit 945016d into dev Sep 9, 2024
6 checks passed
@ablaom ablaom deleted the entity-embeddings branch September 9, 2024 00:52
@ablaom ablaom mentioned this pull request Sep 29, 2024
2 participants