diff --git a/.github/workflows/wd14.yml b/.github/workflows/wd14.yml
index eaa58ddbfef..dbb98e7933e 100644
--- a/.github/workflows/wd14.yml
+++ b/.github/workflows/wd14.yml
@@ -4,7 +4,7 @@ on:
 #  push:
   workflow_dispatch:
   schedule:
-    - cron: '30 18 * * *'
+    - cron: '30 18 * * 0'
 
 jobs:
   sync:
diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py
index 053b6bbaa2b..91c80e26f0c 100644
--- a/imgutils/tagging/wd14.py
+++ b/imgutils/tagging/wd14.py
@@ -3,8 +3,9 @@
 Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags
 <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
 """
+import json
 from functools import lru_cache
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Dict, Union
 
 import numpy as np
 import onnxruntime
@@ -242,3 +243,71 @@ def get_wd14_tags(
             'prediction': preds[0].astype(np.float32),
         }
     )
+
+
+@lru_cache()
+def _inv_data(model_name: str = _DEFAULT_MODEL_NAME):
+    data = np.load(hf_hub_download(
+        repo_id='deepghs/wd14_tagger_with_embeddings',
+        repo_type='model',
+        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
+    ))
+    return data['best_epi'], data['inv_weights'], data['bias']
+
+
+def _inv_sigmoid(x):
+    return np.log(x) - np.log(1 - x)
+
+
+def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME,
+                            epi: Optional[float] = None, norm: bool = False) -> np.ndarray:
+    best_epi, inv_weights, bias = _inv_data(model_name)
+    eps = 10 ** -(epi if epi is not None else best_epi)
+    pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps)
+    inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights
+    if norm:
+        inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
+    return inv_emb_output
+
+
+@lru_cache()
+def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
+    df_tags = pd.read_csv(hf_hub_download(
+        repo_id='deepghs/wd14_tagger_with_embeddings',
+        repo_type='model',
+        filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
+    ))
+
+    from .match import _cached_singular_form, _cache_plural_form
+
+    retval = {}
+    for i, item in enumerate(df_tags.to_dict('records')):
+        tags = sorted({item['name'], *json.loads(item['aliases'])})
+        for tag in tags:
+            forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)})
+            for tag_form in forms:
+                retval[tag_form] = (item['name'], i)
+
+    return retval, len(df_tags)
+
+
+def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
+                               model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
+    from .format import add_underline
+
+    if isinstance(tags, (list, tuple)):
+        tags = {tag: 1.0 for tag in tags}
+
+    mapping, width = _wd14_alias_map(model_name)
+    arr = np.zeros((width,), dtype=np.float32)
+    # arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25
+    # arr = np.clip(arr, a_min=0.0, a_max=1.0)
+    for tag, value in tags.items():
+        origin_tag, tag = tag, add_underline(tag)
+        if tag not in mapping:
+            raise ValueError(f'Unknown tag {origin_tag!r}.')
+
+        real_tag_name, position = mapping[tag]
+        arr[position] = value
+
+    return arr
diff --git a/requirements-zoo.txt b/requirements-zoo.txt
index a74313b8a53..3cbd6963dd1 100644
--- a/requirements-zoo.txt
+++ b/requirements-zoo.txt
@@ -23,4 +23,5 @@ lighttuner
 natsort
 tabulate
 hfmirror>=0.0.7
-tabulate
\ No newline at end of file
+tabulate
+git+https://github.com/deepghs/waifuc.git@main#egg=waifuc
\ No newline at end of file
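To make the inversion math in `inv_wd14_by_predictions()` concrete: the classifier head computes `sigmoid(emb @ weights + bias)`, so the inverse clips the predictions away from exact 0/1, applies the logit, subtracts the bias, and multiplies by a stored pseudo-inverse of the weight matrix. Below is a minimal round-trip sketch; the small random `weights`, `bias`, and `inv_weights` are hypothetical stand-ins for the tensors shipped in `inv.npz`:

```python
import numpy as np


def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def _inv_sigmoid(x):
    # same logit as in the patch: log(x) - log(1 - x)
    return np.log(x) - np.log(1 - x)


rng = np.random.default_rng(0)
emb_dim, n_tags = 16, 64

weights = rng.normal(size=(emb_dim, n_tags))   # stand-in for the fc weight
bias = rng.normal(size=(n_tags,))              # stand-in for the fc bias
inv_weights = np.linalg.pinv(weights)          # plays the role of data['inv_weights']

emb = rng.normal(size=(emb_dim,))
preds = _sigmoid(emb @ weights + bias)         # forward pass of the classifier head

# Same recipe as inv_wd14_by_predictions: clip, undo the sigmoid, undo the fc.
eps = 10 ** -8.0                               # analogue of 10 ** -best_epi
pred_input = np.clip(preds, eps, 1.0 - eps)
emb_rec = (_inv_sigmoid(pred_input) - bias) @ inv_weights

print(np.abs(emb - emb_rec).max())             # tiny: near-exact recovery
```

Since `emb_dim < n_tags`, the random `weights` has full row rank almost surely, so `weights @ inv_weights` is the identity and recovery is exact up to clipping and float64 rounding.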
diff --git a/zoo/wd14/inv.py b/zoo/wd14/inv.py
index 7999691403e..b2db60a60e8 100644
--- a/zoo/wd14/inv.py
+++ b/zoo/wd14/inv.py
@@ -63,6 +63,27 @@ def _make_inverse(model_name, dst_dir: str, onnx_model_file: Optional[str] = Non
     def inv_sigmoid(x):
         return np.log(x) - np.log(1 - x)
 
+    def is_inv_safe(v_epi):
+        eps = 10 ** -v_epi
+        p = np.concatenate([
+            np.ones(10).astype(np.float32),
+            np.zeros(10).astype(np.float32),
+        ])
+        x = np.clip(p, a_min=eps, a_max=1.0 - eps)
+        y = inv_sigmoid(x)
+        return not bool(np.isnan(y).any() or np.isinf(y).any())
+
+    def get_max_safe_epi(tol=1e-6):
+        sl, sr = 1.0, 30.0
+        while sl < sr - tol:
+            sm = (sl + sr) / 2
+            if is_inv_safe(sm):
+                sl = sm
+            else:
+                sr = sm
+
+        return sl
+
     origin = np.load(hf_hub_download(
         repo_id='deepghs/wd14_tagger_inversion',
         repo_type='dataset',
@@ -71,6 +92,8 @@ def inv_sigmoid(x):
     predictions = origin['preds']
     embeddings = origin['embs']
 
+    max_safe_epi = get_max_safe_epi()
+    right = min(right, max_safe_epi)
     records = []
     for r in range(rounds):
         xs, ys = [], []
@@ -109,6 +132,7 @@ def inv_sigmoid(x):
 
         rg = right - left
         left, right = xs[idx] - rg * 0.1, xs[idx] + rg * 0.1
+        right = min(right, max_safe_epi)
 
     df = pd.DataFrame(records)
     df = df.sort_values(by=['epi'], ascending=[True])
diff --git a/zoo/wd14/sync.py b/zoo/wd14/sync.py
index 7824bd02ace..d664b95c6b0 100644
--- a/zoo/wd14/sync.py
+++ b/zoo/wd14/sync.py
@@ -16,6 +16,7 @@
 from imgutils.tagging.wd14 import MODEL_NAMES
 from imgutils.utils import open_onnx_model
 from .inv import _make_inverse
+from .tags import _make_tag_info
 
 logging.try_init_root(logging.INFO)
@@ -42,6 +43,7 @@ def _seg_split(text):
 
 _FC_KEYWORDS_FOR_V2 = {'predictions_dense'}
 _FC_NODE_PREFIXES_FOR_V3 = {
+    "SwinV2": ('core_model', 'head', 'fc'),
     "SwinV2_v3": ('core_model', 'head', 'fc'),
     "ConvNext_v3": ('core_model', 'head', 'fc'),
     "ViT_v3": ('core_model', 'head'),
@@ -131,6 +133,10 @@ def _is_fc(name):
     else:
         invertible = False
 
+    df = _make_tag_info(model_name)
+    assert len(df) == _get_model_tags_length(model_name)
+    df.to_csv(os.path.join(td, MODEL_NAMES[model_name], 'tags_info.csv'), index=False)
+
     records.append({
         'Name': model_name,
         'Source Repository': f'[{MODEL_NAMES[model_name]}](https://huggingface.co/{MODEL_NAMES[model_name]})',
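The bisection added in `inv.py` exists because `1.0 - eps` stops being representable in float32 once `eps` drops below the type's resolution near 1.0: the clip then returns exactly 1.0 and `log(1 - x)` evaluates `log(0) = -inf`. Here is a trimmed, standalone restatement of `get_max_safe_epi()` showing the effect (the printed bound depends on the float type; for float32 it should land near 7.5):

```python
import numpy as np


def inv_sigmoid(x):
    return np.log(x) - np.log(1 - x)


def is_inv_safe(epi: float) -> bool:
    p = np.array([0.0, 1.0], dtype=np.float32)  # worst-case predictions
    eps = 10.0 ** -epi
    x = np.clip(p, eps, 1.0 - eps)  # for tiny eps, 1.0 - eps rounds back to 1.0
    with np.errstate(divide='ignore'):  # log(0) -> -inf is exactly what we test for
        return bool(np.isfinite(inv_sigmoid(x)).all())


left, right = 1.0, 30.0
while right - left > 1e-6:
    mid = (left + right) / 2
    if is_inv_safe(mid):
        left = mid   # still finite: push the bound higher
    else:
        right = mid  # overflowed to -inf: back off

print(f'max safe epi: {left:.4f}')  # ~7.5 for float32
```

Capping `right` at this value keeps every `epi` the sweep evaluates inside the numerically safe range, both at initialization and after each refinement step.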
diff --git a/zoo/wd14/tags.py b/zoo/wd14/tags.py
new file mode 100644
index 00000000000..d611c6d609d
--- /dev/null
+++ b/zoo/wd14/tags.py
@@ -0,0 +1,102 @@
+import json
+import logging
+from functools import lru_cache
+from typing import List, Set
+
+import pandas as pd
+from ditk import logging
+from huggingface_hub import hf_hub_download
+from tqdm import tqdm
+from waifuc.source import DanbooruSource
+from waifuc.utils import srequest
+
+from imgutils.tagging.wd14 import MODEL_NAMES, LABEL_FILENAME
+
+
+@lru_cache()
+def _db_session():
+    s = DanbooruSource(['1girl'])
+    s._prune_session()
+    return s.session
+
+
+@lru_cache(maxsize=65536)
+def _get_tag_by_id(tag_id: int):
+    session = _db_session()
+    return srequest(session, "GET", f'https://danbooru.donmai.us/tags/{tag_id}.json').json()
+
+
+@lru_cache(maxsize=125536)
+def _get_tag_by_name(tag_name: str):
+    session = _db_session()
+    vs = srequest(
+        session, 'GET', f'https://danbooru.donmai.us/tags.json',
+        params={'search[name]': tag_name}
+    ).json()
+    return vs[0] if vs else None
+
+
+@lru_cache(maxsize=65536)
+def _simple_search_related_tags(tag: str) -> List[str]:
+    session = _db_session()
+    tags = []
+    for item in srequest(
+            session, 'GET', 'https://danbooru.donmai.us/tag_aliases.json',
+            params={
+                'search[name_matches]': tag,
+            }
+    ).json():
+        if item['consequent_name'] == tag:
+            tags.append(item['antecedent_name'])
+
+    return tags
+
+
+@lru_cache(maxsize=65536)
+def _search_related_tags(tag: str, model_name: str = 'ConvNext') -> List[str]:
+    existing_names = _tags_name_set(model_name)
+    tags = [tag]
+    i = 0
+    while i < len(tags):
+        append_tags = _simple_search_related_tags(tags[i])
+        for tag_ in append_tags:
+            if tag_ not in tags and tag_ not in existing_names:
+                tags.append(tag_)
+
+        i += 1
+
+    return tags
+
+
+@lru_cache()
+def _tags_list(model_name) -> pd.DataFrame:
+    return pd.read_csv(hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME))
+
+
+@lru_cache()
+def _tags_name_set(model_name) -> Set[str]:
+    return set(_tags_list(model_name)['name'])
+
+
+def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
+    df = _tags_list(model_name)
+    records = []
+    for item in tqdm(df.to_dict('records')):
+        if item['category'] != 9:
+            tag_info = _get_tag_by_id(item['tag_id'])
+            item['count'] = tag_info['post_count']
+            aliases = _search_related_tags(item['name'], model_name)
+            logging.info(f'Aliases {aliases!r} --> {item["name"]!r}')
+            item['aliases'] = json.dumps(aliases)
+        else:
+            item['aliases'] = json.dumps([item['name']])
+        records.append(item)
+
+    df_records = pd.DataFrame(records)
+    return df_records
+
+
+if __name__ == "__main__":
+    logging.try_init_root(logging.INFO)
+    df = _make_tag_info()
+    df.to_csv('test_tags_info.csv', index=False)
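Putting the pieces together, a hypothetical end-to-end use of the two new public helpers (tag names and scores below are illustrative only; the first call downloads `tags_info.csv` and `inv.npz` from `deepghs/wd14_tagger_with_embeddings`):

```python
from imgutils.tagging.wd14 import get_wd14_pred_mask_by_tags, inv_wd14_by_predictions

# A dict assigns each tag its own pseudo-confidence; a plain list implies 1.0.
# Aliases and singular/plural forms resolve to the canonical tag through the
# alias map built from tags_info.csv.
mask = get_wd14_pred_mask_by_tags({'1girl': 0.98, 'solo': 0.95})
embedding = inv_wd14_by_predictions(mask, norm=True)  # unit-norm synthetic embedding

print(mask.shape, embedding.shape)
```

This is the workflow the `tags_info.csv` export in `sync.py` enables: `_make_tag_info` collects Danbooru aliases per tag offline, so the alias lookup at inference time needs no network round-trips beyond the one-time model-file download.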