Commit

Initial commit of ModelSkipGram class
sindre0830 committed Oct 15, 2023
1 parent c85e6a6 commit 3cac3cd
Showing 2 changed files with 41 additions and 0 deletions.
3 changes: 3 additions & 0 deletions source/architechtures/skipgram.py
@@ -30,3 +30,6 @@ def run() -> None:
    # get training data
    training_dataloader = datahandler.loaders.DataLoaderSkipGram(config.batch_size)
    training_dataloader.build(corpus.sentences, vocabulary, config.window_size, config.device)
    # get validation data
    validation_dataloader = datahandler.loaders.ValidationLoader(data_directory="skipgram")
    validation_dataloader.build(vocabulary)
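
(The internals of DataLoaderSkipGram are not part of this commit. For context, a minimal sketch of how skip-gram (center, context) training pairs are typically extracted from a tokenized sentence with a symmetric window might look like the following; the function name and pair format are assumptions for illustration, not the repository's API.)

def build_skipgram_pairs(token_ids: list[int], window_size: int) -> list[tuple[int, int]]:
    # hypothetical helper, not from this repository
    pairs = []
    for center_position, center_id in enumerate(token_ids):
        # context covers up to window_size tokens on each side of the center word
        start = max(0, center_position - window_size)
        end = min(len(token_ids), center_position + window_size + 1)
        for context_position in range(start, end):
            if context_position != center_position:
                pairs.append((center_id, token_ids[context_position]))
    return pairs

# example: a 4-token sentence with window_size=2
print(build_skipgram_pairs([5, 2, 7, 9], window_size=2))
# [(5, 2), (5, 7), (2, 5), (2, 7), (2, 9), (7, 5), (7, 2), (7, 9), (9, 2), (9, 7)]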
38 changes: 38 additions & 0 deletions source/models.py
@@ -153,3 +153,41 @@ def fit(
        # empty GPU cache
        if "cuda" in self.device:
            torch.cuda.empty_cache()


class ModelSkipGram(torch.nn.Module):
    def __init__(
        self,
        device: str,
        vocabulary_size: int,
        embedding_size: int,
        padding_idx: int = None
    ):
        super().__init__()
        self.device = device
        self.filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", "skipgram", "model.pt")
        self.padding_idx = padding_idx
        # init embedding layers
        self.input_embeddings = torch.nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_size,
            padding_idx=self.padding_idx,
            dtype=torch.float32,
            sparse=True
        )
        self.output_embeddings = torch.nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_size,
            padding_idx=self.padding_idx,
            dtype=torch.float32,
            sparse=True
        )
        # set the initial weights to be between -0.5 and 0.5
        self.input_embeddings.weight.data.uniform_(-0.5, 0.5)
        self.output_embeddings.weight.data.uniform_(-0.5, 0.5)
        # set padding vector to zero
        if self.padding_idx is not None:
            self.input_embeddings.weight.data[self.padding_idx, :] = 0
            self.output_embeddings.weight.data[self.padding_idx, :] = 0
        # send model to device
        self.to(self.device)
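
(This commit stops at __init__; no forward pass is defined yet. As a hedged sketch of how the two embedding tables above are commonly used, a skip-gram forward pass with negative sampling could look like the following; the method name, arguments, and tensor shapes are assumptions, not code from this repository.)

    def forward(self, center_ids, context_ids, negative_ids):
        # hypothetical method, not part of this commit
        # center_ids, context_ids: (batch,); negative_ids: (batch, num_negatives)
        center = self.input_embeddings(center_ids)        # (batch, embedding_size)
        context = self.output_embeddings(context_ids)     # (batch, embedding_size)
        negatives = self.output_embeddings(negative_ids)  # (batch, num_negatives, embedding_size)
        # true (center, context) pairs should score high ...
        positive_score = torch.sum(center * context, dim=1)                    # (batch,)
        positive_loss = torch.nn.functional.logsigmoid(positive_score)
        # ... while randomly sampled negatives should score low
        negative_score = torch.bmm(negatives, center.unsqueeze(2)).squeeze(2)  # (batch, num_negatives)
        negative_loss = torch.nn.functional.logsigmoid(-negative_score).sum(dim=1)
        return -(positive_loss + negative_loss).mean()

(Note that because both embedding tables are created with sparse=True, their gradients are sparse tensors, so training would need an optimizer that accepts sparse gradients, such as torch.optim.SparseAdam or torch.optim.SGD.)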
