feat: add doccache
cristianmtr committed May 27, 2021
1 parent fb6d5a2 commit 1d8cca2
Showing 7 changed files with 271 additions and 0 deletions.
7 changes: 7 additions & 0 deletions indexers/keyvalue/DocCache/README.md
@@ -0,0 +1,7 @@
# DocCache Executor

This is the Jina Hub module for a Cache Executor.

It checks whether a Document has already been indexed, by hashing the values of the combination of fields you want to cache on. By default, it checks the `.text` field.

**NOTE**: If you use a DocCache in your Flow, it will override the traversal path for subsequent Executors. This means that whatever traversal path you passed to the DocCache will become the new root path (`'r'`). This is done in order to reduce network traffic to the next executor.
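
For concreteness, here is a standalone sketch of the cache key computation (it mirrors `hash_doc` in `__init__.py` below and needs only `hashlib`, not Jina):

```python
import hashlib

# The cache key is the SHA-256 digest of concatenated 'field:value;' pairs.
# For the default fields=('text',) and a Document whose .text is 'abc':
data = 'text:abc;'
key = hashlib.sha256(data.encode('utf8')).digest()

# Identical field values always produce the same key, which is how a second
# Document with .text == 'abc' is detected as a duplicate.
assert key == hashlib.sha256('text:abc;'.encode('utf8')).digest()
```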
112 changes: 112 additions & 0 deletions indexers/keyvalue/DocCache/__init__.py
@@ -0,0 +1,112 @@
import hashlib
import os
import pickle
from typing import Tuple, Optional, Dict

from jina import Executor, DocumentArray, requests, Document
from jina.logging import JinaLogger


class _CacheHandler:
"""A handler for loading and serializing the in-memory cache of the DocCache.
    :param path: Base path for the cache files; the handler persists to <path>.ids and <path>.cache.
:param logger: Instance of logger.
"""

def __init__(self, path, logger):
self.path = path
        try:
            with open(path + '.ids', 'rb') as ids_fp:
                self.id_to_cache_val = pickle.load(ids_fp)
            with open(path + '.cache', 'rb') as cache_fp:
                self.cache_val_to_id = pickle.load(cache_fp)
except FileNotFoundError as e:
logger.warning(
                f'File path did not exist: {path}.ids or {path}.cache: {e!r}. Creating a new CacheHandler...'
)
self.id_to_cache_val = dict()
self.cache_val_to_id = dict()

def close(self):
"""Flushes the in-memory cache to pickle files."""
        with open(self.path + '.ids', 'wb') as ids_fp:
            pickle.dump(self.id_to_cache_val, ids_fp)
        with open(self.path + '.cache', 'wb') as cache_fp:
            pickle.dump(self.cache_val_to_id, cache_fp)


default_fields = ('text',)


class DocCache(Executor):
"""A cache Executor
Checks if a Document has already been indexed.
If it hasn't, it is kept
If it has been indexed before, it will be removed from the list of Documents
NOTE: The Traversal path used in processing the request will become the root
path (`r`) afterwards. This is required in order to save network traffic.
"""

def __init__(
self,
        fields: Optional[Tuple[str, ...]] = None,
        default_traversal_paths: Tuple[str, ...] = ('r',),
tag: str = 'cache_hit',
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
if fields is None:
fields = default_fields
self.fields = fields
self.tag = tag
self.logger = JinaLogger('DocCache')
self.cache_handler = _CacheHandler(
os.path.join(self.metas.workspace, 'cache'), self.logger
)
self.default_traversal_paths = default_traversal_paths

@requests(on='/index')
def cache(self, docs: DocumentArray, parameters: Dict, **kwargs):
"""Method to handle the index process for caching"""
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
)
docs_to_be_traversed = docs.traverse_flat(traversal_paths)
idx_to_remove = []
for i, d in enumerate(docs_to_be_traversed):
cache_value = DocCache.hash_doc(d, self.fields)
            exists = cache_value in self.cache_handler.cache_val_to_id
if not exists:
self.cache_handler.id_to_cache_val[d.id] = cache_value
self.cache_handler.cache_val_to_id[cache_value] = d.id
else:
idx_to_remove.append(i)
for i in sorted(idx_to_remove, reverse=True):
del docs_to_be_traversed[i]
        # returning the traversed docs makes this traversal path the new root ('r') for downstream Executors
return docs_to_be_traversed

def close(self) -> None:
"""Make sure to flush to file"""
self.cache_handler.close()

@staticmethod
    def hash_doc(doc: Document, fields: Tuple[str, ...]) -> bytes:
        """Calculate the hash by which we cache.
        :param doc: the Document
        :param fields: the tuple of fields to hash on
        :return: the hash value of the fields
"""
values = doc.get_attributes(*fields)
if not isinstance(values, list):
values = [values]
data = ''
for field, value in zip(fields, values):
data += f'{field}:{value};'
        digest = hashlib.sha256(data.encode('utf8')).digest()
return digest

@property
def size(self):
return len(self.cache_handler.id_to_cache_val)
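
As a rough illustration, the Executor can also be exercised outside a Flow. This is a hypothetical sketch: it assumes the module is importable as `DocCache`, that the `./workspace` directory already exists, and the jina 2.0 `metas` keyword for Executors:

```python
from jina import Document, DocumentArray

from DocCache import DocCache  # hypothetical import path

executor = DocCache(metas={'workspace': './workspace'})  # dir must exist for close()
docs = DocumentArray(
    [Document(text='abc'), Document(text='abc'), Document(text='xyz')]
)
kept = executor.cache(docs, parameters={})
assert len(kept) == 2  # the duplicate 'abc' was dropped
executor.close()  # flush the .ids / .cache pickles to disk
```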
1 change: 1 addition & 0 deletions indexers/keyvalue/DocCache/requirements.txt
@@ -0,0 +1 @@
jina==2.0.0
Empty file.
17 changes: 17 additions & 0 deletions indexers/keyvalue/DocCache/tests/flow.yml
@@ -0,0 +1,17 @@
jtype: Flow
version: '1'
with:
rest_api: true
port_expose: 9000
return_results: true
executors:
- name: cache
timeout_ready: '-1'
uses:
jtype: DocCache
metas:
workspace: $WORKSPACE
- name: indexer
timeout_ready: '-1'
uses:
jtype: MyIndexer
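
The `$WORKSPACE` placeholder is substituted from the environment when the YAML is loaded, so a driver script sets it before calling `Flow.load_config` (exactly as `test_doc_cache.py` below does):

```python
import os

from jina import Flow

os.environ['WORKSPACE'] = './workspace'  # substituted into the YAML's $WORKSPACE
f = Flow.load_config('flow.yml')
```

Note that the cache sits ahead of the indexer in this Flow, so duplicates are dropped before they ever reach `MyIndexer`.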
103 changes: 103 additions & 0 deletions indexers/keyvalue/DocCache/tests/helpers.py
@@ -0,0 +1,103 @@
from typing import Tuple, Dict

import numpy as np
from jina import Executor, DocumentArray, requests, Document
from jina.logging import JinaLogger


def _safeget(dct, *keys):
for key in keys:
try:
dct = dct[key]
except KeyError:
return None
return dct


class MyIndexer(Executor):
def __init__(
self,
default_traversal_paths: Tuple[str] = ('r',),
**kwargs,
):
self.logger = JinaLogger('MyIndexer')
self.default_traversal_paths = default_traversal_paths
super().__init__(**kwargs)
self._docs = DocumentArray()

@requests(on='/index')
def index(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        # do not traverse the docs here: the DocCache upstream has already
        # flattened them, so the incoming traversal path is the root ('r')
self._docs.extend(docs)

@requests(on=['/search', '/eval'])
def search(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
traversal_paths = parameters.get(
'traversal_paths', self.default_traversal_paths
)
docs_to_search_with = docs.traverse_flat(traversal_paths)
self.logger.info(f'Searching with {len(docs_to_search_with)} documents...')

a = np.stack(docs_to_search_with.get_attributes('embedding'))
b = np.stack(self._docs.get_attributes('embedding'))
q_emb = _ext_A(_norm(a))
d_emb = _ext_B(_norm(b))
dists = _cosine(q_emb, d_emb)
idx, dist = self._get_sorted_top_k(dists, int(parameters['top_k']))
for _q, _ids, _dists in zip(docs_to_search_with, idx, dist):
for _id, _dist in zip(_ids, _dists):
d = Document(self._docs[int(_id)], copy=True)
d.score.value = 1 - _dist
_q.matches.append(d)

@staticmethod
def _get_sorted_top_k(
        dist: 'np.ndarray', top_k: int
) -> Tuple['np.ndarray', 'np.ndarray']:
if top_k >= dist.shape[1]:
idx = dist.argsort(axis=1)[:, :top_k]
dist = np.take_along_axis(dist, idx, axis=1)
else:
idx_ps = dist.argpartition(kth=top_k, axis=1)[:, :top_k]
dist = np.take_along_axis(dist, idx_ps, axis=1)
idx_fs = dist.argsort(axis=1)
idx = np.take_along_axis(idx_ps, idx_fs, axis=1)
dist = np.take_along_axis(dist, idx_fs, axis=1)

return idx, dist


def _get_ones(x, y):
return np.ones((x, y))


def _ext_A(A):
nA, dim = A.shape
A_ext = _get_ones(nA, dim * 3)
A_ext[:, dim: 2 * dim] = A
A_ext[:, 2 * dim:] = A ** 2
return A_ext


def _ext_B(B):
nB, dim = B.shape
B_ext = _get_ones(dim * 3, nB)
B_ext[:dim] = (B ** 2).T
B_ext[dim: 2 * dim] = -2.0 * B.T
del B
return B_ext


def _euclidean(A_ext, B_ext):
sqdist = A_ext.dot(B_ext).clip(min=0)
return np.sqrt(sqdist)


def _norm(A):
return A / np.linalg.norm(A, ord=2, axis=1, keepdims=True)


def _cosine(A_norm_ext, B_norm_ext):
return A_norm_ext.dot(B_norm_ext).clip(min=0) / 2
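
The `_ext_A`/`_ext_B` trick above packs all three terms of the squared Euclidean distance into a single matrix product: for rows `a` and `b`, `_ext_A(a) @ _ext_B(b)` expands to `b**2 - 2*a*b + a**2` summed over the embedding dimension, i.e. `||a - b||**2`. For L2-normalized rows that equals `2 - 2*cos(a, b)`, so halving it yields the cosine distance. A quick numerical check of that identity, reusing the helpers defined above:

```python
import numpy as np

rng = np.random.default_rng(0)
a = _norm(rng.random((4, 8)))  # 4 L2-normalized query embeddings
b = _norm(rng.random((5, 8)))  # 5 L2-normalized index embeddings

direct = 1 - a @ b.T                     # cosine distance, computed directly
via_ext = _cosine(_ext_A(a), _ext_B(b))  # via the extension trick
assert np.allclose(direct, via_ext)
```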

31 changes: 31 additions & 0 deletions indexers/keyvalue/DocCache/tests/test_doc_cache.py
@@ -0,0 +1,31 @@
import os

from jina import Flow

from .helpers import *
from .. import DocCache


def duplicate_docs():
for _ in range(3):
d = Document()
d.text = 'abc'
d.embedding = np.array([1, 3, 4])
yield d


def test_doc_cache(tmpdir):
os.environ['WORKSPACE'] = str(tmpdir)
f = Flow.load_config('flow.yml')
with f:
f.post(
on='/index',
inputs=duplicate_docs(),
)

results = f.post(
on='/search',
inputs=duplicate_docs(),
parameters={'top_k': 3}
)
assert len(results[0].docs[0].matches) == 1
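
The final assertion holds because the cache let only one of the three identical documents through at index time: `MyIndexer` therefore stores a single document, so even with `top_k` set to 3 the search can return at most one match.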
