Skip to content

Commit

Permalink
dev(narugo): add algo support
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 4, 2024
1 parent 4b70ef7 commit 2326d15
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pillow
numpy
huggingface_hub
tqdm
dghs-imgutils>=0.2.10
dghs-imgutils>=0.4.0
20 changes: 17 additions & 3 deletions sdeval/controllability/bikini_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,20 +217,34 @@ def get_tag(self, tag_text):
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs):
_ = kwargs
_, features, characters = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold)
_, features, characters = get_deepdanbooru_tags(
image=image,
use_real_name=use_real_name,
general_threshold=general_threshold,
character_threshold=character_threshold
)
return {**features, **characters}


def _wd14_tagging(image: Image.Image, model_name: str,
general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs):
_ = kwargs
_, features, characters = get_wd14_tags(image, model_name, general_threshold, character_threshold)
_, features, characters = get_wd14_tags(
image=image,
model_name=model_name,
general_threshold=general_threshold,
character_threshold=character_threshold
)
return {**features, **characters}


def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.0, **kwargs):
_ = kwargs
features = get_mldanbooru_tags(image, use_real_name, general_threshold)
features = get_mldanbooru_tags(
image=image,
use_real_name=use_real_name,
threshold=general_threshold
)
return features


Expand Down
34 changes: 17 additions & 17 deletions test/controllability/test_bikini_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest
from hbutils.system import TemporaryDirectory
from imgutils.data import load_image
from imgutils.data import load_image, istack
from imgutils.sd import get_sdmeta_from_image

from sdeval.controllability import BikiniPlusMetrics
Expand Down Expand Up @@ -87,44 +87,44 @@ def bikini_image_prompts_noneg(bikini_image_files):

@pytest.fixture()
def bikini_images(bikini_image_files):
return [load_image(file) for file in bikini_image_files]
return [istack('white', load_image(file)) for file in bikini_image_files]


@pytest.mark.unittest
class TestControllabilityBikiniPlus:
def test_score(self, bikini_plus_metrics, bikini_image_files):
assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_files] == pytest.approx([
0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468,
0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645,
0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855,
0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192,
0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658,
0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015
])

def test_score_files(self, bikini_plus_metrics, bikini_image_files):
assert np.isclose(bikini_plus_metrics.score(bikini_image_files, mode='seq'), np.array([
0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468,
0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645,
0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855,
0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192,
0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658,
0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015
])).all()

def test_score_dirs(self, bikini_image_dirs, bikini_plus_metrics):
assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_dirs] == pytest.approx([
0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468,
0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645,
0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855,
0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192,
0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658,
0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015
])

def test_score_prompts(self, bikini_image_prompts, bikini_plus_metrics):
assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_prompts] == pytest.approx([
0.8837757309353425, 0.8933908126091592, 0.9055491415894145, 0.8882521965374851, 0.8927720615148468,
0.8469945459720423, 0.8399211360890133, 0.8098674415860692, 0.8363121274014674, 0.8389884182718645,
0.8555319857366422, 0.8579074531926136, 0.8362479325036504, 0.839586421120691, 0.83640841923855,
0.8910938778209551, 0.8970947795984634, 0.9082931538344428, 0.8899684368361426, 0.8970807609124192,
0.8499742177135515, 0.83972368664295, 0.8144942443611721, 0.8406730655450563, 0.84360939347658,
0.8595620814191306, 0.8608996205181652, 0.8430540710658955, 0.8451211202660482, 0.8432768553401015
])

def test_score_prompts_noneg(self, bikini_image_prompts_noneg, bikini_plus_metrics):
assert [bikini_plus_metrics.score(img_file) for img_file in bikini_image_prompts_noneg] == pytest.approx([
0.8057833110347624, 0.8216747866489079, 0.8421623072638775, 0.8129909450109516, 0.8206196565969767,
0.7439588227367002, 0.7331108945469114, 0.6815734980363065, 0.7266735161122132, 0.7308833770642613,
0.7590933092927681, 0.7624429822550852, 0.7263612574844726, 0.731894281100077, 0.7342754910276161,
0.8179163420535865, 0.827850576086485, 0.8467362594800669, 0.8158260787603662, 0.8278514338670646,
0.7489325614901721, 0.7330391553389971, 0.6893156535159684, 0.7340261506217604, 0.7385667816619259,
0.7655768831064597, 0.7674244855061132, 0.7375836722664456, 0.7410484713055709, 0.7444564178495885
])

def test_score_raw_images(self, bikini_images, bikini_plus_metrics):
Expand Down

0 comments on commit 2326d15

Please sign in to comment.