Skip to content
This repository has been archived by the owner on Feb 7, 2023. It is now read-only.

Commit

Permalink
Convert Deprecation (#240)
Browse files Browse the repository at this point in the history
* added-use-removed-convert

* docs-update

* modified-test

* typo

* test-fix

* removed-convert-from-tests

* convert-doh-import

* import-bug

* maybe this

* yet-another-fix

* found-the-bug

* multi-lang
  • Loading branch information
koaning authored Oct 7, 2020
1 parent f0df039 commit a418c49
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 166 deletions.
1 change: 1 addition & 0 deletions docs/api/language/universal_sentence.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: whatlies.language._sentence_encode_lang
5 changes: 5 additions & 0 deletions docs/releases.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
v0.5.3

- Deprecated the `ConveRTLanguage` backend. The original authors removed the embeddings.
- Added support for the Universal Sentence Encoder.

v0.5.2

- Fixed the `ConveRTLanguage` backend. The original source changed their download url.
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ nav:
- Gensim: api/language/gensim_lang.md
- Huggingface: api/language/transformers.md
- TFHub: api/language/tfhub.md
- Universal Sentence Encoder: api/language/universal_sentence.md
- Examples:
- Debiasing Projections: examples/lipstick-pig.md
- Roadmap: roadmap.md
Expand Down
63 changes: 0 additions & 63 deletions tests/test_lang/test_convert_lang.py

This file was deleted.

11 changes: 11 additions & 0 deletions tests/test_lang/test_universal_sentence_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np

from whatlies.language import TFHubLanguage, UniversalSentenceLanguage


def test_same_results():
    """
    The `UniversalSentenceLanguage` shortcut must embed text exactly like a
    `TFHubLanguage` that is pointed at the equivalent tfhub url.
    """
    text = "hello world"
    via_shortcut = UniversalSentenceLanguage("multi", 3)
    via_url = TFHubLanguage(
        "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
    )
    shortcut_vector = via_shortcut[text].vector
    url_vector = via_url[text].vector
    assert np.allclose(shortcut_vector, url_vector)
4 changes: 0 additions & 4 deletions tests/test_sklearn/test_simple_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,19 @@

from whatlies.language import (
FasttextLanguage,
CountVectorLanguage,
SpacyLanguage,
GensimLanguage,
BytePairLanguage,
TFHubLanguage,
ConveRTLanguage,
HFTransformersLanguage,
)


# Language backends exercised by this test module.
# `ConveRTLanguage` is intentionally left out: the backend is deprecated and its
# `__getitem__` now always raises, so it can no longer produce embeddings.
backends = [
    SpacyLanguage("tests/custom_test_lang/"),
    FasttextLanguage("tests/custom_fasttext_model.bin"),
    CountVectorLanguage(n_components=10),
    BytePairLanguage("en", vs=1000, dim=25, cache_dir="tests/cache"),
    GensimLanguage("tests/cache/custom_gensim_vectors.kv"),
    HFTransformersLanguage("sshleifer/tiny-gpt2", framework="tf"),
    TFHubLanguage("https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"),
]
Expand Down
12 changes: 7 additions & 5 deletions whatlies/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@
from ._fasttext_lang import FasttextLanguage
from ._countvector_lang import CountVectorLanguage
from ._bpemblang import BytePairLanguage
from ._bpemblang import BytePairLanguage as BytePairLang
from ._gensim_lang import GensimLanguage
from ._convert_lang import ConveRTLanguage

from whatlies.error import NotInstalled

# Guarded import: if it raises ModuleNotFoundError, both names are bound to
# `NotInstalled` placeholders instead of the real classes.
# NOTE(review): "tfhub" is presumably the name of the pip extra that provides the
# missing modules — confirm against `NotInstalled`.
# NOTE(review): `ConveRTLanguage` is also imported unconditionally further up in this
# file — confirm whether the guarded import and its fallback are still needed.
try:
from ._convert_lang import ConveRTLanguage
from ._tfhub_lang import TFHubLanguage
except ModuleNotFoundError as e:
TFHubLanguage = NotInstalled("TFHubLanguage", "tfhub")
ConveRTLanguage = NotInstalled("ConveRTLanguage", "tfhub")

# Same guard for the Universal Sentence Encoder shortcut; the fallback placeholder
# is registered under the same "tfhub" label as `TFHubLanguage`.
try:
from ._sentence_encode_lang import UniversalSentenceLanguage
except ModuleNotFoundError as e:
UniversalSentenceLanguage = NotInstalled("UniversalSentenceLanguage", "tfhub")

# Same guard for the Huggingface backend, registered under the "transformers" label.
try:
from ._hftransformers_lang import HFTransformersLanguage
except ModuleNotFoundError as e:
HFTransformersLanguage = NotInstalled("HFTransformersLanguage", "transformers")


try:
from ._sense2vec_lang import Sense2VecLanguage
except ModuleNotFoundError as e:
Expand All @@ -31,10 +33,10 @@
"Sense2VecLanguage",
"FasttextLanguage",
"CountVectorLanguage",
"BytePairLang",
"BytePairLanguage",
"GensimLanguage",
"ConveRTLanguage",
"TFHubLanguage",
"HFTransformersLanguage",
"UniversalSentenceLanguage",
]
101 changes: 7 additions & 94 deletions whatlies/language/_convert_lang.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,7 @@
from typing import Union, List

