def io_iter(self, batch_size=1, shuffle=False):
    """Create an iterator over data yielding batch of input item indices, batch of output item indices.
    A sequence `a b c d` produces [a, b, c] and [b, c, d] as input items and output items respectively.

    Input/output pairs from consecutive sessions are pooled in a buffer, so a
    single batch may contain pairs coming from more than one session. Every
    batch holds exactly `batch_size` pairs except possibly the last one, which
    holds the (non-empty) remainder.

    Parameters
    ----------
    batch_size: int, optional, default = 1
        Number of (input, output) item pairs per batch. Must be >= 1.

    shuffle: bool, optional, default: False
        Forwarded to `self.s_iter`, which controls the order in which sessions are visited.

    Returns
    -------
    iterator : batch of input item indices, batch of output item indices

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1. Because this is a generator, the
        error surfaces when the iterator is first advanced.

    """
    if batch_size < 1:
        # A non-positive size would never drain the buffer below.
        raise ValueError("batch_size must be a positive integer")

    item_indices = self.uir_tuple[1]
    input_iids = np.asarray([], dtype="int")
    output_iids = np.asarray([], dtype="int")
    # Sessions are pulled one at a time; `[mapped_ids]` unpacks the singleton batch.
    for _, [mapped_ids] in self.s_iter(1, shuffle):
        if len(mapped_ids) < 2:
            # A session with a single interaction has no (input, output) pair.
            continue
        input_iids = np.concatenate([input_iids, item_indices[mapped_ids[:-1]]])
        output_iids = np.concatenate([output_iids, item_indices[mapped_ids[1:]]])
        # Drain every full batch: one long session can contribute more than
        # `batch_size` pairs, so a single `if` would leave oversized leftovers.
        while len(input_iids) >= batch_size:
            yield input_iids[:batch_size], output_iids[:batch_size]
            input_iids = input_iids[batch_size:]
            output_iids = output_iids[batch_size:]
    # Emit the remainder only when there is one; never yield an empty batch.
    if len(input_iids) > 0:
        yield input_iids, output_iids