Skip to content

Commit

Permalink
Merge pull request #84 from deepghs/dev/ccip
Browse files Browse the repository at this point in the history
dev(narugo): add ccip_merge function
  • Loading branch information
narugo1992 authored May 3, 2024
2 parents 8fff5d3 + 4191a5e commit c25b9e0
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
8 changes: 8 additions & 0 deletions docs/source/api_doc/metrics/ccip.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,11 @@ ccip_clustering



ccip_merge
--------------------------------------------

.. autofunction:: ccip_merge




38 changes: 38 additions & 0 deletions imgutils/metrics/ccip.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

'ccip_default_clustering_params',
'ccip_clustering',

'ccip_merge',
]


Expand Down Expand Up @@ -517,3 +519,39 @@ def _metric(x, y):
assert False, f'Unknown mode for CCIP clustering - {method!r}.' # pragma: no cover

return clustering.labels_.tolist()


def ccip_merge(images: Union[List[_FeatureOrImage], np.ndarray],
               size: int = 384, model: str = _DEFAULT_MODEL_NAMES) -> np.ndarray:
    """
    Merge multiple feature vectors into a single vector.

    Every item of ``images`` is passed through ``_p_feature`` (together with ``size`` and
    ``model``), the resulting vectors are scaled to unit length and averaged, and the
    averaged direction is rescaled to the mean norm of the inputs.

    :param images: The feature vectors or images to merge.
    :type images: Union[List[_FeatureOrImage], numpy.ndarray]
    :param size: The size of the image. (default: 384)
    :type size: int
    :param model: The name of the model. (default: ``ccip-caformer-24-randaug-pruned``)
    :type model: str
    :return: The merged feature vector.
    :rtype: numpy.ndarray
    :raises ValueError: If ``images`` contains no items.

    .. note::
        An input whose norm is zero (or inputs whose unit vectors cancel out exactly)
        leads to a division by zero, so the result contains NaN values.

    Examples::
        >>> from imgutils.metrics import ccip_merge, ccip_batch_differences
        >>>
        >>> images = [f'ccip/{i}.jpg' for i in range(1, 4)]
        >>>
        >>> merged = ccip_merge(images)
        >>> merged.shape
        (768,)
        >>>
        >>> diffs = ccip_batch_differences([merged, *images])[0, 1:]
        >>> diffs
        array([0.07437477, 0.0356068 , 0.04396922], dtype=float32)
        >>> diffs.mean()
        0.05131693
    """
    # Materialise first so that any iterable (not only sized containers) can be checked for emptiness.
    feats = [_p_feature(img, size, model) for img in images]
    if not feats:
        # np.stack would fail here anyway, but with a message unrelated to this API.
        raise ValueError('At least one feature vector or image is required for CCIP merging.')

    embs = np.stack(feats).astype(np.float32)
    lengths = np.linalg.norm(embs, axis=-1)
    # Normalise each row so that every input contributes only its direction to the mean.
    embs = embs / lengths.reshape(-1, 1)
    ret_embedding = embs.mean(axis=0)
    # Re-normalise the averaged direction and restore the average magnitude of the inputs.
    return ret_embedding / np.linalg.norm(ret_embedding) * lengths.mean()
62 changes: 60 additions & 2 deletions test/metrics/test_ccip.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import glob
import json
import os.path
from typing import List, Tuple
from functools import lru_cache
from typing import List, Tuple, Dict, Iterator

import numpy as np
import pytest
from hbutils.testing import disable_output
from huggingface_hub import HfFileSystem, HfApi
from natsort import natsorted
from sklearn.metrics import adjusted_rand_score

from imgutils.metrics import ccip_difference, ccip_default_threshold, ccip_extract_feature, ccip_same, ccip_batch_same, \
ccip_clustering
ccip_clustering, ccip_merge, ccip_batch_differences
from test.testings import get_testfile


Expand Down Expand Up @@ -99,6 +102,52 @@ def s_threshold(threshold) -> float:
return threshold + 0.05