import tensorflow_text # noqa: F401
import tensorflow as tf
import tensorflow_hub as tfhub

from whatlies.embedding import Embedding
from whatlies.embeddingset import EmbeddingSet
from whatlies.language._common import SklearnTransformerMixin, HiddenPrints


class ConveRTLanguage(SklearnTransformerMixin):
class ConveRTLanguage:
"""
This object is used to fetch [Embedding][whatlies.embedding.Embedding]s or
[EmbeddingSet][whatlies.embeddingset.EmbeddingSet]s from a
[ConveRT](https://github.com/PolyAI-LDN/polyai-models) model.
This object is meant for retreival, not plotting.
Important:
This object will automatically download a large file if it is not cached yet.
This language model does not contain a vocabulary, so it cannot be used
to retreive similar tokens. Use an `EmbeddingSet` instead.
This language backend might require you to manually install extra dependencies
unless you installed via either;
```
pip install whatlies[tfhub]
pip install whatlies[all]
```
Arguments:
model_id: identifier used for loading the corresponding TFHub module, which could be one of `'convert`, `'convert-multi-context'` or `'convert-ubuntu'`.
Each one of these correspond to a different model as described in [ConveRT manual](https://github.com/PolyAI-LDN/polyai-models#models).
signature: the TFHub signature of the model, which could be one of `'default'`, `'encode_context'`, `'encode_response'` or `'encode_sequence'`.
Note that `'encode_context'` is not currently supported with `'convert-multi-context'` or `'convert-ubuntu'` models.
**Usage**:
```python
> from whatlies.language import ConveRTLanguage
> lang = ConveRTLanguage()
> lang['bank']
> lang = ConveRTLanguage(model_id='convert-multi-context', signature='encode_sequence')
> lang[['bank of the river', 'money on the bank', 'bank']]
```
This model has been deprecated. The original authors took the embeddings down.
"""

MODEL_URL = {
Expand All @@ -61,53 +18,9 @@ class ConveRTLanguage(SklearnTransformerMixin):
]

def __init__(self, model_id: str = "convert", signature: str = "default") -> None:
if model_id not in self.MODEL_URL:
raise ValueError(
f"The `model_id` value should be one of {list(self.MODEL_URL.keys())}"
)
if signature not in self.MODEL_SIGNATURES:
raise ValueError(
f"The `signature` value should be one of {self.MODEL_SIGNATURES}"
)
if signature == "encode_context" and model_id in [
"convert-multi-context",
"convert-ubuntu",
]:
raise NotImplementedError(
"Currently 'encode_context' signature is not support with multi-context and ubuntu models."
)
self.model_id = model_id
self.signature = signature

