Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce EntityEmbeddings #267

Merged
merged 42 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
fb98c27
✅ Add CategoricalEmbedding layer with documentation and tests
EssamWisam Aug 3, 2024
3e67625
➕ Add ordinal encoder
EssamWisam Aug 4, 2024
d2fc2a9
➕ Add embedding transformer and more test encoder tests
EssamWisam Aug 5, 2024
41ca8e0
𝚫 Override fitresult and input scitypes for the four models
EssamWisam Aug 5, 2024
afab1e0
𝚫 Change name of embedding layer
EssamWisam Aug 5, 2024
86fa74c
👨‍💻 Refactor fit and update in mlj_model_interface
EssamWisam Aug 5, 2024
33cb7d9
🔥 Delete replaced file
EssamWisam Aug 5, 2024
cf30f90
➕ Include necessary files in MLJFlux.jl
EssamWisam Aug 5, 2024
48b1ce2
✅ Expose embedding_dims hyperparameter and update docs
EssamWisam Aug 5, 2024
d1e65c0
✅ Finish adding entity embedding tests
EssamWisam Aug 5, 2024
af91ee9
👨‍🔧 Fix small error
EssamWisam Aug 5, 2024
e7b0fff
🔥 Delete old file
EssamWisam Aug 6, 2024
833b845
🤔 Make all input tables float
EssamWisam Aug 6, 2024
962bbda
🤔 Further ensure all inputs are float
EssamWisam Aug 6, 2024
4ff7cbe
Update src/types.jl
EssamWisam Aug 7, 2024
878a905
Update src/mlj_model_interface.jl
EssamWisam Aug 7, 2024
c90cf6e
Update src/MLJFlux.jl
EssamWisam Aug 31, 2024
4547f2a
Update src/classifier.jl
EssamWisam Aug 31, 2024
e8e87cd
Update src/regressor.jl
EssamWisam Aug 31, 2024
9a836b2
Update src/regressor.jl
EssamWisam Aug 31, 2024
6fb81ff
✅ Fix type for binary classifier
EssamWisam Aug 31, 2024
3a8bdcb
➕ No longer export EntityEmbedder
EssamWisam Sep 1, 2024
f7abf07
✍🏻 Improve docstring
EssamWisam Sep 1, 2024
bbe121f
🚨 Get rid of ScientificTypes
EssamWisam Sep 1, 2024
706b630
Follow up
EssamWisam Sep 1, 2024
b0d8e44
🔼 Better docstrings
EssamWisam Sep 1, 2024
4458627
✅ Add missing dependency
EssamWisam Sep 1, 2024
f1f7dfe
✅ Improve docstring
EssamWisam Sep 1, 2024
daa4b2a
✅ Better default for embedding dims
EssamWisam Sep 1, 2024
59c9bb9
➕ Include in top-level only
EssamWisam Sep 1, 2024
749a247
👨‍🔧 Fix docstring spacing
EssamWisam Sep 1, 2024
46577f7
✅ Enforce columns
EssamWisam Sep 1, 2024
e5d8141
✅ Add trait for embedding model
EssamWisam Sep 1, 2024
f60f789
✅ Get rid of case distinction
EssamWisam Sep 1, 2024
5fd85a3
🌟 Introduce integration tests
EssamWisam Sep 1, 2024
2cc227a
Add `Tables.columns`
EssamWisam Sep 1, 2024
eac7727
✅ Bring back is_entity_enabled and more
EssamWisam Sep 1, 2024
cc56439
✅ Raise tolerance
EssamWisam Sep 1, 2024
407edc5
🫧 Some final polishing
EssamWisam Sep 8, 2024
21a2fdd
tweak document strings
ablaom Sep 9, 2024
e0a2eb8
tag some methods as private to avoid any possible ambiguity
ablaom Sep 9, 2024
310cb12
bump 0.6.0
ablaom Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"

