From 2326d15d9f05f225063fcb15ff1194ee71261515 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Sun, 5 May 2024 00:31:22 +0800 Subject: [PATCH] dev(narugo): add algo support --- requirements.txt | 2 +- sdeval/controllability/bikini_plus.py | 20 +++++++++++--- test/controllability/test_bikini_plus.py | 34 ++++++++++++------------ 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/requirements.txt b/requirements.txt index d91eaf8..b60fedb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ pillow numpy huggingface_hub tqdm -dghs-imgutils>=0.2.10 +dghs-imgutils>=0.4.0 diff --git a/sdeval/controllability/bikini_plus.py b/sdeval/controllability/bikini_plus.py index e6ced70..492d835 100644 --- a/sdeval/controllability/bikini_plus.py +++ b/sdeval/controllability/bikini_plus.py @@ -217,20 +217,34 @@ def get_tag(self, tag_text): def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs): _ = kwargs - _, features, characters = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold) + _, features, characters = get_deepdanbooru_tags( + image=image, + use_real_name=use_real_name, + general_threshold=general_threshold, + character_threshold=character_threshold + ) return {**features, **characters} def _wd14_tagging(image: Image.Image, model_name: str, general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs): _ = kwargs - _, features, characters = get_wd14_tags(image, model_name, general_threshold, character_threshold) + _, features, characters = get_wd14_tags( + image=image, + model_name=model_name, + general_threshold=general_threshold, + character_threshold=character_threshold + ) return {**features, **characters} def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.0, **kwargs): _ = kwargs - features = get_mldanbooru_tags(image, use_real_name, general_threshold) + features = get_mldanbooru_tags( + image=image, + use_real_name=use_real_name, + threshold=general_threshold + ) return features diff --git a/test/controllability/test_bikini_plus.py b/test/controllability/test_bikini_plus.py index 2ccdf55..505ef45 100644 --- a/test/controllability/test_bikini_plus.py +++ b/test/controllability/test_bikini_plus.py @@ -5,7 +5,7 @@ import numpy as np import pytest from hbutils.system import TemporaryDirectory -from imgutils.data import load_image +from imgutils.data import load_image, istack from imgutils.sd import get_sdmeta_from_image from sdeval.controllability import BikiniPlusMetrics @@ -87,44 +87,44 @@ def bikini_image_prompts_noneg(bikini_image_files): @pytest.fixture() def bikini_images(bikini_image_files): - return [load_image(file) for file in bikini_image_files] + return [istack('white', load_image(file)) for file in bikini_image_files] @pytest.mark.unittest class TestControllabilityBikiniPlus: def test_score(self, bikini_plus_metrics, bikini_image_files): assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_files] == pytest.approx([ - 0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468, - 0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645, - 0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855, + 0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192, + 0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658, + 0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015 ]) def test_score_files(self, bikini_plus_metrics, bikini_image_files): assert np.isclose(bikini_plus_metrics.score(bikini_image_files, mode='seq'), np.array([ - 0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468, - 0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645, - 0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855, + 0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192, + 0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658, + 0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015 ])).all() def test_score_dirs(self, bikini_image_dirs, bikini_plus_metrics): assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_dirs] == pytest.approx([ - 0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468, - 0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645, - 0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855, + 0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192, + 0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658, + 0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015 ]) def test_score_prompts(self, bikini_image_prompts, bikini_plus_metrics): assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_prompts] == pytest.approx([ - 0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468, - 0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645, - 0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855, + 0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192, + 0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658, + 0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015 ]) def test_score_prompts_noneg(self, bikini_image_prompts_noneg, bikini_plus_metrics): assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_prompts_noneg] == pytest.approx([ - 0.8057833110347624, 0.8216747866489079, 0.8421623072638775, 0.8129909450109516, 0.8206196565969767, - 0.7439588227367002, 0.7331108945469114, 0.6815734980363065, 0.7266735161122132, 0.7308833770642613, - 0.7590933092927681, 0.7624429822550852, 0.7263612574844726, 0.731894281100077, 0.7342754910276161, + 0.8179163420535865, 0.827850576086485, 0.8467362594800669, 0.8158260787603662, 0.8278514338670646, + 0.7489325614901721, 0.7330391553389971, 0.6893156535159684, 0.7340261506217604, 0.7385667816619259, + 0.7655768831064597, 0.7674244855061132, 0.7375836722664456, 0.7410484713055709, 0.7444564178495885 ]) def test_score_raw_images(self, bikini_images, bikini_plus_metrics):