Skip to content

Commit

Permalink
Refactor of MinHash to work with a single class and fix the shelve backend (#937)
Browse files Browse the repository at this point in the history

* Initial work for minhash

* Add minhash step redirect

* Add first version of minhash and minhashlsh

* Add unit tests for minhash dedup

* Add pipeline testing deduplication

* Add tests to run with disk backend

* Add tests for the disk and ensure unload

* Add private _datasketch module to include a custom storage configuration for the minhash index

* Add docstrings to the internal classes/functions

* Add docstrings for the user facing classes

* Update src/distilabel/steps/filtering/minhash.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Update src/distilabel/steps/filtering/minhash.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Update tests/integration/test_deduplication.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Update src/distilabel/steps/filtering/minhash.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Update src/distilabel/steps/filtering/minhash.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Add installation dependencies

* Apply comments from code review

* Add nltk as a dependency for the tests

* Update tests and interpretation of keep rows vs duplicates

* Remove disk backend from tests temporarily

* Add note in the docs related to minhash storage on disk

* Update tests to run on dict instead of disk as it never ends on CI

* Fix integration test

* Hide import inside of function to avoid installing it on docs building

* Update command to download nltk

* Allow for a name in the shelve based backend to avoid overwrites

* Refactor MinHash to use a single MinHashDedup class that controls all the process

* Refactor tests to use the new class

* Redirect import to steps level

* Create new disk based storage using diskcache

* Add docstrings to clarify the difference between dict/disk

* Refactor to use diskcache

* Fix docstring example

* Update src/distilabel/steps/filtering/minhash.py

Co-authored-by: Gabriel Martín Blázquez <[email protected]>

* Update definition of the step

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Sep 2, 2024
1 parent 88615c7 commit 4b3c9c0
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 223 deletions.
5 changes: 2 additions & 3 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from distilabel.steps.deita import DeitaFiltering
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
from distilabel.steps.filtering.minhash import MinHash, MinHashLSH
from distilabel.steps.filtering.minhash import MinHashDedup
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
FormatChatGenerationDPO,
Expand Down Expand Up @@ -74,8 +74,7 @@
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
"MinHash",
"MinHashLSH",
"MinHashDedup",
"make_generator_step",
"PushToHub",
"Step",
Expand Down
85 changes: 38 additions & 47 deletions src/distilabel/steps/filtering/_datasketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
creating a PR to `datasketch`.
"""

import shelve
import shutil
import struct
from pathlib import Path
from typing import Callable, Dict, Final, Optional, Tuple
Expand All @@ -30,48 +30,46 @@
from datasketch.storage import ordered_storage as _ordered_storage
from datasketch.storage import unordered_storage as _unordered_storage

SHELVE_DIR: Path = Path.home() / ".cache" / "distilabel"
SHELVE_LIST_NAME: Final[str] = ".shelve_list_storage"
SHELVE_SET_NAME: Final[str] = ".shelve_set_storage"
KEY_VALUE_DISK_DIR: Path = Path.home() / ".cache" / "distilabel" / "key_value_store"
KV_DISK_LIST_NAME: Final[str] = "disckache_list_storage"
KV_DISK_SET_NAME: Final[str] = "diskcache_set_storage"


class ShelveListStorage(OrderedStorage):
"""Key/Value storage using shelve to store the hash tables in disk.
It mimics the behaviour of `datasketch.DictListStorage`.
The only difference is the storage in disk.
The functionality is on purpose to avoid unnecessary errors.
"""
class DiskCacheListStorage(OrderedStorage):
def __init__(self, config, name) -> None:
path = config.get("path", self._get_db_name(name))
try:
from diskcache import Index
except ImportError as e:
raise ImportError(
"`diskcache` is required for disk storage using `MinHashDedup`. "
"Please install it using `pip install diskcache`."
) from e

def __init__(self, config) -> None:
path = config.get("path", self._get_db_name())
# Read about writeback here: https://docs.python.org/3/library/shelve.html#shelve.open
writeback = config.get("writeback", True)
# The flag is set to "n" to recreate the file always, we assume
# every pipeline works on it's own and recomputes it instead of trusting
# the cache.
self._db = shelve.open(path, writeback=writeback, flag="n")
# Start with a clean file on each pipeline
if Path(path).exists():
shutil.rmtree(path)
self._db = Index(path)

def _get_db_name(self):
return str(SHELVE_DIR / SHELVE_LIST_NAME)
def _get_db_name(self, name):
return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_LIST_NAME}")

def keys(self):
return self._db.keys()

def get(self, key):
return self._db.get(str(key), [])
return self._db.get(key, [])

def remove(self, *keys):
for key in keys:
del self._db[str(key)]
self._db.clear()

def remove_val(self, key, val):
self._db[str(key)].remove(val)
self.get(key).remove(val)

def insert(self, key, *vals, **kwargs):
key = str(key)
if not self._db.get(key):
self._db[key] = []
self._db[key].extend(vals)
res = self.get(key)
res.extend(vals)
self._db[key] = res

def size(self):
return len(self._db)
Expand All @@ -83,42 +81,35 @@ def has_key(self, key):
return key in self._db

def close(self):
self._db.close()

self._db._cache.close()

class ShelveSetStorage(UnorderedStorage, ShelveListStorage):
"""Key/Value storage using shelve to store the hash tables in disk.
It mimics the behaviour of `datasketch.DictSetStorage`.
The only difference is the storage in disk.
The functionality is on purpose to avoid unnecessary errors.
"""

def _get_db_name(self):
return str(SHELVE_DIR / SHELVE_SET_NAME)
class DiskCacheSetStorage(UnorderedStorage, DiskCacheListStorage):
def _get_db_name(self, name):
return str(KEY_VALUE_DISK_DIR / f"{name}_{KV_DISK_SET_NAME}")

def get(self, key):
return self._db.get(str(key), set())
return self._db.get(key, set())

def insert(self, key, *vals, **kwargs):
key = str(key)
if not self._db.get(key):
self._db[key] = set()
self._db[key].update(vals)
res = self.get(key)
res.update(vals)
self._db[key] = res


def ordered_storage(config, name=None):
"""Copy of `datasketch.storage.ordered_storage` with the addition of `ShelveListStorage`."""
tp = config["type"]
if tp == "disk":
return ShelveListStorage(config)
return DiskCacheListStorage(config, name=name)
return _ordered_storage(config, name=name)


def unordered_storage(config, name=None):
"""Copy of `datasketch.storage.ordered_storage` with the addition of `ShelveSetStorage`."""
tp = config["type"]
if tp == "disk":
return ShelveSetStorage(config)
return DiskCacheSetStorage(config, name=name)
return _unordered_storage(config, name=name)


Expand Down Expand Up @@ -192,8 +183,8 @@ def __init__(
self.keys = ordered_storage(storage_config, name=b"".join([basename, b"_keys"]))

def close(self):
"""Closes the shelve objects."""
if isinstance(self.hashtables[0], ShelveListStorage):
"""Closes the internal connections."""
if isinstance(self.hashtables[0], DiskCacheListStorage):
for ht in self.hashtables:
ht.close()
self.keys.close()
Loading

0 comments on commit 4b3c9c0

Please sign in to comment.