From 5fb64ae416188499e02a3fc5eed8f11f5dd03ae9 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 14 May 2024 17:10:27 +0800 Subject: [PATCH] dev(narugo): add prediction && use float as tagger output --- imgutils/tagging/wd14.py | 22 +++++++++++++++------- test/tagging/test_wd14.py | 5 +++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index ed86b613126..053b6bbaa2b 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -165,6 +165,16 @@ def get_wd14_tags( :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] + .. note:: + About ``fmt`` argument, these are the available names: + + * ``rating``, a dict containing ratings and their confidences + * ``general``, a dict containing general tags and their confidences + * ``character``, a dict containing character tags and their confidences + * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences + * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization + * ``prediction``, a 1-dim prediction result of image + Example: Here are some images for example @@ -202,16 +212,14 @@ def get_wd14_tags( preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) - ratings_names = [labels[i] for i in rating_indexes] - rating = dict(ratings_names) + rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) - general_res = [x for x in general_names if x[1] > general_threshold] - general_res = dict(general_res) + general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) @@ -221,8 +229,7 @@ def get_wd14_tags( character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) - character_res = [x for x in character_names if x[1] > character_threshold] - character_res = dict(character_res) + character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, @@ -231,6 +238,7 @@ def get_wd14_tags( 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, - 'embedding': embeddings[0], + 'embedding': embeddings[0].astype(np.float32), + 'prediction': preds[0].astype(np.float32), } ) diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index cf4f9679583..92f3b4ac9c5 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -21,11 +21,16 @@ def test_get_wd14_tags(self): assert rating['general'] > 0.9 assert tags['cat_girl'] >= 0.8 assert not chars + assert isinstance(rating['general'], float) + assert isinstance(tags['cat_girl'], float) rating, tags, chars = get_wd14_tags(get_testfile('6125785.jpg')) assert 0.6 <= rating['general'] <= 0.8 assert tags['1girl'] >= 0.95 assert chars['hu_tao_(genshin_impact)'] >= 0.95 + assert isinstance(rating['general'], float) + assert isinstance(tags['1girl'], float) + assert isinstance(chars['hu_tao_(genshin_impact)'], float) def test_wd14_tags_sample(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'))