with HiddenPrints():
self.module = tfhub.load(self.MODEL_URL[self.model_id])
self.model = self.module.signatures[self.signature]

def __getitem__(
self, query: Union[str, List[str]]
) -> Union[Embedding, EmbeddingSet]:
"""
Retreive a single embedding or a set of embeddings.
Arguments:
query: single string or list of strings
**Usage**
pass

```python
> from whatlies.language import ConveRTLanguage
> lang = ConveRTLanguage()
> lang['bank']
> lang = ConveRTLanguage()
> lang[['bank of the river', 'money on the bank', 'bank']]
```
"""
if isinstance(query, str):
query_tensor = tf.convert_to_tensor([query])
encoding = self.model(query_tensor)
if self.signature == "encode_sequence":
vec = encoding["sequence_encoding"].numpy().sum(axis=1)[0]
else:
vec = encoding["default"].numpy()[0]
return Embedding(query, vec)
return EmbeddingSet(*[self[tok] for tok in query])
def __getitem__(self, item):
    """
    Always fail: this backend is deprecated and can no longer be queried.

    Arguments:
        item: the requested query; it is ignored

    Raises:
        DeprecationWarning: on every call, raised as an exception
    """
    message = "This model has been deprecated. The original authors took the embeddings down."
    raise DeprecationWarning(message)
46 changes: 46 additions & 0 deletions whatlies/language/_sentence_encode_lang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Union

from ._tfhub_lang import TFHubLanguage


def UniversalSentenceLanguage(variant: str = "base", version: Union[int, None] = None):
    """
    Retrieve a [universal sentence encoder](https://tfhub.dev/google/collections/universal-sentence-encoder/1) model from tfhub.

    You can download specific versions for specific variants. The variants that we support are listed below.

    - `"base"`: the base variant (915MB) [link](https://tfhub.dev/google/universal-sentence-encoder/4)
    - `"large"`: the large variant (523MB) [link](https://tfhub.dev/google/universal-sentence-encoder-large/5)
    - `"qa"`: the variant based on question/answer (528MB) [link](https://tfhub.dev/google/universal-sentence-encoder-qa/3)
    - `"multi"`: the multi-language variant (245MB) [link](https://tfhub.dev/google/universal-sentence-encoder-multilingual/3)
    - `"multi-large"`: the large multi-language variant (303MB) [link](https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3)
    - `"multi-qa"`: the multi-language qa variant (310MB) [link](https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/3)

    TFHub reports that the multi-language models support Arabic, Chinese-simplified, Chinese-traditional,
    English, French, German, Italian, Japanese, Korean, Dutch, Polish, Portuguese, Spanish, Thai, Turkish and Russian.

    Arguments:
        variant: select a specific variant
        version: select a specific version, if kept `None` we'll use the version pinned for that variant in this function

    Returns:
        A `TFHubLanguage` constructed with the tfhub url of the requested variant and version.

    Raises:
        KeyError: if `variant` is not one of the supported variants listed above.
    """
    # Every base url ends with a slash and carries NO version number: the version is
    # appended below. (A version baked into the url would be doubled, e.g. ".../33".)
    urls = {
        "base": "https://tfhub.dev/google/universal-sentence-encoder/",
        "large": "https://tfhub.dev/google/universal-sentence-encoder-large/",
        "qa": "https://tfhub.dev/google/universal-sentence-encoder-qa/",
        "multi": "https://tfhub.dev/google/universal-sentence-encoder-multilingual/",
        "multi-large": "https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/",
        "multi-qa": "https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/",
    }

    # Version used for each variant when the caller does not ask for a specific one.
    versions = {
        "base": 4,
        "large": 5,
        "qa": 3,
        "multi": 3,
        "multi-large": 3,
        "multi-qa": 3,
    }

    # `None` (or any other falsy value, as before) selects the pinned default version.
    if not version:
        version = versions[variant]
    url = urls[variant] + str(version)
    return TFHubLanguage(url=url)

0 comments on commit a418c49

Please sign in to comment.