def build(self, words: list[str], vocabulary: Vocabulary, window_size: int, device: str):
    """Build (or load from cache) the co-occurrence training data.

    If ``data/glove/training_data/cooccurrence_matrix.npz`` exists under
    ``PROJECT_DIRECTORY_PATH`` it is loaded as-is; otherwise a symmetric
    context window of ``window_size`` tokens on each side is slid over
    ``words``, every (center, context) pair is counted, and the resulting
    matrix is saved to that path.

    NOTE(review): the cache path is fixed and the existence check does not
    look at ``words``, ``vocabulary`` or ``window_size``, so a cache built
    with different arguments is reused silently. Delete the file to force a
    rebuild.

    Args:
        words: Corpus tokens, in reading order.
        vocabulary: Object supporting ``len()``, ``in``, ``[]``,
            ``get_index()``, ``subsample()`` and the ``unknown_index`` /
            ``padding_index`` attributes.
        window_size: Number of neighbours considered on each side of the
            center word.
        device: Torch device on which the resulting tensors are created.

    Side effects:
        Sets ``self._token_ids`` (one ``[row, col]`` pair per non-zero
        entry), ``self._cooccurr_counts`` (the matching counts) and
        ``self._num_samples``.
    """
    # Imported locally under an alias so the progress bar works whether the
    # module imported tqdm as `import tqdm` or `from tqdm import tqdm`.
    from tqdm import tqdm as make_progress_bar

    cooccurrence_matrix_filepath = os.path.join(
        PROJECT_DIRECTORY_PATH, "data", "glove", "training_data", "cooccurrence_matrix.npz"
    )

    if os.path.exists(cooccurrence_matrix_filepath):
        # Single-step bar so the console output matches the build branch.
        with make_progress_bar(desc="Building training data", total=1) as progress_bar:
            cooccurrence_matrix = utils.load_npz(cooccurrence_matrix_filepath)
            progress_bar.update(1)
    else:
        vocabulary_size = len(vocabulary)
        num_words = len(words)
        unknown_index = vocabulary.unknown_index
        padding_index = vocabulary.padding_index
        # LIL format supports the incremental element-wise updates below.
        counts = scipy.sparse.lil_matrix((vocabulary_size, vocabulary_size), dtype=np.float32)

        for idx, word in enumerate(make_progress_bar(words, desc="Building training data")):
            # Guard first: the index lookup must only happen for words that
            # are in the vocabulary. A truthy `subsample` result is treated
            # as "skip this occurrence".
            if word not in vocabulary or vocabulary.subsample(word, threshold=1e-5):
                continue
            word_idx = vocabulary[word]

            # Define the boundaries of the window, clamped to the corpus.
            start_position = max(0, idx - window_size)
            end_position = min(num_words, idx + window_size + 1)

            # Iterate over the context window, excluding the center word.
            for j in range(start_position, end_position):
                if j == idx:
                    continue
                context_word_idx = vocabulary.get_index(words[j])
                if context_word_idx == unknown_index or context_word_idx == padding_index:
                    continue
                counts[word_idx, context_word_idx] += 1.0

        cooccurrence_matrix = counts.tocoo()
        # Make sure the cache directory exists before writing into it.
        os.makedirs(os.path.dirname(cooccurrence_matrix_filepath), exist_ok=True)
        utils.save_npz(cooccurrence_matrix_filepath, cooccurrence_matrix)

    # Stack row/col into an (nnz, 2) array; unlike list(zip(...)) this keeps
    # the 2-column shape even when the matrix has no non-zero entries.
    pair_indices = np.stack((cooccurrence_matrix.row, cooccurrence_matrix.col), axis=1)
    self._token_ids = torch.tensor(pair_indices, dtype=torch.long, device=device)
    self._cooccurr_counts = torch.tensor(cooccurrence_matrix.data, dtype=torch.float32, device=device)
    # __iter__ batches over `_num_samples`, so it must follow the new data.
    self._num_samples = len(self._token_ids)