Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.npz serialization #39

Merged
merged 2 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
exclude: src/blacksquare/word_list.npz
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
hooks:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ local_scheme = "no-local-version"
where = ["src"]

[tool.setuptools.package-data]
blacksquare = ["*.dict"]
blacksquare = ["*.npz"]
Binary file not shown.
72 changes: 53 additions & 19 deletions src/blacksquare/word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,39 @@ def __init__(
"""Representation of a scored word list.

Args:
source (Union[ str, Path, List[str], Dict[str, Union[int, float]], ]): The
source for the word list. Can be a list of strings, a dict of strings
to scores, or a path to a .dict file with words in "word;score" format.
Words will be normalized and scores will be scaled from 0-1.
source: The source for the word list. Can be a list of strings, a dict of
strings to scores, a path to a .dict file with words in "word;score"
format, or a path to a .npz file (produced to `.to_npz`) Words will be
normalized and scores will be scaled from 0-1.

Raises:
ValueError: If input type is not recognized
"""
if (
isinstance(source, str)
or isinstance(source, Path)
or isinstance(source, io.IOBase)
):
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
if isinstance(source, str) or isinstance(source, Path):
if Path(source).suffix == ".npz":
loaded = np.load(source)
length_keys = {
k.split("_")[0]
for k in loaded.keys()
if k not in ("words", "scores")
}
self._words = loaded["words"]
self._scores = loaded["scores"]
self._word_scores_by_length = {
int(k): (loaded[f"{k}_words"], loaded[f"{k}_scores"])
for k in length_keys
}
return
else:
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
elif isinstance(source, list):
assert len(source) > 0 and isinstance(source[0], str)
raw_words_scores = [(w, 1) for w in source]
Expand Down Expand Up @@ -210,8 +221,31 @@ def score_filter(self, threshold: float) -> WordList:
return WordList(dict(zip(self._words[score_mask], self._scores[score_mask])))

def filter(self, filter_fn: Callable[[ScoredWord], bool]) -> WordList:
"""Returns a new word list filtered by a custom function.

Args:
filtern_fn: The filtering function. Takes a ScoredWord as an
input and outputs a bool.

Returns:
The resulting WordList
"""
return WordList(dict([w for w in self if filter_fn(w)]))

def to_npz(self, file: str | Path) -> None:
"""Serializes word list to a .npz format that is fast to load from disk.

Args:
file: The output file path.
"""
by_length_arrays = {}
for k, v in self._word_scores_by_length.items():
by_length_arrays[f"{k}_words"] = v[0]
by_length_arrays[f"{k}_scores"] = v[1]
np.savez_compressed(
file, words=self._words, scores=self._scores, **by_length_arrays
)

def __len__(self):
return len(self._words)

Expand Down Expand Up @@ -347,4 +381,4 @@ def _normalize(word: str) -> str:
return word.upper().replace(" ", "")


DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("xwordlist.dict"))
DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("word_list.npz"))
18 changes: 18 additions & 0 deletions tests/test_word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def test_add_word_list(self):
def test_score_filter(self, word_list):
assert len(word_list.score_filter(0.5)) == 6

def test_serialization(self, tmp_path):
words = """
AAA;1.0
BBB;0.99
BB;0.5
C;0.1
"""
with (tmp_path / "list.dict").open("w") as f:
f.write(words)

xw = Crossword(3)
wl = WordList(tmp_path / "list.dict")
dict_matches = wl.find_matches(xw[ACROSS, 1])
wl.to_npz(tmp_path / "list.npz")
wl_from_npz = WordList(tmp_path / "list.npz")
npz_matches = wl_from_npz.find_matches(xw[ACROSS, 1])
assert dict_matches.words == npz_matches.words


class TestMatchWordList:
@pytest.fixture
Expand Down