This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb6d5a2
commit 1d8cca2
Showing
7 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# DocCache Executor | ||
|
||
This is the hub module for a Cache executor for Jina. | ||
|
||
This checks whether a document has already been indexed. It does this by checking the hash of the combined values of the fields you want to cache on. By default, it checks the `.text` field. | ||
|
||
**NOTE**: If you use a DocCache in your Flow, it will override the traversal path for subsequent Executors. This means that whatever traversal path you passed to the DocCache will become the new root path (`'r'`). This is done in order to reduce network traffic to the next executor. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import hashlib | ||
import os | ||
import pickle | ||
from typing import Tuple, Optional, Dict | ||
|
||
from jina import Executor, DocumentArray, requests, Document | ||
from jina.logging import JinaLogger | ||
|
||
|
||
class _CacheHandler:
    """A handler for loading and serializing the in-memory cache of the DocCache.

    Two dictionaries are kept in memory and persisted as pickle files next to
    each other: ``<path>.ids`` (``id_to_cache_val``) and ``<path>.cache``
    (``cache_val_to_id``).

    :param path: Path to the file from which to build the actual paths.
    :param logger: Instance of logger; only its ``warning`` method is used.
    """

    def __init__(self, path, logger):
        self.path = path
        try:
            # NOTE(security): pickle executes arbitrary code on load. These
            # files are assumed to be written only by `close()` below; never
            # point `path` at files from an untrusted source.
            with open(path + '.ids', 'rb') as ids_file:
                id_to_cache_val = pickle.load(ids_file)
            with open(path + '.cache', 'rb') as cache_file:
                cache_val_to_id = pickle.load(cache_file)
        except FileNotFoundError as e:
            logger.warning(
                f'File path did not exist : {path}.ids or {path}.cache: {e!r}. Creating new CacheHandler...'
            )
            # If either file is missing, start from scratch so that the two
            # mappings can never disagree with each other.
            id_to_cache_val = dict()
            cache_val_to_id = dict()
        self.id_to_cache_val = id_to_cache_val
        self.cache_val_to_id = cache_val_to_id

    def close(self):
        """Flushes the in-memory cache to pickle files."""
        # Create the target directory when it does not exist yet so that the
        # cache is not lost at shutdown.
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.path + '.ids', 'wb') as ids_file:
            pickle.dump(self.id_to_cache_val, ids_file)
        with open(self.path + '.cache', 'wb') as cache_file:
            pickle.dump(self.cache_val_to_id, cache_file)
