use ruff (#73)
* use ruff

* fix pytest call

* set line length to 120
kha-white authored Jun 29, 2024
1 parent f965f39 commit 083ddbd
Showing 19 changed files with 58 additions and 130 deletions.
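The pyproject.toml change behind the "set line length to 120" bullet is not visible in this excerpt. For reference, a ruff configuration with that line length is typically a small [tool.ruff] table in pyproject.toml (or ruff.toml); the sketch below uses ruff's documented option names but is an illustration, not the actual contents of this commit:

```toml
# Hypothetical pyproject.toml excerpt (not shown in this diff view):
# how the "set line length to 120" change would typically look.
[tool.ruff]
line-length = 120

[tool.ruff.lint]
# Rule selection is an assumption; the project may enable a different set.
select = ["E", "F"]
```
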
10 changes: 0 additions & 10 deletions .github/workflows/black.yml

This file was deleted.

7 changes: 2 additions & 5 deletions .github/workflows/main.yml
@@ -16,8 +16,6 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v3
-        with:
-          path: manga_ocr

       - name: Set up Python
         uses: actions/setup-python@v3
@@ -27,9 +25,8 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install pytest
-          pip install -e manga_ocr
+          pip install -e ".[dev]"

       - name: Test
         run: |
-          pytest manga_ocr/tests
+          pytest

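The workflow above now relies on a dev extra (`pip install -e ".[dev]"`) instead of installing pytest by hand. That extra would be declared in the project metadata rather than in the workflow; a minimal sketch of such a declaration follows. The name `dev` comes from the install command above, but the exact dependency list is an assumption, not part of the shown diff:

```toml
# Hypothetical pyproject.toml excerpt: a "dev" extra that lets
# `pip install -e ".[dev]"` pull in test and lint tooling.
[project.optional-dependencies]
dev = [
    "pytest",
    "ruff",
]
```
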
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,8 @@
+name: Ruff
+on: [ push, pull_request ]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1

4 changes: 2 additions & 2 deletions manga_ocr/__init__.py
@@ -1,2 +1,2 @@
-from ._version import __version__
-from manga_ocr.ocr import MangaOcr
+from ._version import __version__ as __version__
+from manga_ocr.ocr import MangaOcr as MangaOcr

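The redundant-looking `as` aliases above are the conventional way to mark explicit re-exports: ruff's unused-import rule (F401), like type checkers following the same convention, treats `from module import name as name` in an `__init__.py` as an intentional part of the public API rather than an unused import.
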
16 changes: 4 additions & 12 deletions manga_ocr/ocr.py
@@ -9,17 +9,11 @@


 class MangaOcr:
-    def __init__(
-        self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False
-    ):
+    def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False):
         logger.info(f"Loading OCR model from {pretrained_model_name_or_path}")
-        self.processor = ViTImageProcessor.from_pretrained(
-            pretrained_model_name_or_path
-        )
+        self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
-        self.model = VisionEncoderDecoderModel.from_pretrained(
-            pretrained_model_name_or_path
-        )
+        self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path)

         if not force_cpu and torch.cuda.is_available():
             logger.info("Using CUDA")
@@ -43,9 +37,7 @@ def __call__(self, img_or_path):
         elif isinstance(img_or_path, Image.Image):
             img = img_or_path
         else:
-            raise ValueError(
-                f"img_or_path must be a path or PIL.Image, instead got: {img_or_path}"
-            )
+            raise ValueError(f"img_or_path must be a path or PIL.Image, instead got: {img_or_path}")

         img = img.convert("L").convert("RGB")

16 changes: 4 additions & 12 deletions manga_ocr/run.py
@@ -34,9 +34,7 @@ def process_and_write_results(mocr, img_or_path, write_to):
     else:
         write_to = Path(write_to)
         if write_to.suffix != ".txt":
-            raise ValueError(
-                'write_to must be either "clipboard" or a path to a text file'
-            )
+            raise ValueError('write_to must be either "clipboard" or a path to a text file')

         with write_to.open("a", encoding="utf-8") as f:
             f.write(text + "\n")
@@ -102,23 +100,17 @@ def run(
                     # Pillow error when clipboard contains text (Linux, X11)
                     pass
                 else:
-                    logger.warning(
-                        "Error while reading from clipboard ({})".format(error)
-                    )
+                    logger.warning("Error while reading from clipboard ({})".format(error))
             else:
-                if isinstance(img, Image.Image) and not are_images_identical(
-                    img, old_img
-                ):
+                if isinstance(img, Image.Image) and not are_images_identical(img, old_img):
                     process_and_write_results(mocr, img, write_to)

             time.sleep(delay_secs)

     else:
         read_from = Path(read_from)
         if not read_from.is_dir():
-            raise ValueError(
-                'read_from must be either "clipboard" or a path to a directory'
-            )
+            raise ValueError('read_from must be either "clipboard" or a path to a directory')

         logger.info(f"Reading from directory {read_from}")

3 changes: 1 addition & 2 deletions manga_ocr_dev/data/generate_backgrounds.py
@@ -77,8 +77,7 @@ def generate_backgrounds(crops_per_page=5, min_size=40):

             if crop.shape[0] >= min_size and crop.shape[1] >= min_size:
                 out_filename = (
-                    "_".join(Path(page_path).with_suffix("").parts[-2:])
-                    + f"_{ymin}_{ymax}_{xmin}_{xmax}.png"
+                    "_".join(Path(page_path).with_suffix("").parts[-2:]) + f"_{ymin}_{ymax}_{xmin}_{xmax}.png"
                 )
                 cv2.imwrite(str(BACKGROUND_DIR / out_filename), crop)

16 changes: 4 additions & 12 deletions manga_ocr_dev/data/process_manga109s.py
@@ -14,9 +14,7 @@ def get_books():
     books = pd.DataFrame(
         {
             "book": books,
-            "annotations": [
-                str(root / "annotations" / f"{book}.xml") for book in books
-            ],
+            "annotations": [str(root / "annotations" / f"{book}.xml") for book in books],
             "images": [str(root / "images" / book) for book in books],
         }
     )
@@ -36,9 +34,7 @@ def export_frames():
                 row = {}
                 row["book"] = book.book
                 row["page_index"] = int(page.attrib["index"])
-                row["page_path"] = str(
-                    Path(book.images) / f'{row["page_index"]:03d}.jpg'
-                )
+                row["page_path"] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                 row["page_width"] = int(page.attrib["width"])
                 row["page_height"] = int(page.attrib["height"])
                 row["id"] = frame.attrib["id"]
@@ -69,9 +65,7 @@ def export_crops():
                 row = {}
                 row["book"] = book.book
                 row["page_index"] = int(page.attrib["index"])
-                row["page_path"] = str(
-                    Path(book.images) / f'{row["page_index"]:03d}.jpg'
-                )
+                row["page_path"] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                 row["page_width"] = int(page.attrib["width"])
                 row["page_height"] = int(page.attrib["height"])
                 row["id"] = text.attrib["id"]
@@ -93,9 +87,7 @@ def export_crops():
     data.crop_path = data.crop_path.apply(lambda x: "/".join(Path(x).parts[-2:]))
     data.to_csv(MANGA109_ROOT / "data.csv", index=False)

-    for page_path, boxes in tqdm(
-        data.groupby("page_path"), total=data.page_path.nunique()
-    ):
+    for page_path, boxes in tqdm(data.groupby("page_path"), total=data.page_path.nunique()):
         img = cv2.imread(str(MANGA109_ROOT / page_path))

         for box in boxes.itertuples():

9 changes: 2 additions & 7 deletions manga_ocr_dev/synthetic_data_generator/generator.py
@@ -128,7 +128,6 @@ def add_random_furigana(self, line, word_prob=1.0, vocab=None):
         kanji_group = ""
         ascii_group = ""
         for i, c in enumerate(line):
-
             if is_kanji(c):
                 c_type = "kanji"
                 kanji_group += c
@@ -141,12 +140,8 @@ def add_random_furigana(self, line, word_prob=1.0, vocab=None):
             if c_type != "kanji" or i == len(line) - 1:
                 if kanji_group:
                     if np.random.uniform() < word_prob:
-                        furigana_len = int(
-                            np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group)
-                        )
-                        char_source = np.random.choice(
-                            ["hiragana", "katakana", "all"], p=[0.8, 0.15, 0.05]
-                        )
+                        furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
+                        char_source = np.random.choice(["hiragana", "katakana", "all"], p=[0.8, 0.15, 0.05])
                         char_source = {
                             "hiragana": self.hiragana,
                             "katakana": self.katakana,

36 changes: 8 additions & 28 deletions manga_ocr_dev/synthetic_data_generator/renderer.py
@@ -60,9 +60,7 @@ def get_random_css_params():
         if np.random.rand() < 0.7:
             params["text_orientation"] = "upright"

-        stroke_variant = np.random.choice(
-            ["stroke", "shadow", "none"], p=[0.8, 0.15, 0.05]
-        )
+        stroke_variant = np.random.choice(["stroke", "shadow", "none"], p=[0.8, 0.15, 0.05])
         if stroke_variant == "stroke":
             params["stroke_size"] = np.random.choice([1, 2, 3, 4, 8])
             params["stroke_color"] = "white"
@@ -88,9 +86,7 @@ def render_background(self, img):
                 A.HorizontalFlip(),
                 A.RandomRotate90(),
                 A.InvertImg(),
-                A.RandomBrightnessContrast(
-                    (-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1
-                ),
+                A.RandomBrightnessContrast((-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1),
                 A.Blur((3, 5), p=0.3),
                 A.Resize(img.shape[0], img.shape[1]),
             ]
@@ -108,17 +104,9 @@ def render_background(self, img):
             sigma = np.random.randint(10, 15)

             ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            ymax = (
-                img.shape[0]
-                - m0
-                + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            )
+            ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
             xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            xmax = (
-                img.shape[1]
-                - m0
-                + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            )
+            xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))

             bubble_fill_color = (255, 255, 255, 255)
             bubble_contour_color = (0, 0, 0, 255)
@@ -150,13 +138,9 @@ def render_background(self, img):
         img = blend(img, background)

         ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        ymax = (
-            img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        )
+        ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
         xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        xmax = (
-            img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        )
+        xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
         img = img[ymin:ymax, xmin:xmax]
         return img

@@ -184,9 +168,7 @@ def blend(img, background):
     return img


-def rounded_rectangle(
-    src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA
-):
+def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
     """From https://stackoverflow.com/a/60210706"""

     # corners:
@@ -345,9 +327,7 @@ def get_css(
         # stroke is simulated by shadow overlaid multiple times
         styles.extend(
             [
-                f"text-shadow: "
-                + ",".join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size)
-                + ";",
+                "text-shadow: " + ",".join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size) + ";",
                 "-webkit-font-smoothing: antialiased;",
             ]
         )

6 changes: 2 additions & 4 deletions manga_ocr_dev/synthetic_data_generator/run_generate.py
@@ -24,7 +24,7 @@ def f(args):
         ret = source, id_, text_gt, params["vertical"], str(font_path)
         return ret

-    except Exception as e:
+    except Exception:
         print(traceback.format_exc())


@@ -54,9 +54,7 @@ def run(package=0, n_random=1000, n_limit=None, max_workers=16):
     OUT_DIR = DATA_SYNTHETIC_ROOT / "img" / package
     OUT_DIR.mkdir(parents=True, exist_ok=True)

-    data = thread_map(
-        f, args, max_workers=max_workers, desc=f"Processing package {package}"
-    )
+    data = thread_map(f, args, max_workers=max_workers, desc=f"Processing package {package}")

     data = pd.DataFrame(data, columns=["source", "id", "text", "vertical", "font_path"])
     meta_path = DATA_SYNTHETIC_ROOT / f"meta/{package}.csv"

4 changes: 1 addition & 3 deletions manga_ocr_dev/synthetic_data_generator/utils.py
@@ -52,7 +52,5 @@ def get_charsets(vocab_path=None):
 def get_font_meta():
     df = pd.read_csv(ASSETS_PATH / "fonts.csv")
     df.font_path = df.font_path.apply(lambda x: str(FONTS_ROOT / x))
-    font_map = {
-        row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()
-    }
+    font_map = {row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()}
     return df, font_map

4 changes: 1 addition & 3 deletions manga_ocr_dev/training/dataset.py
@@ -40,9 +40,7 @@ def __init__(
                 continue
             df = pd.read_csv(path)
             df = df.dropna()
-            df["path"] = df.id.apply(
-                lambda x: str(DATA_SYNTHETIC_ROOT / "img" / path.stem / f"{x}.jpg")
-            )
+            df["path"] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / "img" / path.stem / f"{x}.jpg"))
             df = df[["path", "text"]]
             df["synthetic"] = True
             data.append(df)

12 changes: 3 additions & 9 deletions manga_ocr_dev/training/get_model.py
@@ -40,21 +40,15 @@ def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):

     if num_decoder_layers is not None:
         if decoder_config.model_type == "bert":
-            decoder.bert.encoder.layer = decoder.bert.encoder.layer[
-                -num_decoder_layers:
-            ]
+            decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
         elif decoder_config.model_type in ("roberta", "xlm-roberta"):
-            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[
-                -num_decoder_layers:
-            ]
+            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
         else:
             raise ValueError(f"Unsupported model_type: {decoder_config.model_type}")

         decoder_config.num_hidden_layers = num_decoder_layers

-    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
-        encoder.config, decoder.config
-    )
+    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
     config.tie_word_embeddings = False
     model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

4 changes: 1 addition & 3 deletions manga_ocr_dev/training/metrics.py
@@ -21,9 +21,7 @@ def compute_metrics(self, pred):

         results = {}
         try:
-            results["cer"] = self.cer_metric.compute(
-                predictions=pred_str, references=label_str
-            )
+            results["cer"] = self.cer_metric.compute(predictions=pred_str, references=label_str)
         except Exception as e:
             print(e)
             print(pred_str)

12 changes: 3 additions & 9 deletions manga_ocr_dev/training/train.py
@@ -20,17 +20,11 @@ def run(
 ):
     wandb.login()

-    model, processor = get_model(
-        encoder_name, decoder_name, max_len, num_decoder_layers
-    )
+    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

     # keep package 0 for validation
-    train_dataset = MangaDataset(
-        processor, "train", max_len, augment=True, skip_packages=[0]
-    )
-    eval_dataset = MangaDataset(
-        processor, "test", max_len, augment=False, skip_packages=range(1, 9999)
-    )
+    train_dataset = MangaDataset(processor, "train", max_len, augment=True, skip_packages=[0])
+    eval_dataset = MangaDataset(processor, "test", max_len, augment=False, skip_packages=range(1, 9999))

     metrics = Metrics(processor)

7 changes: 1 addition & 6 deletions manga_ocr_dev/training/utils.py
@@ -37,9 +37,4 @@ def decoder_summary(model, batch_size=4):


 def tensor_to_image(img):
-    return (
-        ((img.cpu().numpy() + 1) / 2 * 255)
-        .clip(0, 255)
-        .astype(np.uint8)
-        .transpose(1, 2, 0)
-    )
+    return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)