This repository has been archived by the owner on Dec 17, 2021. It is now read-only.

Commit 6c4e369

initial commit
kavorite committed Jul 10, 2021 · 0 parents
Showing 63 changed files with 12,111 additions and 0 deletions.
15 changes: 15 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,15 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        }
    ]
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,3 @@
{
    "python.pythonPath": "C:\\Users\\staly\\AppData\\Local\\Programs\\Python\\Python39\\python.exe"
}
1 change: 1 addition & 0 deletions DeepDanbooru
Submodule DeepDanbooru added at 157e33
243 changes: 243 additions & 0 deletions download.py
@@ -0,0 +1,243 @@
import csv
import gzip
import io
import os
import urllib.request as http
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from threading import BoundedSemaphore
from urllib.error import HTTPError

import numpy as np
import tensorflow as tf
from tqdm import tqdm


def download_with_progress(endpoint, output_path=None, chunk_size=4096):
    output_path = output_path or os.path.basename(endpoint)
    with tqdm(unit_scale=True, unit_divisor=1024, unit="B") as progress:
        with open(output_path, "wb+") as ostrm:
            # before the first response, total is unknown; (-1 < 0) admits us
            # into the loop, and afterward we run until n reaches total
            while (progress.n or -1) < (progress.total or 0):
                try:
                    headers = {"User-Agent": "e6tag hydration by kavorite"}
                    if progress.total:
                        # resume an interrupted transfer at the current offset
                        headers["Range"] = f"bytes={progress.n or 0}-{progress.total-1}"
                    request = http.Request(endpoint, headers=headers)
                    with http.urlopen(request, timeout=2) as rsp:
                        progress.total = progress.total or int(
                            rsp.headers["Content-Length"]
                        )
                        progress.refresh()
                        while chunk := rsp.read(chunk_size):
                            progress.update(len(chunk))
                            ostrm.write(chunk)
                except OSError:
                    # timeouts and dropped connections: retry with a Range header
                    continue


# e621 regenerates its database exports daily; (re)download any dump that is
# missing or older than yesterday's export
yesterday = (datetime.today() - timedelta(days=1)).date()
if (
    not os.path.exists("tags.csv.gz")
    or datetime.fromtimestamp(os.stat("tags.csv.gz").st_mtime).date() <= yesterday
):
    endpoint = f"https://e621.net/db_export/tags-{yesterday.isoformat()}.csv.gz"
    download_with_progress(endpoint, "tags.csv.gz")

if (
    not os.path.exists("posts.csv.gz")
    or datetime.fromtimestamp(os.stat("posts.csv.gz").st_mtime).date() <= yesterday
):
    endpoint = f"https://e621.net/db_export/posts-{yesterday.isoformat()}.csv.gz"
    download_with_progress(endpoint, "posts.csv.gz")

tag_idx = dict()
NUM_TAGS = 1024
with gzip.open("./tags.csv.gz") as istrm:
    tag_rows = csv.DictReader(io.TextIOWrapper(istrm, encoding="utf8"))
    # keep printable general-category tags (category 0), excluding comics
    tags = [
        tag
        for tag in tqdm(tag_rows)
        if tag["name"].isprintable()
        and tag["category"] == "0"
        and "comic" not in tag["name"]
    ]

# rank tags by post count and keep the NUM_TAGS most frequent; tag_idx maps
# a tag name to its rank, with 0 the most frequent
tags.sort(key=lambda tag: int(tag["post_count"]))
tags = tags[::-1][:NUM_TAGS]
tag_idx = {tag["name"]: i for i, tag in enumerate(tags)}


csv.field_size_limit(1 << 20)
posts_by_id = dict()
# quota of positive examples to collect for each tag in the vocabulary
hit_tags = {tag["name"]: 512 for tag in tags}
all_tags = set(tag_idx.keys())
min_tags = 16


def tags_of(post):
    # a post's in-vocabulary tags, ordered rarest-first
    post_tags = [t for t in post["tag_string"].split() if t in tag_idx]
    post_tags.sort(key=lambda t: tag_idx[t])
    post_tags = post_tags[::-1]
    return post_tags


with gzip.open("./posts.csv.gz") as istrm:
    post_rows = csv.DictReader(io.TextIOWrapper(istrm, encoding="utf8"))
    skip = 1_000_000
    print(f"skipping the first {skip} posts...\n")
    for i, post in tqdm(enumerate(post_rows), total=skip):
        if i + 1 >= skip:
            break
    goal = sum(hit_tags.values())
    print(f"pulling {goal} positive hits...")
    for post in tqdm(post_rows, total=2_800_000 - skip):
        if post["is_deleted"] == "t":
            continue
        if post["file_ext"] in {"webm", "gif", "swf"}:
            continue
        if "comic" in post["tag_string"]:
            continue
        post_tags = tags_of(post)
        if len(post_tags) < min_tags:
            continue
        post_id = int(post["id"])
        for t in post_tags:
            if t not in hit_tags:
                continue
            # count this post against the rarest needed tag it carries,
            # record it once, and stop scanning its tags
            hit_tags[t] -= 1
            if hit_tags[t] == 0:
                del hit_tags[t]
            md5 = post["md5"]
            ext = post["file_ext"] or "jpg"
            link = f"https://static1.e621.net/data/sample/{md5[0:2]}/{md5[2:4]}/{md5}.{ext}"
            post["link"] = link
            post = {k: post[k] for k in ["link", "id", "tag_string"]}
            posts_by_id[post_id] = post
            break


# sparse multi-hot labels: one row per post id, one column per vocabulary tag
indices = set()
for i, post in enumerate(tqdm(posts_by_id.values())):
    for t in tags_of(post):
        if t in tag_idx:
            j = tag_idx[t]
            indices.add((int(post["id"]), j))
# rows are indexed by raw post id, so the first dimension must span the
# largest id rather than the number of posts kept
label_shape = (max(posts_by_id) + 1, len(tag_idx))
label_values = np.ones(len(indices))
labels = tf.SparseTensor(list(indices), label_values, label_shape)
labels = tf.sparse.reorder(labels)


def sparse_row(A, i):
    # column indices of the nonzero entries in row i of sparse matrix A
    return tf.squeeze(tf.gather(A.indices, tf.where(A.indices[:, 0] == i)))[:, 1]


def bytes_feature(x):
    x = tf.train.BytesList(value=[x])
    x = tf.train.Feature(bytes_list=x)
    return x


def int64_feature(x):
    x = tf.train.Int64List(value=x)
    x = tf.train.Feature(int64_list=x)
    return x


enc_tag_names = [tag["name"].encode("utf8") for tag in tags]
enc_tag_names = np.array(enc_tag_names, dtype=object)


def make_example(post, image_str):
    tag_indxs = sparse_row(labels, int(post["id"])).numpy()
    tag_names_ft = bytes_feature(b" ".join(enc_tag_names[tag_indxs]))
    tag_indxs_ft = int64_feature(tag_indxs)
    # use this post's id, not the module-level post_id left over from the scan
    post_id_ft = int64_feature([int(post["id"])])
    image_str_ft = bytes_feature(image_str)
    feature = dict(
        image_str=image_str_ft,
        tag_indxs=tag_indxs_ft,
        tag_names=tag_names_ft,
        post_id=post_id_ft,
    )
    return tf.train.Example(features=tf.train.Features(feature=feature))


def download_post(post, sample=True):
    endpoint = post["link"]
    if not sample:
        # full-size files live at the same path without the /sample/ segment
        endpoint = endpoint.replace("/sample/", "/")
    max_retries = 2
    for _ in range(max_retries + 1):
        try:
            with http.urlopen(endpoint) as rsp:
                return rsp.read()
        except HTTPError as err:
            # HTTPError subclasses OSError, so handle it first: a 404 is
            # permanent and retrying it would be useless
            if err.code == 404:
                return None
            raise
        except OSError:
            continue
    return None


def download_posts(posts, sample=False):
    # stream downloads through a thread pool, yielding bodies in submission
    # order; the semaphore bounds the number of requests in flight
    concurrency = os.cpu_count() * 4
    posts = iter(posts)
    semaphore = BoundedSemaphore(concurrency)
    with ThreadPoolExecutor() as pool:
        jobs = deque()
        while True:
            if semaphore.acquire(blocking=False):
                post = next(posts, None)
                if post is not None:
                    job = pool.submit(download_post, post, sample=sample)
                    # free a slot as soon as the request completes
                    job.add_done_callback(lambda _: semaphore.release())
                    jobs.append(job)
                else:
                    semaphore.release()
                    if not jobs:
                        # input exhausted and every result yielded
                        return
            # emit completed downloads in the order they were submitted
            while jobs and jobs[0].done():
                yield jobs.popleft().result()


shard_size = 256
total_shards = int(tf.math.ceil(len(posts_by_id) / shard_size))
# width to zero-pad shard indices in filenames
zpad = int(tf.math.ceil(tf.math.log(float(total_shards)) / tf.math.log(10.0)) + 1)
SHARD_ROOT = "D:/yiff"

if not os.path.exists(SHARD_ROOT):
    os.makedirs(SHARD_ROOT)

# persist the tag vocabulary, one name per line, in index order
with open(os.path.join(SHARD_ROOT, "tags.txt"), "w+", encoding="utf8") as ostrm:
    ostrm.write("\n".join(tag_idx.keys()))


all_posts = list(posts_by_id.values())
print("sharding posts...")
# deal posts round-robin across shards
shard_posts = [[] for _ in range(total_shards)]
for i, post in tqdm(enumerate(all_posts), total=len(all_posts)):
    shard_posts[i % total_shards].append(post)
with tqdm(total=len(posts_by_id)) as progress:
    for i in range(total_shards):
        index = str(i).zfill(zpad)
        name = f"samples.shard{index}of{total_shards}.tfrecords"
        name = os.path.join(SHARD_ROOT, name)
        posts = shard_posts[i]
        # skip shards that already look complete: every shard but the last
        # should weigh in above ~200MB
        if (
            os.path.exists(name)
            and i != total_shards - 1
            and os.stat(name).st_size > 200e6
        ):
            progress.update(len(posts))
            continue
        with tf.io.TFRecordWriter(name) as records:
            image_strs = download_posts(iter(posts))
            for post, image_str in zip(posts, image_strs):
                progress.update()
                if image_str is None:
                    continue
                example = make_example(post, image_str)
                records.write(example.SerializeToString())
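
What follows is not part of the commit: a minimal sketch of reading these shards back with tf.data, assuming only the feature spec defined in make_example above (image_str, tag_indxs, tag_names, post_id) and the D:/yiff shard root; the repository's actual input pipeline lives in train.py, which this page did not load.

import tensorflow as tf

NUM_TAGS = 1024  # must match download.py's vocabulary size
feature_spec = {
    "image_str": tf.io.FixedLenFeature([], tf.string),
    "tag_indxs": tf.io.VarLenFeature(tf.int64),
    "tag_names": tf.io.FixedLenFeature([], tf.string),
    "post_id": tf.io.FixedLenFeature([1], tf.int64),
}


def parse_example(record):
    example = tf.io.parse_single_example(record, feature_spec)
    image = tf.io.decode_image(
        example["image_str"], channels=3, expand_animations=False
    )
    # scatter the sparse tag indices into a dense multi-hot label vector
    indxs = tf.sparse.to_dense(example["tag_indxs"])
    label = tf.scatter_nd(
        indxs[:, None], tf.ones_like(indxs, tf.float32), [NUM_TAGS]
    )
    return image, label


dataset = tf.data.TFRecordDataset(tf.io.gfile.glob("D:/yiff/*.tfrecords")).map(
    parse_example, num_parallel_calls=tf.data.AUTOTUNE
)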
15 changes: 15 additions & 0 deletions dump.py
@@ -0,0 +1,15 @@
import os

# hide the GPU and quiet TF logging; must happen before tensorflow is imported
os.environ |= {"CUDA_VISIBLE_DEVICES": "-1", "TF_CPP_MIN_LOG_LEVEL": "2"}

import tensorflow as tf

from train import build_model

model = build_model()
# restore the last checkpoint by filename and export the whole model
model.load_weights(tf.io.gfile.glob("./train/*.h5")[-1])
model.save("./e6tag.h5")
# cvter = tf.lite.TFLiteConverter.from_keras_model(model)
# cvter.optimizations = [tf.lite.Optimize.DEFAULT]
# quant = cvter.convert()
# quant.write_bytes("./yiff_tag.tflite")
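
Also not part of the commit: a sketch of tagging one image with the exported model. It assumes build_model uses only stock Keras layers (so load_model needs no custom_objects), that inputs are RGB scaled to [0, 1] (the real preprocessing lives in train.py, not shown here), and that the lines of tags.txt written by download.py align with the model's output indices; "sample.jpg" is a hypothetical input path.

import tensorflow as tf

model = tf.keras.models.load_model("./e6tag.h5")
with open("D:/yiff/tags.txt", encoding="utf8") as istrm:
    tag_names = istrm.read().splitlines()

image = tf.io.decode_image(
    tf.io.read_file("sample.jpg"), channels=3, expand_animations=False
)
height, width = model.inputs[0].shape[1:3]
image = tf.image.resize(tf.cast(image, tf.float32) / 255.0, (height, width))
scores = model.predict(image[None])[0]
# print the ten highest-scoring tags
for i in scores.argsort()[::-1][:10]:
    print(tag_names[i], scores[i])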
17 changes: 17 additions & 0 deletions eval.py
@@ -0,0 +1,17 @@
import tensorflow as tf

from train import build_model, make_dataset, shard_names

SHARD_ROOT = "D:/yiff"
model = build_model()
dataset = make_dataset(shard_names(SHARD_ROOT), model.inputs[0].shape[1:]).batch(32)
model.compile(
    metrics=[
        tf.keras.metrics.AUC(name="auc"),
        tf.keras.metrics.BinaryAccuracy(name="acc"),
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
    ]
)
model.load_weights(tf.io.gfile.glob("./train/*.h5")[-1])
model.evaluate(dataset, verbose=1)