Skip to content

Commit

Permalink
Add plot_frequency_distribution helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
sindre0830 committed Oct 15, 2023
1 parent 84cae04 commit b0b4df5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
3 changes: 2 additions & 1 deletion source/architechtures/cbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def run() -> None:
pipeline=[
corpus.download,
corpus.flatten
]
],
data_directory="cbow"
)
# get vocabulary
vocabulary = datahandler.loaders.Vocabulary(add_padding=True, add_unknown=False)
Expand Down
2 changes: 1 addition & 1 deletion source/config_cbow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ device: "cuda:0"
vocabulary_size: 10000
batch_size: 32
window_size: 2
embedding_size: 100
embedding_size: 200

learning_rate: 0.05
max_epochs: 20
Expand Down
4 changes: 3 additions & 1 deletion source/datahandler/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ def __init__(self):
self.sentences = None
self.words = None

def build(self, pipeline) -> None:
def build(self, pipeline, data_directory: str) -> None:
for step in tqdm.tqdm(pipeline, desc="Building corpus"):
step()
if self.words is not None:
utils.plot_frequency_distribution(self.words, data_directory)

def download(self) -> None:
self.sentences = list(gensim.downloader.load("text8"))
Expand Down
27 changes: 27 additions & 0 deletions source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import tqdm
import collections


def load_config(filepath: str) -> argparse.Namespace:
Expand Down Expand Up @@ -160,3 +161,29 @@ def plot_target_words_occurances(target_words: np.ndarray, data_directory: str):
# save plot
save_plot(filepath)
plt.close()


def plot_frequency_distribution(corpus, data_directory: str):
# check if it already exists
title = "Word Frequencies in Descending Order"
filepath = os.path.join(PROJECT_DIRECTORY_PATH, "data", data_directory, "plots", f"{title}.png")

word_freq = collections.Counter(corpus)
word_freq = sorted(word_freq.values(), reverse=True)
ranks = np.arange(1, len(word_freq) + 1)

plt.figure(figsize=(12, 6))
# add a Zipfian reference line
x = np.linspace(min(ranks), max(ranks), len(word_freq))
y = word_freq[0] * (x ** -1)
plt.plot(x, y, linestyle="--", color="red", label="Zipfian Reference")

plt.loglog(ranks, word_freq, label="Actual Data")
plt.xlabel("Rank (log scale)")
plt.ylabel("Frequency (log scale)")
plt.legend()
plt.title(title)
plt.grid(True)
# save plot
save_plot(filepath)
plt.close()

0 comments on commit b0b4df5

Please sign in to comment.