Skip to content

Commit

Permalink
dev(narugo): no tags, ci skip
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 15, 2024
1 parent d3dba8f commit c4685af
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
3 changes: 2 additions & 1 deletion requirements-zoo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ lighttuner
natsort
tabulate
hfmirror>=0.0.7
tabulate
tabulate
git+https://github.com/deepghs/waifuc.git@main#egg=waifuc
99 changes: 99 additions & 0 deletions zoo/wd14/tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
import logging
from functools import lru_cache
from typing import List, Set

import pandas as pd
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from waifuc.source import DanbooruSource

from imgutils.tagging.wd14 import MODEL_NAMES, LABEL_FILENAME


@lru_cache()
def _db_session():
    """
    Build the HTTP session used for every Danbooru request in this module.

    A ``DanbooruSource`` is created only to obtain its ``session`` attribute
    after ``_prune_session()`` has been called on it. The result is memoized,
    so all callers share one session object.

    NOTE(review): ``_prune_session`` is a private waifuc method; presumably it
    prepares the session (headers/auth) for Danbooru -- confirm against waifuc.

    :return: The ``session`` attribute of the prepared ``DanbooruSource``.
    """
    source = DanbooruSource(['1girl'])
    source._prune_session()
    session = source.session
    return session


@lru_cache(maxsize=65536)
def _get_tag_by_id(tag_id: int):
    """
    Fetch the JSON record of a Danbooru tag by its numeric id.

    Responses are memoized per ``tag_id``.

    NOTE(review): the HTTP status is not checked here, so whatever JSON body
    the server returns is decoded, returned and cached.

    :param tag_id: Id interpolated into ``/tags/{tag_id}.json``.
    :return: The decoded JSON body of the response.
    """
    url = f'https://danbooru.donmai.us/tags/{tag_id}.json'
    response = _db_session().get(url)
    return response.json()


@lru_cache(maxsize=125536)
def _get_tag_by_name(tag_name: str):
    """
    Look up a Danbooru tag record by name.

    Results (including ``None`` for "not found") are memoized per ``tag_name``.

    :param tag_name: Value sent as the ``search[name]`` query parameter.
    :return: The first element of the JSON list returned by ``/tags.json``,
        or ``None`` when that list is empty.
    """
    session = _db_session()
    # The URL has no placeholders, so it is a plain literal rather than an
    # f-string (avoids the "f-string without placeholders" lint).
    response = session.get(
        'https://danbooru.donmai.us/tags.json',
        params={'search[name]': tag_name},
    )
    matches = response.json()
    return matches[0] if matches else None


@lru_cache(maxsize=65536)
def _simple_search_related_tags(tag: str) -> List[str]:
    """
    Return the direct aliases of ``tag`` as reported by Danbooru.

    Queries ``/tag_aliases.json`` with ``search[name_matches]`` set to ``tag``
    and keeps the ``antecedent_name`` of every entry whose ``consequent_name``
    equals ``tag``. Results are memoized per ``tag``; the returned list is the
    cached object itself, so callers should not mutate it.

    NOTE(review): only a single response is read -- confirm whether the
    endpoint paginates for tags with many aliases.

    :param tag: Tag name whose aliases are wanted.
    :return: Antecedent names that alias to ``tag``, in response order.
    """
    response = _db_session().get(
        'https://danbooru.donmai.us/tag_aliases.json',
        params={'search[name_matches]': tag},
    )
    return [
        entry['antecedent_name']
        for entry in response.json()
        if entry['consequent_name'] == tag
    ]


@lru_cache(maxsize=65536)
def _search_related_tags(tag: str, model_name: str = 'ConvNext') -> List[str]:
    """
    Collect ``tag`` together with its aliases, followed transitively.

    Starting from ``tag``, the aliases of every collected name are looked up
    in turn (breadth-first). A discovered alias is added only if it has not
    been collected yet and is not itself one of the model's own tag names.
    The starting ``tag`` is always the first element of the result.

    :param tag: Tag name to start from.
    :param model_name: Model whose tag-name set is used to exclude aliases.
    :return: ``[tag, alias1, alias2, ...]`` in discovery order.
    """
    known_names = _tags_name_set(model_name)
    collected = [tag]
    # Mirrors ``collected`` for constant-time "already collected" checks.
    seen = {tag}
    cursor = 0
    while cursor < len(collected):
        for candidate in _simple_search_related_tags(collected[cursor]):
            if candidate in seen or candidate in known_names:
                continue
            seen.add(candidate)
            collected.append(candidate)
        cursor += 1

    return collected


@lru_cache()
def _tags_list(model_name) -> pd.DataFrame:
    """
    Load the label table of a model as a DataFrame.

    The file ``LABEL_FILENAME`` is downloaded via ``hf_hub_download`` using
    ``MODEL_NAMES[model_name]`` as the first argument, then parsed as CSV.
    The DataFrame is memoized per ``model_name``, so callers share one object.

    :param model_name: Key into ``MODEL_NAMES``.
    :return: The parsed label table.
    """
    repo_id = MODEL_NAMES[model_name]
    csv_path = hf_hub_download(repo_id, LABEL_FILENAME)
    return pd.read_csv(csv_path)


@lru_cache()
def _tags_name_set(model_name) -> Set[str]:
    """
    Return the set of values in the ``name`` column of a model's label table.

    The set is memoized per ``model_name``.

    :param model_name: Model whose label table is read via ``_tags_list``.
    :return: Set built from the ``name`` column.
    """
    name_column = _tags_list(model_name)['name']
    return set(name_column)


def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
    """
    Build the tag table of a model enriched with counts and aliases.

    For every row of the model's label table:

    * rows whose ``category`` is ``9`` get an ``aliases`` value holding just
      their own name;
    * all other rows get ``count`` set to the ``post_count`` of the tag
      fetched by ``tag_id``, and ``aliases`` set to the result of
      ``_search_related_tags`` for their name.

    ``aliases`` is stored as a JSON-encoded list string.

    :param model_name: Model whose label table is processed.
    :return: New DataFrame with one row per input row.
    """
    rows = []
    for row in tqdm(_tags_list(model_name).to_dict('records')):
        if row['category'] == 9:
            alias_names = [row['name']]
        else:
            row['count'] = _get_tag_by_id(row['tag_id'])['post_count']
            alias_names = _search_related_tags(row['name'], model_name)
            logging.info(f'Aliases {alias_names!r} --> {row["name"]!r}')
        row['aliases'] = json.dumps(alias_names)
        rows.append(row)

    return pd.DataFrame(rows)


if __name__ == "__main__":
    # Script entry: build the tag table for the default model and write it
    # to a CSV file in the current directory, without the index column.
    _make_tag_info().to_csv('test_tags_info.csv', index=False)

0 comments on commit c4685af

Please sign in to comment.