From 712b42fbf8969ed111f07a146053474ee9380648 Mon Sep 17 00:00:00 2001 From: Sindre Eiklid Date: Mon, 16 Oct 2023 08:48:58 +0200 Subject: [PATCH] Initial commit of `DataLoaderCooccurrence` --- source/datahandler/loaders.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/source/datahandler/loaders.py b/source/datahandler/loaders.py index 3aab1ae..758bf35 100644 --- a/source/datahandler/loaders.py +++ b/source/datahandler/loaders.py @@ -17,6 +17,7 @@ import matplotlib.pyplot as plt import matplotlib.ticker as ticker import scipy.stats +import scipy.sparse class Corpus(): @@ -378,3 +379,20 @@ def __iter__(self): def __len__(self): return (self._num_samples + self._batch_size - 1) // self._batch_size + + +class DataLoaderCooccurrence: + def __init__(self, batch_size: int): + self._token_ids = None + self._cooccurr_counts = None + + self._num_samples = len(self._token_ids) + self._batch_size = batch_size + + def __iter__(self): + for start in range(0, self._num_samples, self._batch_size): + end = min(start + self._batch_size, self._num_samples) + yield (self._token_ids[start:end], self._cooccurr_counts[start:end]) + + def __len__(self): + return (self._num_samples + self._batch_size - 1) // self._batch_size