diff --git a/source/architechtures/skipgram.py b/source/architechtures/skipgram.py index b58e5df..c970125 100644 --- a/source/architechtures/skipgram.py +++ b/source/architechtures/skipgram.py @@ -27,3 +27,6 @@ def run() -> None: # get vocabulary vocabulary = datahandler.loaders.Vocabulary(add_padding=True, add_unknown=False) vocabulary.build(corpus.words, config.vocabulary_size) + # get training data + training_dataloader = datahandler.loaders.DataLoaderSkipGram(config.batch_size) + training_dataloader.build(corpus.sentences, vocabulary, config.window_size, config.device) diff --git a/source/datahandler/loaders.py b/source/datahandler/loaders.py index bea7036..3aab1ae 100644 --- a/source/datahandler/loaders.py +++ b/source/datahandler/loaders.py @@ -332,3 +332,49 @@ def __init__(self, batch_size: int): self._num_samples = 0 self._batch_size = batch_size + + def build(self, sentences: list[list[str]], vocabulary: Vocabulary, window_size: int, device: str): + context_words_filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "skipgram", "training_data", "context_words.npy") + target_words_filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "skipgram", "training_data", "target_words.npy") + + if os.path.exists(context_words_filepath) and os.path.exists(target_words_filepath): + progress_bar = tqdm.tqdm(desc="Building training data", total=1) + self.context_words = torch.tensor(utils.load_numpy(context_words_filepath), dtype=torch.long, device=device) + self.target_words = torch.tensor(utils.load_numpy(target_words_filepath), dtype=torch.long, device=device) + self._num_samples = len(self.target_words) + progress_bar.update(1) + return + + context_words = [] + target_words = [] + for sentence in tqdm.tqdm(sentences, desc="Building training data"): + for center_position, center_word in enumerate(sentence): + if center_word not in vocabulary or vocabulary.subsample(center_word, threshold=1e-5): + continue + # define the boundaries of the window + start_position = max(0, center_position - window_size) + end_position = min(len(sentence), center_position + window_size + 1) + # extract words around the center word within the window + for word in sentence[start_position:end_position]: + if word == center_word or word not in vocabulary: + continue + context_words.append(vocabulary.get_index(word)) + target_words.append(vocabulary.get_index(word)) + + context_words = np.array(context_words) + target_words = np.array(target_words) + utils.save_numpy(context_words_filepath, context_words) + utils.save_numpy(target_words_filepath, target_words) + + self.context_words = torch.tensor(context_words, dtype=torch.long, device=device) + self.target_words = torch.tensor(target_words, dtype=torch.long, device=device) + self._num_samples = len(self.target_words) + utils.plot_target_words_occurances(target_words, data_directory="skipgram") + + def __iter__(self): + for start in range(0, self._num_samples, self._batch_size): + end = min(start + self._batch_size, self._num_samples) + yield (self.context_words[start:end], self.target_words[start:end]) + + def __len__(self): + return (self._num_samples + self._batch_size - 1) // self._batch_size