Adapting the script movielens_recommendations_transformers.py to be backend-agnostic #2039

Open · wants to merge 5 commits into base: master
190 changes: 104 additions & 86 deletions examples/structured_data/movielens_recommendations_transformers.py
@@ -2,8 +2,9 @@
Title: A Transformer-based recommendation system
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2020/12/30
- Last modified: 2025/01/03
+ Last modified: 2025/01/27
Description: Rating prediction using the Behavior Sequence Transformer (BST) model on the Movielens dataset.
+ Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
Accelerator: GPU
"""

@@ -52,7 +53,7 @@

import os

- os.environ["KERAS_BACKEND"] = "tensorflow"
+ os.environ["KERAS_BACKEND"] = "jax"  # or torch, or tensorflow

import math
from zipfile import ZipFile
@@ -61,8 +62,7 @@
import keras
import numpy as np
import pandas as pd
- import tensorflow as tf
- from keras import layers
+ from keras import layers, ops
from keras.layers import StringLookup

"""
@@ -254,29 +254,89 @@ def create_sequences(values, window_size, step_size):

MOVIE_FEATURES = ["genres"]


"""
- ## Create `tf.data.Dataset` for training and evaluation
+ ## Encode input features

+ The `encode_input_features` function works as follows:

+ 1. Each categorical user feature is encoded using `layers.Embedding`, with an
+ embedding dimension equal to the square root of the feature's vocabulary size.
+ The embeddings of these features are concatenated to form a single input tensor.

+ 2. Each movie in the movie sequence, and the target movie, is encoded using
+ `layers.Embedding`, where the dimension size is the square root of the number of movies.

+ 3. A multi-hot genres vector for each movie is concatenated with its embedding vector,
+ and processed using a non-linear `layers.Dense` to output a vector of the same movie
+ embedding dimensions.

+ 4. A positional embedding is added to each movie embedding in the sequence, and then
+ multiplied by its rating from the ratings sequence (a minimal sketch of steps 1-4
+ follows this docstring).

+ 5. The target movie embedding is concatenated to the sequence movie embeddings, producing
+ a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected
+ by the attention layer of the transformer architecture.

+ 6. The function returns a tuple of two elements: `encoded_transformer_features` and
+ `encoded_other_features`.
"""

+ # Required for tf.data.Dataset
+ import tensorflow as tf


def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):

def process(features):
movie_ids_string = features["sequence_movie_ids"]
sequence_movie_ids = tf.strings.split(movie_ids_string, ",").to_tensor()

# The last movie id in the sequence is the target movie.
features["target_movie_id"] = sequence_movie_ids[:, -1]
features["sequence_movie_ids"] = sequence_movie_ids[:, :-1]

# Sequence ratings
ratings_string = features["sequence_ratings"]
sequence_ratings = tf.strings.to_number(
tf.strings.split(ratings_string, ","), tf.dtypes.float32
).to_tensor()

# The last rating in the sequence is the target for the model to predict.
target = sequence_ratings[:, -1]
features["sequence_ratings"] = sequence_ratings[:, :-1]

+ def encoding_helper(feature_name):

+ # These are target_movie_id and sequence_movie_ids; they share the
+ # vocabulary of movie_id.
+ if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
+ index_lookup = StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ # Convert the string input values into integer indices.
+ value_index = index_lookup(features[feature_name])
+ features[feature_name] = value_index

+ else:
+ # movie_id is not part of the input features, so it is not processed here.
+ # It was only needed above for its vocabulary.
+ if feature_name == "movie_id":
+ pass
+ else:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+ index_lookup = StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ # Convert the string input values into integer indices.
+ value_index = index_lookup(features[feature_name])
+ features[feature_name] = value_index

+ # Encode the user features.
+ for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ encoding_helper(feature_name)
+ # Encode the target_movie_id.
+ encoding_helper("target_movie_id")
+ # Encode the sequence movie_ids.
+ encoding_helper("sequence_movie_ids")
return dict(features), target

dataset = tf.data.experimental.make_csv_dataset(
@@ -292,62 +352,11 @@ def process(features):
return dataset
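As a usage illustration (an editorial sketch, not part of the diff): since the `StringLookup` encoding now runs inside the `tf.data` pipeline, the dataset yields integer indices that any backend can consume:

train_dataset = get_dataset_from_csv("train_data.csv", batch_size=265, shuffle=True)
for features, target in train_dataset.take(1):
    print(features["target_movie_id"].dtype)  # int64 indices; no strings reach the model
    print(target.shape)                       # (265,) ratings for the model to predict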


"""
## Create model inputs
"""


def create_model_inputs():
return {
"user_id": keras.Input(name="user_id", shape=(1,), dtype="string"),
"sequence_movie_ids": keras.Input(
name="sequence_movie_ids", shape=(sequence_length - 1,), dtype="string"
),
"target_movie_id": keras.Input(
name="target_movie_id", shape=(1,), dtype="string"
),
"sequence_ratings": keras.Input(
name="sequence_ratings", shape=(sequence_length - 1,), dtype=tf.float32
),
"sex": keras.Input(name="sex", shape=(1,), dtype="string"),
"age_group": keras.Input(name="age_group", shape=(1,), dtype="string"),
"occupation": keras.Input(name="occupation", shape=(1,), dtype="string"),
}


"""
## Encode input features

The `encode_input_features` method works as follows:

1. Each categorical user feature is encoded using `layers.Embedding`, with embedding
dimension equals to the square root of the vocabulary size of the feature.
The embeddings of these features are concatenated to form a single input tensor.

2. Each movie in the movie sequence and the target movie is encoded `layers.Embedding`,
where the dimension size is the square root of the number of movies.

3. A multi-hot genres vector for each movie is concatenated with its embedding vector,
and processed using a non-linear `layers.Dense` to output a vector of the same movie
embedding dimensions.

4. A positional embedding is added to each movie embedding in the sequence, and then
multiplied by its rating from the ratings sequence.

5. The target movie embedding is concatenated to the sequence movie embeddings, producing
a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected
by the attention layer for the transformer architecture.

6. The method returns a tuple of two elements: `encoded_transformer_features` and
`encoded_other_features`.
"""


def encode_input_features(
inputs,
- include_user_id=True,
- include_user_features=True,
- include_movie_features=True,
+ include_user_id,
+ include_user_features,
+ include_movie_features,
):
encoded_transformer_features = []
encoded_other_features = []
@@ -360,11 +369,7 @@ def encode_input_features(

## Encode user features
for feature_name in other_feature_names:
- # Convert the string input values into integer indices.
vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
- idx = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)(
- inputs[feature_name]
- )
# Compute embedding dimensions
embedding_dims = int(math.sqrt(len(vocabulary)))
# Create an embedding layer with the specified dimensions.
@@ -374,7 +379,7 @@
name=f"{feature_name}_embedding",
)
# Convert the index values to embedding representations.
- encoded_other_features.append(embedding_encoder(idx))
+ encoded_other_features.append(embedding_encoder(inputs[feature_name]))

## Create a single embedding vector for the user features
if len(encoded_other_features) > 1:
@@ -387,13 +392,6 @@
## Create a movie embedding encoder
movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))
- # Create a lookup to convert string values to integer indices.
- movie_index_lookup = StringLookup(
- vocabulary=movie_vocabulary,
- mask_token=None,
- num_oov_indices=0,
- name="movie_index_lookup",
- )
# Create an embedding layer with the specified dimensions.
movie_embedding_encoder = layers.Embedding(
input_dim=len(movie_vocabulary),
@@ -419,11 +417,10 @@
## Define a function to encode a given movie id.
def encode_movie(movie_id):
# Convert the string input values into integer indices.
- movie_idx = movie_index_lookup(movie_id)
- movie_embedding = movie_embedding_encoder(movie_idx)
+ movie_embedding = movie_embedding_encoder(movie_id)
encoded_movie = movie_embedding
if include_movie_features:
- movie_genres_vector = movie_genres_lookup(movie_idx)
+ movie_genres_vector = movie_genres_lookup(movie_id)
encoded_movie = movie_embedding_processor(
layers.concatenate([movie_embedding, movie_genres_vector])
)
@@ -442,11 +439,11 @@
output_dim=movie_embedding_dims,
name="position_embedding",
)
- positions = tf.range(start=0, limit=sequence_length - 1, delta=1)
Review thread:

Contributor: Did the old code really need extensive refactoring? Best I can tell, only this line needed to change (to become `ops.arange()`).

Contributor author (@Humbulani1234, Jan 30, 2025): Changing only the `tf.range` to `ops.arange` does not work, and I believe it is because `StringLookup` is part of the model, not of the `tf.data.Dataset`, in the original script, and only TensorFlow can handle strings.

So my approach was to split the function `encode_input_features` into two parts: StringLookups and Embeddings. I placed the StringLookup functionality into the `tf.data.Dataset` processing step, and created a class to handle each input's embeddings. The requirement that the script must run whether or not one chooses to include user features also adds some refactoring.

However, if there is a less invasive approach, I'm willing to learn and implement it. I must confess the refactoring felt a bit extensive, but it was required by my approach.

+ positions = ops.arange(start=0, stop=sequence_length - 1, step=1)
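For reference, a minimal editorial sketch of the one-line swap discussed above (assuming TensorFlow is installed alongside the chosen backend):

import tensorflow as tf
from keras import ops

sequence_length = 8
tf.range(start=0, limit=sequence_length - 1, delta=1)  # TensorFlow-only API
ops.arange(start=0, stop=sequence_length - 1, step=1)  # equivalent, backend-agnostic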
encoded_positions = position_embedding_encoder(positions)
# Retrieve sequence ratings to incorporate them into the encoding of the movie.
sequence_ratings = inputs["sequence_ratings"]
- sequence_ratings = keras.ops.expand_dims(sequence_ratings, -1)
+ sequence_ratings = ops.expand_dims(sequence_ratings, -1)
# Add the positional encoding to the movie encodings and multiply them by rating.
encoded_sequence_movies_with_position_and_rating = layers.Multiply()(
[(encoded_sequence_movies + encoded_positions), sequence_ratings]
@@ -455,17 +452,38 @@ def encode_movie(movie_id):
# Construct the transformer inputs.
for i in range(sequence_length - 1):
feature = encoded_sequence_movies_with_position_and_rating[:, i, ...]
- feature = keras.ops.expand_dims(feature, 1)
+ feature = ops.expand_dims(feature, 1)
encoded_transformer_features.append(feature)
encoded_transformer_features.append(encoded_target_movie)

encoded_transformer_features = layers.concatenate(
encoded_transformer_features, axis=1
)

return encoded_transformer_features, encoded_other_features


"""
## Create model inputs
"""


def create_model_inputs():
return {
"user_id": keras.Input(name="user_id", shape=(1,), dtype="int32"),
"sequence_movie_ids": keras.Input(
name="sequence_movie_ids", shape=(sequence_length - 1,), dtype="int32"
),
"target_movie_id": keras.Input(
name="target_movie_id", shape=(1,), dtype="int32"
),
"sequence_ratings": keras.Input(
name="sequence_ratings", shape=(sequence_length - 1,), dtype="float32"
),
"sex": keras.Input(name="sex", shape=(1,), dtype="int32"),
"age_group": keras.Input(name="age_group", shape=(1,), dtype="int32"),
"occupation": keras.Input(name="occupation", shape=(1,), dtype="int32"),
}


"""
## Create a BST model
"""
@@ -480,11 +498,11 @@


def create_model():

inputs = create_model_inputs()
transformer_features, other_features = encode_input_features(
inputs, include_user_id, include_user_features, include_movie_features
)

# Create a multi-headed attention layer.
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate
@@ -501,21 +519,20 @@ def create_model():
transformer_features = layers.LayerNormalization()(transformer_features)
features = layers.Flatten()(transformer_features)

- # Included the other features.
+ # Include the other_features.
if other_features is not None:
features = layers.concatenate(
[features, layers.Reshape([other_features.shape[-1]])(other_features)]
)

# Fully-connected layers.
for num_units in hidden_units:
features = layers.Dense(num_units)(features)
features = layers.BatchNormalization()(features)
features = layers.LeakyReLU()(features)
features = layers.Dropout(dropout_rate)(features)

outputs = layers.Dense(units=1)(features)
model = keras.Model(inputs=inputs, outputs=outputs)

return model


@@ -533,6 +550,7 @@ def create_model():
)

# Read the training data.

train_dataset = get_dataset_from_csv("train_data.csv", batch_size=265, shuffle=True)

# Fit the model with the training data.
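The compile and fit calls are collapsed in this view; a minimal editorial sketch of the usual flow for this example (the optimizer, loss, and epoch count are assumptions, not the diff's exact values):

model = create_model()
model.compile(
    optimizer=keras.optimizers.Adagrad(learning_rate=0.01),  # assumed hyperparameters
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)
model.fit(train_dataset, epochs=2)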