dev(narugo): embedding inverse #94

Merged · 12 commits · May 15, 2024
2 changes: 1 addition & 1 deletion .github/workflows/wd14.yml
@@ -4,7 +4,7 @@ on:
# push:
workflow_dispatch:
schedule:
- cron: '30 18 * * *'
- cron: '30 18 * * 0'

jobs:
sync:
71 changes: 70 additions & 1 deletion imgutils/tagging/wd14.py
@@ -3,8 +3,9 @@
Tagging utils based on wd14 v2, inspired by
`SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
import json
from functools import lru_cache
from typing import List, Tuple
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import onnxruntime
@@ -242,3 +243,71 @@ def get_wd14_tags(
'prediction': preds[0].astype(np.float32),
}
)


@lru_cache()
def _inv_data(model_name: str = _DEFAULT_MODEL_NAME):
data = np.load(hf_hub_download(
repo_id='deepghs/wd14_tagger_with_embeddings',
repo_type='model',
filename=f'{MODEL_NAMES[model_name]}/inv.npz',
))
return data['best_epi'], data['inv_weights'], data['bias']


def _inv_sigmoid(x):
return np.log(x) - np.log(1 - x)


def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME,
epi: Optional[float] = None, norm: bool = False) -> np.ndarray:
best_epi, inv_weights, bias = _inv_data(model_name)
eps = 10 ** -(epi if epi is not None else best_epi)
pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps)
inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights
if norm:
inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
return inv_emb_output


@lru_cache()
def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
df_tags = pd.read_csv(hf_hub_download(
repo_id='deepghs/wd14_tagger_with_embeddings',
repo_type='model',
filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
))

from .match import _cached_singular_form, _cache_plural_form

retval = {}
for i, item in enumerate(df_tags.to_dict('records')):
tags = sorted({item['name'], *json.loads(item['aliases'])})
for tag in tags:
forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)})
for tag_form in forms:
retval[tag_form] = (item['name'], i)

return retval, len(df_tags)


def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
from .format import add_underline

if isinstance(tags, (list, tuple)):
tags = {tag: 1.0 for tag in tags}

mapping, width = _wd14_alias_map(model_name)
arr = np.zeros((width,), dtype=np.float32)
# arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25
# arr = np.clip(arr, a_min=0.0, a_max=1.0)
for tag, value in tags.items():
origin_tag, tag = tag, add_underline(tag)
if tag not in mapping:
raise ValueError(f'Unknown tag {origin_tag!r}.')

real_tag_name, position = mapping[tag]
arr[position] = value

return arr
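A minimal usage sketch of the two new helpers added above (hedged: the example tags and expected shapes are assumptions based on this diff, not output from a run):

```python
import numpy as np
from imgutils.tagging.wd14 import get_wd14_pred_mask_by_tags, inv_wd14_by_predictions

# Build a synthetic prediction vector: 1.0 for the requested tags, 0.0 everywhere else.
# Tag aliases and singular/plural forms are resolved via the per-model tags_info.csv.
preds = get_wd14_pred_mask_by_tags(['1girl', 'long_hair', 'smile'])

# Map the predictions back through the pseudo-inverse of the final dense layer to get
# an approximate embedding; norm=True L2-normalises it for cosine-similarity search.
emb = inv_wd14_by_predictions(preds, norm=True)
print(emb.shape, float(np.linalg.norm(emb)))
```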
3 changes: 2 additions & 1 deletion requirements-zoo.txt
@@ -23,4 +23,5 @@ lighttuner
natsort
tabulate
hfmirror>=0.0.7
tabulate
tabulate
git+https://github.com/deepghs/waifuc.git@main#egg=waifuc
24 changes: 24 additions & 0 deletions zoo/wd14/inv.py
@@ -63,6 +63,27 @@ def _make_inverse(model_name, dst_dir: str, onnx_model_file: Optional[str] = None
def inv_sigmoid(x):
return np.log(x) - np.log(1 - x)

def is_inv_safe(v_epi):
eps = 10 ** -v_epi
p = np.concatenate([
np.ones(10).astype(np.float32),
np.zeros(10).astype(np.float32),
])
x = np.clip(p, a_min=eps, a_max=1.0 - eps)
y = inv_sigmoid(x)
return not bool(np.isnan(y).any() or np.isinf(y).any())

def get_max_safe_epi(tol=1e-6):
sl, sr = 1.0, 30.0
while sl < sr - tol:
sm = (sl + sr) / 2
if is_inv_safe(sm):
sl = sm
else:
sr = sm

return sl

origin = np.load(hf_hub_download(
repo_id='deepghs/wd14_tagger_inversion',
repo_type='dataset',
@@ -71,6 +92,8 @@ def inv_sigmoid(x):
predictions = origin['preds']
embeddings = origin['embs']

max_safe_epi = get_max_safe_epi()
right = min(right, max_safe_epi)
records = []
for r in range(rounds):
xs, ys = [], []
@@ -109,6 +132,7 @@ def inv_sigmoid(x):

rg = right - left
left, right = xs[idx] - rg * 0.1, xs[idx] + rg * 0.1
right = min(right, max_safe_epi)

df = pd.DataFrame(records)
df = df.sort_values(by=['epi'], ascending=[True])
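The `max_safe_epi` cap above guards against float32 overflow in the inverse sigmoid; a standalone sketch of the failure mode it avoids (NumPy only, values chosen for illustration):

```python
import numpy as np

def inv_sigmoid(x):
    return np.log(x) - np.log(1 - x)

# Once eps = 10**-epi drops below float32 resolution near 1.0 (about 2**-24, i.e. epi ~ 7.2),
# clipping at 1 - eps no longer moves the value away from 1.0 and the logit becomes +inf.
p = np.ones(4, dtype=np.float32)
for epi in (6.0, 7.0, 8.0):
    eps = 10.0 ** -epi
    x = np.clip(p, a_min=eps, a_max=1.0 - eps)
    print(epi, inv_sigmoid(x)[0])
```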
6 changes: 6 additions & 0 deletions zoo/wd14/sync.py
@@ -16,6 +16,7 @@
from imgutils.tagging.wd14 import MODEL_NAMES
from imgutils.utils import open_onnx_model
from .inv import _make_inverse
from .tags import _make_tag_info

logging.try_init_root(logging.INFO)

@@ -42,6 +43,7 @@ def _seg_split(text):

_FC_KEYWORDS_FOR_V2 = {'predictions_dense'}
_FC_NODE_PREFIXES_FOR_V3 = {
"SwinV2": ('core_model', 'head', 'fc'),
"SwinV2_v3": ('core_model', 'head', 'fc'),
"ConvNext_v3": ('core_model', 'head', 'fc'),
"ViT_v3": ('core_model', 'head'),
@@ -131,6 +133,10 @@ def _is_fc(name):
else:
invertible = False

df = _make_tag_info(model_name)
assert len(df) == _get_model_tags_length(model_name)
df.to_csv(os.path.join(td, MODEL_NAMES[model_name], 'tags_info.csv'), index=False)

records.append({
'Name': model_name,
'Source Repository': f'[{MODEL_NAMES[model_name]}](https://huggingface.co/{MODEL_NAMES[model_name]})',
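For context, the prefix table extended above is used when locating the classification head whose weight and bias get inverted; a hypothetical sketch of that kind of name-prefix lookup on an exported ONNX graph (the helper name and the '/'-joined node-name convention are assumptions, not this repo's code):

```python
import onnx

def find_head_nodes(onnx_model_file: str, prefix=('core_model', 'head', 'fc')):
    # Assumption: the exporter encodes module paths into node names with '/' separators,
    # so the joined prefix appears as a substring of the head's node names.
    model = onnx.load(onnx_model_file)
    joined = '/'.join(prefix)
    return [node for node in model.graph.node if joined in node.name]
```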
102 changes: 102 additions & 0 deletions zoo/wd14/tags.py
@@ -0,0 +1,102 @@
import json
from functools import lru_cache
from typing import List, Set

import pandas as pd
from ditk import logging
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from waifuc.source import DanbooruSource
from waifuc.utils import srequest

from imgutils.tagging.wd14 import MODEL_NAMES, LABEL_FILENAME


@lru_cache()
def _db_session():
s = DanbooruSource(['1girl'])
s._prune_session()
return s.session


@lru_cache(maxsize=65536)
def _get_tag_by_id(tag_id: int):
session = _db_session()
return srequest(session, "GET", f'https://danbooru.donmai.us/tags/{tag_id}.json').json()


@lru_cache(maxsize=125536)
def _get_tag_by_name(tag_name: str):
session = _db_session()
vs = srequest(
session, 'GET', f'https://danbooru.donmai.us/tags.json',
params={'search[name]': tag_name}
).json()
return vs[0] if vs else None


@lru_cache(maxsize=65536)
def _simple_search_related_tags(tag: str) -> List[str]:
session = _db_session()
tags = []
for item in srequest(
session, 'GET', 'https://danbooru.donmai.us/tag_aliases.json',
params={
'search[name_matches]': tag,
}
).json():
if item['consequent_name'] == tag:
tags.append(item['antecedent_name'])

return tags


@lru_cache(maxsize=65536)
def _search_related_tags(tag: str, model_name: str = 'ConvNext') -> List[str]:
existing_names = _tags_name_set(model_name)
tags = [tag]
i = 0
while i < len(tags):
append_tags = _simple_search_related_tags(tags[i])
for tag_ in append_tags:
if tag_ not in tags and tag_ not in existing_names:
tags.append(tag_)

i += 1

return tags


@lru_cache()
def _tags_list(model_name) -> pd.DataFrame:
return pd.read_csv(hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME))


@lru_cache()
def _tags_name_set(model_name) -> Set[str]:
return set(_tags_list(model_name)['name'])


def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
df = _tags_list(model_name)
records = []
for item in tqdm(df.to_dict('records')):
if item['category'] != 9:
tag_info = _get_tag_by_id(item['tag_id'])
item['count'] = tag_info['post_count']
aliases = _search_related_tags(item['name'], model_name)
logging.info(f'Aliases {aliases!r} --> {item["name"]!r}')
item['aliases'] = json.dumps(aliases)
else:
item['aliases'] = json.dumps([item['name']])
records.append(item)

df_records = pd.DataFrame(records)
return df_records


if __name__ == "__main__":
logging.try_init_root(logging.INFO)
df = _make_tag_info()
df.to_csv('test_tags_info.csv', index=False)
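A minimal sketch of the alias lookup that `_simple_search_related_tags` performs (hedged: this uses plain `requests` against the public Danbooru endpoint instead of waifuc's session and `srequest`, and assumes unauthenticated reads are allowed):

```python
import requests

def related_tags(tag: str):
    resp = requests.get(
        'https://danbooru.donmai.us/tag_aliases.json',
        params={'search[name_matches]': tag},
        timeout=30,
    )
    resp.raise_for_status()
    # Each alias record maps an old (antecedent) tag onto the tag it was merged into
    # (consequent); collect the antecedents that point at the queried tag.
    return [item['antecedent_name'] for item in resp.json()
            if item['consequent_name'] == tag]
```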