Skip to content

Commit

Permalink
Merge pull request #35 from stephantul/34-add-pooling-functions-with-…
Browse files Browse the repository at this point in the history
…safeguard

34 add pooling functions with safeguard
  • Loading branch information
stephantul authored May 3, 2023
2 parents 338da69 + 8f5f814 commit c3dbd05
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 50 deletions.
109 changes: 96 additions & 13 deletions reach/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import logging
from io import TextIOWrapper, open
from pathlib import Path
from typing import Any, Dict, Generator, Hashable, List, Optional, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Hashable,
Iterable,
List,
Optional,
Tuple,
Union,
)

import numpy as np
from tqdm import tqdm
Expand All @@ -16,7 +26,7 @@
Matrix = Union[np.ndarray, List[np.ndarray]]
SimilarityItem = List[Tuple[Hashable, float]]
SimilarityResult = List[SimilarityItem]
Tokens = List[Hashable]
Tokens = Iterable[Hashable]


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,8 +112,8 @@ def indices(self) -> Dict[int, Hashable]:
return self._indices

@property
def sorted_items(self) -> List[Hashable]:
items: List[Hashable] = [
def sorted_items(self) -> Tokens:
items: Tokens = [
item for item, _ in sorted(self.items.items(), key=lambda x: x[1])
]
return items
Expand Down Expand Up @@ -209,8 +219,8 @@ def load(
desired_dtype,
**kwargs,
)
except ValueError as e:
raise e
except ValueError as exc:
raise exc
finally:
if came_from_path:
file_handle.close()
Expand Down Expand Up @@ -371,6 +381,82 @@ def vectorize(
else:
return self.vectors[index]

def mean_pool(
    self, tokens: Tokens, remove_oov: bool = False, safeguard: bool = True
) -> np.ndarray:
    """
    Vectorize a sequence of tokens and average the resulting vectors.

    Parameters
    ----------
    tokens : iterable of hashable items.
        The items to vectorize and then mean pool.
    remove_oov : bool.
        Whether to remove OOV items from the input.
        If this is False, and an unknown item is encountered, then
        the <UNK> symbol will be inserted if it is set. If it is not set,
        then the function will throw a ValueError.
    safeguard : bool.
        Controls what happens when vectorization raises a ValueError.
        Typical causes are:
            - The list is empty after removing OOV items
            - An OOV item is kept but no <UNK> symbol is set
            - The list of tokens is empty to begin with
        If safeguard is True, the ValueError is re-raised. If it is
        False, a zero vector of length ``self.size`` is returned instead.

    Returns
    -------
    vector : np.ndarray
        The mean over the first axis of the vectorized tokens, or a zero
        vector when vectorization failed and safeguard is False.

    Raises
    ------
    ValueError
        If vectorization fails and safeguard is True.

    """
    try:
        # The third positional argument is passed through as False
        # (presumably vectorize's normalization flag; confirm against
        # the vectorize signature).
        token_matrix = self.vectorize(tokens, remove_oov, False)
        pooled = token_matrix.mean(0)
    except ValueError as exc:
        if not safeguard:
            # Best-effort mode: fall back to an all-zero vector.
            return np.zeros(self.size)
        raise exc
    return pooled

def mean_pool_corpus(
    self, corpus: List[Tokens], remove_oov: bool = False, safeguard: bool = True
) -> np.ndarray:
    """
    Mean pool every token sequence in a corpus.

    Each element of the corpus is passed to ``mean_pool`` and the resulting
    vectors are stacked into a single matrix.

    Parameters
    ----------
    corpus : a list of token sequences.
        The sequences to vectorize and then mean pool, one row per sequence.
    remove_oov : bool.
        Whether to remove OOV items from the input.
        If this is False, and an unknown item is encountered, then
        the <UNK> symbol will be inserted if it is set. If it is not set,
        then the function will throw a ValueError.
    safeguard : bool.
        Forwarded to ``mean_pool``. Typical reasons a sequence cannot be
        vectorized are:
            - The list is empty after removing OOV items
            - An OOV item is kept but no <UNK> symbol is set
            - The list of tokens is empty to begin with
        If safeguard is True, such a sequence raises a ValueError.
        If it is False, a zero vector is supplied for it instead.

    Returns
    -------
    matrix : np.ndarray
        A matrix with n rows, where n is the number of input sequences,
        and s columns, where s is the size of a single vector. An empty
        corpus yields a matrix of shape (0, s).

    Raises
    ------
    ValueError
        If pooling a sequence fails; the message names the index of the
        offending sequence and the original error is chained as the cause.

    """
    pooled = []
    for index, tokens in enumerate(corpus):
        try:
            pooled.append(self.mean_pool(tokens, remove_oov, safeguard))
        except ValueError as exc:
            # Re-raise with the position so callers can locate the bad entry.
            raise ValueError(f"Tokens at {index} errored out") from exc

    if not pooled:
        # np.stack rejects an empty sequence with an unrelated ValueError;
        # an empty corpus simply maps to a matrix with zero rows.
        return np.zeros((0, self.size))

    return np.stack(pooled)

def bow(self, tokens: Tokens, remove_oov: bool = False) -> List[int]:
"""
Create a bow representation of a list of tokens.
Expand Down Expand Up @@ -413,7 +499,7 @@ def bow(self, tokens: Tokens, remove_oov: bool = False) -> List[int]:
return out

def transform(
self, corpus: List[List[Hashable]], remove_oov: bool = False, norm: bool = False
self, corpus: List[Tokens], remove_oov: bool = False, norm: bool = False
) -> List[np.ndarray]:
"""
Transform a corpus by repeated calls to vectorize, defined above.
Expand Down Expand Up @@ -469,10 +555,7 @@ def most_similar(
"""
if isinstance(items, str):
if items not in self.items:
raise KeyError(f"{items} is not in the set of items.")
items = [items]

vectors = np.stack([self.norm_vectors[self.items[item]] for item in items])
result = self._most_similar_batch(
vectors, batch_size, num + 1, show_progressbar
Expand Down Expand Up @@ -519,8 +602,6 @@ def threshold(
"""
if isinstance(items, str):
if items not in self.items:
raise KeyError(f"{items} is not in the set of items.")
items = [items]

vectors = np.stack([self.norm_vectors[self.items[x]] for x in items])
Expand Down Expand Up @@ -711,6 +792,7 @@ def vector_similarity(self, vector: np.ndarray, items: Tokens) -> np.ndarray:
"""Compute the similarity between a vector and a set of items."""
if isinstance(items, str):
items = [items]

items_vec = np.stack([self.norm_vectors[self.items[item]] for item in items])
return self._sim(vector, items_vec)

Expand Down Expand Up @@ -741,6 +823,7 @@ def similarity(self, items_1: Tokens, items_2: Tokens) -> np.ndarray:
items_1 = [items_1]
if isinstance(items_2, str):
items_2 = [items_2]

items_1_matrix = np.stack(
[self.norm_vectors[self.items[item]] for item in items_1]
)
Expand All @@ -749,7 +832,7 @@ def similarity(self, items_1: Tokens, items_2: Tokens) -> np.ndarray:
)
return self._sim(items_1_matrix, items_2_matrix)

def intersect(self, itemlist: List[Hashable]) -> Reach:
def intersect(self, itemlist: Tokens) -> Reach:
"""
Intersect a reach instance with a list of items.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_auto.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Tuple, List
import unittest
from typing import Hashable, List, Tuple

import numpy as np

from reach import AutoReach, Reach


class TestAuto(unittest.TestCase):
def data(self) -> Tuple[List[str], np.ndarray]:
words = [
def data(self) -> Tuple[List[Hashable], np.ndarray]:
words: List[Hashable] = [
"donatello",
"leonardo",
"raphael",
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_lower(self) -> None:
instance = AutoReach(vectors, words, lowercase="auto")
self.assertTrue(instance.lowercase)

words[0] = words[0].title()
words[0] = words[0].title() # type: ignore
instance = AutoReach(vectors, words, lowercase="auto")
self.assertFalse(instance.lowercase)

Expand Down
12 changes: 7 additions & 5 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import unittest
from typing import Tuple, List
from typing import Hashable, List, Tuple

import numpy as np

from reach import Reach


class TestInit(unittest.TestCase):
def data(self) -> Tuple[List[str], np.ndarray]:
words = [
def data(self) -> Tuple[List[Hashable], np.ndarray]:
words: List[Hashable] = [
"donatello",
"leonardo",
"raphael",
"michelangelo",
"splinter",
"hideout",
]
vectors = np.stack([np.arange(1, 7)] * 5).T
random_generator = np.random.RandomState(seed=44)
vectors = random_generator.standard_normal((6, 50))

return words, vectors

Expand All @@ -24,7 +26,7 @@ def test_init(self) -> None:
instance = Reach(vectors, words)

self.assertEqual(len(instance), 6)
self.assertEqual(instance.size, 5)
self.assertEqual(instance.size, 50)
self.assertTrue(np.allclose(instance.vectors, vectors))

sorted_words, _ = zip(*sorted(instance.items.items(), key=lambda x: x[1]))
Expand Down
7 changes: 2 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
import unittest
from pathlib import Path
from tempfile import NamedTemporaryFile

import numpy as np

from reach import Reach
from tempfile import NamedTemporaryFile


class TestLoad(unittest.TestCase):
Expand Down Expand Up @@ -189,7 +189,6 @@ def test_load_from_file_without_header(self) -> None:
instance = Reach.load(tempfile.name, num_to_load=-1)

def test_load_from_file_with_header(self) -> None:

with NamedTemporaryFile(mode="w+") as tempfile:
lines = self.lines()
tempfile.write(lines)
Expand Down Expand Up @@ -229,7 +228,6 @@ def test_load_from_file_with_header(self) -> None:
instance = Reach.load(tempfile.name, num_to_load=-1)

def test_save_load_fast_format(self) -> None:

with NamedTemporaryFile("w+") as tempfile:
lines = self.lines()
tempfile.write(lines)
Expand All @@ -246,7 +244,6 @@ def test_save_load_fast_format(self) -> None:
self.assertEqual(instance.name, instance_2.name)

def test_save_load(self) -> None:

with NamedTemporaryFile("w+") as tempfile:
lines = self.lines()
tempfile.write(lines)
Expand Down
Loading

0 comments on commit c3dbd05

Please sign in to comment.