Add method build to DataLoaderCooccurrence
sindre0830 committed Oct 16, 2023
1 parent e4ecfac commit 21d001c
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions source/datahandler/loaders.py
@@ -389,6 +389,40 @@ def __init__(self, batch_size: int):
        self._num_samples = len(self._token_ids)
        self._batch_size = batch_size

    def build(self, words: list[str], vocabulary: Vocabulary, window_size: int, device: str):
        cooccurrence_matrix: scipy.sparse.coo_matrix = None
        cooccurrence_matrix_filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "glove", "training_data", "cooccurrence_matrix.npz")
        if os.path.exists(cooccurrence_matrix_filepath):
            progress_bar = tqdm.tqdm(desc="Building training data", total=1)
            cooccurrence_matrix = utils.load_npz(cooccurrence_matrix_filepath)
            progress_bar.update(1)
        else:
            vocabulary_size = len(vocabulary)
            cooccurrence_matrix = scipy.sparse.lil_matrix((vocabulary_size, vocabulary_size), dtype=np.float32)

            for idx, word in enumerate(tqdm.tqdm(words, desc="Building training data")):
                # skip out-of-vocabulary words and words dropped by subsampling
                if word not in vocabulary or vocabulary.subsample(word, threshold=1e-5):
                    continue
                word_idx = vocabulary[word]
                # define the boundaries of the window
                start_position = max(0, idx - window_size)
                end_position = min(len(words), idx + window_size + 1)
                # iterate over the context window
                for j in range(start_position, end_position):
                    if j == idx:
                        continue
                    context_word_idx = vocabulary.get_index(words[j])
                    if context_word_idx == vocabulary.unknown_index or context_word_idx == vocabulary.padding_index:
                        continue
                    cooccurrence_matrix[word_idx, context_word_idx] += 1.0

            cooccurrence_matrix = cooccurrence_matrix.tocoo()
            utils.save_npz(cooccurrence_matrix_filepath, cooccurrence_matrix)

        # each row of _token_ids is a (word index, context index) pair from the sparse matrix,
        # and _cooccurr_counts holds the matching co-occurrence count for that pair
        self._token_ids = torch.tensor(np.array(list(zip(cooccurrence_matrix.row, cooccurrence_matrix.col))), dtype=torch.long, device=device)
        self._cooccurr_counts = torch.tensor(cooccurrence_matrix.data, dtype=torch.float32, device=device)


    def __iter__(self):
        for start in range(0, self._num_samples, self._batch_size):
            end = min(start + self._batch_size, self._num_samples)
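A rough usage sketch, not part of this commit: it assumes `words` is the tokenized corpus, `vocabulary` is an already-built Vocabulary instance from the same project, and that iterating the loader yields batches of (word index, context index) pairs with their counts, which is inferred from the tensors built above since the rest of __iter__ lies outside the hunk.

# hypothetical usage of the new build method (names other than the diff's are assumptions)
loader = DataLoaderCooccurrence(batch_size=512)
loader.build(words, vocabulary, window_size=5, device="cuda" if torch.cuda.is_available() else "cpu")

for batch in loader:
    # presumably a slice of (word_idx, context_idx) pairs plus the matching co-occurrence counts
    ...

Accumulating the counts in a scipy.sparse.lil_matrix and converting to COO only once at the end keeps the incremental `+= 1.0` updates cheap; the resulting COO row/col/data arrays then map directly onto the token-id pairs and count tensor that a GloVe-style objective consumes.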