[targets]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "ScientificTypes", "Test"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think MLJBase, which depends on ScientificTypes, re-exports all the public ScientificTypes methods, so you may be able to dump it here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ablaom So replace using ScientificTypes: coerce, Multiclass, OrderedFactor with using MLJBase: coerce, Multiclass, OrderedFactor in the test file? This is minor because it's in the test only right and it shouldn't redownload already downloaded package?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I think using MLJBase suffices. All those objects are re-exported.

11 changes: 6 additions & 5 deletions src/MLJFlux.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module MLJFlux

export CUDALibs, CPU1

import Flux
using MLJModelInterface
using MLJModelInterface.ScientificTypesBase
Expand All @@ -17,22 +16,24 @@ import Metalhead
import Optimisers

include("utilities.jl")
const MMI=MLJModelInterface
const MMI = MLJModelInterface

include("encoders.jl")
include("entity_embedding.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
include("core.jl")
include("regressor.jl")
include("classifier.jl")
include("image.jl")
include("fit_utils.jl")
include("entity_embedding_utils.jl")
include("mlj_model_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export CUDALibs, CPU1

include("deprecated.jl")


end #module
end # module
28 changes: 16 additions & 12 deletions src/classifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

A private method that returns the shape of the input and output of the model for given
data `X` and `y`.

"""
function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
X = X isa Matrix ? Tables.table(X) : X
Expand All @@ -29,24 +28,28 @@ MLJFlux.fitresult(
model::Union{NeuralNetworkClassifier, NeuralNetworkBinaryClassifier},
chain,
y,
) = (chain, MLJModelInterface.classes(y[1]))
ordinal_mappings = nothing,
embedding_matrices = nothing,
) = (chain, MLJModelInterface.classes(y[1]), ordinal_mappings, embedding_matrices)

function MLJModelInterface.predict(
model::NeuralNetworkClassifier,
fitresult,
Xnew,
)
chain, levels = fitresult
)
chain, levels, ordinal_mappings, _ = fitresult
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) # what if Xnew is a matrix
X = reformat(Xnew)
probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
return MLJModelInterface.UnivariateFinite(levels, probs)
end


MLJModelInterface.metadata_model(
NeuralNetworkClassifier,
input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
target_scitype=AbstractVector{<:Finite},
load_path="MLJFlux.NeuralNetworkClassifier",
input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)},
target_scitype = AbstractVector{<:Finite},
load_path = "MLJFlux.NeuralNetworkClassifier",
)

#### Binary Classifier
Expand All @@ -61,16 +64,17 @@ function MLJModelInterface.predict(
model::NeuralNetworkBinaryClassifier,
fitresult,
Xnew,
)
chain, levels = fitresult
)
chain, levels, ordinal_mappings, _ = fitresult
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
X = reformat(Xnew)
probs = vec(chain(X))
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
end

MLJModelInterface.metadata_model(
NeuralNetworkBinaryClassifier,
input_scitype=Union{AbstractMatrix{Continuous},Table(Continuous)},
target_scitype=AbstractVector{<:Finite{2}},
load_path="MLJFlux.NeuralNetworkBinaryClassifier",
input_scitype = Union{AbstractMatrix{Continuous}, Table(Continuous, Finite)},
target_scitype = AbstractVector{<:Finite{2}},
load_path = "MLJFlux.NeuralNetworkBinaryClassifier",
)
152 changes: 152 additions & 0 deletions src/encoders.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
File containing ordinal encoder and entity embedding encoder. Borrows code from the MLJTransforms package.
"""

### Ordinal Encoder
"""
**Private Method**

Fits an ordinal encoder to the table `X`, using only the columns with indices in `featinds`.

Returns a dictionary mapping each column index to a dictionary mapping each level in that column to an integer.
"""
function ordinal_encoder_fit(X; featinds)
# 1. Define mapping per column per level dictionary
mapping_matrix = Dict()

# 2. Use feature mapper to compute the mapping of each level in each column
for i in featinds
feat_col = Tables.getcolumn(Tables.columns(X), i)
feat_levels = levels(feat_col)
# Check if feat levels is already ordinal encoded in which case we skip
(Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
# Compute the dict using the given feature_mapper function
mapping_matrix[i] =
Dict{Any, AbstractFloat}(
value => float(index) for (index, value) in enumerate(feat_levels)
)
end
return mapping_matrix
end

"""
**Private Method**

Checks that all levels in `test_levels` are also in `train_levels`. If not, throws an error.
"""
function check_unkown_levels(train_levels, test_levels)
# test levels must be a subset of train levels
if !issubset(test_levels, train_levels)
# get the levels in test that are not in train
lost_levels = setdiff(test_levels, train_levels)
error(
"While transforming, found novel levels for the column: $(lost_levels) that were not seen while training.",
)
end
end

"""
**Private Method**

Transforms the table `X` using the ordinal encoder defined by `mapping_matrix`.

Returns a new table with the same column names as `X`, but with categorical columns replaced by integer columns.
"""
function ordinal_encoder_transform(X, mapping_matrix)
isnothing(mapping_matrix) && return X
isempty(mapping_matrix) && return X
feat_names = Tables.schema(X).names
numfeats = length(feat_names)
new_feats = []
for ind in 1:numfeats
col = Tables.getcolumn(Tables.columns(X), ind)

# Create the transformation function for each column
if ind in keys(mapping_matrix)
train_levels = keys(mapping_matrix[ind])
test_levels = levels(col)
check_unkown_levels(train_levels, test_levels)
level2scalar = mapping_matrix[ind]
new_col = recode(col, level2scalar...)
push!(new_feats, new_col)
else
push!(new_feats, col)
end
end

transformed_X = NamedTuple{tuple(feat_names...)}(tuple(new_feats)...)
# Attempt to preserve table type
transformed_X = Tables.materializer(X)(transformed_X)
return transformed_X
end

"""
**Private Method**

Combine ordinal_encoder_fit and ordinal_encoder_transform and return both X and ordinal_mappings
"""
function ordinal_encoder_fit_transform(X; featinds)
ordinal_mappings = ordinal_encoder_fit(X; featinds = featinds)
return ordinal_encoder_transform(X, ordinal_mappings), ordinal_mappings
end



## Entity Embedding Encoder (assuming precomputed weights)
"""
**Private method.**

Function to generate new feature names: feat_name_0, feat_name_1,..., feat_name_n
"""
function generate_new_feat_names(feat_name, num_inds, existing_names)
conflict = true # will be kept true as long as there is a conflict
count = 1 # number of conflicts+1 = number of underscores

new_column_names = []
while conflict
suffix = repeat("_", count)
new_column_names = [Symbol("$(feat_name)$(suffix)$i") for i in 1:num_inds]
conflict = any(name -> name in existing_names, new_column_names)
count += 1
end
return new_column_names
end


"""
Given X and a dict of mapping_matrices that map each categorical column to a matrix, use the matrix to transform
each level in each categorical columns using the columns of the matrix.

This is used with the embedding matrices of the entity embedding layer in entity enabled models to implement entity embeddings.
"""
function embedding_transform(X, mapping_matrices)
(isempty(mapping_matrices)) && return X
feat_names = Tables.schema(X).names
new_feat_names = Symbol[]
new_cols = []
for feat_name in feat_names
col = Tables.getcolumn(Tables.columns(X), feat_name)
# Create the transformation function for each column
if feat_name in keys(mapping_matrices)
level2vector = mapping_matrices[feat_name]
new_multi_col = map(x -> level2vector[:, Int.(unwrap(x))], col)
new_multi_col = [col for col in eachrow(hcat(new_multi_col...))]
push!(new_cols, new_multi_col...)
feat_names_with_inds = generate_new_feat_names(
feat_name,
size(level2vector, 1),
feat_names,
)
push!(new_feat_names, feat_names_with_inds...)
else
# Not to be transformed => left as is
push!(new_feat_names, feat_name)
push!(new_cols, col)
end
end

transformed_X = NamedTuple{tuple(new_feat_names...)}(tuple(new_cols)...)
# Attempt to preserve table type
transformed_X = Tables.materializer(X)(transformed_X)
return transformed_X
end
100 changes: 62 additions & 38 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,74 @@
# This is just some experimental code
# to implement EntityEmbeddings for purely
# categorical features
"""
A layer that implements entity embedding layers as presented in 'Entity Embeddings of
Categorical Variables by Cheng Guo, Felix Berkhahn'. Expects a matrix of dimensions (numfeats, batchsize)
and applies entity embeddings to each specified categorical feature. Other features will be left as is.

# Arguments
- `entityprops`: a vector of named tuples each of the form `(index=..., levels=..., newdim=...)` to
specify the feature index, the number of levels and the desired embeddings dimensionality for selected features of the input.
- `numfeats`: the number of features in the input.

using Flux
# Example
```julia
# Prepare a batch of four features where the 2nd and the 4th are categorical
batch = [
0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
1 2 3 4 5 6 7 8 9 10;
0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
1 1 2 2 1 1 2 2 1 1;
]

mutable struct EmbeddingMatrix
e
levels

function EmbeddingMatrix(levels; dim=4)
if dim <= 0
dimension = div(length(levels), 2)
else
dimension = min(length(levels), dim) # Dummy function for now
end
return new(Dense(length(levels), dimension), levels), dimension
end
entityprops = [
(index=2, levels=10, newdim=2),
(index=4, levels=2, newdim=1)
]
numfeats = 4

# Run it through the categorical embedding layer
embedder = EntityEmbedder(entityprops, 4)
julia> output = embedder(batch)
5×10 Matrix{Float64}:
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1
-1.27129 -0.417667 -1.40326 -0.695701 0.371741 1.69952 -1.40034 -2.04078
-0.166796 0.657619 -0.659249 -0.337757 -0.717179 -0.0176273 -1.2817 -0.0372752
0.9 0.1 0.4 0.5 0.8 0.9 1.0 1.1
-0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354
```
""" # 1. Define layer struct to hold parameters
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}

embedders::A1
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

Flux.@treelike EmbeddingMatrix

function (embed::EmbeddingMatrix)(ip)
return embed.e(Flux.onehot(ip, embed.levels))
end
# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)

mutable struct EntityEmbedding
embeddingmatrix
# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
embedders = []
modifiers = []
# Setup entityprops
cat_inds = [entityprop.index for entityprop in entityprops]
levels_per_feat = [entityprop.levels for entityprop in entityprops]
newdims = [entityprop.newdim for entityprop in entityprops]

function EntityEmbedding(a...)
return new(a)
c = 1
for i in 1:numfeats
if i in cat_inds
push!(embedders, Flux.Embedding(levels_per_feat[c] => newdims[c], init = init))
push!(modifiers, (x, i) -> Int.(x[i, :]))
c += 1
else
push!(embedders, feat -> feat)
push!(modifiers, (x, i) -> x[i:i, :])
end
end
end

Flux.@treelike EntityEmbedding


# ip is an array of tuples
function (embed::EntityEmbedding)(ip)
return hcat((vcat((embed.embeddingmatrix[i](ip[idx][i]) for i=1:length(ip[idx]))...) for idx =1:length(ip))...)
EntityEmbedder(embedders, modifiers, numfeats)
end


# Q1. How should this be called in the API?
# nn = NeuralNetworkClassifier(builder=builder, optimiser = .., embeddingdimension = 5)
#
#
#
# 4. Register it as layer with Flux
Flux.@layer EntityEmbedder
Loading
Loading