|
||
|
||
# Fields that DocCache hashes when the caller does not pass `fields`.
default_fields = ('text',) | ||
|
||
|
||
class DocCache(Executor):
    """A cache Executor.

    Checks if a Document has already been indexed.
    If it hasn't, it is kept and its hash is recorded.
    If it has been indexed before, it is removed from the list of Documents.

    NOTE: The traversal path used in processing the request will become the
    root path (`r`) afterwards, because the flattened, filtered Documents are
    what gets returned. This is required in order to save network traffic.

    :param fields: Document fields whose values are hashed together; falls
        back to the module-level ``default_fields`` when ``None``.
    :param default_traversal_paths: traversal paths used when the request
        parameters carry no ``traversal_paths`` entry.
    :param tag: stored on the instance as ``self.tag``; not read anywhere
        else in this class.
    """

    def __init__(
        self,
        fields: Optional[Tuple[str]] = None,
        default_traversal_paths: Tuple[str] = ('r',),
        tag: str = 'cache_hit',
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.fields = default_fields if fields is None else fields
        self.tag = tag
        self.logger = JinaLogger('DocCache')
        self.cache_handler = _CacheHandler(
            os.path.join(self.metas.workspace, 'cache'), self.logger
        )
        self.default_traversal_paths = default_traversal_paths

    @requests(on='/index')
    def cache(self, docs: DocumentArray, parameters: Dict, **kwargs):
        """Method to handle the index process for caching.

        :param docs: the Documents of the request
        :param parameters: request parameters; ``traversal_paths`` overrides
            ``default_traversal_paths`` when present
        :param kwargs: unused
        :return: the flattened Documents with already-cached ones removed
        """
        traversal_paths = parameters.get(
            'traversal_paths', self.default_traversal_paths
        )
        docs_to_be_traversed = docs.traverse_flat(traversal_paths)
        id_to_cache_val = self.cache_handler.id_to_cache_val
        cache_val_to_id = self.cache_handler.cache_val_to_id

        idx_to_remove = []
        for i, d in enumerate(docs_to_be_traversed):
            cache_value = DocCache.hash_doc(d, self.fields)
            if cache_value in cache_val_to_id:
                idx_to_remove.append(i)
            else:
                # Registered immediately, so a duplicate later in the same
                # request is detected as well.
                id_to_cache_val[d.id] = cache_value
                cache_val_to_id[cache_value] = d.id

        # Indices were collected in ascending order; delete from the back so
        # the remaining indices stay valid.
        for i in reversed(idx_to_remove):
            del docs_to_be_traversed[i]
        # when you use a cache the traversal path is assumed from the Cache side
        return docs_to_be_traversed

    def close(self) -> None:
        """Make sure to flush to file"""
        self.cache_handler.close()

    @staticmethod
    def hash_doc(doc: Document, fields: Tuple[str]) -> bytes:
        """Calculate hash by which we cache.

        The digest is the SHA-256 of the UTF-8 encoding of
        ``'<field>:<value>;'`` concatenated for every field, in order.

        :param doc: the Document
        :param fields: the list of fields
        :return: the hash value of the fields
        """
        values = doc.get_attributes(*fields)
        # A non-list result is treated as the value of a single field.
        if not isinstance(values, list):
            values = [values]
        data = ''.join(f'{field}:{value};' for field, value in zip(fields, values))
        return hashlib.sha256(data.encode('utf8')).digest()

    @property
    def size(self):
        """Number of Document ids currently tracked by the cache."""
        return len(self.cache_handler.id_to_cache_val)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
jina==2.0.0 |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
jtype: Flow | ||
version: '1' | ||
with: | ||
rest_api: true | ||
port_expose: 9000 | ||
return_results: true | ||
executors: | ||
- name: cache | ||
timeout_ready: '-1' | ||
uses: | ||
jtype: DocCache | ||
metas: | ||
workspace: $WORKSPACE | ||
- name: indexer | ||
timeout_ready: '-1' | ||
uses: | ||
jtype: MyIndexer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from typing import Tuple, Dict | ||
|
||
import numpy as np | ||
from jina import Executor, DocumentArray, requests, Document | ||
from jina.logging import JinaLogger | ||
|
||
|
||
def _safeget(dct, *keys):
    """Walk nested containers along ``keys`` and return the value found.

    :param dct: the outermost container (typically a dict)
    :param keys: keys (or indices) to follow, outermost first
    :return: the nested value, or ``None`` as soon as a step cannot be taken
        (missing key, out-of-range index, or a value that cannot be indexed).
        With no ``keys`` the container itself is returned.
    """
    for key in keys:
        try:
            dct = dct[key]
        except (KeyError, IndexError, TypeError):
            # TypeError covers an intermediate value that is not subscriptable
            # (e.g. None or an int); IndexError covers nested sequences.
            return None
    return dct
|
||
|
||
class MyIndexer(Executor):
    """Minimal in-memory indexer used together with DocCache.

    Indexed Documents are kept in a ``DocumentArray``; search compares the
    embeddings of the query Documents against all stored embeddings.

    :param default_traversal_paths: traversal paths used by ``search`` when
        the request parameters carry no ``traversal_paths`` entry.
    """

    def __init__(
        self,
        default_traversal_paths: Tuple[str] = ('r',),
        **kwargs,
    ):
        self.logger = JinaLogger('MyIndexer')
        self.default_traversal_paths = default_traversal_paths
        super().__init__(**kwargs)
        self._docs = DocumentArray()

    @requests(on='/index')
    def index(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        """Append the incoming Documents to the in-memory index.

        :param docs: the Documents to store
        :param parameters: unused
        :param kwargs: unused
        """
        # you cannot change the traversal path here
        # it is assumed from the DocCache side
        self._docs.extend(docs)

    @requests(on=['/search', '/eval'])
    def search(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        """Attach the ``top_k`` closest stored Documents as matches to each query.

        :param docs: the query Documents
        :param parameters: must contain ``top_k``; may contain
            ``traversal_paths``
        :param kwargs: unused
        """
        traversal_paths = parameters.get(
            'traversal_paths', self.default_traversal_paths
        )
        docs_to_search_with = docs.traverse_flat(traversal_paths)
        self.logger.info(f'Searching with {len(docs_to_search_with)} documents...')

        # np.stack raises on an empty sequence, so there is nothing to match
        # when either side is empty.
        if len(docs_to_search_with) == 0 or len(self._docs) == 0:
            return

        a = np.stack(docs_to_search_with.get_attributes('embedding'))
        b = np.stack(self._docs.get_attributes('embedding'))
        q_emb = _ext_A(_norm(a))
        d_emb = _ext_B(_norm(b))
        dists = _cosine(q_emb, d_emb)
        idx, dist = self._get_sorted_top_k(dists, int(parameters['top_k']))
        for _q, _ids, _dists in zip(docs_to_search_with, idx, dist):
            for _id, _dist in zip(_ids, _dists):
                d = Document(self._docs[int(_id)], copy=True)
                # The score is one minus the value produced by `_cosine`.
                d.score.value = 1 - _dist
                _q.matches.append(d)

    @staticmethod
    def _get_sorted_top_k(
        dist: 'np.array', top_k: int
    ) -> Tuple['np.ndarray', 'np.ndarray']:
        """Return, per row, the indices and values of the ``top_k`` smallest entries.

        :param dist: 2-D array with one row per query
        :param top_k: number of entries to keep per row
        :return: ``(idx, dist)``, both sorted by ascending value within a row
        """
        if top_k >= dist.shape[1]:
            # Fewer candidates than requested: a full sort is all that is needed.
            idx = dist.argsort(axis=1)[:, :top_k]
            dist = np.take_along_axis(dist, idx, axis=1)
        else:
            # Partition first (unordered top_k), then sort only those top_k.
            idx_ps = dist.argpartition(kth=top_k, axis=1)[:, :top_k]
            dist = np.take_along_axis(dist, idx_ps, axis=1)
            idx_fs = dist.argsort(axis=1)
            idx = np.take_along_axis(idx_ps, idx_fs, axis=1)
            dist = np.take_along_axis(dist, idx_fs, axis=1)

        return idx, dist
|
||
|
||
def _get_ones(x, y):
    """Return an ``x`` by ``y`` array filled with ones."""
    shape = (x, y)
    return np.ones(shape)
|
||
|
||
def _ext_A(A):
    """Build the row-wise extended matrix ``[1 | A | A**2]``.

    :param A: 2-D array of shape ``(nA, dim)``
    :return: array of shape ``(nA, 3 * dim)``
    """
    rows, width = A.shape
    extended = _get_ones(rows, 3 * width)
    # The first `width` columns keep their initial value of one.
    extended[:, width:2 * width] = A
    extended[:, 2 * width:] = A ** 2
    return extended
|
||
|
||
def _ext_B(B):
    """Build the column-wise extended matrix ``[B**2 ; -2 * B ; 1]`` (transposed).

    :param B: 2-D array of shape ``(nB, dim)``
    :return: array of shape ``(3 * dim, nB)``
    """
    count, width = B.shape
    extended = _get_ones(3 * width, count)
    extended[:width] = (B ** 2).T
    extended[width:2 * width] = -2.0 * B.T
    # The last `width` rows keep their initial value of one.
    return extended
|
||
|
||
def _euclidean(A_ext, B_ext):
    """Return the square root of the product ``A_ext @ B_ext``, floored at zero.

    :param A_ext: left operand of the matrix product
    :param B_ext: right operand of the matrix product
    :return: element-wise square root of the non-negative product
    """
    product = A_ext.dot(B_ext)
    # Negative entries are clipped to zero before taking the square root.
    non_negative = product.clip(min=0)
    return np.sqrt(non_negative)
|
||
|
||
def _norm(A):
    """Scale every row of ``A`` to unit L2 length.

    Rows whose norm is zero are returned unchanged (all zeros) instead of
    being divided by zero, which would yield NaN.

    :param A: 2-D array with one vector per row
    :return: array of the same shape with L2-normalised rows
    """
    norms = np.linalg.norm(A, ord=2, axis=1, keepdims=True)
    # Substitute 1 for zero norms so that a zero row stays a zero row.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return A / safe_norms
|
||
|
||
def _cosine(A_norm_ext, B_norm_ext):
    """Return half of the product ``A_norm_ext @ B_norm_ext``, floored at zero.

    :param A_norm_ext: left operand of the matrix product
    :param B_norm_ext: right operand of the matrix product
    :return: the clipped product divided by two
    """
    product = A_norm_ext.dot(B_norm_ext)
    clipped = product.clip(min=0)
    return clipped / 2
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import os | ||
|
||
from jina import Flow | ||
|
||
from .helpers import * | ||
from .. import DocCache | ||
|
||
|
||
def duplicate_docs():
    """Yield three separate Documents that share the same text and embedding."""
    remaining = 3
    while remaining > 0:
        doc = Document()
        doc.text = 'abc'
        doc.embedding = np.array([1, 3, 4])
        yield doc
        remaining -= 1
|
||
|
||
def test_doc_cache(tmpdir):
    """Index three identical Documents and expect only one of them to be stored."""
    # The workspace for the run is passed through the WORKSPACE variable.
    workspace = str(tmpdir)
    os.environ['WORKSPACE'] = workspace

    flow = Flow.load_config('flow.yml')
    with flow:
        flow.post(on='/index', inputs=duplicate_docs())

        results = flow.post(
            on='/search',
            inputs=duplicate_docs(),
            parameters={'top_k': 3},
        )
        first_query = results[0].docs[0]
        assert len(first_query.matches) == 1