diff --git a/.github/workflows/wd14.yml b/.github/workflows/wd14.yml index 45812a3aa8..22a42e30dc 100644 --- a/.github/workflows/wd14.yml +++ b/.github/workflows/wd14.yml @@ -4,6 +4,10 @@ on: # push: workflow_dispatch: inputs: + models: + description: 'Models To Make' + type: str + default: '' tag_lazy_mode: description: 'Enable Tag Lazy Mode' type: boolean @@ -55,5 +59,6 @@ jobs: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} + MODELS: ${{ github.event.inputs.models || '' }} run: | python -m zoo.wd14.sync diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index fbd716571f..43b2145223 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -27,12 +27,14 @@ SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3' VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' VIT_LARGE_MODEL_REPO = 'SmilingWolf/wd-vit-large-tagger-v3' +EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" _IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17' MODEL_NAMES = { + "EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO, "ViT_Large": VIT_LARGE_MODEL_REPO, "SwinV2": SWIN_MODEL_REPO, diff --git a/zoo/wd14/sync.py b/zoo/wd14/sync.py index 59bfeeffd7..9d1f6ffd80 100644 --- a/zoo/wd14/sync.py +++ b/zoo/wd14/sync.py @@ -1,6 +1,7 @@ import os.path import re from functools import lru_cache +from typing import List, Optional import numpy as np import onnx @@ -48,16 +49,24 @@ def _seg_split(text): "ConvNext_v3": ('core_model', 'head', 'fc'), "ViT_v3": ('core_model', 'head'), "ViT_Large": ('core_model', 'head'), + "EVA02_Large": ('core_model', 'head'), } -def sync(tag_lazy_mode: bool = False): +def sync(tag_lazy_mode: bool = False, models: Optional[List[str]] = None): hf_fs = get_hf_fs() + if models: + _make_all = False + _model_names = models + else: + _make_all = True + _model_names = MODEL_NAMES + import onnxruntime with TemporaryDirectory() as td: records = [] - for model_name in tqdm(MODEL_NAMES): + for model_name in tqdm(_model_names): model_file = _get_model_file(model_name) logging.info(f'Model name: {model_name!r}, model file: {model_file!r}') logging.info(f'Loading model {model_name!r} ...') @@ -171,11 +180,13 @@ def _is_fc(name): local_directory=td, path_in_repo='.', message=f'Upload {plural_word(len(df_records), "models")}', - clear=True, + clear=True if _make_all else False, ) if __name__ == '__main__': + _MODELS = list(filter(bool, re.split('[,\s]+', os.environ.get('MODELS') or ''))) sync( tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')), + models=_MODELS if _MODELS else None, )