Skip to content

Commit

Permalink
dev(narugo): add tag lazy mode, ci skip
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Jul 29, 2024
1 parent 5d632f2 commit 052404f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 14 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/wd14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ name: Sync WD14 Models
on:
# push:
workflow_dispatch:
inputs:
tag_lazy_mode:
description: 'Enable Tag Lazy Mode'
type: boolean
default: false
schedule:
- cron: '30 18 * * 0'

Expand Down Expand Up @@ -36,6 +41,16 @@ jobs:
if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
if [ -f requirements-test.txt ]; then pip install -r requirements-zoo.txt; fi
pip install --upgrade build
- name: Enable Tag Lazy Mode
if: ${{ (github.event.inputs.drop_multi || 'false') == 'true' }}
shell: bash
run: |
echo 'TAG_LAZY_MODE=1' >> $GITHUB_ENV
- name: Disable Tag Lazy Mode
if: ${{ (github.event.inputs.drop_multi || 'false') == 'false' }}
shell: bash
run: |
echo 'TAG_LAZY_MODE=' >> $GITHUB_ENV
- name: Sync Models
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand Down
8 changes: 5 additions & 3 deletions zoo/wd14/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _seg_split(text):
}


def sync():
def sync(tag_lazy_mode: bool = False):
hf_fs = get_hf_fs()

import onnxruntime
Expand Down Expand Up @@ -134,7 +134,7 @@ def _is_fc(name):
else:
invertible = False

df = _make_tag_info(model_name)
df = _make_tag_info(model_name, lazy_mode=tag_lazy_mode)
assert len(df) == _get_model_tags_length(model_name)
df.to_csv(os.path.join(td, MODEL_NAMES[model_name], 'tags_info.csv'), index=False)

Expand Down Expand Up @@ -176,4 +176,6 @@ def _is_fc(name):


if __name__ == '__main__':
sync()
sync(
tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')),
)
46 changes: 35 additions & 11 deletions zoo/wd14/tags.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
import logging
import os
from functools import lru_cache
from typing import List, Set
from typing import List, Set, Dict

import numpy as np
import pandas as pd
from ditk import logging
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from waifuc.source import DanbooruSource
Expand Down Expand Up @@ -79,7 +82,24 @@ def _tags_name_set(model_name) -> Set[str]:
return set(_tags_list(model_name)['name'])


def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
@lru_cache()
def _load_all_previous(repository: str = 'deepghs/wd14_tagger_with_embeddings') -> Dict[int, dict]:
hf_fs = get_hf_fs()
d = {}
for path in hf_fs.glob(f'{repository}/**/tags_info.csv'):
relpath = os.path.relpath(path, f'{repository}')
df = pd.read_csv(hf_hub_download(
repo_id=repository,
repo_type='model',
filename=relpath,
)).replace(np.nan, None)
for item in df.to_dict('records'):
if item['tag_id'] not in d:
d[item['tag_id']] = item
return d


def _make_tag_info(model_name='ConvNext', lazy_mode: bool = False) -> pd.DataFrame:
with open(hf_hub_download(
repo_id='deepghs/tags_meta',
repo_type='dataset',
Expand All @@ -88,18 +108,22 @@ def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
attire_tags = json.load(f)

df = _tags_list(model_name)
d = _load_all_previous()
records = []
for item in tqdm(df.to_dict('records')):
if item['category'] != 9:
tag_info = _get_tag_by_id(item['tag_id'])
item['count'] = tag_info['post_count']
aliases = _search_related_tags(item['name'], model_name)
logging.info(f'Aliases {aliases!r} --> {item["name"]!r}')
item['aliases'] = json.dumps(aliases)
if lazy_mode and item['tag_id'] in d:
item = d[item['tag_id']]
else:
item['aliases'] = json.dumps([item['name']])
item['is_core'] = (item['category'] == 0) and (is_basic_character_tag(item['name']))
item['is_attire'] = (item['category'] == 0) and (item['name'] in attire_tags)
if item['category'] != 9:
tag_info = _get_tag_by_id(item['tag_id'])
item['count'] = tag_info['post_count']
aliases = _search_related_tags(item['name'], model_name)
logging.info(f'Aliases {aliases!r} --> {item["name"]!r}')
item['aliases'] = json.dumps(aliases)
else:
item['aliases'] = json.dumps([item['name']])
item['is_core'] = (item['category'] == 0) and (is_basic_character_tag(item['name']))
item['is_attire'] = (item['category'] == 0) and (item['name'] in attire_tags)
records.append(item)

df_records = pd.DataFrame(records)
Expand Down

0 comments on commit 052404f

Please sign in to comment.