Introduce EntityEmbeddings #267
Merged
Commits (42)

- fb98c27 ✅ Add CategoricalEmbedding layer with documentation and tests (EssamWisam)
- 3e67625 ➕ Add ordinal encoder (EssamWisam)
- d2fc2a9 ➕ Add embedding transformer and more test encoder tests (EssamWisam)
- 41ca8e0 𝚫 Override fitresult and input scitypes for the four models (EssamWisam)
- afab1e0 𝚫 Change name of embedding layer (EssamWisam)
- 86fa74c 👨💻 Refactor fit and update in mlj_model_interface (EssamWisam)
- 33cb7d9 🔥 Delete replaced file (EssamWisam)
- cf30f90 ➕ Include necessary files in MLJFlux.jl (EssamWisam)
- 48b1ce2 ✅ Expose embedding_dims hyperparameter and update docs (EssamWisam)
- d1e65c0 ✅ Finish adding entity embedding tests (EssamWisam)
- af91ee9 👨🔧 Fix small error (EssamWisam)
- e7b0fff 🔥 Delete old file (EssamWisam)
- 833b845 🤔 Make all input tables float (EssamWisam)
- 962bbda 🤔 Further ensure all inputs are float (EssamWisam)
- 4ff7cbe Update src/types.jl (EssamWisam)
- 878a905 Update src/mlj_model_interface.jl (EssamWisam)
- c90cf6e Update src/MLJFlux.jl (EssamWisam)
- 4547f2a Update src/classifier.jl (EssamWisam)
- e8e87cd Update src/regressor.jl (EssamWisam)
- 9a836b2 Update src/regressor.jl (EssamWisam)
- 6fb81ff ✅ Fix type for binary classifier (EssamWisam)
- 3a8bdcb ➕ No longer export EntityEmbedder (EssamWisam)
- f7abf07 ✍🏻 Improve docstring (EssamWisam)
- bbe121f 🚨 Get rid of ScientificTypes (EssamWisam)
- 706b630 Follow up (EssamWisam)
- b0d8e44 🔼 Better docstrings (EssamWisam)
- 4458627 ✅ Add missing dependency (EssamWisam)
- f1f7dfe ✅ Improve docstring (EssamWisam)
- daa4b2a ✅ Better default for embedding dims (EssamWisam)
- 59c9bb9 ➕ Include in top-level only (EssamWisam)
- 749a247 👨🔧 Fix docstring spacing (EssamWisam)
- 46577f7 ✅ Enforce columns (EssamWisam)
- e5d8141 ✅ Add trait for embedding model (EssamWisam)
- f60f789 ✅ Get rid of case distinction (EssamWisam)
- 5fd85a3 🌟 Introduce integration tests (EssamWisam)
- 2cc227a Add `Tables.columns` (EssamWisam)
- eac7727 ✅ Bring back is_entity_enabled and more (EssamWisam)
- cc56439 ✅ Raise tolerance (EssamWisam)
- 407edc5 🫧 Some final polishing (EssamWisam)
- 21a2fdd tweak document strings (ablaom)
- e0a2eb8 tag some methods as private to avoid any possible ambiguity (ablaom)
- 310cb12 bump 0.6.0 (ablaom)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@@ -0,0 +1,152 @@
"""
File containing the ordinal encoder and the entity embedding encoder. Borrows code from the MLJTransforms package.
"""

### Ordinal Encoder

"""
**Private Method**

Fits an ordinal encoder to the table `X`, using only the columns with indices in `featinds`.

Returns a dictionary mapping each column index to a dictionary mapping each level in that column to an integer.
"""
function ordinal_encoder_fit(X; featinds)
    # 1. Define mapping per column per level dictionary
    mapping_matrix = Dict()

    # 2. Use feature mapper to compute the mapping of each level in each column
    for i in featinds
        feat_col = Tables.getcolumn(Tables.columns(X), i)
        feat_levels = levels(feat_col)
        # Skip this column if its levels are already ordinally encoded
        (Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
        # Compute the dict using the given feature_mapper function
        mapping_matrix[i] =
            Dict{Any, AbstractFloat}(
                value => float(index) for (index, value) in enumerate(feat_levels)
            )
    end
    return mapping_matrix
end
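The core of the fit step is the per-column level-to-Float dictionary. A minimal standalone sketch of that logic, using a plain vector in place of a Tables.jl column (`feat_levels` here is an illustrative stand-in for `levels(feat_col)`):

```julia
# Sketch: each level is mapped to its (Float) position in the level pool.
feat_levels = ["low", "medium", "high"]   # stand-in for levels(feat_col)
mapping = Dict{Any, AbstractFloat}(
    value => float(index) for (index, value) in enumerate(feat_levels)
)

# The "already encoded" skip test compares the level pool to 1.0:n.
already_encoded = Set(float.(1:length(feat_levels))) == Set(feat_levels)
```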
"""
**Private Method**

Checks that all levels in `test_levels` are also in `train_levels`. If not, throws an error.
"""
function check_unkown_levels(train_levels, test_levels)
    # test levels must be a subset of train levels
    if !issubset(test_levels, train_levels)
        # get the levels in test that are not in train
        lost_levels = setdiff(test_levels, train_levels)
        error(
            "While transforming, found novel levels for the column: $(lost_levels) that were not seen while training.",
        )
    end
end
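The novel-level guard reduces to `issubset` plus `setdiff` for the error message. A minimal sketch with illustrative level pools:

```julia
train_levels = ["a", "b"]
test_levels  = ["a", "b", "c"]

# The levels that would trigger the error during transform:
lost_levels = issubset(test_levels, train_levels) ? [] : setdiff(test_levels, train_levels)
```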
"""
**Private Method**

Transforms the table `X` using the ordinal encoder defined by `mapping_matrix`.

Returns a new table with the same column names as `X`, but with the encoded categorical columns replaced by columns of floats.
"""
function ordinal_encoder_transform(X, mapping_matrix)
    isnothing(mapping_matrix) && return X
    isempty(mapping_matrix) && return X
    feat_names = Tables.schema(X).names
    numfeats = length(feat_names)
    new_feats = []
    for ind in 1:numfeats
        col = Tables.getcolumn(Tables.columns(X), ind)

        # Create the transformation function for each column
        if ind in keys(mapping_matrix)
            train_levels = keys(mapping_matrix[ind])
            test_levels = levels(col)
            check_unkown_levels(train_levels, test_levels)
            level2scalar = mapping_matrix[ind]
            new_col = recode(col, level2scalar...)
            push!(new_feats, new_col)
        else
            push!(new_feats, col)
        end
    end

    transformed_X = NamedTuple{tuple(feat_names...)}(tuple(new_feats...))
    # Attempt to preserve table type
    transformed_X = Tables.materializer(X)(transformed_X)
    return transformed_X
end
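Here `recode` (from CategoricalArrays) rewrites each level to its fitted scalar. With plain vectors, `Base.replace` gives the same effect; a sketch with hypothetical values:

```julia
level2scalar = Dict("a" => 1.0, "b" => 2.0)   # a fitted per-column mapping
col = ["b", "a", "b"]
new_col = replace(col, level2scalar...)       # splat the Dict as level => scalar Pairs
```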
"""
**Private Method**

Combines `ordinal_encoder_fit` and `ordinal_encoder_transform`, returning both the transformed `X` and the `ordinal_mappings`.
"""
function ordinal_encoder_fit_transform(X; featinds)
    ordinal_mappings = ordinal_encoder_fit(X; featinds = featinds)
    return ordinal_encoder_transform(X, ordinal_mappings), ordinal_mappings
end
## Entity Embedding Encoder (assuming precomputed weights)

"""
**Private method.**

Function to generate new feature names: feat_name_1, feat_name_2, ..., feat_name_n
"""
function generate_new_feat_names(feat_name, num_inds, existing_names)
    conflict = true    # will be kept true as long as there is a conflict
    count = 1          # number of conflicts + 1 = number of underscores

    new_column_names = []
    while conflict
        suffix = repeat("_", count)
        new_column_names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds]
        conflict = any(name -> name in existing_names, new_column_names)
        count += 1
    end
    return new_column_names
end
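Underscores accumulate until the generated names no longer collide with names already in the table. A runnable restatement of the loop (the name `fresh_names` is illustrative, not from the source):

```julia
# Mirrors generate_new_feat_names: add underscores until no generated
# name clashes with an existing column name.
function fresh_names(feat_name, num_inds, existing_names)
    count = 1
    names = Symbol[]
    while true
        suffix = repeat("_", count)
        names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds]
        any(in(existing_names), names) || break   # stop once there is no conflict
        count += 1
    end
    return names
end

# :Country_1 is taken, so a second underscore is used:
fresh_names(:Country, 2, [:Country_1])   # -> [:Country__1, :Country__2]
```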
"""
Given `X` and a dict `mapping_matrices` that maps each categorical column to a matrix, transform
each level in each categorical column using the columns of that matrix.

This is used with the embedding matrices of the entity embedding layer in entity-enabled models to implement entity embeddings.
"""
function embedding_transform(X, mapping_matrices)
    (isempty(mapping_matrices)) && return X
    feat_names = Tables.schema(X).names
    new_feat_names = Symbol[]
    new_cols = []
    for feat_name in feat_names
        col = Tables.getcolumn(Tables.columns(X), feat_name)
        # Create the transformation function for each column
        if feat_name in keys(mapping_matrices)
            level2vector = mapping_matrices[feat_name]
            new_multi_col = map(x -> level2vector[:, Int.(unwrap(x))], col)
            new_multi_col = [col for col in eachrow(hcat(new_multi_col...))]
            push!(new_cols, new_multi_col...)
            feat_names_with_inds = generate_new_feat_names(
                feat_name,
                size(level2vector, 1),
                feat_names,
            )
            push!(new_feat_names, feat_names_with_inds...)
        else
            # Not to be transformed => left as is
            push!(new_feat_names, feat_name)
            push!(new_cols, col)
        end
    end

    transformed_X = NamedTuple{tuple(new_feat_names...)}(tuple(new_cols...))
    # Attempt to preserve table type
    transformed_X = Tables.materializer(X)(transformed_X)
    return transformed_X
end
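Each ordinal code indexes a column of the feature's embedding matrix, and the resulting vectors are re-split row-wise into one new scalar column per embedding dimension. A dependency-free sketch with made-up numbers:

```julia
embedding = [0.1 0.3;
             0.2 0.4]             # newdim × levels matrix (2-dim embedding, 2 levels)
col = [1.0, 2.0, 1.0]             # an ordinally encoded column

# Look up each code's embedding vector, then split back into scalar columns.
vecs = map(x -> embedding[:, Int(x)], col)
new_cols = [collect(r) for r in eachrow(hcat(vecs...))]
```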
@@ -1,50 +1,74 @@
"""
A layer that implements entity embedding layers as presented in 'Entity Embeddings of
Categorical Variables' by Cheng Guo, Felix Berkhahn. Expects a matrix of dimensions (numfeats, batchsize)
and applies entity embeddings to each specified categorical feature. Other features will be left as is.

# Arguments
- `entityprops`: a vector of named tuples each of the form `(index=..., levels=..., newdim=...)` to
  specify the feature index, the number of levels and the desired embedding dimensionality for selected features of the input.
- `numfeats`: the number of features in the input.

# Example
```julia
# Prepare a batch of four features where the 2nd and the 4th are categorical
batch = [
    0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
    1 2 3 4 5 6 7 8 9 10;
    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
    1 1 2 2 1 1 2 2 1 1;
]

entityprops = [
    (index=2, levels=10, newdim=2),
    (index=4, levels=2, newdim=1)
]
numfeats = 4

# Run it through the entity embedding layer
embedder = EntityEmbedder(entityprops, 4)

julia> output = embedder(batch)
5×10 Matrix{Float64}:
  0.2        0.3        0.4       0.5       …   0.8        0.9         1.0       1.1
 -1.27129   -0.417667  -1.40326  -0.695701      0.371741   1.69952    -1.40034  -2.04078
 -0.166796   0.657619  -0.659249 -0.337757     -0.717179  -0.0176273  -1.2817   -0.0372752
  0.9        0.1        0.4       0.5           0.8        0.9         1.0       1.1
 -0.847354  -0.847354  -1.66261  -1.66261      -1.66261   -1.66261    -0.847354 -0.847354
```
"""
# 1. Define layer struct to hold parameters
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
    embedders::A1
    modifiers::A2       # applied on the input before passing it to the embedder
    numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
    vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)

# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
    embedders = []
    modifiers = []
    # Setup entityprops
    cat_inds = [entityprop.index for entityprop in entityprops]
    levels_per_feat = [entityprop.levels for entityprop in entityprops]
    newdims = [entityprop.newdim for entityprop in entityprops]

    c = 1
    for i in 1:numfeats
        if i in cat_inds
            push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init = init))
            push!(modifiers, (x, i) -> Int.(x[i, :]))
            c += 1
        else
            push!(embedders, feat -> feat)
            push!(modifiers, (x, i) -> x[i:i, :])
        end
    end

    EntityEmbedder(embedders, modifiers, numfeats)
end

# 4. Register it as layer with Flux
Flux.@layer EntityEmbedder
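The forward pass is just per-feature modifier plus embedder, vertically concatenated. A Flux-free sketch of that mechanism with hand-rolled embedders (all values and names here are illustrative, not from the PR):

```julia
x = [0.2 0.3;     # feature 1: continuous, passes through unchanged
     1.0 2.0]     # feature 2: categorical codes
E = [0.5 0.7]     # 1-dim embedding matrix for feature 2 (one column per level)

# Per-feature modifier (slices/casts the input row) and embedder (maps it):
embedders = Any[v -> v, codes -> E[:, codes]]
modifiers = Any[(x, i) -> x[i:i, :], (x, i) -> Int.(x[i, :])]

out = vcat([embedders[i](modifiers[i](x, i)) for i in 1:2]...)
```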
Review comments

ablaom: I think MLJBase, which depends on ScientificTypes, re-exports all the public ScientificTypes methods, so you may be able to dump it here.

EssamWisam: @ablaom So replace `using ScientificTypes: coerce, Multiclass, OrderedFactor` with `using MLJBase: coerce, Multiclass, OrderedFactor` in the test file? This is minor because it's only in the tests, right? And it shouldn't re-download an already downloaded package.

ablaom: No, I think `using MLJBase` suffices. All those objects are re-exported.