Skip to content

Commit

Permalink
Merge pull request #1 from adamjgrant/v2.1
Browse files Browse the repository at this point in the history
V2.1
  • Loading branch information
adamjgrant authored Mar 2, 2024
2 parents 13574a2 + a071150 commit 72e46b6
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 16 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
training
train.csv
train.txt
.DS_Store
__pycache__
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@
A demonstration of predictive text without an LLM, using permy.link

[Check it out](https://adamgrant.info/tiny-predictive-text)

## Training
`pip install tqdm`
`pip install python-slugify`
[Training data is here](https://cdn.everything.io/datasets/blogs-news-twitter.txt.zip)
Save it in the root as `train.txt`.
Run the training with `python train.py train.txt`.
102 changes: 102 additions & 0 deletions create_dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import pickle
import json

# Assuming prune_unpopular is defined in train.py and is importable
TARGET_DICTIONARY_COUNT = 10000

def convert_to_array(obj):
"""
Recursively converts dictionary objects to the specified array format, ensuring each string ends with a space.
"""
result = []
for key, value in obj.items():
key_with_space = key + " " # Append a space to the key
if isinstance(value, dict):
# If the value is a dictionary, recursively process it
child = convert_to_array(value)
result.append([key_with_space, child])
else:
result.append(key_with_space)
return result

def main():
# Path configuration
dictionaries_path = 'training/dictionaries'
scores_3_words_file_path = 'training/scores_3_words.pkl'
scores_2_words_file_path = 'training/scores_2_words.pkl'
scores_1_word_file_path = 'training/scores_1_word.pkl'
output_file = 'dictionary.js'

# Prune the dictionaries first
prune_unpopular(scores_3_words_file_path, os.path.join(dictionaries_path, "3_words"))
prune_unpopular(scores_2_words_file_path, os.path.join(dictionaries_path, "2_words"), target_dictionary_count=5000)
prune_unpopular(scores_1_word_file_path, os.path.join(dictionaries_path, "1_word"), target_dictionary_count=1000)

# Initialize the dictionary object
dictionary = {}

# Iterate over every .pkl file in the dictionaries directory
# Define the subdirectories
subdirectories = ["3_words", "2_words", "1_word"]

print("Getting all dictionaries...")
for subdirectory in subdirectories:
# Construct the path to the subdirectory
subdirectory_path = os.path.join(dictionaries_path, subdirectory)

# Iterate over every .pkl file in the current subdirectory
for filename in os.listdir(subdirectory_path):
if filename.endswith('.pkl'):
slug, _ = os.path.splitext(filename)
file_path = os.path.join(subdirectory_path, filename)

with open(file_path, 'rb') as file:
trie = pickle.load(file)
# Convert trie to the specified array format
# Use a modified slug that includes the subdirectory name for uniqueness
dictionary[slug] = convert_to_array(trie)

print(f"Dictionary length is %s" % dictionary.keys().__len__())
# Write the dictionary object to dictionary.js in the desired format
with open(output_file, 'w') as js_file:
# Minimize by removing unnecessary whitespace in json.dumps and adjusting js_content formatting
minimized_json = json.dumps(dictionary, separators=(',', ':'))
js_content = f"export const dictionary = {minimized_json};"
js_file.write(js_content)

def prune_unpopular(scores_file_path, dictionaries_path, target_dictionary_count=TARGET_DICTIONARY_COUNT):
# Load the scores
if os.path.exists(scores_file_path):
with open(scores_file_path, 'rb') as file:
scores = pickle.load(file)
else:
print("Scores file does not exist.")
return

print(f"\nStopping to prune least popular entries down to target dictionary size of %s..." % target_dictionary_count)

# Sort scores by value in descending order and get the top_n keys
top_slugs = sorted(scores, key=scores.get, reverse=True)[:target_dictionary_count]

# Convert to set for faster lookup
top_slugs_set = set(top_slugs)

# Track slugs to be removed from scores
slugs_to_remove = []

# Iterate over all files in dictionaries directory
for filename in os.listdir(dictionaries_path):
slug, ext = os.path.splitext(filename)
if slug not in top_slugs_set:
# This file is not among the top scoring, so delete it
os.remove(os.path.join(dictionaries_path, filename))
slugs_to_remove.append(slug)

# Remove the pruned entries from scores
for slug in slugs_to_remove:
if slug in scores:
del scores[slug]

if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion dict.js

This file was deleted.

1 change: 1 addition & 0 deletions dictionary.js

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion index.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ <h3>Made with <a href="https://permy.link">Permy</a></h3>
<div id="suggestion">...</div>

<textarea id="entry" placeholder="Start typing"></textarea>
<script src="script.min.js"></script>
<script type="module" src="permute.js"></script>
<script type="module" src="dictionary.js"></script>
<script type="module" src="script.js"></script>
2 changes: 2 additions & 0 deletions permute.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 72e46b6

Please sign in to comment.