Fixing dict
AG committed Mar 4, 2024
2 parents f048165 + 3912cf9 commit ad8b369
Showing 3 changed files with 73 additions and 120 deletions.
103 changes: 32 additions & 71 deletions create_dictionary.py
@@ -1,5 +1,4 @@
import os
import pickle
import json

# Across all dictionaries, how many entry word sets total should we regularly prune the
@@ -60,95 +59,57 @@ def convert_to_array(obj):
result.append(key_with_space)
return result

def main():
# Path configuration
dictionaries_path = 'training/dictionaries'
scores_3_words_file_path = 'training/scores_3_words.pkl'
scores_2_words_file_path = 'training/scores_2_words.pkl'
scores_1_word_file_path = 'training/scores_1_word.pkl'
def main(trie_store):
# File-path configuration is no longer needed; the 3_words/2_words/1_word
# split is retained as keys inside trie_store
output_file = 'dictionary.js'

# Prune the dictionaries first
prune_unpopular(scores_3_words_file_path, os.path.join(dictionaries_path, "3_words"), target_dictionary_count=int(TARGET_DICTIONARY_COUNT * THREE_WORD_STAKE_PERCENT))
prune_unpopular(scores_2_words_file_path, os.path.join(dictionaries_path, "2_words"), target_dictionary_count=int(TARGET_DICTIONARY_COUNT * TWO_WORD_STAKE_PERCENT))
prune_unpopular(scores_1_word_file_path, os.path.join(dictionaries_path, "1_word"), target_dictionary_count=int(TARGET_DICTIONARY_COUNT * ONE_WORD_STAKE_PERCENT))
prune_unpopular(trie_store, "3_words", target_dictionary_count=int(TARGET_DICTIONARY_COUNT * THREE_WORD_STAKE_PERCENT))
prune_unpopular(trie_store, "2_words", target_dictionary_count=int(TARGET_DICTIONARY_COUNT * TWO_WORD_STAKE_PERCENT))
prune_unpopular(trie_store, "1_word", target_dictionary_count=int(TARGET_DICTIONARY_COUNT * ONE_WORD_STAKE_PERCENT))

# Initialize the dictionary object
dictionary = {}

# Iterate over every .pkl file in the dictionaries directory
# Define the subdirectories
subdirectories = ["3_words", "2_words", "1_word"]

print("Getting all dictionaries...")
for subdirectory in subdirectories:
# Construct the path to the subdirectory
subdirectory_path = os.path.join(dictionaries_path, subdirectory)

# Iterate over every .pkl file in the current subdirectory
for filename in os.listdir(subdirectory_path):
if filename.endswith('.pkl'):
slug, _ = os.path.splitext(filename)
file_path = os.path.join(subdirectory_path, filename)

with open(file_path, 'rb') as file:
trie = pickle.load(file)
# Convert trie to the specified array format
# Use a modified slug that includes the subdirectory name for uniqueness
dictionary[slug] = convert_to_array(trie)
# Iterate over trie_store's sub-keys instead of .pkl files
for dictionary_key in ["3_words", "2_words", "1_word"]:
for slug, trie in trie_store['tries'].get(dictionary_key, {}).items():
# Directly use trie from trie_store for conversion
dictionary[slug] = convert_to_array(trie)

print(f"Dictionary width is %s" % dictionary.keys().__len__())
print(f"Dictionary width is {len(dictionary.keys())}")
# Write the dictionary object to dictionary.js in the desired format
with open(output_file, 'w') as js_file:
# Minimize by removing unnecessary whitespace in json.dumps and adjusting js_content formatting
minimized_json = json.dumps(dictionary, separators=(',', ':'))
js_content = f"export const dictionary = {minimized_json};"
js_file.write(js_content)
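
As a quick illustration of the emission step above, this standalone sketch writes a toy dictionary in the same minimized format (the nested-array entry shape is hypothetical, since convert_to_array's full output format isn't shown in this diff):

import json

# Toy dictionary; the entry shape is a guess at convert_to_array's output.
dictionary = {"hello_world": ["again", ["and"]]}

minimized_json = json.dumps(dictionary, separators=(',', ':'))  # no spaces after ',' / ':'
js_content = f"export const dictionary = {minimized_json};"
print(js_content)
# export const dictionary = {"hello_world":["again",["and"]]};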

def prune_unpopular(scores_file_path, dictionaries_path, target_dictionary_count=TARGET_DICTIONARY_COUNT):
# Load the scores
if os.path.exists(scores_file_path):
with open(scores_file_path, 'rb') as file:
scores = pickle.load(file)
else:
print("Scores file does not exist.")
return

print(f"\nStopping to prune least popular entries down to target dictionary size of %s..." % target_dictionary_count)

# Sort scores by value in descending order and get the top_n keys
def prune_unpopular(trie_store, dictionary_key, target_dictionary_count=TARGET_DICTIONARY_COUNT):
# Access the nested scores directly
scores = trie_store['scores'].get(dictionary_key, {})

# Sort scores by value in descending order to identify top slugs
top_slugs = sorted(scores, key=scores.get, reverse=True)[:target_dictionary_count]

# Convert to set for faster lookup
top_slugs_set = set(top_slugs)

# Track slugs to be removed from scores
slugs_to_remove = []

# Iterate over all files in dictionaries directory
for filename in os.listdir(dictionaries_path):
slug, ext = os.path.splitext(filename)
full_path = os.path.join(dictionaries_path, filename)
if slug not in top_slugs_set:
# This file is not among the top scoring, so delete it
os.remove(full_path)
slugs_to_remove.append(slug)
else:
# Since we're keeping the file, let's prune its branches.
with open(full_path, 'rb') as f:
trie = pickle.load(f)

# Apply branch pruning to the trie
branch_pruner(trie)

# Save the pruned trie back to the file
with open(full_path, 'wb') as f:
pickle.dump(trie, f, protocol=pickle.HIGHEST_PROTOCOL)
existing_slugs = set(trie_store['tries'][dictionary_key].keys())
slugs_to_keep = set(top_slugs)
slugs_to_prune = existing_slugs - slugs_to_keep

# Remove the pruned entries from scores
for slug in slugs_to_remove:
# Prune the tries not in top slugs
for slug in slugs_to_prune:
del trie_store['tries'][dictionary_key][slug]

# Optionally, prune scores not associated with top slugs
for slug in slugs_to_prune:
if slug in scores:
del scores[slug]
del trie_store['scores'][dictionary_key][slug]

# Apply branch pruning to each trie in top slugs within the specific dictionary_key
for slug in slugs_to_keep:
trie = trie_store['tries'][dictionary_key].get(slug, {})
branch_pruner(trie)

if __name__ == "__main__":
main()
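
To see the new in-memory pruning in isolation, here is a simplified, self-contained sketch of the same keep-the-top-N idea (branch_pruner is omitted and the scores are made up):

trie_store = {
    'tries':  {'1_word': {'cat': {}, 'dog': {}, 'axolotl': {}}},
    'scores': {'1_word': {'cat': 9, 'dog': 5, 'axolotl': 1}},
}

def prune(trie_store, dictionary_key, target_count):
    # Keep only the target_count highest-scoring slugs; drop the rest.
    scores = trie_store['scores'].get(dictionary_key, {})
    top_slugs = set(sorted(scores, key=scores.get, reverse=True)[:target_count])
    for slug in set(trie_store['tries'][dictionary_key]) - top_slugs:
        del trie_store['tries'][dictionary_key][slug]
        trie_store['scores'][dictionary_key].pop(slug, None)

prune(trie_store, '1_word', target_count=2)
print(sorted(trie_store['tries']['1_word']))  # ['cat', 'dog']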
2 changes: 1 addition & 1 deletion dictionary.js

Large diffs are not rendered by default.

88 changes: 40 additions & 48 deletions train.py
@@ -1,9 +1,6 @@
# Import necessary modules
import os
import sys
import pickle
import re
from collections import defaultdict
import shutil
from tqdm import tqdm
from slugify import slugify
@@ -14,6 +11,12 @@
PRUNE_FREQUENCY = 200000 # Every this many document positions
CHUNK_SIZE = 1024 # 1KB per chunk

# Global variable to hold tries and scores
trie_store = {
'tries': {'3_words': {}, '2_words': {}, '1_word': {}},
'scores': {}
}
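
For orientation, a hypothetical trie_store state after a little training might look like the following (the exact trie node shape, including the '\ranked' key, is inferred from update_trie below; all values are made up):

# Hypothetical state; node shape inferred, values invented for illustration.
example_store = {
    'tries': {
        '3_words': {'the_quick_brown': {'fox': {}, '\ranked': ['fox']}},
        '2_words': {'quick_brown': {'fox': {}, '\ranked': ['fox']}},
        '1_word':  {'brown': {'fox': {}, '\ranked': ['fox']}},
    },
    'scores': {  # filled in lazily by update_scores below
        '3_words': {'the_quick_brown': 1},
        '2_words': {'quick_brown': 1},
        '1_word':  {'brown': 1},
    },
}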

# Define a flag to indicate when an interrupt has been caught
interrupted = False

@@ -53,32 +56,31 @@ def update_trie(trie, predictive_words):
trie['\ranked'].insert(max(0, index - 1), trie['\ranked'].pop(index))
trie = trie[word]
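
The '\ranked' bookkeeping above amounts to a bubble-up promotion: each time a word is seen again, it swaps one slot toward the front of its node's ranking. A tiny standalone sketch of that step:

ranked = ['fox', 'dog', 'cat']
index = ranked.index('cat')                          # 'cat' was just seen again
ranked.insert(max(0, index - 1), ranked.pop(index))  # bubble up one slot
print(ranked)  # ['fox', 'cat', 'dog']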

# Define a function to load or initialize the trie from a .pkl file
def load_trie(file_path):
if os.path.exists(file_path):
with open(file_path, 'rb') as file:
trie = pickle.load(file)
else:
trie = {}
return trie

# Define a function to save the updated trie back to the .pkl file
def save_trie(trie, file_path):
with open(file_path, 'wb') as file:
pickle.dump(trie, file, protocol=pickle.HIGHEST_PROTOCOL)

# Define a function to update scores in scores.pkl
def update_scores(scores_file_path, context_slug):
if os.path.exists(scores_file_path):
with open(scores_file_path, 'rb') as file:
scores = pickle.load(file)
else:
scores = {}

scores[context_slug] = scores.get(context_slug, 0) + 1
# Define a function to load or initialize the trie from memory
def load_trie(path, context_slug):
# Access the trie data by first navigating to the path, then the context_slug
return trie_store['tries'].get(path, {}).get(context_slug, {})

def save_trie(trie, path, context_slug):
# Check if the path exists in 'tries'; if not, create it
if path not in trie_store['tries']:
trie_store['tries'][path] = {}

with open(scores_file_path, 'wb') as file:
pickle.dump(scores, file, protocol=pickle.HIGHEST_PROTOCOL)
# Now, path exists for sure; check for context_slug under this path
# This step might be redundant if you're always going to assign a new trie,
# but it's crucial if you're updating or merging with existing data.
if context_slug not in trie_store['tries'][path]:
trie_store['tries'][path][context_slug] = {}

# Assign the trie to the specified path and context_slug
trie_store['tries'][path][context_slug] = trie

def update_scores(path, context_slug):
if path not in trie_store['scores']:
trie_store['scores'][path] = {}
if context_slug not in trie_store['scores'][path]:
trie_store['scores'][path][context_slug] = 0
trie_store['scores'][path][context_slug] += 1
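
As a design note, the lazy two-level initialization in update_scores is equivalent to a nested defaultdict; a sketch of that alternative:

from collections import defaultdict

scores = defaultdict(lambda: defaultdict(int))  # path -> slug -> count

def update_scores_alt(path, context_slug):
    scores[path][context_slug] += 1

update_scores_alt('3_words', 'the_quick_brown')
update_scores_alt('3_words', 'the_quick_brown')
print(scores['3_words']['the_quick_brown'])  # 2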

# Define a main function to orchestrate the training process
def main():
@@ -107,12 +109,6 @@ def main():
scores_2_words_file_path = 'training/scores_2_words.pkl'
scores_1_word_file_path = 'training/scores_1_word.pkl'

# Set each score file with an empty object if they don't exist.
for path in [scores_1_word_file_path, scores_2_words_file_path, scores_3_words_file_path]:
if not os.path.exists(path):
with open(path, 'wb') as scores_file:
pickle.dump({}, scores_file, protocol=pickle.HIGHEST_PROTOCOL)

# Read the TXT file and process the training data

# Get the total size of the file to calculate the number of iterations needed
@@ -146,7 +142,7 @@ def main():
words = row.split()

# Every now and then save our progress.
print(f"Saving the current position of %s" % current_position)
# print(f"Saving the current position of %s" % current_position)
# Save the current progress (file position)
with open(progress_file, 'w') as f:
f.write(str(current_position))
@@ -159,7 +155,8 @@
if (current_position - prune_position_marker > PRUNE_FREQUENCY):
prune_position_marker = current_position
print(f"Passed %s positions. Time to optimize before continuing..." % PRUNE_FREQUENCY)
flatten_to_dictionary()
global trie_store
flatten_to_dictionary(trie_store)

# Process words three at a time with shifting window
for i in range(len(words) - 2):
@@ -185,37 +182,32 @@
if not predictive_words: # Skip if there are no predictive words
continue

finish_filing(context_words, predictive_words, scores_3_words_file_path, "3_words")
finish_filing(context_words, predictive_words, "3_words")

## Two word alternative
context_words_2 = words[i+1:i+3]
predictive_words_2 = predictive_words[:2]
finish_filing(context_words_2, predictive_words_2, scores_2_words_file_path, "2_words")
finish_filing(context_words_2, predictive_words_2, "2_words")

## One word alternative
context_words_1 = words[i+2:i+3]
finish_filing(context_words_1, predictive_words_2, scores_1_word_file_path, "1_word")
finish_filing(context_words_1, predictive_words_2, "1_word")
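
To make the shifting window concrete: for each position i, the loop files the same predictive words under three context sizes. The 2-word and 1-word slices are shown verbatim above; the 3-word context and the lookahead slice are assumptions, since those lines are collapsed in this diff:

words = "the quick brown fox jumps over".split()
i = 0
context_words    = words[i:i + 3]      # assumed: ['the', 'quick', 'brown']
predictive_words = words[i + 3:i + 5]  # assumed lookahead: ['fox', 'jumps']

context_words_2 = words[i + 1:i + 3]   # ['quick', 'brown']
context_words_1 = words[i + 2:i + 3]   # ['brown']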

def finish_filing(context_words, predictive_words, scores_file_path, dictionary_subpath):
def finish_filing(context_words, predictive_words, dictionary_subpath):
# Slugify the context words
context_slug = _slugify('_'.join(context_words))

# Before loading or initializing the trie, ensure the directory exists
dictionary_directory = os.path.join('training/dictionaries', dictionary_subpath)
os.makedirs(dictionary_directory, exist_ok=True)

# Now you can safely proceed with the trie file path
trie_file_path = os.path.join(dictionary_directory, f'{context_slug}.pkl')
trie = load_trie(trie_file_path)
trie = load_trie(dictionary_subpath, context_slug)

# Update the trie with the predictive words
update_trie(trie, predictive_words)

# Save the updated trie back to the .pkl file
save_trie(trie, trie_file_path)
save_trie(trie, dictionary_subpath, context_slug)

# Update the score count for the context words slug
update_scores(scores_file_path, context_slug)
update_scores(dictionary_subpath, context_slug)
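
Finally, a hedged end-to-end sketch of the filing flow with stand-ins for the module's helpers (_slugify here is a plain placeholder rather than python-slugify, and update_trie's '\ranked' upkeep is omitted):

trie_store = {'tries': {'3_words': {}}, 'scores': {}}

def _slugify(text):  # placeholder for the real slugify-based helper
    return text.lower()

def finish_filing_sketch(context_words, predictive_words, dictionary_subpath):
    context_slug = _slugify('_'.join(context_words))
    trie = trie_store['tries'][dictionary_subpath].setdefault(context_slug, {})
    node = trie
    for word in predictive_words:  # simplified update_trie: nesting only
        node = node.setdefault(word, {})
    path_scores = trie_store['scores'].setdefault(dictionary_subpath, {})
    path_scores[context_slug] = path_scores.get(context_slug, 0) + 1

finish_filing_sketch(['the', 'quick', 'brown'], ['fox', 'jumps'], '3_words')
print(trie_store['tries']['3_words']['the_quick_brown'])   # {'fox': {'jumps': {}}}
print(trie_store['scores']['3_words']['the_quick_brown'])  # 1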

# Check if the script is being run directly and call the main function
if __name__ == "__main__":
