Skip to content

Commit

Permalink
dev(narugo): new code support, ci skip
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Jul 29, 2024
1 parent c92aa23 commit 3ae7b5f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/wd14.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ on:
# push:
workflow_dispatch:
inputs:
models:
description: 'Models To Make'
type: str
default: ''
tag_lazy_mode:
description: 'Enable Tag Lazy Mode'
type: boolean
Expand Down Expand Up @@ -55,5 +59,6 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
MODELS: ${{ github.event.inputs.models || '' }}
run: |
python -m zoo.wd14.sync
2 changes: 2 additions & 0 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3'
VIT_LARGE_MODEL_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

_IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'

MODEL_NAMES = {
"EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO,
"ViT_Large": VIT_LARGE_MODEL_REPO,

"SwinV2": SWIN_MODEL_REPO,
Expand Down
17 changes: 14 additions & 3 deletions zoo/wd14/sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import re
from functools import lru_cache
from typing import List, Optional

import numpy as np
import onnx
Expand Down Expand Up @@ -48,16 +49,24 @@ def _seg_split(text):
"ConvNext_v3": ('core_model', 'head', 'fc'),
"ViT_v3": ('core_model', 'head'),
"ViT_Large": ('core_model', 'head'),
"EVA02_Large": ('core_model', 'head'),
}


def sync(tag_lazy_mode: bool = False):
def sync(tag_lazy_mode: bool = False, models: Optional[List[str]] = None):
hf_fs = get_hf_fs()

if models:
_make_all = False
_model_names = models
else:
_make_all = True
_model_names = MODEL_NAMES

import onnxruntime
with TemporaryDirectory() as td:
records = []
for model_name in tqdm(MODEL_NAMES):
for model_name in tqdm(_model_names):
model_file = _get_model_file(model_name)
logging.info(f'Model name: {model_name!r}, model file: {model_file!r}')
logging.info(f'Loading model {model_name!r} ...')
Expand Down Expand Up @@ -171,11 +180,13 @@ def _is_fc(name):
local_directory=td,
path_in_repo='.',
message=f'Upload {plural_word(len(df_records), "models")}',
clear=True,
clear=True if _make_all else False,
)


if __name__ == '__main__':
_MODELS = list(filter(bool, re.split('[,\s]+', os.environ.get('MODELS') or '')))
sync(
tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')),
models=_MODELS if _MODELS else None,
)

0 comments on commit 3ae7b5f

Please sign in to comment.