VincentAuriau
Collaborator

The Shopper model's 'end-of-basket' (eob) item is now handled differently:

  • TripDataset should be instantiated normally, that is, without the eob item (it does not appear in the assortment, the prices, the purchases, etc.)
  • model.n_items is now set to exactly n_items
  • the embeddings of the eob item are stored separately and concatenated, so that the eob item is treated as the (n_items + 1)-th item
  • model.compute_item_likelihood and related methods take inputs sized for n_items items and return values for n_items + 1 items, the last one being the eob
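
The convention described above can be sketched in plain Python. This is only an illustration, not the Shopper implementation: the helper name `item_likelihoods`, the dot-product utility, and the toy numbers are assumptions made for the example. The item embeddings and the single eob embedding are kept separate, concatenated at scoring time, and the output has n_items + 1 entries with the eob last.

```python
import math


def item_likelihoods(item_embeddings, eob_embedding, context):
    """Return n_items + 1 probabilities; the last entry is the end-of-basket item.

    Hypothetical helper: utilities are plain dot products with a context vector.
    """
    # Concatenate the separately stored eob embedding after the item embeddings.
    all_embeddings = list(item_embeddings) + [eob_embedding]
    utilities = [sum(e * c for e, c in zip(emb, context)) for emb in all_embeddings]
    # Numerically stable softmax over the n_items + 1 utilities.
    shift = max(utilities)
    exps = [math.exp(u - shift) for u in utilities]
    total = sum(exps)
    return [x / total for x in exps]


n_items = 3
item_embeddings = [[0.1, 0.3], [0.4, -0.2], [-0.5, 0.2]]  # n_items rows
eob_embedding = [0.0, 0.1]  # one row, stored separately
probabilities = item_likelihoods(item_embeddings, eob_embedding, context=[1.0, 2.0])
print(len(probabilities))  # n_items + 1
```

The caller only ever supplies data for the n_items real items; the extra entry is added inside the function.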

Contributor

Summary of Changes

Hello @VincentAuriau, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant architectural change to the Shopper model by implementing explicit and automatic handling of an 'end-of-basket' item. This refactoring ensures that the model consistently recognizes and processes the termination of a basket as a distinct event, impacting how item embeddings are managed, how data is augmented, and how item likelihoods are computed. The changes aim to improve the model's accuracy and robustness by providing a dedicated representation for the end-of-basket state.

Highlights

  • Explicit End-of-Basket (EoB) Handling: The Shopper model now explicitly manages an 'end-of-basket' item, treating it as a distinct entity (n_items + 1) with its own dedicated embeddings.
  • Refined Data Processing: The TripDataset and its data augmentation methods have been updated to correctly incorporate the EoB item in permutations, padded baskets, and availability/price arrays.
  • Updated Likelihood Computations: Core model methods like compute_item_likelihood, thinking_ahead, and compute_batch_utility have been adjusted to account for the EoB item when calculating utilities and probabilities.
  • Test Suite Alignment: Integration tests were modified to reflect the new EoB handling, ensuring correctness and consistency with the updated model logic.

Contributor

gemini-code-assist bot left a comment

Code Review

This pull request refactors the Shopper model to automatically handle an 'end-of-basket' item, which is a significant improvement. The changes are mostly well-implemented, separating the end-of-basket embeddings and updating the data generation and utility computation logic accordingly.

I've found a few areas for improvement:

  • There's a critical bug in compute_item_likelihood where the end-of-basket item is made unavailable, which would prevent it from ever being chosen.
  • Some debugging print statements have been left in the code.
  • There are opportunities to reduce code duplication in basket_dataset.py and shopper.py to improve maintainability.

My detailed comments are below.

if len(prices) == self.n_items:
    prices = np.concatenate([prices, [0.0]], axis=0)
if len(available_items_copy) == self.n_items:
    available_items_copy = np.concatenate([available_items_copy, [0.0]], axis=0)

critical

The end-of-basket item's availability is being set to 0.0. This will prevent it from ever being chosen, as softmax_with_availabilities will multiply its probability by zero. Based on the logic elsewhere (e.g., in basket_dataset.py), it should be available for selection, so this should be 1.0.

Suggested change
- available_items_copy = np.concatenate([available_items_copy, [0.0]], axis=0)
+ available_items_copy = np.concatenate([available_items_copy, [1.0]], axis=0)
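
To make the effect of the availability value concrete, here is a minimal plain-Python sketch. `masked_softmax` is a hypothetical stand-in written for this example, not the library's softmax_with_availabilities; it assumes only that availabilities multiply the exponentiated utilities before normalisation, as the comment above describes.

```python
import math


def masked_softmax(utilities, availabilities):
    """Softmax in which entries with availability 0.0 get probability 0."""
    shift = max(utilities)
    weighted = [math.exp(u - shift) * a for u, a in zip(utilities, availabilities)]
    total = sum(weighted)
    return [w / total for w in weighted]


utilities = [0.5, -0.3, 1.2, 0.0]  # three items, then the end-of-basket item

# Availability 0.0 for the last entry: end-of-basket can never be chosen.
blocked = masked_softmax(utilities, [1.0, 1.0, 1.0, 0.0])
# Availability 1.0 for the last entry: end-of-basket competes with the items.
allowed = masked_softmax(utilities, [1.0, 1.0, 1.0, 1.0])

print(blocked[-1])      # 0.0
print(allowed[-1] > 0)  # True
```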

Comment on lines +872 to +874
print(batch_size, available_item_batch.shape, item_batch.shape)
print(basket_batch.shape, future_batch.shape, item_batch.shape)
print(available_item_batch[batch_size - 1])

high

These print statements appear to be for debugging purposes. They should be removed before merging to keep the output clean and avoid potential performance issues.

Comment on lines 526 to 547
if data_method == "shopper":
    buffer = (
        np.empty(0, dtype=int),  # Items
        np.empty((0, self.max_length), dtype=int),  # Baskets
        np.empty((0, self.max_length), dtype=int),  # Future purchases
        np.empty(0, dtype=int),  # Stores
        np.empty(0, dtype=int),  # Weeks
        np.empty((0, self.n_items + 1), dtype=int),  # Prices
        np.empty((0, self.n_items + 1), dtype=int),  # Available items
    )
elif data_method == "aleacarta":
    buffer = (
        np.empty(0, dtype=int),  # Items
        np.empty((0, self.max_length), dtype=int),  # Baskets
        np.empty((0, self.max_length), dtype=int),  # Future purchases
        np.empty(0, dtype=int),  # Stores
        np.empty(0, dtype=int),  # Weeks
        np.empty((0, self.n_items), dtype=int),  # Prices
        np.empty((0, self.n_items), dtype=int),  # Available items
    )
else:
    raise ValueError(f"Unknown data method: {data_method}")

medium

There's some code duplication in the initialization of the buffer. The first five elements of the tuple are identical for both shopper and aleacarta data methods. Consider refactoring this to improve maintainability by defining the common part of the buffer first, and then appending the method-specific parts based on data_method.
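
One way the suggested refactoring could look is sketched below. To stay runnable without NumPy, the sketch only computes the array shapes; the function name `buffer_shapes` and its free-standing signature are assumptions for illustration, whereas the real code is a method that reads self.max_length and self.n_items.

```python
def buffer_shapes(data_method, max_length, n_items):
    """Return the shapes of the empty buffer arrays for a given data method."""
    # Method-specific width of the price and availability arrays.
    if data_method == "shopper":
        n_columns = n_items + 1  # extra column for the end-of-basket item
    elif data_method == "aleacarta":
        n_columns = n_items
    else:
        raise ValueError(f"Unknown data method: {data_method}")

    # Part shared by both data methods.
    common = [
        (0,),             # Items
        (0, max_length),  # Baskets
        (0, max_length),  # Future purchases
        (0,),             # Stores
        (0,),             # Weeks
    ]
    # Part whose width depends on the data method.
    specific = [
        (0, n_columns),  # Prices
        (0, n_columns),  # Available items
    ]
    return tuple(common + specific)


shapes = buffer_shapes("shopper", max_length=5, n_items=10)
print(shapes[-1])  # (0, 11)
```

In the actual method, each shape would then be passed to np.empty(shape, dtype=int), as in the snippet under review.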

Comment on lines +171 to +192
# end-of-basket rho
self.rho_eob = tf.Variable(
    tf.random_normal_initializer(mean=0, stddev=1.0, seed=42)(
        shape=(1, self.latent_sizes["preferences"])
    ),  # Dimension for 1 item: latent_sizes["preferences"]
    trainable=True,
    name="rho_eob",
)
self.alpha = tf.Variable(
    tf.random_normal_initializer(mean=0, stddev=1.0, seed=42)(
        shape=(n_items, self.latent_sizes["preferences"])
    ),  # Dimension for 1 item: latent_sizes["preferences"]
    trainable=True,
    name="alpha",
)
self.alpha_eob = tf.Variable(  # end-of-basket alpha
    tf.random_normal_initializer(mean=0, stddev=1.0, seed=42)(
        shape=(1, self.latent_sizes["preferences"])
    ),  # Dimension for 1 item: latent_sizes["preferences"]
    trainable=True,
    name="alpha_eob",
)

medium

The initialization of embedding variables like rho, alpha, and their _eob counterparts is quite repetitive. You could introduce a private helper method to create these tf.Variable embeddings. This would reduce code duplication and make the instantiate method cleaner and easier to maintain. For example:

def _create_embedding(self, name, shape, stddev, latent_size_key, mean=0.0):
    initializer = tf.random_normal_initializer(mean=mean, stddev=stddev, seed=42)
    return tf.Variable(
        initializer(shape=(shape, self.latent_sizes[latent_size_key])),
        trainable=True,
        name=name,
    )

# In instantiate method:
self.rho = self._create_embedding("rho", n_items, 1.0, "preferences")
self.rho_eob = self._create_embedding("rho_eob", 1, 1.0, "preferences")
# ... and so on for other embeddings

Contributor

github-actions bot commented Sep 27, 2025

Coverage

Coverage Report for Python 3.9
File | Stmts | Miss | Cover | Missing
choice_learn
   __init__.py | 2 | 0 | 100% |
   tf_ops.py | 62 | 1 | 98% | 283
choice_learn/basket_models
   __init__.py | 4 | 0 | 100% |
   alea_carta.py | 136 | 45 | 67% | 84–88, 90–94, 96–100, 104, 107, 129, 157, 230–241, 418–443, 515–518, 570–654
   base_basket_model.py | 211 | 93 | 56% | 110–111, 122, 140, 183, 251, 352, 457, 554–556, 602–616, 643, 683–778, 827–853, 863–885, 896–899, 921–944
   basic_attention_model.py | 78 | 11 | 86% | 353–384, 408, 411, 417, 424, 430–432
   shopper.py | 197 | 152 | 23% | 132, 156–261, 272–283, 300, 342–401, 457–579, 621–739, 774–817, 868–950
choice_learn/basket_models/data
   __init__.py | 2 | 0 | 100% |
   basket_dataset.py | 143 | 19 | 87% | 73–76, 284, 317, 393, 483, 554, 562, 583–584, 589, 601, 610–618
   preprocessing.py | 94 | 78 | 17% | 43–45, 128–364
   synthetic_dataset.py | 72 | 6 | 92% | 54, 158–163, 203
choice_learn/basket_models/utils
   __init__.py | 0 | 0 | 100% |
   permutation.py | 22 | 1 | 95% | 37
choice_learn/data
   __init__.py | 3 | 0 | 100% |
   choice_dataset.py | 649 | 33 | 95% | 198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py | 241 | 23 | 90% | 20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py | 161 | 6 | 96% | 22, 33, 51, 56, 61, 71
   store.py | 72 | 72 | 0% | 3–275
choice_learn/datasets
   __init__.py | 4 | 0 | 100% |
   base.py | 400 | 5 | 99% | 42–43, 153–154, 714
   expedia.py | 102 | 83 | 19% | 37–301
   tafeng.py | 49 | 0 | 100% |
choice_learn/datasets/data
   __init__.py | 0 | 0 | 100% |
choice_learn/models
   __init__.py | 14 | 2 | 86% | 15–16
   base_model.py | 269 | 15 | 94% | 144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439, 540, 654–655
   baseline_models.py | 49 | 0 | 100% |
   conditional_logit.py | 269 | 26 | 90% | 49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py | 124 | 2 | 98% | 186, 374
   latent_class_base_model.py | 286 | 39 | 86% | 55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py | 62 | 6 | 90% | 257–261, 296
   learning_mnl.py | 67 | 3 | 96% | 157, 182, 188
   nested_logit.py | 291 | 12 | 96% | 55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py | 132 | 6 | 95% | 285, 360, 369, 374, 382, 432
   rumnet.py | 236 | 3 | 99% | 748–751, 982
   simple_mnl.py | 139 | 6 | 96% | 167, 275, 347, 355, 357, 359
   tastenet.py | 94 | 3 | 97% | 142, 180, 188
choice_learn/toolbox
   __init__.py | 0 | 0 | 100% |
   assortment_optimizer.py | 27 | 6 | 78% | 28–30, 93–95, 160–162
   gurobi_opt.py | 236 | 236 | 0% | 3–675
   or_tools_opt.py | 230 | 11 | 95% | 103, 107, 296–305, 315, 319, 607, 611
TOTAL | 5229 | 1004 | 81% |

Tests | Skipped | Failures | Errors | Time
204 | 0 💤 | 7 ❌ | 2 🔥 | 6m 41s ⏱️

Contributor

github-actions bot commented Sep 27, 2025

Coverage

Coverage Report for Python 3.10
File | Stmts | Miss | Cover | Missing
choice_learn
   __init__.py | 2 | 0 | 100% |
   tf_ops.py | 62 | 1 | 98% | 283
choice_learn/basket_models
   __init__.py | 4 | 0 | 100% |
   alea_carta.py | 136 | 45 | 67% | 84–88, 90–94, 96–100, 104, 107, 129, 157, 230–241, 418–443, 515–518, 570–654
   base_basket_model.py | 211 | 93 | 56% | 110–111, 122, 140, 183, 251, 352, 457, 554–556, 602–616, 643, 683–778, 827–853, 863–885, 896–899, 921–944
   basic_attention_model.py | 78 | 11 | 86% | 353–384, 408, 411, 417, 424, 430–432
   shopper.py | 197 | 152 | 23% | 132, 156–261, 272–283, 300, 342–401, 457–579, 621–739, 774–817, 868–950
choice_learn/basket_models/data
   __init__.py | 2 | 0 | 100% |
   basket_dataset.py | 143 | 19 | 87% | 73–76, 284, 317, 393, 482, 553, 561, 582–583, 588, 600, 609–617
   preprocessing.py | 94 | 78 | 17% | 43–45, 128–364
   synthetic_dataset.py | 72 | 6 | 92% | 54, 158–163, 203
choice_learn/basket_models/utils
   __init__.py | 0 | 0 | 100% |
   permutation.py | 22 | 1 | 95% | 37
choice_learn/data
   __init__.py | 3 | 0 | 100% |
   choice_dataset.py | 649 | 33 | 95% | 198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py | 241 | 23 | 90% | 20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py | 161 | 6 | 96% | 22, 33, 51, 56, 61, 71
   store.py | 72 | 72 | 0% | 3–275
choice_learn/datasets
   __init__.py | 4 | 0 | 100% |
   base.py | 400 | 5 | 99% | 42–43, 153–154, 714
   expedia.py | 102 | 83 | 19% | 37–301
   tafeng.py | 49 | 0 | 100% |
choice_learn/datasets/data
   __init__.py | 0 | 0 | 100% |
choice_learn/models
   __init__.py | 14 | 2 | 86% | 15–16
   base_model.py | 269 | 15 | 94% | 144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439, 540, 654–655
   baseline_models.py | 49 | 0 | 100% |
   conditional_logit.py | 269 | 26 | 90% | 49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py | 124 | 2 | 98% | 186, 374
   latent_class_base_model.py | 286 | 39 | 86% | 55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py | 62 | 6 | 90% | 257–261, 296
   learning_mnl.py | 67 | 3 | 96% | 157, 182, 188
   nested_logit.py | 291 | 12 | 96% | 55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py | 132 | 6 | 95% | 285, 360, 369, 374, 382, 432
   rumnet.py | 236 | 3 | 99% | 748–751, 982
   simple_mnl.py | 139 | 6 | 96% | 167, 275, 347, 355, 357, 359
   tastenet.py | 94 | 3 | 97% | 142, 180, 188
choice_learn/toolbox
   __init__.py | 0 | 0 | 100% |
   assortment_optimizer.py | 27 | 6 | 78% | 28–30, 93–95, 160–162
   gurobi_opt.py | 238 | 238 | 0% | 3–675
   or_tools_opt.py | 230 | 11 | 95% | 103, 107, 296–305, 315, 319, 607, 611
TOTAL | 5231 | 1006 | 81% |

Tests | Skipped | Failures | Errors | Time
204 | 0 💤 | 7 ❌ | 2 🔥 | 6m 37s ⏱️

Contributor

github-actions bot commented Sep 27, 2025

Coverage

Coverage Report for Python 3.11
File | Stmts | Miss | Cover | Missing
choice_learn
   __init__.py | 2 | 0 | 100% |
   tf_ops.py | 62 | 1 | 98% | 283
choice_learn/basket_models
   __init__.py | 4 | 0 | 100% |
   alea_carta.py | 136 | 45 | 67% | 84–88, 90–94, 96–100, 104, 107, 129, 157, 230–241, 418–443, 515–518, 570–654
   base_basket_model.py | 211 | 93 | 56% | 110–111, 122, 140, 183, 251, 352, 457, 554–556, 602–616, 643, 683–778, 827–853, 863–885, 896–899, 921–944
   basic_attention_model.py | 78 | 11 | 86% | 353–384, 408, 411, 417, 424, 430–432
   shopper.py | 197 | 152 | 23% | 132, 156–261, 272–283, 300, 342–401, 457–579, 621–739, 774–817, 868–950
choice_learn/basket_models/data
   __init__.py | 2 | 0 | 100% |
   basket_dataset.py | 143 | 19 | 87% | 73–76, 284, 317, 393, 482, 553, 561, 582–583, 588, 600, 609–617
   preprocessing.py | 94 | 78 | 17% | 43–45, 128–364
   synthetic_dataset.py | 72 | 6 | 92% | 54, 158–163, 203
choice_learn/basket_models/utils
   __init__.py | 0 | 0 | 100% |
   permutation.py | 22 | 1 | 95% | 37
choice_learn/data
   __init__.py | 3 | 0 | 100% |
   choice_dataset.py | 649 | 33 | 95% | 198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py | 241 | 23 | 90% | 20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py | 161 | 6 | 96% | 22, 33, 51, 56, 61, 71
   store.py | 72 | 72 | 0% | 3–275
choice_learn/datasets
   __init__.py | 4 | 0 | 100% |
   base.py | 400 | 5 | 99% | 42–43, 153–154, 714
   expedia.py | 102 | 83 | 19% | 37–301
   tafeng.py | 49 | 0 | 100% |
choice_learn/datasets/data
   __init__.py | 0 | 0 | 100% |
choice_learn/models
   __init__.py | 14 | 2 | 86% | 15–16
   base_model.py | 269 | 15 | 94% | 144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439, 540, 654–655
   baseline_models.py | 49 | 0 | 100% |
   conditional_logit.py | 269 | 26 | 90% | 49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py | 124 | 2 | 98% | 186, 374
   latent_class_base_model.py | 286 | 39 | 86% | 55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py | 62 | 6 | 90% | 257–261, 296
   learning_mnl.py | 67 | 3 | 96% | 157, 182, 188
   nested_logit.py | 291 | 12 | 96% | 55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py | 132 | 6 | 95% | 285, 360, 369, 374, 382, 432
   rumnet.py | 236 | 3 | 99% | 748–751, 982
   simple_mnl.py | 139 | 6 | 96% | 167, 275, 347, 355, 357, 359
   tastenet.py | 94 | 3 | 97% | 142, 180, 188
choice_learn/toolbox
   __init__.py | 0 | 0 | 100% |
   assortment_optimizer.py | 27 | 6 | 78% | 28–30, 93–95, 160–162
   gurobi_opt.py | 238 | 238 | 0% | 3–675
   or_tools_opt.py | 230 | 11 | 95% | 103, 107, 296–305, 315, 319, 607, 611
TOTAL | 5231 | 1006 | 81% |

Tests | Skipped | Failures | Errors | Time
204 | 0 💤 | 7 ❌ | 2 🔥 | 7m 21s ⏱️

Contributor

github-actions bot commented Sep 27, 2025

Coverage

Coverage Report for Python 3.12
File | Stmts | Miss | Cover | Missing
choice_learn
   __init__.py | 2 | 0 | 100% |
   tf_ops.py | 62 | 1 | 98% | 283
choice_learn/basket_models
   __init__.py | 4 | 0 | 100% |
   alea_carta.py | 136 | 45 | 67% | 84–88, 90–94, 96–100, 104, 107, 129, 157, 230–241, 418–443, 515–518, 570–654
   base_basket_model.py | 211 | 93 | 56% | 110–111, 122, 140, 183, 251, 352, 457, 554–556, 602–616, 643, 683–778, 827–853, 863–885, 896–899, 921–944
   basic_attention_model.py | 78 | 11 | 86% | 353–384, 408, 411, 417, 424, 430–432
   shopper.py | 197 | 152 | 23% | 132, 156–261, 272–283, 300, 342–401, 457–579, 621–739, 774–817, 868–950
choice_learn/basket_models/data
   __init__.py | 2 | 0 | 100% |
   basket_dataset.py | 143 | 19 | 87% | 73–76, 284, 317, 393, 482, 553, 561, 582–583, 588, 600, 609–617
   preprocessing.py | 94 | 78 | 17% | 43–45, 128–364
   synthetic_dataset.py | 72 | 6 | 92% | 54, 158–163, 203
choice_learn/basket_models/utils
   __init__.py | 0 | 0 | 100% |
   permutation.py | 22 | 1 | 95% | 37
choice_learn/data
   __init__.py | 3 | 0 | 100% |
   choice_dataset.py | 649 | 33 | 95% | 198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py | 241 | 23 | 90% | 20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py | 161 | 6 | 96% | 22, 33, 51, 56, 61, 71
   store.py | 72 | 72 | 0% | 3–275
choice_learn/datasets
   __init__.py | 4 | 0 | 100% |
   base.py | 400 | 5 | 99% | 42–43, 153–154, 714
   expedia.py | 102 | 83 | 19% | 37–301
   tafeng.py | 49 | 0 | 100% |
choice_learn/datasets/data
   __init__.py | 0 | 0 | 100% |
choice_learn/models
   __init__.py | 14 | 2 | 86% | 15–16
   base_model.py | 269 | 15 | 94% | 144, 186, 283, 302, 342, 349, 378, 397, 428–429, 438–439, 540, 654–655
   baseline_models.py | 49 | 0 | 100% |
   conditional_logit.py | 269 | 26 | 90% | 49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py | 124 | 2 | 98% | 186, 374
   latent_class_base_model.py | 286 | 39 | 86% | 55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py | 62 | 6 | 90% | 257–261, 296
   learning_mnl.py | 67 | 3 | 96% | 157, 182, 188
   nested_logit.py | 291 | 12 | 96% | 55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py | 132 | 6 | 95% | 285, 360, 369, 374, 382, 432
   rumnet.py | 236 | 3 | 99% | 748–751, 982
   simple_mnl.py | 139 | 6 | 96% | 167, 275, 347, 355, 357, 359
   tastenet.py | 94 | 3 | 97% | 142, 180, 188
choice_learn/toolbox
   __init__.py | 0 | 0 | 100% |
   assortment_optimizer.py | 27 | 6 | 78% | 28–30, 93–95, 160–162
   gurobi_opt.py | 238 | 238 | 0% | 3–675
   or_tools_opt.py | 230 | 11 | 95% | 103, 107, 296–305, 315, 319, 607, 611
TOTAL | 5231 | 1006 | 81% |

Tests | Skipped | Failures | Errors | Time
204 | 0 💤 | 7 ❌ | 2 🔥 | 7m 44s ⏱️
