
Commit

Add build method to DataLoaderSkipGram class
sindre0830 committed Oct 15, 2023
1 parent c8fa59f commit c85e6a6
Showing 2 changed files with 49 additions and 0 deletions.
3 changes: 3 additions & 0 deletions source/architechtures/skipgram.py
@@ -27,3 +27,6 @@ def run() -> None:
    # get vocabulary
    vocabulary = datahandler.loaders.Vocabulary(add_padding=True, add_unknown=False)
    vocabulary.build(corpus.words, config.vocabulary_size)
    # get training data
    training_dataloader = datahandler.loaders.DataLoaderSkipGram(config.batch_size)
    training_dataloader.build(corpus.sentences, vocabulary, config.window_size, config.device)
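Note: the windowing applied by the new build method (added to loaders.py below) can be illustrated with a small standalone sketch. The make_pairs helper here is hypothetical and skips vocabulary lookups and subsampling, but it pairs words the same way: each center (target) word is paired with every other word inside its window.

def make_pairs(sentence: list[str], window_size: int) -> list[tuple[str, str]]:
    # pair every in-window word (context) with the center word (target)
    pairs = []
    for center_position, center_word in enumerate(sentence):
        start_position = max(0, center_position - window_size)
        end_position = min(len(sentence), center_position + window_size + 1)
        for word in sentence[start_position:end_position]:
            if word == center_word:
                continue
            pairs.append((center_word, word))
    return pairs

print(make_pairs(["the", "quick", "brown", "fox"], window_size=1))
# [('the', 'quick'), ('quick', 'the'), ('quick', 'brown'), ('brown', 'quick'), ('brown', 'fox'), ('fox', 'brown')]
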
46 changes: 46 additions & 0 deletions source/datahandler/loaders.py
@@ -332,3 +332,49 @@ def __init__(self, batch_size: int):

        self._num_samples = 0
        self._batch_size = batch_size

    def build(self, sentences: list[list[str]], vocabulary: Vocabulary, window_size: int, device: str):
        context_words_filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "skipgram", "training_data", "context_words.npy")
        target_words_filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "skipgram", "training_data", "target_words.npy")

        # load cached training data if it was built in a previous run
        if os.path.exists(context_words_filepath) and os.path.exists(target_words_filepath):
            progress_bar = tqdm.tqdm(desc="Building training data", total=1)
            self.context_words = torch.tensor(utils.load_numpy(context_words_filepath), dtype=torch.long, device=device)
            self.target_words = torch.tensor(utils.load_numpy(target_words_filepath), dtype=torch.long, device=device)
            self._num_samples = len(self.target_words)
            progress_bar.update(1)
            return

        context_words = []
        target_words = []
        for sentence in tqdm.tqdm(sentences, desc="Building training data"):
            for center_position, center_word in enumerate(sentence):
                if center_word not in vocabulary or vocabulary.subsample(center_word, threshold=1e-5):
                    continue
                # define the boundaries of the window
                start_position = max(0, center_position - window_size)
                end_position = min(len(sentence), center_position + window_size + 1)
                # extract words around the center word within the window
                for word in sentence[start_position:end_position]:
                    if word == center_word or word not in vocabulary:
                        continue
                    # pair each in-window word (context) with the center word (target)
                    context_words.append(vocabulary.get_index(word))
                    target_words.append(vocabulary.get_index(center_word))

        # cache the training data to disk for future runs
        context_words = np.array(context_words)
        target_words = np.array(target_words)
        utils.save_numpy(context_words_filepath, context_words)
        utils.save_numpy(target_words_filepath, target_words)

        self.context_words = torch.tensor(context_words, dtype=torch.long, device=device)
        self.target_words = torch.tensor(target_words, dtype=torch.long, device=device)
        self._num_samples = len(self.target_words)
        utils.plot_target_words_occurances(target_words, data_directory="skipgram")

    def __iter__(self):
        for start in range(0, self._num_samples, self._batch_size):
            end = min(start + self._batch_size, self._num_samples)
            yield (self.context_words[start:end], self.target_words[start:end])

    def __len__(self):
        return (self._num_samples + self._batch_size - 1) // self._batch_size
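
Note: once built, the loader yields (context_words, target_words) index batches, so a training loop could consume it roughly as sketched below. The SkipGramModel here is a generic embedding-plus-linear stand-in, not necessarily the model used elsewhere in this repository; the embedding size and learning rate are placeholder values, config and training_dataloader are assumed to be in scope as in skipgram.py above, and config.vocabulary_size is assumed to match the number of indices the vocabulary actually produces.

import torch

class SkipGramModel(torch.nn.Module):
    # generic skip-gram: embed the target (center) word, project to vocabulary-sized logits
    def __init__(self, vocabulary_size: int, embedding_dim: int):
        super().__init__()
        self.embeddings = torch.nn.Embedding(vocabulary_size, embedding_dim)
        self.output = torch.nn.Linear(embedding_dim, vocabulary_size)

    def forward(self, target_words: torch.Tensor) -> torch.Tensor:
        return self.output(self.embeddings(target_words))

model = SkipGramModel(config.vocabulary_size, embedding_dim=300).to(config.device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for context_words, target_words in training_dataloader:
    optimizer.zero_grad()
    logits = model(target_words)             # predict the context from the center word
    loss = criterion(logits, context_words)  # context indices act as class labels
    loss.backward()
    optimizer.step()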
