Skip to content

Commit

Permalink
Merge pull request #64 from deepghs/dev/safe
Browse files Browse the repository at this point in the history
dev(narugo): add safe check model
  • Loading branch information
narugo1992 authored Jan 2, 2024
2 parents e9a2e57 + b846435 commit 97ca445
Show file tree
Hide file tree
Showing 19 changed files with 2,682 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/validate/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ imgutils.validate
portrait
rating
real
safe
style_age
teen
truncate
21 changes: 21 additions & 0 deletions docs/source/api_doc/validate/safe.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
imgutils.validate.safe
=============================================

.. currentmodule:: imgutils.validate.safe

.. automodule:: imgutils.validate.safe


safe_check_score
-----------------------------

.. autofunction:: safe_check_score



safe_check
-----------------------------

.. autofunction:: safe_check


46 changes: 46 additions & 0 deletions docs/source/api_doc/validate/safe_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import random

from huggingface_hub import HfFileSystem
from natsort import natsorted

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import safe_check

hf_fs = HfFileSystem()

REPOSITORY = 'mf666/shit-checker'
MODELS = natsorted([
os.path.splitext(os.path.relpath(file, REPOSITORY))[0]
for file in hf_fs.glob(f'{REPOSITORY}/*.onnx')
])


class SafeCheckBenchmark(BaseBenchmark):
def __init__(self, model):
BaseBenchmark.__init__(self)
self.model = model

def load(self):
from imgutils.validate.safe import _open_model
_ = _open_model(self.model)

def unload(self):
from imgutils.validate.safe import _open_model
_open_model.cache_clear()

def run(self):
image_file = random.choice(self.all_images)
_ = safe_check(image_file, self.model)


if __name__ == '__main__':
create_plot_cli(
[
(name, SafeCheckBenchmark(name))
for name in MODELS
],
title='Benchmark for Safe Check Models',
run_times=10,
try_times=20,
)()
Loading

0 comments on commit 97ca445

Please sign in to comment.