diff --git a/recnn/data/utils.py b/recnn/data/utils.py index c569f03..3c06246 100755 --- a/recnn/data/utils.py +++ b/recnn/data/utils.py @@ -128,9 +128,7 @@ def prepare_batch_static_size(batch, item_embeddings_tensor=False, frame_size=10 # id_to_key - dict index -> key -def make_items_tensor(items_embeddings_key_dict, include_zero=True): - if include_zero: - items_embeddings_key_dict[0] = torch.zeros([128, ]) +def make_items_tensor(items_embeddings_key_dict): keys = list(sorted(items_embeddings_key_dict.keys())) key_to_id = dict(zip(keys, range(len(keys)))) id_to_key = dict(zip(range(len(keys)), keys))