From 052404f415864e2d9255c0a43513ca6b4fdf8cb1 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 29 Jul 2024 17:03:51 +0800 Subject: [PATCH] dev(narugo): add tag lazy mode, ci skip --- .github/workflows/wd14.yml | 15 +++++++++++++ zoo/wd14/sync.py | 8 ++++--- zoo/wd14/tags.py | 46 +++++++++++++++++++++++++++++--------- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/.github/workflows/wd14.yml b/.github/workflows/wd14.yml index dbb98e7933..a2b757ca33 100644 --- a/.github/workflows/wd14.yml +++ b/.github/workflows/wd14.yml @@ -3,6 +3,11 @@ name: Sync WD14 Models on: # push: workflow_dispatch: + inputs: + tag_lazy_mode: + description: 'Enable Tag Lazy Mode' + type: boolean + default: false schedule: - cron: '30 18 * * 0' @@ -36,6 +41,16 @@ jobs: if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi if [ -f requirements-test.txt ]; then pip install -r requirements-zoo.txt; fi pip install --upgrade build + - name: Enable Tag Lazy Mode + if: ${{ (github.event.inputs.tag_lazy_mode || 'false') == 'true' }} + shell: bash + run: | + echo 'TAG_LAZY_MODE=1' >> $GITHUB_ENV + - name: Disable Tag Lazy Mode + if: ${{ (github.event.inputs.tag_lazy_mode || 'false') == 'false' }} + shell: bash + run: | + echo 'TAG_LAZY_MODE=' >> $GITHUB_ENV - name: Sync Models env: HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/zoo/wd14/sync.py b/zoo/wd14/sync.py index 92f30458c8..59bfeeffd7 100644 --- a/zoo/wd14/sync.py +++ b/zoo/wd14/sync.py @@ -51,7 +51,7 @@ def _seg_split(text): } -def sync(): +def sync(tag_lazy_mode: bool = False): hf_fs = get_hf_fs() import onnxruntime @@ -134,7 +134,7 @@ def _is_fc(name): else: invertible = False - df = _make_tag_info(model_name) + df = _make_tag_info(model_name, lazy_mode=tag_lazy_mode) assert len(df) == _get_model_tags_length(model_name) df.to_csv(os.path.join(td, MODEL_NAMES[model_name], 'tags_info.csv'), index=False) @@ -176,4 +176,6 @@ def _is_fc(name): if __name__ == '__main__': - sync() + sync( + 
tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')), + ) diff --git a/zoo/wd14/tags.py b/zoo/wd14/tags.py index 773fcbffa7..f122d9eb79 100644 --- a/zoo/wd14/tags.py +++ b/zoo/wd14/tags.py @@ -1,10 +1,13 @@ import json import logging +import os from functools import lru_cache -from typing import List, Set +from typing import List, Set, Dict +import numpy as np import pandas as pd from ditk import logging +from hfutils.operate import get_hf_fs from huggingface_hub import hf_hub_download from tqdm import tqdm from waifuc.source import DanbooruSource @@ -79,7 +82,24 @@ def _tags_name_set(model_name) -> Set[str]: return set(_tags_list(model_name)['name']) -def _make_tag_info(model_name='ConvNext') -> pd.DataFrame: +@lru_cache() +def _load_all_previous(repository: str = 'deepghs/wd14_tagger_with_embeddings') -> Dict[int, dict]: + hf_fs = get_hf_fs() + d = {} + for path in hf_fs.glob(f'{repository}/**/tags_info.csv'): + relpath = os.path.relpath(path, f'{repository}') + df = pd.read_csv(hf_hub_download( + repo_id=repository, + repo_type='model', + filename=relpath, + )).replace(np.nan, None) + for item in df.to_dict('records'): + if item['tag_id'] not in d: + d[item['tag_id']] = item + return d + + +def _make_tag_info(model_name='ConvNext', lazy_mode: bool = False) -> pd.DataFrame: with open(hf_hub_download( repo_id='deepghs/tags_meta', repo_type='dataset', @@ -88,18 +108,22 @@ def _make_tag_info(model_name='ConvNext') -> pd.DataFrame: attire_tags = json.load(f) df = _tags_list(model_name) + d = _load_all_previous() if lazy_mode else {} records = [] for item in tqdm(df.to_dict('records')): - if item['category'] != 9: - tag_info = _get_tag_by_id(item['tag_id']) - item['count'] = tag_info['post_count'] - aliases = _search_related_tags(item['name'], model_name) - logging.info(f'Aliases {aliases!r} --> {item["name"]!r}') - item['aliases'] = json.dumps(aliases) + if lazy_mode and item['tag_id'] in d: + item = d[item['tag_id']] else: - item['aliases'] = json.dumps([item['name']]) - item['is_core'] 
= (item['category'] == 0) and (is_basic_character_tag(item['name'])) - item['is_attire'] = (item['category'] == 0) and (item['name'] in attire_tags) + if item['category'] != 9: + tag_info = _get_tag_by_id(item['tag_id']) + item['count'] = tag_info['post_count'] + aliases = _search_related_tags(item['name'], model_name) + logging.info(f'Aliases {aliases!r} --> {item["name"]!r}') + item['aliases'] = json.dumps(aliases) + else: + item['aliases'] = json.dumps([item['name']]) + item['is_core'] = (item['category'] == 0) and (is_basic_character_tag(item['name'])) + item['is_attire'] = (item['category'] == 0) and (item['name'] in attire_tags) records.append(item) df_records = pd.DataFrame(records)