Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions choice_learn/basket_models/alea_carta.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,6 @@ def __init__(
# Add epsilon to prices to avoid NaN values (log(0))
self.epsilon_price = epsilon_price

if len(tf.config.get_visible_devices("GPU")):
# At least one available GPU
self.on_gpu = True
else:
# No available GPU
self.on_gpu = False
# /!\ If a model trained on GPU is loaded on CPU, self.on_gpu must be set
# to False manually after loading the model, and vice versa

self.instantiated = False

def instantiate(
Expand Down Expand Up @@ -278,7 +269,7 @@ def compute_batch_utility(
week_batch: np.ndarray,
price_batch: np.ndarray,
) -> tf.Tensor:
"""Compute the utility of all the items in item_batch.
"""Compute the utility of all the items in item_batch given the items in basket_batch.

Parameters
----------
Expand Down Expand Up @@ -357,7 +348,7 @@ def compute_batch_utility(
) # Shape: (batch_size,)

# Create a RaggedTensor from the indices with padding removed
item_indices_ragged = tf.cast(
"""item_indices_ragged = tf.cast(
tf.ragged.boolean_mask(basket_batch, basket_batch != -1),
dtype=tf.int32,
)
Expand All @@ -371,7 +362,11 @@ def compute_batch_utility(
)
else:
# Gather the embeddings using a ragged tensor of indices
alpha_by_basket = tf.ragged.map_flat_values(tf.gather, self.alpha, item_indices_ragged)
alpha_by_basket = tf.ragged.map_flat_values(tf.gather, self.alpha, item_indices_ragged)"""
alpha_by_basket = tf.gather(
tf.concat([tf.zeros((1, self.alpha.shape[1])), self.alpha], axis=0),
basket_batch + tf.ones_like(basket_batch),
)
# Basket interaction: one vs all
alpha_i = tf.expand_dims(alpha_item, axis=1) # Shape: (batch_size, 1, latent_size)
# Compute the dot product along the last dimension (latent_size)
Expand All @@ -393,8 +388,9 @@ def compute_basket_utility(
prices: Union[None, np.ndarray] = None,
trip: Union[None, Trip] = None,
) -> float:
"""Compute the utility of an (unordered) basket.
r"""Compute the utility of an (unordered) basket.

Corresponds to the sum of all the conditional utilities: \sum_{i \in basket} U(i | basket \ {i})
Take as input directly a Trip object or separately basket, store,
week and prices.

Expand Down Expand Up @@ -461,7 +457,7 @@ def compute_item_likelihood(
prices: Union[None, np.ndarray] = None,
trip: Union[None, Trip] = None,
) -> tf.Tensor:
"""Compute the likelihood of all items for a given trip.
"""Compute the likelihood for all items (as next item) with a given basket.

Take as input directly a Trip object or separately basket, available_items,
store, week and prices.
Expand Down Expand Up @@ -527,6 +523,7 @@ def compute_item_likelihood(
prices = trip.prices

# Prevent unintended side effects from in-place modifications
# Likelihood of an item in the basket = 0
available_items_copy = available_items.copy()
for basket_item in basket:
if basket_item != -1:
Expand All @@ -535,7 +532,7 @@ def compute_item_likelihood(
# Compute the utility of all the items
all_utilities = self.compute_batch_utility(
# All items
item_batch=np.array([item_id for item_id in range(self.n_items)]),
item_batch=np.arange(self.n_items),
# For each item: same basket / store / week / prices / available items
basket_batch=np.array([basket for _ in range(self.n_items)]),
store_batch=np.array([store for _ in range(self.n_items)]),
Expand Down
8 changes: 5 additions & 3 deletions notebooks/basket_models/alea_carta.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"import os\n",
"import sys\n",
"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
"\n",
"sys.path.append(\"./../../\")\n",
"print(os.getcwd())\n",
"\n",
Expand Down Expand Up @@ -132,7 +134,7 @@
"metadata": {},
"source": [
"### References\n",
"[1] Better Capturing Interactions between Products in Retail: Revisited Negative Sampling for Basket Choice Modeling, Désir, J.; Auriaut, V.; Možina, M.; Malherbe, E. (2025), ECML PKDDD Applied Data Science"
"[1] Better Capturing Interactions between Products in Retail: Revisited Negative Sampling for Basket Choice Modeling, Désir, J.; Auriau, V.; Možina, M.; Malherbe, E. (2025), ECML-PKDDD"
]
},
{
Expand All @@ -145,7 +147,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tf_env",
"display_name": "basics",
"language": "python",
"name": "python3"
},
Expand All @@ -159,7 +161,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down
Loading