Commit

dev(narugo): save the data source
narugo1992 committed Aug 3, 2024
1 parent 1472a18 commit 31127c5
Showing 8 changed files with 115 additions and 18 deletions.
3 changes: 3 additions & 0 deletions felinewhisker/datasource/__init__.py
@@ -0,0 +1,3 @@
+from .base import BaseDataSource
+from .cheesechaser import CheeseChaserDataSource
+from .local import LocalDataSource
12 changes: 7 additions & 5 deletions felinewhisker/datasource/base.py
@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from os import PathLike
-from typing import Union, Any, Optional, ContextManager, Iterator
+from typing import Union, Any, Optional, ContextManager, Iterator, Callable

 from PIL import Image
 from hbutils.random import random_sha1_with_timestamp
@@ -13,7 +13,7 @@

 @dataclass
 class ImageItem:
-    id: Optional[Union[int, str]]
+    id: str
     image: Union[str, PathLike, Image.Image]
     annotation: Optional[Any]

@@ -56,11 +56,12 @@ def make_file(self, max_size: int = 2048, format: str = 'webp', quality: Optiona


 class BaseDataSource:
-    def __init__(self):
+    def __init__(self, fn_contains_id: Optional[Callable[[str], bool]] = None):
         self._status = 'idle'
+        self._fn_contains_id = fn_contains_id or (lambda x: False)

     def _iter(self):
-        raise NotImplementedError
+        raise NotImplementedError  # pragma: no cover

     def _init(self):
         pass
@@ -109,4 +110,5 @@ def __iter__(self) -> Iterator[ImageItem]:
                 annotate = None

             id_ = id_ or random_sha1_with_timestamp()
-            yield ImageItem(id_, v, annotate)
+            if not self._fn_contains_id(id_):
+                yield ImageItem(id_, v, annotate)
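
With this change, any data source can be handed an fn_contains_id predicate and will skip items whose ids are already known. A minimal sketch of the behaviour, assuming only what the diff above shows (the ToyDataSource subclass and its ids are hypothetical, for illustration only):

from felinewhisker.datasource import BaseDataSource

class ToyDataSource(BaseDataSource):
    # Hypothetical subclass; like the real ones, _iter() yields
    # (id, image, annotation) tuples that __iter__ wraps into ImageItem.
    def _iter(self):
        for i in range(5):
            yield f'toy__{i}', f'/path/to/{i}.png', None

known = {'toy__0', 'toy__3'}
source = ToyDataSource(fn_contains_id=lambda id_: id_ in known)
print([item.id for item in source])  # ['toy__1', 'toy__2', 'toy__4']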
40 changes: 40 additions & 0 deletions felinewhisker/datasource/cheesechaser.py
@@ -0,0 +1,40 @@
+from typing import Optional, Callable, Union, Iterator
+
+from cheesechaser.datapool import DataPool
+from cheesechaser.pipe import Pipe, SimpleImagePipe, PipeItem
+from hbutils.string import underscore
+
+from .base import BaseDataSource
+
+
+class CheeseChaserDataSource(BaseDataSource):
+    def __init__(self, source: Union[DataPool, Pipe], id_generator: Iterator[Union[str, int]],
+                 source_id: Optional[str] = None, fn_contains_id: Optional[Callable[[str], bool]] = None):
+        if isinstance(source, DataPool):
+            self._pipe = SimpleImagePipe(source)
+            default_source_id = underscore(source.__class__.__name__.replace('DataPool', ''))
+        elif isinstance(source, Pipe):
+            self._pipe = source
+            default_source_id = None
+        else:
+            raise TypeError(f'Unknown source type - {source!r}.')
+        self._source_id = source_id or default_source_id
+        self._id_generator = id_generator
+        BaseDataSource.__init__(self, fn_contains_id=fn_contains_id)
+
+    def _cid_to_id(self, cid) -> str:
+        if self._source_id:
+            return f'cheesechaser__{self._source_id}__{cid}'
+        else:
+            return f'cheesechaser__{cid}'
+
+    def _iter_cids(self):
+        for cid in self._id_generator:
+            if not self._fn_contains_id(self._cid_to_id(cid)):
+                yield cid
+
+    def _iter(self):
+        with self._pipe.batch_retrieve(self._iter_cids()) as session:
+            for item in session:
+                item: PipeItem
+                yield self._cid_to_id(item.id), item.data, None
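
Ids are filtered twice here: _iter_cids() avoids downloading known items at all, and the base class filters again at yield time. Usage might look like the following sketch; DanbooruNewestDataPool stands in for any cheesechaser data pool, and the id range is invented:

from cheesechaser.datapool import DanbooruNewestDataPool

from felinewhisker.datasource import CheeseChaserDataSource

# source_id defaults to underscore('DanbooruNewest') == 'danbooru_newest',
# so cid 7000000 maps to 'cheesechaser__danbooru_newest__7000000'.
source = CheeseChaserDataSource(
    source=DanbooruNewestDataPool(),
    id_generator=iter(range(7000000, 7000100)),  # invented id range
    fn_contains_id=lambda id_: False,  # nothing known yet; keep everything
)
for item in source:
    print(item.id)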
10 changes: 6 additions & 4 deletions felinewhisker/datasource/local.py
@@ -1,16 +1,17 @@
 import mimetypes
 import os.path
 import re
-from typing import Optional
+from typing import Optional, Callable

 from .base import BaseDataSource

 mimetypes.add_type('image/webp', '.webp')


 class LocalDataSource(BaseDataSource):
-    def __init__(self, local_dir: str, source_id: Optional[str] = None):
-        BaseDataSource.__init__(self)
+    def __init__(self, local_dir: str, source_id: Optional[str] = None,
+                 fn_contains_id: Optional[Callable[[str], bool]] = None):
+        BaseDataSource.__init__(self, fn_contains_id=fn_contains_id)
         self.local_dir = os.path.abspath(os.path.normpath(os.path.expanduser(os.path.normcase(local_dir))))
         self.source_id = source_id or re.sub(r'[\W_]+', '_', self.local_dir).strip('_')

@@ -22,4 +23,5 @@ def _iter(self):
             if mimetype.startswith('image/'):
                 file_token = re.sub(r'[\W_]+', '_', os.path.relpath(path, self.local_dir)).strip('_')
                 id_ = f'localdir__{self.source_id}__{file_token}'
-                yield id_, path, None
+                if not self._fn_contains_id(id_):
+                    yield id_, path, None
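
A short sketch of wiring a local directory to the repository-level duplicate check; the directory path is made up, and repo is assumed to be a concrete DatasetRepository (see repository/base.py below):

from felinewhisker.datasource import LocalDataSource

# Ids take the form 'localdir__<source_id>__<file_token>', so the same file
# re-scanned later maps to the same id and is skipped.
source = LocalDataSource(
    '~/datasets/raw_images',          # made-up path
    fn_contains_id=repo.contains_id,  # repo: a DatasetRepository instance
)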
25 changes: 20 additions & 5 deletions felinewhisker/repository/base.py
@@ -15,17 +15,22 @@

 class WriterSession:
     def __init__(self, author: Optional[str], checker: AnnotationChecker,
-                 save_func: Callable[[str, str, str], None]):
+                 fn_save: Callable[[str, str, str], None], fn_contains_id: Callable[[str], bool]):
         self._author = author
         self._checker = checker
         self._token = random_sha1_with_timestamp()
         if self._author:
             self._token = f'{self._token}__{self._author}'
         self._storage_tmpdir = TemporaryDirectory()
         self._records = {}
-        self._save_func = save_func
+        self._fn_save = fn_save
+        self._fn_contains_id = fn_contains_id
         self._lock = Lock()

+    def is_id_duplicated(self, id_: str) -> bool:
+        with self._lock:
+            return id_ in self._records or self._fn_contains_id(id_)
+
     def add(self, id_: str, image_file: str, annotation):
         with self._lock:
             if annotation is not None:
@@ -62,6 +67,10 @@ def __len__(self):
         with self._lock:
             return len(self._records)

+    def __contains__(self, item):
+        with self._lock:
+            return item in self._records
+
     def _save(self):
         with TemporaryDirectory() as td:
             records = []
@@ -78,7 +87,7 @@ def _save(self):
             data_file = os.path.join(td, 'data.parquet')
             df = pd.DataFrame(records)
             df.to_parquet(data_file, engine='pyarrow', index=False)
-            self._save_func(tar_file, data_file, self._token)
+            self._fn_save(tar_file, data_file, self._token)

     def save(self):
         with self._lock:
@@ -105,6 +114,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 class DatasetRepository:
     def __init__(self):
         self.meta_info = None
+        self._exist_ids = None
         self._annotation_checker: Optional[AnnotationChecker] = None
         self._lock = Lock()
         self._sync()
@@ -119,7 +129,7 @@ def _squash(self):
         raise NotImplementedError  # pragma: no cover

     def _sync(self):
-        self.meta_info = self._read()
+        self.meta_info, self._exist_ids = self._read()
         self._annotation_checker = parse_annotation_checker_from_meta(self.meta_info)

     def squash(self):
@@ -137,5 +147,10 @@ def write(self, author: Optional[str] = None):
         return WriterSession(
             author=author,
             checker=self._annotation_checker,
-            save_func=self._write,
+            fn_save=self._write,
+            fn_contains_id=lambda id_: id_ in self._exist_ids,
         )
+
+    def contains_id(self, id_: str):
+        with self._lock:
+            return id_ in self._exist_ids
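
Putting the pieces together: a writer session can now be asked up front whether an id already exists, either in the session itself or in the synced repository. A sketch of the intended flow, not the library's documented API (repo and source are assumed; make_file() is the ImageItem helper from datasource/base.py):

with repo.write(author='narugo') as session:
    for item in source:
        if session.is_id_duplicated(item.id):
            continue  # already in this session or in the repository
        with item.make_file(format='webp') as image_file:
            session.add(item.id, image_file, item.annotation)
    session.save()  # persists via fn_save -> repository._write()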
31 changes: 29 additions & 2 deletions felinewhisker/repository/huggingface.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 import shutil
 from typing import Optional, List
@@ -8,6 +9,7 @@
 from PIL import Image
 from hbutils.string import plural_word
 from hbutils.system import TemporaryDirectory
+from hfutils.cache import delete_detached_cache
 from hfutils.index import tar_get_index_info, hf_tar_file_download
 from hfutils.operate import upload_directory_as_directory, get_hf_fs, get_hf_client
 from hfutils.utils import hf_normpath, hf_fs_path, parse_hf_fs_path
@@ -44,7 +46,10 @@ def _write(self, tar_file: str, data_file: str, token: str):
         df = pd.DataFrame(records)
         df.to_parquet(dst_data_file, engine='pyarrow', index=False)

-        named_authors = set(filter(bool, df['author'].tolist()))
+        if len(df) > 0:
+            named_authors = set(filter(bool, df['author'].tolist()))
+        else:
+            named_authors = []
         pack_name = os.path.basename(dst_tar_file)
         if named_authors:
             commit_message = f'Add package with {plural_word(len(df), "sample")} contributed ' \
@@ -61,17 +66,36 @@
         )

     def _read(self):
+        delete_detached_cache(repo_id=self._repo_id, repo_type='dataset')
         hf_fs = get_hf_fs(hf_token=os.environ.get('HF_TOKEN'))
+        hf_client = get_hf_client(hf_token=os.environ['HF_TOKEN'])

         meta_info = json.loads(hf_fs.read_text(hf_fs_path(
             repo_id=self._repo_id,
             repo_type='dataset',
             revision=self._revision,
             filename='meta.json',
         )))
-        return meta_info
+        if hf_fs.exists(hf_fs_path(
+                repo_id=self._repo_id,
+                repo_type='dataset',
+                revision=self._revision,
+                filename='data.parquet'
+        )):
+            df = pd.read_parquet(hf_client.hf_hub_download(
+                repo_id=self._repo_id,
+                repo_type='dataset',
+                revision=self._revision,
+                filename='data.parquet'
+            ))
+            exist_ids = set(df['id'])
+        else:
+            exist_ids = set()
+
+        return meta_info, exist_ids

     def _squash(self):
+        delete_detached_cache(repo_id=self._repo_id, repo_type='dataset')
         hf_fs = get_hf_fs(hf_token=os.environ.get('HF_TOKEN'))
         hf_client = get_hf_client(hf_token=os.environ.get('HF_TOKEN'))

@@ -128,6 +152,9 @@ def _load_image_by_id(id_: str):

         with TemporaryDirectory() as td:
             df = pd.DataFrame(list(records.values()))
+            if len(df) == 0:
+                logging.warning('No samples in total, squash operation cancelled.')
+                return
             df = df.sort_values(by=['updated_at', 'id'], ascending=[False, True])
             df.to_parquet(os.path.join(td, 'data.parquet'), engine='pyarrow', index=False)
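
On the HuggingFace backend, _read() now returns the meta info together with the set of ids already committed to data.parquet, which is what backs contains_id(). A hedged end-to-end sketch; the repository class name and constructor are assumptions, as neither is visible in this diff:

import os

from felinewhisker.repository import HfOnlineRepository  # name assumed

os.environ.setdefault('HF_TOKEN', '<your token>')  # _read() expects HF_TOKEN
repo = HfOnlineRepository(repo_id='user/dataset')  # signature assumed
print(repo.contains_id('cheesechaser__danbooru_newest__7000000'))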
9 changes: 8 additions & 1 deletion felinewhisker/repository/local.py
@@ -1,5 +1,6 @@
 import glob
 import json
+import logging
 import os.path
 import shutil
 from typing import Optional, List
@@ -44,7 +45,11 @@ def _write(self, tar_file: str, data_file: str, token: str):
     def _read(self):
         with open(self._meta_info_file, 'r') as f:
             meta_info = json.load(f)
-        return meta_info
+        if os.path.exists(self._data_file):
+            exist_ids = set(pd.read_parquet(self._data_file)['id'])
+        else:
+            exist_ids = set()
+        return meta_info, exist_ids

     def _squash(self):
         data_file = os.path.join(self._repo_dir, 'data.parquet')
@@ -60,6 +65,8 @@ def _squash(self):
                 records[item['id']] = item
                 files_to_drop.append(file)
         df = pd.DataFrame(list(records.values()))
+        if len(df) == 0:
+            logging.warning('No samples in total, squash operation cancelled.')
         df = df.sort_values(by=['updated_at', 'id'], ascending=[False, True])
         df.to_parquet(data_file, engine='pyarrow', index=False)
         for file in files_to_drop:
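
The local backend mirrors the HuggingFace logic: existing ids are read from data.parquet on disk when it is present. A small sketch; the class name, constructor, and path are assumptions not shown in this diff:

from felinewhisker.repository import LocalRepository  # name assumed

repo = LocalRepository('/path/to/repo')  # signature assumed; made-up path
with repo.write(author='alice') as session:
    print('localdir__my_dir__img_001' in session)  # __contains__: ids added in this session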
3 changes: 2 additions & 1 deletion requirements.txt
@@ -11,4 +11,5 @@ numpy<2
 pandas
 pyyaml>=6
 pyarrow
-tabulate
+tabulate
+git+https://github.com/deepghs/cheesechaser.git