# Character tags used to parametrize the ``ccip_merge`` test; each tag is looked up
# in the character index loaded from ``SRC_REPO``.
MERGE_TAGS = [
'little_red_riding_hood_(grimm)',
'maria_cadenzavna_eve',
'misaka_mikoto',
'dido_(azur_lane)',
'hina_(dress)_(blue_archive)',
'warspite_(kancolle)',
'kallen_kaslana',
"kal'tsit_(arknights)",
'anastasia_(fate)',
"m16a1_(girls'_frontline)",
]

# Hugging Face clients, created at import time. The token is read from the
# ``HF_TOKEN`` environment variable and is ``None`` when that variable is unset.
hf_fs = HfFileSystem(token=os.environ.get('HF_TOKEN'))
hf_client = HfApi(token=os.environ.get('HF_TOKEN'))
# Repository id accessed as a dataset repository by the helper functions in this module.
SRC_REPO = 'deepghs/character_index'


@lru_cache()
def _get_source_list() -> List[dict]:
    """
    Read ``characters.json`` from the ``SRC_REPO`` dataset and parse it as JSON.

    The parsed result is cached, so the file is read at most once per process.

    :return: The parsed content of ``characters.json``.
    :rtype: List[dict]
    """
    index_path = f'datasets/{SRC_REPO}/characters.json'
    raw_text = hf_fs.read_text(index_path)
    return json.loads(raw_text)


@lru_cache()
def _get_source_dict() -> Dict[str, dict]:
    """
    Build a cached lookup table from the ``'tag'`` field of each character entry to the entry itself.

    When several entries share the same tag, the last one wins.

    :return: Mapping of tag to character entry.
    :rtype: Dict[str, dict]
    """
    by_tag: Dict[str, dict] = {}
    for entry in _get_source_list():
        by_tag[entry['tag']] = entry
    return by_tag


def list_character_tags() -> Iterator[str]:
    """
    Lazily yield the ``'tag'`` field of every entry in the character index.

    :return: Iterator over character tags, in index order.
    :rtype: Iterator[str]
    """
    yield from (entry['tag'] for entry in _get_source_list())


def get_detailed_character_info(tag: str) -> dict:
    """
    Look up the character index entry for the given tag.

    :param tag: Tag of the character.
    :type tag: str
    :return: The index entry stored under ``tag``.
    :rtype: dict
    :raises KeyError: If ``tag`` is not present in the index.
    """
    tag_index = _get_source_dict()
    return tag_index[tag]


def get_np_feats(tag):
    """
    Download ``feat.npy`` for the given character tag and load it with numpy.

    The file path inside the ``SRC_REPO`` dataset is built from the ``'hprefix'`` and
    ``'short_tag'`` fields of the character's index entry.

    :param tag: Tag of the character.
    :return: The object returned by ``numpy.load`` for the downloaded file.
    """
    info = get_detailed_character_info(tag)
    feat_filename = f'{info["hprefix"]}/{info["short_tag"]}/feat.npy'
    local_path = hf_client.hf_hub_download(
        repo_id=SRC_REPO,
        repo_type='dataset',
        filename=feat_filename,
    )
    return np.load(local_path)


@pytest.mark.unittest
class TestMetricCCIP:
def test_ccip_difference(self, img_1, img_2, img_3, img_4, img_5, img_6, img_7, s_threshold):
Expand Down Expand Up @@ -159,3 +208,12 @@ def test_ccip_cluster(self, images_12, images_cids):

with pytest.raises(KeyError):
_ = ccip_clustering(images_12, min_samples=2, method='what_the_fxxk')

@pytest.mark.parametrize('tag', MERGE_TAGS)
def test_ccip_merge(self, tag):
    """
    Check that the merged embedding of a character stays close to each source feature.

    The mean difference between the merged vector and the source features must not
    exceed 0.085, and every source feature must be judged the same as the merged vector.
    """
    char_feats = get_np_feats(tag)
    merged = ccip_merge(char_feats)

    # Row 0 belongs to the merged vector; columns 1.. are the individual source features.
    diffs_to_merged = ccip_batch_differences([merged, *char_feats])[0, 1:]
    assert diffs_to_merged.mean() <= 0.085

    same_as_merged = ccip_batch_same([merged, *char_feats])[0, 1:]
    assert same_as_merged.all()

0 comments on commit c25b9e0

Please sign in to comment.