diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..c5c6294
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203
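+# E203 (whitespace before ':') conflicts with black's slice formatting.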
+
diff --git a/.gitignore b/.gitignore
index fb9cb04..e0ba26a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@ var/
*.egg-info/
.installed.cfg
*.egg
+.mypy_cache/
# PyInstaller
# Usually these files are written by a python script from a template
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..309b846
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+ - repo: https://github.com/psf/black
+ rev: stable
+ hooks:
+ - id: black
+ language_version: python3.7
+ - repo: https://github.com/timothycrosley/isort
+ rev: 4.3.21
+ hooks:
+ - id: isort
+ language_version: python3.7
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..fd2fcb9
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,9 @@
+# Changelog
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]
+
+[Unreleased]: https://github.com/Breta01/handwriting-ocr/releases
diff --git a/LICENSE b/LICENSE
index a8221aa..b7c48c9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2017 Břetislav Hájek
+Copyright (c) 2020 Břetislav Hájek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..6bd8299
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,51 @@
+.PHONY: help bootstrap data lint clean
+
+SHELL=/bin/bash
+
+VENV_NAME?=venv
+VENV_BIN=$(shell pwd)/${VENV_NAME}/bin
+VENV_ACTIVATE=source $(VENV_NAME)/bin/activate
+
+PROJECT_DIR=handwriting_ocr
+
+PYTHON=${VENV_NAME}/bin/python3
+
+.DEFAULT: help
+help:
+	@echo "Make file commands:"
+	@echo "    make bootstrap"
+	@echo "        Prepare complete development environment"
+	@echo "    make data"
+	@echo "        Download and prepare data for training"
+	@echo "    make lint"
+	@echo "        Run pylint and flake8"
+	@echo "    make clean"
+	@echo "        Clean repository"
+
+bootstrap:
+	sudo xargs apt-get -y install < requirements-apt.txt
+	python3.7 -m pip install -U pip
+	python3.7 -m pip install virtualenv
+	make venv
+	${VENV_ACTIVATE}; pre-commit install
+
+# Runs only when one of the dependency files changes
+venv: $(VENV_NAME)/bin/activate
+$(VENV_NAME)/bin/activate: setup.py requirements.txt requirements-dev.txt
+	test -d $(VENV_NAME) || virtualenv -p python3.7 $(VENV_NAME)
+	${PYTHON} -m pip install -U pip
+	${PYTHON} -m pip install -e .[dev]
+	touch $(VENV_NAME)/bin/activate
+
+data:
+	${PYTHON} ${PROJECT_DIR}/data/data_create_sets.py
+
+lint: venv
+# pylint supports pyproject.toml from version 2.5. Switch to the following cmd once updated:
+# ${PYTHON} -m pylint ${PROJECT_DIR}
+	${PYTHON} -m pylint --extension-pkg-whitelist=cv2 --variable-rgx='[a-z_][a-z0-9_]{0,30}$$' --max-line-length=88 ${PROJECT_DIR}
+	${PYTHON} -m flake8 ${PROJECT_DIR}
+
+clean:
+	find . -name '*.pyc' -exec rm --force {} +
+	rm -rf $(VENV_NAME) *.eggs *.egg-info dist build docs/_build .cache
diff --git a/README.md b/README.md
index aecb5cd..a6294be 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
# Handwriting OCR
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
The project aims to create software for recognition of handwritten text from photos (also for the Czech language). It uses computer vision and machine learning, and it experiments with different approaches to the problem. It started as a school project that I had the chance to present at Intel ISEF 2018.
@@ -15,34 +17,37 @@ Main files combining all the steps are [OCR.ipynb](notebooks/OCR.ipynb) or [OCR-
## Getting Started
### 1. Clone the repository
-```
+```bash
git clone https://github.com/Breta01/handwriting-ocr.git
```
After downloading the repo, you have to download the datasets and models (for more info look into [data](data/) and [models](models/) folders).
### 2. Requirements
-The project is created using Python 3.6 with Jupyter Notebook. I recommend using Anaconda. If you have it, you can run the installation as:
+```bash
+make bootstrap
```
-conda create --name ocr-env --file environment.yml
-conda activate ocr-env
+The project uses Python 3.7 with Jupyter Notebook. I recommend using virtualenv. Running `make bootstrap` should install all necessary packages and prepare the development environment.
+
+Main libraries (all required libraries are in [requirements.txt](requirements.txt) and [requirements-dev.txt](requirements-dev.txt)):
+* Numpy
+* Tensorflow
+* OpenCV
+* Pandas
+* Matplotlib
+
+### Activate and Run
+```bash
+source venv/bin/activate
```
-Main libraries (all required libraries are in [environment.yml](environment.yml)):
-* Numpy (1.13)
-* Tensorflow (1.4)
-* OpenCV (3.1)
-* Pandas (0.21)
-* Matplotlib (2.1)
-
-### Run
-With all required libraries installed and cloned repo, run `jupyter notebook` in the directory of the project. Then you can work on the particular notebook.
+This command will activate the virtualenv. Then you can run `jupyter notebook` in the directory of the project and work on the particular notebook.
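+
+For example:
+
+```bash
+source venv/bin/activate
+jupyter notebook
+```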
## Contributing
The best way to get involved is through creating [GitHub issues](https://github.com/Breta01/handwriting-ocr/issues) or solving one! If there aren't any open issues, you can contact me directly by email.
## License
-**MIT**
+[MIT](./LICENSE.md)
## Support the project
If this project helped you, if you want to support quick answers to questions and issues, or if you simply think it is an interesting project, please consider a small donation.
-[![paypal](https://www.paypalobjects.com/en_US/i/btn/btn_donate_LG.gif)](https://paypal.me/bretahajek/2)
+[![paypal](https://www.paypalobjects.com/en_US/i/btn/btn_donate_LG.gif)](https://paypal.me/bretahajek/5)
diff --git a/data/README.md b/data/README.md
index 00a589e..a483581 100644
--- a/data/README.md
+++ b/data/README.md
@@ -9,9 +9,10 @@ After downloading these datasets, there are scripts in `src/data/` folder which
### Breta’s data (1)
*5000 images*
-All data owned by [@Breta01](https://github.com/Breta01) are available on this link (distributed under the same license as this repository). The data should be placed either in `raw/breta/` or `processed/breta/` folder according to their location in archive from the link below. (I removed the Czech accents from words. If you want to use them, you have to recover them using CSV files containing: `word_without_accents, original_word` in UTF-8 encoding.)
+All data owned by [@Breta01](https://github.com/Breta01) are available at the links below (distributed under the same license as this repository). The data should be placed either in the `raw/breta/` or `processed/breta/` folder accordingly (see links below). (I removed the Czech accents from words. If you want to use them, you have to recover them using CSV files containing: `word_without_accents, original_word` in UTF-8 encoding.)
-
+`raw/breta/`:
+`processed/breta/`:
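+
+A minimal sketch of recovering the accents from such a CSV file (the file name `accents.csv` is only illustrative):
+
+```python
+import csv
+
+with open("accents.csv", encoding="utf-8") as f:
+    mapping = dict(csv.reader(f))  # word_without_accents -> original_word
+
+label = "prilis"
+print(mapping.get(label, label))  # prints the accented form if it is in the CSV
+```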
### IAM Handwriting Database (2)
*85000 images*
diff --git a/data/characters/README.md b/data/characters/README.md
deleted file mode 100644
index f281da0..0000000
--- a/data/characters/README.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Single Character Images (Legacy)
-- 2613 images
-
-The images were already pre-processed by Sobel edge normalization, grayscaled, 64x64px size and saved as flatten array. They may not be so useful due to different normalization. More lettere images can be obtained form words with gaplines. The images are part of my dataset archive:
-
-
-
-The `en-data.csv` contains values of individual pixels of images and `en-labels.csv` contains corresponding characters.
-
-### CZ Characters
-2 * 42 - upper, lower alphabet with accents (without ch and Ch); plus null char
- = 83 characters
-### EN chars
-2 * 26 - upper, lower alphabet; plus null char
- = 53 characters
diff --git a/data/raw/README.md b/data/raw/README.md
deleted file mode 100644
index 0ab43d4..0000000
--- a/data/raw/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# Raw Data Folder
-Here you shold place all raw downlaoded data. Placing the dataset in apropriate folders:
-```
-data/raw/
- - breta/ (6100 images)
- - iam/ (85012 images)
- - cvl/ (84164 images)
- - orand/ (11719 images)
- - camb/ (5260 images)
-```
diff --git a/environment.yml b/environment.yml
deleted file mode 100644
index af40541..0000000
--- a/environment.yml
+++ /dev/null
@@ -1,127 +0,0 @@
-# This file may be used to create an environment using:
-# $ conda create --name --file
-# platform: linux-64
-_tflow_select=2.3.0
-absl-py=0.6.1
-astor=0.7.1
-backcall=0.1.0
-blas=1.0
-bleach=3.0.2
-bzip2=1.0.6
-c-ares=1.15.0
-ca-certificates=2018.03.07
-cairo=1.14.12
-certifi=2018.10.15
-cycler=0.10.0
-dbus=1.13.2
-decorator=4.3.0
-entrypoints=0.2.3
-expat=2.2.6
-ffmpeg=4.0
-fontconfig=2.13.0
-freeglut=3.0.0
-freetype=2.9.1
-gast=0.2.0
-glib=2.56.2
-gmp=6.1.2
-graphite2=1.3.12
-grpcio=1.14.1
-gst-plugins-base=1.14.0
-gstreamer=1.14.0
-h5py=2.8.0
-harfbuzz=1.8.8
-hdf5=1.10.2
-icu=58.2
-intel-openmp=2019.0
-ipykernel=5.1.0
-ipython=7.1.1
-ipython_genutils=0.2.0
-ipywidgets=7.4.2
-jasper=2.0.14
-jedi=0.13.1
-jinja2=2.10
-jpeg=9b
-jsonschema=2.6.0
-jupyter_client=5.2.3
-jupyter_core=4.4.0
-keras-applications=1.0.6
-keras-preprocessing=1.0.5
-kiwisolver=1.0.1
-libedit=3.1.20170329
-libffi=3.2.1
-libgcc-ng=8.2.0
-libgfortran-ng=7.3.0
-libglu=9.0.0
-libopencv=3.4.2
-libopus=1.3
-libpng=1.6.35
-libprotobuf=3.6.1
-libsodium=1.0.16
-libstdcxx-ng=8.2.0
-libtiff=4.0.9
-libuuid=1.0.3
-libvpx=1.7.0
-libxcb=1.13
-libxml2=2.9.8
-markdown=3.0.1
-markupsafe=1.1.0
-matplotlib=3.0.1
-mistune=0.8.4
-mkl
-mkl_fft=1.0.6
-mkl_random=1.0.1
-nbconvert=5.3.1
-nbformat=4.4.0
-ncurses=6.1
-notebook=5.7.1
-numpy=1.15.4
-numpy-base=1.15.4
-opencv=3.4.2
-openssl=1.0.2p
-pandoc=2.2.3.2
-pandocfilters=1.4.2
-parso=0.3.1
-pcre=8.42
-pexpect=4.6.0
-pickleshare=0.7.5
-pip=18.1
-pixman=0.34.0
-prometheus_client=0.4.2
-prompt_toolkit=2.0.7
-protobuf=3.6.1
-ptyprocess=0.6.0
-py-opencv=3.4.2
-pygments=2.2.0
-pyparsing=2.3.0
-pyqt=5.9.2
-python=3.6.6
-python-dateutil=2.7.5
-pytz=2018.7
-pyzmq=17.1.2
-qt=5.9.6
-readline=7.0
-scipy=1.1.0
-send2trash=1.5.0
-setuptools=40.6.2
-simplejson=3.16.0
-sip=4.19.8
-six=1.11.0
-sqlite=3.25.3
-tensorboard=1.12.0
-tensorflow=1.12.0
-tensorflow-base=1.12.0
-termcolor=1.1.0
-terminado=0.8.1
-testpath=0.4.2
-tk=8.6.8
-tornado=5.1.1
-traitlets=4.3.2
-unidecode=1.0.22
-wcwidth=0.1.7
-webencodings=0.5.1
-werkzeug=0.14.1
-wheel=0.32.2
-widgetsnbextension=3.4.2
-xz=5.2.4
-zeromq=4.2.5
-zlib=1.2.11
diff --git a/src/__init__.py b/handwriting_ocr/__init__.py
similarity index 100%
rename from src/__init__.py
rename to handwriting_ocr/__init__.py
diff --git a/src/data/__init__.py b/handwriting_ocr/data/__init__.py
similarity index 100%
rename from src/data/__init__.py
rename to handwriting_ocr/data/__init__.py
diff --git a/handwriting_ocr/data/data.py b/handwriting_ocr/data/data.py
new file mode 100644
index 0000000..7e9d48c
--- /dev/null
+++ b/handwriting_ocr/data/data.py
@@ -0,0 +1,3 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Modelu for providing datasets for training and inference."""
diff --git a/handwriting_ocr/data/data_create_sets.py b/handwriting_ocr/data/data_create_sets.py
new file mode 100644
index 0000000..3108d3f
--- /dev/null
+++ b/handwriting_ocr/data/data_create_sets.py
@@ -0,0 +1,101 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Modelu for creating sets (train/dev/test) of normalized images."""
+
+import argparse
+from pathlib import Path
+
+import cv2 as cv
+import numpy as np
+from tqdm import tqdm
+
+from handwriting_ocr.data.data_loader import DATASETS, DATA_FOLDER
+from handwriting_ocr.ocr.normalization import word_normalization
+
+
+def create_sets(data_path, test_size, dev_size, seed=42):
+    """Loads all data and processes it into train/dev/test sets.
+
+    It loads all available datasets from data_path, splits them into train/dev/test
+    sets, and normalizes the images. Labels are saved into a labels.txt file as
+    `{filename.png}\t{label}` (tab-separated).
+
+ Args:
+ data_path (Path): Data folder path
+ test_size (float): Percentage of test images from all images (between 0-1)
+ dev_size (float): Percentage of dev images from all images (between 0-1)
+ seed (int): Seed value for reproducibility of split
+ """
+ np.random.seed(seed)
+ lines = []
+ for d in DATASETS:
+ if d.is_downloaded(data_path):
+ lines.extend(d.load(data_path))
+ np.random.shuffle(lines)
+
+ test_i, dev_i = (len(lines) * np.array([test_size, test_size + dev_size])).astype(
+ int
+ )
+
+ sets = {
+ "test": lines[:test_i],
+ "dev": lines[test_i:dev_i],
+ "train": lines[dev_i:],
+ }
+
+ for k, set_lines in sets.items():
+ print(f"{k} images: {len(set_lines)}")
+
+ folder = data_path / "sets" / k
+ folder.mkdir(parents=True, exist_ok=True)
+
+ label_file = folder.joinpath("labels.txt").open("w+")
+ for i, (path, label) in enumerate(tqdm(set_lines)):
+ image = cv.imread(str(path))
+ if image.shape[0] < 15:
+ continue
+
+ norm = word_normalization(
+ image, height=64, border=False, tilt=False, hyst_norm=False
+ )
+
+ out_name = f"image_{i:05}.png"
+ cv.imwrite(str(folder.joinpath(out_name)), norm)
+ label_file.write(f"{out_name}\t{label}\n")
+
+ label_file.close()
+
+
+def get_args():
+ """ArgumentParser for sets creation."""
+ parser = argparse.ArgumentParser(
+ description="Script for creating sets (train/dev/test) of normalized images."
+ )
+ parser.add_argument(
+ "--test_size",
+ type=float,
+ default=0.1,
+ help="Percentage (0-1) size of test set.",
+ )
+ parser.add_argument(
+ "--dev_size",
+ type=float,
+ default=0.1,
+ help="Percentage (0-1) size of development set.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Seed for random shuffle.")
+ parser.add_argument(
+ "--data_path",
+ type=Path,
+ default=DATA_FOLDER,
+ help="Path of data folder (default is recommended).",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ for d in DATASETS:
+        d.download(args.data_path)
+
+ create_sets(args.data_path, args.test_size, args.dev_size, args.seed)
diff --git a/src/data/data_creation/WordClassDM.py b/handwriting_ocr/data/data_creation/WordClassDM.py
similarity index 75%
rename from src/data/data_creation/WordClassDM.py
rename to handwriting_ocr/data/data_creation/WordClassDM.py
index 6b9c1bd..8ec21a1 100644
--- a/src/data/data_creation/WordClassDM.py
+++ b/handwriting_ocr/data/data_creation/WordClassDM.py
@@ -17,8 +17,7 @@
import glob
import argparse
import simplejson
-from ocr.normalization import imageNorm
-from ocr.viz import printProgressBar
+from handwriting_ocr.ocr.normalization import word_normalization
def loadImages(dataloc, idx=0, num=None):
@@ -26,9 +25,9 @@ def loadImages(dataloc, idx=0, num=None):
print("Loading words...")
# Load images and short them from the oldest to the newest
- imglist = glob.glob(os.path.join(dataloc, u'*.jpg'))
+ imglist = glob.glob(os.path.join(dataloc, "*.jpg"))
imglist.sort(key=lambda x: float(x.split("_")[-1][:-4]))
- tmpLabels = [name[len(dataloc):] for name in imglist]
+ tmpLabels = [name[len(dataloc) :] for name in imglist]
labels = np.array(tmpLabels)
images = np.empty(len(imglist), dtype=object)
@@ -42,19 +41,19 @@ def loadImages(dataloc, idx=0, num=None):
for i, img in enumerate(imglist):
# TODO Speed up loading - Normalization
if i >= idx and i < upper:
- images[i] = imageNorm(
+ images[i] = word_normalization(
cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB),
height=60,
border=False,
tilt=True,
- hystNorm=True)
- printProgressBar(i-idx, upper-idx-1)
+                hyst_norm=True,
+ )
print()
return (images[idx:num], labels[idx:num])
def locCheck(loc):
- return loc + '/' if loc[-1] != '/' else loc
+ return loc + "/" if loc[-1] != "/" else loc
class Cycler:
@@ -67,33 +66,33 @@ def __init__(self, idx, data_loc, save_loc):
# Create save_loc directory if not exists
if not os.path.exists(save_loc):
os.makedirs(save_loc)
-
+
self.data_loc = locCheck(data_loc)
self.save_loc = locCheck(save_loc)
-
+
self.idx = 0
self.org_idx = idx
self.blockLoad()
self.image_act = self.images[self.idx]
- cv2.namedWindow('image')
- cv2.setMouseCallback('image', self.mouseHandler)
+ cv2.namedWindow("image")
+ cv2.setMouseCallback("image", self.mouseHandler)
self.nextImage()
self.run()
def run(self):
- while(1):
+        while True:
self.imageShow()
k = cv2.waitKey(1) & 0xFF
- if k == ord('d'):
+ if k == ord("d"):
# Delete last line
self.deleteLastLine()
- elif k == ord('r'):
+ elif k == ord("r"):
# Clear current gaplines
self.nextImage()
- elif k == ord('s'):
+ elif k == ord("s"):
# Save gaplines with image
if self.saveData():
self.idx += 1
@@ -101,7 +100,7 @@ def run(self):
if not self.blockLoad():
break
self.nextImage()
- elif k == ord('n'):
+ elif k == ord("n"):
# Skip to next image
self.idx += 1
if self.idx >= len(self.images):
@@ -116,29 +115,32 @@ def run(self):
def blockLoad(self):
self.images, self.labels = loadImages(
- self.data_loc, self.org_idx + self.idx, 100)
+ self.data_loc, self.org_idx + self.idx, 100
+ )
self.org_idx += self.idx
self.idx = 0
        return len(self.images) != 0
def imageShow(self):
cv2.imshow(
- 'image',
+ "image",
cv2.resize(
self.image_act,
- (0,0),
+ (0, 0),
fx=self.scaleF,
fy=self.scaleF,
- interpolation=cv2.INTERSECT_NONE))
+                interpolation=cv2.INTER_NEAREST,
+ ),
+ )
def nextImage(self):
self.image_act = cv2.cvtColor(self.images[self.idx], cv2.COLOR_GRAY2RGB)
self.label_act = self.labels[self.idx][:-4]
- self.gaplines = [0, self.image_act.shape[1]]
+ self.gaplines = [0, self.image_act.shape[1]]
self.redrawLines()
print(self.org_idx + self.idx, ":", self.label_act.split("_")[0])
- self.imageShow();
+ self.imageShow()
def saveData(self):
self.gaplines.sort()
@@ -148,9 +150,9 @@ def saveData(self):
assert len(self.gaplines) - 1 == len(self.label_act.split("_")[0])
cv2.imwrite(
- self.save_loc + '%s.jpg' % (self.label_act),
- self.images[self.idx])
- with open(self.save_loc + '%s.txt' % (self.label_act), 'w') as fp:
+ self.save_loc + "%s.jpg" % (self.label_act), self.images[self.idx]
+ )
+ with open(self.save_loc + "%s.txt" % (self.label_act), "w") as fp:
simplejson.dump(self.gaplines, fp)
return True
except:
@@ -171,8 +173,7 @@ def redrawLines(self):
self.drawLine(x)
def drawLine(self, x):
- cv2.line(
- self.image_act, (x, 0), (x, self.image_act.shape[0]), (0,255,0), 1)
+ cv2.line(self.image_act, (x, 0), (x, self.image_act.shape[0]), (0, 255, 0), 1)
def mouseHandler(self, event, x, y, flags, param):
# Clip x into image width range
@@ -194,26 +195,20 @@ def mouseHandler(self, event, x, y, flags, param):
self.drawLine(x)
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(
- "Script creating UI for gaplines classification")
- parser.add_argument(
- "--index",
- type=int,
- default=0,
- help="Index of starting image")
-
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Script creating UI for gaplines classification")
+ parser.add_argument("--index", type=int, default=0, help="Index of starting image")
+
parser.add_argument(
- "--data",
- type=str,
- default='data/words_raw',
- help="Path to folder with images")
-
+ "--data", type=str, default="data/words_raw", help="Path to folder with images"
+ )
+
parser.add_argument(
"--save",
type=str,
- default='data/words2',
- help="Path to folder for saving images with gaplines")
+ default="data/words2",
+ help="Path to folder for saving images with gaplines",
+ )
args = parser.parse_args()
Cycler(args.index, args.data, args.save)
diff --git a/src/data/data_creation/words_labeling.ipynb b/handwriting_ocr/data/data_creation/words_labeling.ipynb
similarity index 100%
rename from src/data/data_creation/words_labeling.ipynb
rename to handwriting_ocr/data/data_creation/words_labeling.ipynb
diff --git a/handwriting_ocr/data/data_loader.py b/handwriting_ocr/data/data_loader.py
new file mode 100644
index 0000000..1f75c47
--- /dev/null
+++ b/handwriting_ocr/data/data_loader.py
@@ -0,0 +1,280 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Modelu for downloading datasets and loading data (DATASETS list)."""
+
+import gzip
+import xml.etree.ElementTree
+from abc import abstractmethod
+from pathlib import Path
+
+import cv2 as cv
+import numpy as np
+
+from handwriting_ocr.data.loader import Loader
+
+DATA_FOLDER = Path(__file__).parent.joinpath("../../data/")
+
+
+class Dataset(Loader):
+ """Abstract class for managing data.
+
+ Attributes:
+ name (str): Name of dataset. It is used for naming appropriate folders.
+ files (List[Tuple[str, str, str, str]]): List of datasets' files/folders (URL,
+ tmp file, final file or folder, dataset type folder)
+ require_auth (bool): If authentication is required (default = False)
+        username (str): (Optional) username for authentication during download
+        password (str): (Optional) password for authentication during download
+ """
+
+ @abstractmethod
+ def load(self, data_path):
+ """Returns path to line images with corresponding labels (sorted by path).
+
+ Args:
+ data_path (Path): Path to data folder
+
+ Returns:
+ lines (List[Tuple[Path, str]]): List of tuples containing path to image and
+                label. It should always return images in the same order (sort on return).
+ """
+ ...
+
+ def __str__(self):
+ return f"dataset-{self.name}"
+
+
+class Breta(Dataset):
+ """Handwriting data from Břetislav Hájek."""
+
+ name = "breta"
+ # TODO: Remove data.zip archive (no longer in use)
+ files = [
+ (
+ "https://drive.google.com/uc?id=1y6Kkcfk4DkEacdy34HJtwjPVa1ZhyBgg",
+ "data.zip",
+ "",
+ "raw",
+ ),
+ (
+ "https://drive.google.com/uc?id=1p7tZWzK0yWZO35lipNZ_9wnfXRNIZOqj",
+ "data.zip",
+ "",
+ "processed",
+ ),
+ ]
+
+ def __init__(self, name="breta"):
+ self.name = name
+
+ def load(self, data_path):
+ folder = data_path / self.files[0][3] / self.name / self.files[0][2]
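+        # The label is taken from the file name before the first underscore,
+        # e.g. an illustrative "word_1509548281.png" yields the label "word".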
+ return sorted((p, p.name.split("_")[0]) for p in folder.glob("**/*.png"))
+
+
+class CVL(Dataset):
+ """CVL Database
+ More info at: https://zenodo.org/record/1492267#.Xob4lPGxXeR
+ """
+
+ name = "cvl"
+ files = [
+ (
+ "https://zenodo.org/record/1492267/files/cvl-database-1-1.zip",
+ "cvl-database-1-1.zip",
+ "",
+ "raw",
+ )
+ ]
+
+ def __init__(self, name="cvl"):
+ self.name = name
+
+ def load(self, data_path):
+ lines = []
+
+ folder = data_path / self.files[0][3] / self.name / self.files[0][2]
+ l_dic = {}
+ for xf in folder.glob("**/xml/*.xml"):
+ try:
+ with open(xf, "r") as f:
+ root = xml.etree.ElementTree.fromstring(f.read())
+            except (UnicodeDecodeError, xml.etree.ElementTree.ParseError):
+ with open(xf, "r", encoding="iso-8859-15") as f:
+ root = xml.etree.ElementTree.fromstring(f.read())
+ # Get tag schema
+ tg = root.tag[: -len(root.tag.split("}", 1)[1])]
+ for attr in root.findall(
+ f".//{tg}AttrRegion[@attrType='2'][@fontType='2']"
+ ):
+ target = " ".join(
+ x.get("text") for x in attr.findall(f"{tg}AttrRegion[@text]")
+ )
+ if len(target) != 0:
+ l_dic[attr.get("id")] = target
+
+ ln_f = folder / self.files[0][2]
+ return sorted(
+ (p, l_dic[p.with_suffix("").name])
+ for p in ln_f.glob("**/lines/*/*.tif")
+ if p.with_suffix("").name in l_dic
+ )
+
+
+class IAM(Dataset):
+ """IAM Handwriting Database
+ More info at: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+ """
+
+ name = "iam"
+ require_auth = True
+ files = [
+ (
+ "http://www.fki.inf.unibe.ch/DBs/iamDB/data/ascii/lines.txt",
+ "lines.txt",
+ "lines.txt",
+ "raw",
+ ),
+ (
+ "http://www.fki.inf.unibe.ch/DBs/iamDB/data/lines/lines.tgz",
+ "lines.tgz",
+ "lines",
+ "raw",
+ ),
+ ]
+
+ def __init__(self, name="iam", username=None, password=None):
+ self.name = name
+ self.username = username
+ self.password = password
+
+ def load(self, data_path):
+ lines = []
+
+ folder = data_path / self.files[0][3] / self.name
+ with open(folder.joinpath(self.files[0][2]), "r") as f:
+ labels = [l.strip() for l in f if l.strip()[0] != "#"]
+ labels = map(lambda x: (x.split(" ")[0], x.split(" ", 8)[-1]), labels)
+ l_dic = {im: label.replace("|", " ") for im, label in labels}
+
+ ln_f = folder / self.files[1][2]
+ return sorted((p, l_dic[p.with_suffix("").name]) for p in ln_f.glob("**/*.png"))
+
+
+class ORAND(Dataset):
+ """ORAND CAR 2014 dataset
+ More info at: https://www.orand.cl/icfhr2014-hdsr/#datasets
+ """
+
+ name = "orand"
+ files = [
+ (
+ "https://www.orand.cl/orand_car/ORAND-CAR-2014.tar.gz",
+ "ORAND-CAR-2014.tar.gz",
+ "",
+ "raw",
+ )
+ ]
+
+ def __init__(self, name="orand"):
+ self.name = name
+
+ def load(self, data_path):
+ lines = []
+
+ folder = data_path / self.files[0][3] / self.name / self.files[0][2]
+ for label_f in folder.glob("**/CAR-*/*.txt"):
+ im_folder = Path(str(label_f)[:-6] + "images")
+ with open(label_f, "r") as f:
+ labels = map(lambda x: x.strip().split("\t"), f)
+ lines.extend((im_folder.joinpath(im), w) for im, w in labels)
+ return sorted(lines)
+
+
+class Camb(Dataset):
+ """Cambridge Handwriting Database
+ More info at: ftp://svr-ftp.eng.cam.ac.uk/pub/data/handwriting_databases.README
+ """
+
+ name = "camb"
+ files = [
+ (
+ "ftp://svr-ftp.eng.cam.ac.uk/pub/data/handwriting_databases.README",
+ "handwriting_databases.README",
+ "handwriting_databases.README",
+ "raw",
+ ),
+ ("ftp://svr-ftp.eng.cam.ac.uk/pub/data/lob.tar", "lob.tar", "lob", "raw"),
+ (
+ "ftp://svr-ftp.eng.cam.ac.uk/pub/data/numbers.tar",
+ "numbers.tar",
+ "numbers",
+ "raw",
+ ),
+ ]
+
+ def __init__(self, name="camb"):
+ self.name = name
+
+ def post_download(self, data_path):
+ print(f"Running post-download processing on {self.name}...")
+ folder = data_path / self.files[0][3] / self.name
+ output = folder / "extracted"
+ output.mkdir(parents=True, exist_ok=True)
+
+ for i, seg_f in enumerate(sorted(folder.glob("**/*.seg"))):
+ with gzip.open(seg_f.with_suffix(".tiff.gz"), "rb") as f:
+                buff = np.frombuffer(f.read(), dtype=np.uint8)
+ image = cv.imdecode(buff, cv.IMREAD_UNCHANGED)
+
+ with open(seg_f, "r") as f:
+ f.readline()
+ for line in f:
+ rect = list(map(int, line.strip().split(" ")[1:]))
+ word = line.split(" ")[0]
+ im = image[rect[2] : rect[3], rect[0] : rect[1]]
+
+ if 0 in im.shape:
+ continue
+ cv.imwrite(str(output.joinpath(f"{word}_{i:04}.png")), im)
+
+ def load(self, data_path):
+ folder = data_path / self.files[0][3] / self.name / "extracted"
+ return sorted((p, p.name.split("_")[0]) for p in folder.glob("**/*.png"))
+
+
+class NIST(Dataset):
+ """NIST SD 19 - character dataset
+ More info at: https://www.nist.gov/srd/nist-special-database-19
+ """
+
+ name = "nist"
+ files = [
+ (
+ "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip",
+ "by_class.zip",
+ "",
+ "raw",
+ ),
+ ]
+
+ def __init__(self, name="nist"):
+ self.name = name
+
+ def load(self, data_path):
+ # TODO: Generate lines from NIST characters
+ return []
+
+ def load_characters(self, data_path):
+ folder = data_path / self.files[0][3] / self.name / self.files[0][2]
+ return sorted(
+ (p, chr(int(p.name.split("_")[1], 16)))
+            for p in folder.glob("**/train_*/*.png")
+ )
+
+
+DATASETS = [Breta(), CVL(), IAM(), ORAND(), Camb(), NIST()]
+
+
+if __name__ == "__main__":
+ for d in DATASETS:
+ d.download(DATA_FOLDER)
+ print(d, len(d.load(DATA_FOLDER)))
diff --git a/handwriting_ocr/data/loader.py b/handwriting_ocr/data/loader.py
new file mode 100644
index 0000000..4dec824
--- /dev/null
+++ b/handwriting_ocr/data/loader.py
@@ -0,0 +1,186 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Module for downloading data."""
+
+import getpass
+import shutil
+import tarfile
+import urllib.request
+import zipfile
+from abc import ABCMeta, abstractmethod
+
+import gdown
+from tqdm import tqdm
+
+
+class Progressbar(tqdm):
+ """Helper class for download progressbar."""
+
+ def update_to(self, b=1, bsize=1, tsize=None):
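+        # Matches the reporthook signature of urllib.request.urlretrieve:
+        # b = blocks transferred so far, bsize = block size, tsize = total size.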
+ if tsize is not None:
+ self.total = tsize
+ self.update(b * bsize - self.n)
+
+
+def download_url(url, output, username=None, password=None):
+ """Download file from URL to output location.
+
+ Args:
+ url (str): URL for downloading the file
+ output (Path): Path where should be downloaded file stored
+ username (str): (Optional) username for authentication
+        password (str): (Optional) password for authentication
+
+ Returns:
+ status (int): Returns status of download (200: OK, 401: Unauthorized,
+ -1: Unknown)
+ """
+
+ if "drive.google.com" in url:
+ gdown.download(url, str(output), quiet=False)
+ return 200
+
+ for _ in range(3):
+ if username or password:
+ # create a password manager
+ password_mgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
+ password_mgr.add_password(None, url, username, password)
+ handler = urllib.request.HTTPBasicAuthHandler(password_mgr)
+ # create "opener"
+ opener = urllib.request.build_opener(handler)
+ urllib.request.install_opener(opener)
+
+ with Progressbar(
+ unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]
+ ) as t:
+ try:
+ urllib.request.urlretrieve(url, filename=output, reporthook=t.update_to)
+ return 200
+ except urllib.error.HTTPError as e:
+ if hasattr(e, "code") and e.code == 401:
+ return 401
+                print(f"\nError occurred during download:\n{e}")
+ return -1
+
+
+def file_extract(file_path, out_path):
+ """Extract archive file into given location.
+
+ Args:
+ file_path (Path): Path of archive file
+ out_path (Path): Path of folder where should be the file extracted
+ """
+ print(f"Extracting {file_path} file...")
+ if file_path.suffix == ".zip":
+
+ def open_file(x):
+ return zipfile.ZipFile(x, "r")
+
+ elif file_path.suffix in [".gz", ".tgz"]:
+
+ def open_file(x):
+ return tarfile.open(x, "r:gz")
+
+    elif file_path.suffix == ".tar":
+
+        def open_file(x):
+            return tarfile.open(x, "r")
+
+    else:
+        raise ValueError(f"Unsupported archive type: {file_path.suffix}")
+
+ with open_file(file_path) as data_file:
+ data_file.extractall(out_path)
+
+
+class Loader(metaclass=ABCMeta):
+ """Abstract class for downloading data.
+
+ Attributes:
+ name (str): Name of data. It is used for naming appropriate folders.
+ files (List[Tuple[str, str, str, str]]): List of files/folders (URL, tmp file,
+ final file or folder, dataset type folder)
+ require_auth (bool): If authentication is required (default = False)
+        username (str): (Optional) username for authentication during download
+        password (str): (Optional) password for authentication during download
+ """
+
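+    # An entry in `files` reads (URL, tmp file, final file/folder, type folder),
+    # e.g. a hypothetical ("https://example.com/words.zip", "words.zip", "words", "raw").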
+ require_auth = False
+ username, password = None, None
+
+ @property
+ @abstractmethod
+ def name(self):
+ ...
+
+ @property
+ @abstractmethod
+ def files(self):
+ ...
+
+ def clear(self, data_path):
+ """Clear all downloaded files.
+
+ Args:
+ data_path (Path): Path to data folder
+ """
+ for _, _, res, folder in self.files:
+ d = data_path / folder / self.name
+            if d.exists():
+ shutil.rmtree(d)
+
+ def is_downloaded(self, data_path):
+ """Check if files are downloaded.
+
+ Args:
+ data_path (Path): Path to data folder
+
+ Returns:
+ is_downloaded (bool): True if files are downloaded (no folder missing)
+ """
+ for _, _, res, folder in self.files:
+ if not data_path.joinpath(folder, self.name, res).exists():
+ return False
+ return True
+
+ def download(self, data_path):
+ print(f"Collecting {self}...")
+ downloaded = False
+ for url, f, res, folder in self.files:
+ folder = data_path / folder / self.name
+ tmp_output = folder / f
+ res_output = folder / res
+
+ if not res_output.exists():
+ tmp_output.parent.mkdir(parents=True, exist_ok=True)
+ # Try the authentication 3 times
+ for i in range(3):
+ if self.require_auth:
+ if not self.username:
+ self.username = input(f"Username for {self}: ")
+ if not self.password:
+ self.password = getpass.getpass(f"Password for {self}: ")
+ status = download_url(url, tmp_output, self.username, self.password)
+ if status == 200:
+ break
+
+ if status == 401 and i < 2:
+                        print("Invalid username or password, please try again.")
+                        self.username, self.password = None, None
+ else:
+ print(f"{self} skipped.")
+ return
+ downloaded = True
+
+ if tmp_output.suffix in [".zip", ".gz", ".tgz", ".tar"]:
+ file_extract(tmp_output, res_output)
+ tmp_output.unlink()
+
+ if downloaded:
+ self.post_download(data_path)
+
+ def post_download(self, data_path):
+ """Run post-processing on downloaded data (e.g. cut lines from form images)
+
+ Args:
+ data_path (Path): Path to data folder
+ """
+
+ def __str__(self):
+ return self.name
diff --git a/handwriting_ocr/data/model_loader.py b/handwriting_ocr/data/model_loader.py
new file mode 100644
index 0000000..968b6c2
--- /dev/null
+++ b/handwriting_ocr/data/model_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Modelu for downloading and providing pre-trained models."""
+
+from pathlib import Path
+
+from handwriting_ocr.data.loader import Loader
+
+MODEL_FOLDER = Path(__file__).parent.joinpath("../../models/")
+
+
+class Models(Loader):
+ """Download pre-trained models."""
+
+ name = "models"
+ files = [
+ (
+            "https://drive.google.com/open?id=1YbmsiJK3Wclfm6K8PrJuz-QROEKX1qis",
+            "ocr-handwriting-models.zip",
+ "",
+ "",
+ ),
+ ]
+
+ def __init__(self, name="models"):
+ self.name = name
diff --git a/src/data/datasets/__init__.py b/handwriting_ocr/ocr/__init__.py
similarity index 100%
rename from src/data/datasets/__init__.py
rename to handwriting_ocr/ocr/__init__.py
diff --git a/src/ocr/characters.py b/handwriting_ocr/ocr/characters.py
similarity index 66%
rename from src/ocr/characters.py
rename to handwriting_ocr/ocr/characters.py
index 3cd2a53..64764a6 100644
--- a/src/ocr/characters.py
+++ b/handwriting_ocr/ocr/characters.py
@@ -1,23 +1,26 @@
-# -*- coding: utf-8 -*-
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+
import os
+import math
+
+import cv2
import numpy as np
import tensorflow as tf
-import cv2
-import math
-from .helpers import *
-from .tfhelpers import Model
+from handwriting_ocr.ocr.helpers import *
+from handwriting_ocr.ocr.tfhelpers import Model
+
# Preloading trained model with activation function
# Loading is slow -> prevent multiple loads
print("Loading segmentation models...")
location = os.path.dirname(os.path.abspath(__file__))
-CNN_model = Model(
- os.path.join(location, '../../models/gap-clas/CNN-CG'))
+CNN_model = Model(os.path.join(location, "../../models/gap-clas/CNN-CG"))
CNN_slider = (60, 30)
RNN_model = Model(
- os.path.join(location, '../../models/gap-clas/RNN/Bi-RNN-new'),
- 'prediction')
+ os.path.join(location, "../../models/gap-clas/RNN/Bi-RNN-new"), "prediction"
+)
RNN_slider = (60, 60)
@@ -25,20 +28,24 @@ def _classify(img, step=2, RNN=False, slider=(60, 60)):
"""Slice the image and return raw output of classifier."""
length = (img.shape[1] - slider[1]) // 2 + 1
if RNN:
- input_seq = np.zeros((1, length, slider[0]*slider[1]), dtype=np.float32)
- input_seq[0][:] = [img[:, loc * step: loc * step + slider[1]].flatten()
- for loc in range(length)]
- pred = RNN_model.eval_feed({'inputs:0': input_seq,
- 'length:0': [length],
- 'keep_prob:0': 1})[0]
+ input_seq = np.zeros((1, length, slider[0] * slider[1]), dtype=np.float32)
+ input_seq[0][:] = [
+ img[:, loc * step : loc * step + slider[1]].flatten()
+ for loc in range(length)
+ ]
+ pred = RNN_model.eval_feed(
+ {"inputs:0": input_seq, "length:0": [length], "keep_prob:0": 1}
+ )[0]
else:
- input_seq = np.zeros((length, slider[0]*slider[1]), dtype=np.float32)
- input_seq[:] = [img[:, loc * step: loc * step + slider[1]].flatten()
- for loc in range(length)]
+ input_seq = np.zeros((length, slider[0] * slider[1]), dtype=np.float32)
+ input_seq[:] = [
+ img[:, loc * step : loc * step + slider[1]].flatten()
+ for loc in range(length)
+ ]
pred = CNN_model.run(input_seq)
-
+
return pred
-
+
def segment(img, step=2, RNN=False, debug=False):
"""Take preprocessed image of word and
@@ -47,7 +54,7 @@ def segment(img, step=2, RNN=False, debug=False):
slider = CNN_slider
if RNN:
slider = RNN_slider
-
+
# Run the classifier
pred = _classify(img, step=step, RNN=RNN, slider=slider)
@@ -84,17 +91,14 @@ def segment(img, step=2, RNN=False, debug=False):
if gap_block_first != 0:
gaps.append(int(gap_block_first))
else:
- gap_position_sum += (len(pred) - 1) * 2 + slider[1]/2
+ gap_position_sum += (len(pred) - 1) * 2 + slider[1] / 2
gaps.append(int(gap_position_sum / (gap_count + 1)))
-
+
if debug:
# Drawing lines
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
for gap in gaps:
- cv2.line(img,
- ((int)(gap), 0),
- ((int)(gap), slider[0]),
- (0, 255, 0), 1)
+ cv2.line(img, ((int)(gap), 0), ((int)(gap), slider[0]), (0, 255, 0), 1)
implt(img, t="Separated characters")
-
- return gaps
\ No newline at end of file
+
+ return gaps
diff --git a/src/ocr/datahelpers.py b/handwriting_ocr/ocr/datahelpers.py
similarity index 60%
rename from src/ocr/datahelpers.py
rename to handwriting_ocr/ocr/datahelpers.py
index d584760..f0e878d 100644
--- a/src/ocr/datahelpers.py
+++ b/handwriting_ocr/ocr/datahelpers.py
@@ -1,45 +1,109 @@
-# -*- coding: utf-8 -*-
-"""
-Helper functions for loading and creating datasets
-"""
-import numpy as np
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Helper functions for loading and creating datasets"""
+
import glob
import simplejson
import os
-import cv2
import csv
import sys
+
+import cv2
import unidecode
+import numpy as np
-from .helpers import implt
-from .normalization import letter_normalization
-from .viz import print_progress_bar
+from handwriting_ocr.ocr.helpers import implt
+from handwriting_ocr.ocr.normalization import letter_normalization
-CHARS = ['', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
- 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S',
- 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c',
- 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
- 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w',
- 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6',
- '7', '8', '9', '.', '-', '+', "'"]
+CHARS = [
+ "",
+ "A",
+ "B",
+ "C",
+ "D",
+ "E",
+ "F",
+ "G",
+ "H",
+ "I",
+ "J",
+ "K",
+ "L",
+ "M",
+ "N",
+ "O",
+ "P",
+ "Q",
+ "R",
+ "S",
+ "T",
+ "U",
+ "V",
+ "W",
+ "X",
+ "Y",
+ "Z",
+ "a",
+ "b",
+ "c",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "q",
+ "r",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ "0",
+ "1",
+ "2",
+ "3",
+ "4",
+ "5",
+ "6",
+ "7",
+ "8",
+ "9",
+ ".",
+ "-",
+ "+",
+ "'",
+]
CHAR_SIZE = len(CHARS)
idxs = [i for i in range(len(CHARS))]
idx_2_chars = dict(zip(idxs, CHARS))
chars_2_idx = dict(zip(CHARS, idxs))
+
def char2idx(c, sequence=False):
if sequence:
return chars_2_idx[c] + 1
return chars_2_idx[c]
+
def idx2char(idx, sequence=False):
if sequence:
- return idx_2_chars[idx-1]
+ return idx_2_chars[idx - 1]
return idx_2_chars[idx]
-
-def load_words_data(dataloc='data/words/', is_csv=False, load_gaplines=False):
+
+def load_words_data(dataloc="data/words/", is_csv=False, load_gaplines=False):
"""
Load word images with corresponding labels and gaplines (if load_gaplines == True).
Args:
@@ -59,36 +123,29 @@ def load_words_data(dataloc='data/words/', is_csv=False, load_gaplines=False):
for loc in dataloc:
with open(loc) as csvfile:
reader = csv.reader(csvfile)
- length += max(sum(1 for row in csvfile)-1, 0)
+ length += max(sum(1 for row in csvfile) - 1, 0)
labels = np.empty(length, dtype=object)
images = np.empty(length, dtype=object)
- i = 0
for loc in dataloc:
print(loc)
with open(loc) as csvfile:
reader = csv.DictReader(csvfile)
-                for row in reader:
+                for i, row in enumerate(reader):
- shape = np.fromstring(
- row['shape'],
- sep=',',
- dtype=int)
- img = np.fromstring(
- row['image'],
- sep=', ',
- dtype=np.uint8).reshape(shape)
- labels[i] = row['label']
+ shape = np.fromstring(row["shape"], sep=",", dtype=int)
+ img = np.fromstring(row["image"], sep=", ", dtype=np.uint8).reshape(
+ shape
+ )
+ labels[i] = row["label"]
images[i] = img
-
- print_progress_bar(i, length)
- i += 1
+
else:
img_list = []
tmp_labels = []
for loc in dataloc:
- tmp_list = glob.glob(os.path.join(loc, '*.png'))
+ tmp_list = glob.glob(os.path.join(loc, "*.png"))
img_list += tmp_list
- tmp_labels += [name[len(loc):].split("_")[0] for name in tmp_list]
+ tmp_labels += [name[len(loc) :].split("_")[0] for name in tmp_list]
labels = np.array(tmp_labels)
images = np.empty(len(img_list), dtype=object)
@@ -96,21 +153,20 @@ def load_words_data(dataloc='data/words/', is_csv=False, load_gaplines=False):
# Load grayscaled images
for i, img in enumerate(img_list):
images[i] = cv2.imread(img, 0)
- print_progress_bar(i, len(img_list))
# Load gaplines (lines separating letters) from txt files
if load_gaplines:
gaplines = np.empty(len(img_list), dtype=object)
for i, name in enumerate(img_list):
- with open(name[:-3] + 'txt', 'r') as fp:
+ with open(name[:-3] + "txt", "r") as fp:
gaplines[i] = np.array(simplejson.load(fp))
-
+
if load_gaplines:
assert len(labels) == len(images) == len(gaplines)
else:
assert len(labels) == len(images)
print("-> Number of words:", len(labels))
-
+
if load_gaplines:
return (images, labels, gaplines)
return (images, labels)
@@ -120,24 +176,24 @@ def _words2chars(images, labels, gaplines):
"""Transform word images with gaplines into individual chars."""
# Total number of chars
length = sum([len(l) for l in labels])
-
+
imgs = np.empty(length, dtype=object)
new_labels = []
-
+
height = images[0].shape[0]
-
- idx = 0;
+
+ idx = 0
for i, gaps in enumerate(gaplines):
for pos in range(len(gaps) - 1):
- imgs[idx] = images[i][0:height, gaps[pos]:gaps[pos+1]]
+ imgs[idx] = images[i][0:height, gaps[pos] : gaps[pos + 1]]
new_labels.append(char2idx(labels[i][pos]))
idx += 1
-
- print("Loaded chars from words:", length)
+
+ print("Loaded chars from words:", length)
return imgs, new_labels
-def load_chars_data(charloc='data/charclas/', wordloc='data/words/', lang='cz'):
+def load_chars_data(charloc="data/charclas/", wordloc="data/words/", lang="cz"):
"""
Load chars images with corresponding labels.
Args:
@@ -150,47 +206,48 @@ def load_chars_data(charloc='data/charclas/', wordloc='data/words/', lang='cz'):
images = np.zeros((1, 4096))
labels = []
- if charloc != '':
+ if charloc != "":
# Get subfolders with chars
dir_list = glob.glob(os.path.join(charloc, lang, "*/"))
- dir_list.sort()
+ dir_list.sort()
# if lang == 'en':
chars = CHARS[:53]
-
- assert [d[-2] if d[-2] != '0' else '' for d in dir_list] == chars
+
+ assert [d[-2] if d[-2] != "0" else "" for d in dir_list] == chars
# For every label load images and create corresponding labels
# cv2.imread(img, 0) - for loading images in grayscale
# Images are scaled to 64x64 = 4096 px
for i in range(len(chars)):
- img_list = glob.glob(os.path.join(dir_list[i], '*.jpg'))
- imgs = np.array([letter_normalization(cv2.imread(img, 0)) for img in img_list])
+ img_list = glob.glob(os.path.join(dir_list[i], "*.jpg"))
+ imgs = np.array(
+ [letter_normalization(cv2.imread(img, 0)) for img in img_list]
+ )
images = np.concatenate([images, imgs.reshape(len(imgs), 4096)])
labels.extend([i] * len(imgs))
-
- if wordloc != '':
+
+ if wordloc != "":
imgs, words, gaplines = load_words_data(wordloc, load_gaplines=True)
- if lang != 'cz':
- words = np.array([unidecode.unidecode(w) for w in words])
+ if lang != "cz":
+ words = np.array([unidecode.unidecode(w) for w in words])
imgs, chars = _words2chars(imgs, words, gaplines)
-
+
labels.extend(chars)
- images2 = np.zeros((len(imgs), 4096))
+ images2 = np.zeros((len(imgs), 4096))
for i in range(len(imgs)):
- print_progress_bar(i, len(imgs))
images2[i] = letter_normalization(imgs[i]).reshape(1, 4096)
- images = np.concatenate([images, images2])
+ images = np.concatenate([images, images2])
images = images[1:]
labels = np.array(labels)
-
+
print("-> Number of chars:", len(labels))
return (images, labels)
-def load_gap_data(loc='data/gapdet/large/', slider=(60, 120), seq=False, flatten=True):
+def load_gap_data(loc="data/gapdet/large/", slider=(60, 120), seq=False, flatten=True):
"""
Load gap data from location with corresponding labels.
Args:
@@ -202,51 +259,70 @@ def load_gap_data(loc='data/gapdet/large/', slider=(60, 120), seq=False, flatten
Returns:
(images, labels)
"""
- print('Loading gap data...')
+ print("Loading gap data...")
dir_list = glob.glob(os.path.join(loc, "*/"))
dir_list.sort()
-
+
if slider[1] > 120:
# TODO Implement for higher dimmensions
slider[1] = 120
-
- cut_s = None if (120 - slider[1]) // 2 <= 0 else (120 - slider[1]) // 2
+
+ cut_s = None if (120 - slider[1]) // 2 <= 0 else (120 - slider[1]) // 2
cut_e = None if (120 - slider[1]) // 2 <= 0 else -(120 - slider[1]) // 2
-
+
if seq:
images = np.empty(len(dir_list), dtype=object)
labels = np.empty(len(dir_list), dtype=object)
-
+
for i, loc in enumerate(dir_list):
# TODO Check for empty directories
- img_list = glob.glob(os.path.join(loc, '*.jpg'))
- if (len(img_list) != 0):
- img_list = sorted(imglist, key=lambda x: int(x[len(loc):].split("_")[1][:-4]))
- images[i] = np.array([(cv2.imread(img, 0)[:, cut_s:cut_e].flatten() if flatten else
- cv2.imread(img, 0)[:, cut_s:cut_e])
- for img in img_list])
- labels[i] = np.array([int(name[len(loc):].split("_")[0]) for name in img_list])
-
+ img_list = glob.glob(os.path.join(loc, "*.jpg"))
+ if len(img_list) != 0:
+ img_list = sorted(
+                    img_list, key=lambda x: int(x[len(loc) :].split("_")[1][:-4])
+ )
+ images[i] = np.array(
+ [
+ (
+ cv2.imread(img, 0)[:, cut_s:cut_e].flatten()
+ if flatten
+ else cv2.imread(img, 0)[:, cut_s:cut_e]
+ )
+ for img in img_list
+ ]
+ )
+ labels[i] = np.array(
+ [int(name[len(loc) :].split("_")[0]) for name in img_list]
+ )
+
else:
- images = np.zeros((1, slider[0]*slider[1]))
+ images = np.zeros((1, slider[0] * slider[1]))
labels = []
for i in range(len(dir_list)):
- img_list = glob.glob(os.path.join(dir_list[i], '*.jpg'))
- if (len(img_list) != 0):
- imgs = np.array([cv2.imread(img, 0)[:, cut_s:cut_e] for img in img_list])
- images = np.concatenate([images, imgs.reshape(len(imgs), slider[0]*slider[1])])
+ img_list = glob.glob(os.path.join(dir_list[i], "*.jpg"))
+ if len(img_list) != 0:
+ imgs = np.array(
+ [cv2.imread(img, 0)[:, cut_s:cut_e] for img in img_list]
+ )
+ images = np.concatenate(
+ [images, imgs.reshape(len(imgs), slider[0] * slider[1])]
+ )
                labels.extend([int(img[len(dir_list[i])]) for img in img_list])
images = images[1:]
labels = np.array(labels)
-
+
if seq:
- print("-> Number of words / gaps and letters:",
- len(labels), '/', sum([len(l) for l in labels]))
+ print(
+ "-> Number of words / gaps and letters:",
+ len(labels),
+ "/",
+ sum([len(l) for l in labels]),
+ )
else:
print("-> Number of gaps and letters:", len(labels))
- return (images, labels)
+ return (images, labels)
def corresponding_shuffle(a):
@@ -277,11 +353,13 @@ def sequences_to_sparse(sequences):
values = []
for n, seq in enumerate(sequences):
- indices.extend(zip([n]*len(seq), range(len(seq))))
+ indices.extend(zip([n] * len(seq), range(len(seq))))
values.extend(seq)
-
+
indices = np.asarray(indices, dtype=np.int64)
values = np.asarray(values, dtype=np.int32)
- shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1]+1], dtype=np.int64)
+ shape = np.asarray(
+ [len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64
+ )
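+    # (indices, values, shape) is the triple expected by tf.SparseTensor; the CTC
+    # loss consumes target sequences in this sparse form during training.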
return indices, values, shape
diff --git a/src/ocr/dataiterator.py b/handwriting_ocr/ocr/dataiterator.py
similarity index 61%
rename from src/ocr/dataiterator.py
rename to handwriting_ocr/ocr/dataiterator.py
index 716e940..c9a1d88 100644
--- a/src/ocr/dataiterator.py
+++ b/handwriting_ocr/ocr/dataiterator.py
@@ -1,21 +1,27 @@
-# -*- coding: utf-8 -*-
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
"""Classes for feeding data during training."""
+
import numpy as np
import pandas as pd
-from .helpers import img_extend
-from .datahelpers import sequences_to_sparse
+
+from handwriting_ocr.ocr.helpers import img_extend
+from handwriting_ocr.ocr.datahelpers import sequences_to_sparse
-class BucketDataIterator():
+class BucketDataIterator:
"""Iterator for feeding CTC model during training."""
- def __init__(self,
- images,
- targets,
- num_buckets=5,
- slider=(60, 30),
- augmentation=None,
- dropout=0.0,
- train=True):
+
+ def __init__(
+ self,
+ images,
+ targets,
+ num_buckets=5,
+ slider=(60, 30),
+ augmentation=None,
+ dropout=0.0,
+ train=True,
+ ):
self.train = train
self.slider = slider
@@ -23,24 +29,25 @@ def __init__(self,
self.dropout = dropout
for i in range(len(images)):
images[i] = img_extend(
- images[i],
- (self.slider[0],
- max(images[i].shape[1], self.slider[1])))
+ images[i], (self.slider[0], max(images[i].shape[1], self.slider[1]))
+ )
in_length = [image.shape[1] for image in images]
-
+
# Create pandas dataFrame and sort it by images width (length)
- self.dataFrame = pd.DataFrame({
- 'in_length': in_length,
- 'images': images,
- 'targets': targets}).sort_values('in_length').reset_index(drop=True)
+ self.dataFrame = (
+ pd.DataFrame({"in_length": in_length, "images": images, "targets": targets})
+ .sort_values("in_length")
+ .reset_index(drop=True)
+ )
bsize = int(len(images) / num_buckets)
self.num_buckets = num_buckets
self.buckets = []
- for bucket in range(num_buckets-1):
+ for bucket in range(num_buckets - 1):
self.buckets.append(
- self.dataFrame.iloc[bucket * bsize: (bucket+1) * bsize])
- self.buckets.append(self.dataFrame.iloc[(num_buckets-1) * bsize:])
+ self.dataFrame.iloc[bucket * bsize : (bucket + 1) * bsize]
+ )
+ self.buckets.append(self.dataFrame.iloc[(num_buckets - 1) * bsize :])
self.buckets_size = [len(bucket) for bucket in self.buckets]
self.cursor = np.array([0] * num_buckets)
@@ -49,14 +56,12 @@ def __init__(self,
self.shuffle()
print("Iterator created.")
-
def shuffle(self, idx=None):
"""Shuffle idx bucket or each bucket separately."""
for i in [idx] if idx is not None else range(self.num_buckets):
self.buckets[i] = self.buckets[i].sample(frac=1).reset_index(drop=True)
self.cursor[i] = 0
-
def next_batch(self, batch_size):
"""Creates next training batch of size.
Args:
@@ -74,25 +79,26 @@ def next_batch(self, batch_size):
self.shuffle(i_bucket)
# Handle too big batch sizes
- if (batch_size > self.buckets_size[i_bucket]):
+ if batch_size > self.buckets_size[i_bucket]:
batch_size = self.buckets_size[i_bucket]
- res = self.buckets[i_bucket].iloc[self.cursor[i_bucket]:
- self.cursor[i_bucket]+batch_size]
+ res = self.buckets[i_bucket].iloc[
+ self.cursor[i_bucket] : self.cursor[i_bucket] + batch_size
+ ]
self.cursor[i_bucket] += batch_size
# PAD input sequence and output
- input_max = max(res['in_length'])
+ input_max = max(res["in_length"])
input_imgs = np.zeros(
- (batch_size, self.slider[0], input_max, 1), dtype=np.uint8)
- for i, img in enumerate(res['images']):
- input_imgs[i][:, :res['in_length'].values[i], 0] = img
-
+ (batch_size, self.slider[0], input_max, 1), dtype=np.uint8
+ )
+ for i, img in enumerate(res["images"]):
+ input_imgs[i][:, : res["in_length"].values[i], 0] = img
+
if self.train:
input_imgs = self.augmentation.augment_images(input_imgs)
input_imgs = input_imgs.astype(np.float32)
- targets = sequences_to_sparse(res['targets'].values)
- return input_imgs, targets, res['in_length'].values
-
+ targets = sequences_to_sparse(res["targets"].values)
+ return input_imgs, targets, res["in_length"].values
diff --git a/src/ocr/helpers.py b/handwriting_ocr/ocr/helpers.py
similarity index 73%
rename from src/ocr/helpers.py
rename to handwriting_ocr/ocr/helpers.py
index 2c9038f..5d366e5 100644
--- a/src/ocr/helpers.py
+++ b/handwriting_ocr/ocr/helpers.py
@@ -1,16 +1,16 @@
-# -*- coding: utf-8 -*-
-"""
-Helper functions for ocr project
-"""
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Helper functions for ocr project."""
+
+import cv2
import matplotlib.pyplot as plt
import numpy as np
-import cv2
SMALL_HEIGHT = 800
-def implt(img, cmp=None, t=''):
+def implt(img, cmp=None, t=""):
"""Show image using plt."""
plt.imshow(img, cmap=cmp)
plt.title(t)
@@ -19,10 +19,9 @@ def implt(img, cmp=None, t=''):
def resize(img, height=SMALL_HEIGHT, allways=False):
"""Resize image to given height."""
- if (img.shape[0] > height or allways):
+ if img.shape[0] > height or allways:
rat = height / img.shape[0]
return cv2.resize(img, (int(rat * img.shape[1]), height))
-
return img
@@ -41,5 +40,5 @@ def img_extend(img, shape):
Extended image
"""
x = np.zeros(shape, np.uint8)
- x[:img.shape[0], :img.shape[1]] = img
- return x
\ No newline at end of file
+ x[: img.shape[0], : img.shape[1]] = img
+ return x
diff --git a/handwriting_ocr/ocr/imgtransform.py b/handwriting_ocr/ocr/imgtransform.py
new file mode 100644
index 0000000..66453ac
--- /dev/null
+++ b/handwriting_ocr/ocr/imgtransform.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Functions for transforming and preprocessing images for training."""
+
+import cv2
+import numpy as np
+from scipy.ndimage.interpolation import map_coordinates
+
+
+def coordinates_remap(image, factor_alpha, factor_sigma):
+    """Transforming image using remapping coordinates."""
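+    # factor_alpha scales the magnitude of the random displacement field, while
+    # factor_sigma controls its smoothness (the classic elastic-transform
+    # augmentation: larger sigma gives smoother, more global distortion).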
+ alpha = image.shape[1] * factor_alpha
+ sigma = image.shape[1] * factor_sigma
+ shape = image.shape
+
+ blur_size = int(4 * sigma) | 1
+ dx = alpha * cv2.GaussianBlur(
+ (np.random.rand(*shape) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma
+ )
+ dy = alpha * cv2.GaussianBlur(
+ (np.random.rand(*shape) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma
+ )
+
+ x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
+ indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))
+
+ # TODO use cv2.remap(image, dx, dy, interpolation=cv2.INTER_LINEAR)
+ return np.array(
+ map_coordinates(image, indices, order=1, mode="constant").reshape(shape)
+ )
diff --git a/src/ocr/mlhelpers.py b/handwriting_ocr/ocr/mlhelpers.py
similarity index 69%
rename from src/ocr/mlhelpers.py
rename to handwriting_ocr/ocr/mlhelpers.py
index e2a6ffb..76968b8 100644
--- a/src/ocr/mlhelpers.py
+++ b/handwriting_ocr/ocr/mlhelpers.py
@@ -1,11 +1,12 @@
-# -*- coding: utf-8 -*-
-"""
-Classes for controling machine learning processes
-"""
-import numpy as np
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Classes for controling machine learning processes."""
+
+import csv
import math
+
+import numpy as np
import matplotlib.pyplot as plt
-import csv
class TrainingPlot:
@@ -14,6 +15,7 @@ class TrainingPlot:
    REQUIRES notebook backend: %matplotlib notebook
@TODO Migrate to Tensorboard
"""
+
train_loss = []
train_acc = []
valid_acc = []
@@ -37,12 +39,12 @@ def __init__(self, steps, test_itr, loss_itr):
self._update_plot()
# Description
- self.ax1.set_xlabel('Iteration')
- self.ax1.set_ylabel('Train Loss')
- self.ax2.set_ylabel('Valid. Accuracy')
+ self.ax1.set_xlabel("Iteration")
+ self.ax1.set_ylabel("Train Loss")
+ self.ax2.set_ylabel("Valid. Accuracy")
# Axes limits
- self.ax1.set_ylim([0,10])
+ self.ax1.set_ylim([0, 10])
def _update_plot(self):
self.fig.canvas.draw()
@@ -51,8 +53,12 @@ def update_loss(self, loss_train, index):
self.trainLoss.append(loss_train)
if len(self.train_loss) == 1:
self.ax1.set_ylim([0, min(10, math.ceil(loss_train))])
- self.ax1.plot(self.lossInterval * np.arange(len(self.train_loss)),
- self.train_loss, 'b', linewidth=1.0)
+ self.ax1.plot(
+ self.lossInterval * np.arange(len(self.train_loss)),
+ self.train_loss,
+ "b",
+ linewidth=1.0,
+ )
self.updatePlot()
@@ -60,18 +66,27 @@ def update_acc(self, acc_val, acc_train, index):
self.validAcc.append(acc_val)
self.trainAcc.append(acc_train)
- self.ax2.plot(self.test_iter * np.arange(len(self.valid_acc)),
- self.valid_acc, 'r', linewidth=1.0)
- self.ax2.plot(self.test_iter * np.arange(len(self.train_acc)),
- self.train_acc, 'g',linewidth=1.0)
-
- self.ax2.set_title('Valid. Accuracy: {:.4f}'.format(self.valid_acc[-1]))
+ self.ax2.plot(
+ self.test_iter * np.arange(len(self.valid_acc)),
+ self.valid_acc,
+ "r",
+ linewidth=1.0,
+ )
+ self.ax2.plot(
+ self.test_iter * np.arange(len(self.train_acc)),
+ self.train_acc,
+ "g",
+ linewidth=1.0,
+ )
+
+ self.ax2.set_title("Valid. Accuracy: {:.4f}".format(self.valid_acc[-1]))
self.updatePlot()
class DataSet:
"""Class for training data and feeding train function."""
+
images = None
labels = None
length = 0
diff --git a/src/ocr/normalization.py b/handwriting_ocr/ocr/normalization.py
similarity index 72%
rename from src/ocr/normalization.py
rename to handwriting_ocr/ocr/normalization.py
index 7bbac97..a190912 100644
--- a/src/ocr/normalization.py
+++ b/handwriting_ocr/ocr/normalization.py
@@ -1,20 +1,22 @@
-# -*- coding: utf-8 -*-
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
"""
Include functions for normalizing images of words and letters
Main functions: word_normalization, letter_normalization, image_standardization
"""
+import math
+
import numpy as np
import cv2
-import math
-from .helpers import *
+from handwriting_ocr.ocr.helpers import *
def image_standardization(image):
"""Image standardization should result in same output
as tf.image.per_image_standardization.
"""
- return (image - np.mean(image)) / max(np.std(image), 1.0/math.sqrt(image.size))
+ return (image - np.mean(image)) / max(np.std(image), 1.0 / math.sqrt(image.size))
def _crop_add_border(img, height, threshold=50, border=True, border_size=15):
@@ -33,7 +35,7 @@ def _crop_add_border(img, height, threshold=50, border=True, border_size=15):
break
for i in reversed(range(img.shape[0])):
if np.count_nonzero(img[i, :]) > 1:
- y1 = i+1
+ y1 = i + 1
break
for i in range(img.shape[1]):
if np.count_nonzero(img[:, i]) > 1:
@@ -41,7 +43,7 @@ def _crop_add_border(img, height, threshold=50, border=True, border_size=15):
break
for i in reversed(range(img.shape[1])):
if np.count_nonzero(img[:, i]) > 1:
- x1 = i+1
+ x1 = i + 1
break
if height != 0:
@@ -50,23 +52,25 @@ def _crop_add_border(img, height, threshold=50, border=True, border_size=15):
img = img[y0:y1, x0:x1]
if border:
- return cv2.copyMakeBorder(img, 0, 0, border_size, border_size,
- cv2.BORDER_CONSTANT,
- value=[0, 0, 0])
+ return cv2.copyMakeBorder(
+ img, 0, 0, border_size, border_size, cv2.BORDER_CONSTANT, value=[0, 0, 0]
+ )
return img
def _word_tilt(img, height, border=True, border_size=15):
"""Detect the angle and tilt the image."""
- edges = cv2.Canny(img, 50, 150, apertureSize = 3)
- lines = cv2.HoughLines(edges, 1, np.pi/180, 30)
+ edges = cv2.Canny(img, 50, 150, apertureSize=3)
+ lines = cv2.HoughLines(edges, 1, np.pi / 180, 30)
if lines is not None:
meanAngle = 0
# Set min number of valid lines (try higher)
numLines = np.sum(1 for l in lines if l[0][1] < 0.7 or l[0][1] > 2.6)
if numLines > 1:
- meanAngle = np.mean([l[0][1] for l in lines if l[0][1] < 0.7 or l[0][1] > 2.6])
+ meanAngle = np.mean(
+ [l[0][1] for l in lines if l[0][1] < 0.7 or l[0][1] > 2.6]
+ )
# Look for angle with correct value
if meanAngle != 0 and (meanAngle < 0.7 or meanAngle > 2.6):
@@ -78,23 +82,21 @@ def _tilt_by_angle(img, angle, height):
"""Tilt the image by given angle."""
dist = np.tan(angle) * height
width = len(img[0])
- sPoints = np.float32([[0,0], [0,height], [width,height], [width,0]])
+ sPoints = np.float32([[0, 0], [0, height], [width, height], [width, 0]])
# Dist is positive for angle < 0.7; negative for angle > 2.6
    # Image must be shifted to the right
if dist > 0:
- tPoints = np.float32([[0,0],
- [dist,height],
- [width+dist,height],
- [width,0]])
+ tPoints = np.float32(
+ [[0, 0], [dist, height], [width + dist, height], [width, 0]]
+ )
else:
- tPoints = np.float32([[-dist,0],
- [0,height],
- [width,height],
- [width-dist,0]])
+ tPoints = np.float32(
+ [[-dist, 0], [0, height], [width, height], [width - dist, 0]]
+ )
M = cv2.getPerspectiveTransform(sPoints, tPoints)
- return cv2.warpPerspective(img, M, (int(width+abs(dist)), height))
+ return cv2.warpPerspective(img, M, (int(width + abs(dist)), height))
def _sobel_detect(channel):
@@ -111,7 +113,7 @@ class HysterThresh:
def __init__(self, img):
img = 255 - img
img = (img - np.min(img)) / (np.max(img) - np.min(img)) * 255
- hist, bins = np.histogram(img.ravel(), 256, [0,256])
+ hist, bins = np.histogram(img.ravel(), 256, [0, 256])
self.high = np.argmax(hist) + 65
self.low = np.argmax(hist) + 45
@@ -126,12 +128,14 @@ def get_image(self):
def _hyster_rec(self, r, c):
h, w = self.img.shape
- for ri in range(r-1, r+2):
- for ci in range(c-1, c+2):
- if (h > ri >= 0
+ for ri in range(r - 1, r + 2):
+ for ci in range(c - 1, c + 2):
+ if (
+ h > ri >= 0
and w > ci >= 0
and self.im[ri, ci] == 0
- and self.high > self.img[ri, ci] >= self.low):
+ and self.high > self.img[ri, ci] >= self.low
+ ):
self.im[ri, ci] = self.img[ri, ci] + self.diff
self._hyster_rec(ri, ci)
@@ -139,7 +143,7 @@ def _hyster(self):
r, c = self.img.shape
for ri in range(r):
for ci in range(c):
- if (self.img[ri, ci] >= self.high):
+ if self.img[ri, ci] >= self.high:
self.im[ri, ci] = 255
self.img[ri, ci] = 255
self._hyster_rec(ri, ci)
@@ -148,12 +152,14 @@ def _hyster(self):
def _hyst_word_norm(image):
"""Word normalization using hystheresis thresholding."""
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
-# img = cv2.bilateralFilter(gray, 0, 10, 30)
+ # img = cv2.bilateralFilter(gray, 0, 10, 30)
img = cv2.bilateralFilter(gray, 10, 10, 30)
return HysterThresh(img).get_image()
-def word_normalization(image, height, border=True, tilt=True, border_size=15, hyst_norm=False):
+def word_normalization(
+ image, height, border=True, tilt=True, border_size=15, hyst_norm=False
+):
""" Preprocess a word - resize, binarize, tilt world."""
image = resize(image, height, True)
@@ -163,16 +169,16 @@ def word_normalization(image, height, border=True, tilt=True, border_size=15, hy
img = cv2.bilateralFilter(image, 10, 30, 30)
gray = 255 - cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
norm = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX)
- ret,th = cv2.threshold(norm, 50, 255, cv2.THRESH_TOZERO)
+ ret, th = cv2.threshold(norm, 50, 255, cv2.THRESH_TOZERO)
if tilt:
return _word_tilt(th, height, border, border_size)
return _crop_add_border(th, height, border, border_size)
-def _resize_letter(img, size = 56):
+def _resize_letter(img, size=56):
"""Resize bigger side of the image to given size."""
- if (img.shape[0] > img.shape[1]):
+ if img.shape[0] > img.shape[1]:
rat = size / img.shape[0]
return cv2.resize(img, (int(rat * img.shape[1]), size))
else:
@@ -194,12 +200,14 @@ def letter_normalization(image, is_thresh=True, dim=False):
offset = [0, 0]
# Calculate offset for smaller size
if image.shape[0] > image.shape[1]:
- offset = [int((result.shape[1] - resized.shape[1])/2), 4]
+ offset = [int((result.shape[1] - resized.shape[1]) / 2), 4]
else:
- offset = [4, int((result.shape[0] - resized.shape[0])/2)]
- # Replace zeros by image
- result[offset[1]:offset[1] + resized.shape[0],
- offset[0]:offset[0] + resized.shape[1]] = resized
+ offset = [4, int((result.shape[0] - resized.shape[0]) / 2)]
+ # Replace zeros by image
+ result[
+ offset[1] : offset[1] + resized.shape[0],
+ offset[0] : offset[0] + resized.shape[1],
+ ] = resized
if dim:
return result, image.shape
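
image_standardization mirrors tf.image.per_image_standardization: subtract the mean, then divide by max(stddev, 1/sqrt(N)) so constant images avoid division by zero. A quick equivalence check, sketched under the assumption that the package is installed (pip install -e .) and the tensorflow==2.1.0 pin from requirements.txt is in place:

import numpy as np
import tensorflow as tf

from handwriting_ocr.ocr.normalization import image_standardization

img = np.random.rand(64, 128).astype(np.float32)
ours = image_standardization(img)
# per_image_standardization expects a trailing channel dimension
ref = tf.image.per_image_standardization(img[..., None]).numpy()[..., 0]
assert np.allclose(ours, ref, atol=1e-5)
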
diff --git a/src/ocr/page.py b/handwriting_ocr/ocr/page.py
similarity index 56%
rename from src/ocr/page.py
rename to handwriting_ocr/ocr/page.py
index 833d846..10d5b14 100644
--- a/src/ocr/page.py
+++ b/handwriting_ocr/ocr/page.py
@@ -1,47 +1,44 @@
-# -*- coding: utf-8 -*-
-"""
-Crop background and transform perspective from the photo of page
-"""
-import numpy as np
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Crop background and transform perspective from the photo of page."""
+
import cv2
+import numpy as np
+
+from handwriting_ocr.ocr.helpers import *
-from .helpers import *
def detection(image):
"""Finding Page."""
# Edge detection
image_edges = _edges_detection(image, 200, 250)
-
+
    # Close gaps between edges (double page close => rectangle kernel)
- closed_edges = cv2.morphologyEx(image_edges,
- cv2.MORPH_CLOSE,
- np.ones((5, 11)))
+ closed_edges = cv2.morphologyEx(image_edges, cv2.MORPH_CLOSE, np.ones((5, 11)))
    # Contours
page_contour = _find_page_contours(closed_edges, resize(image))
# Recalculate to original scale
- page_contour = page_contour.dot(ratio(image))
+ page_contour = page_contour.dot(ratio(image))
    # Transform perspective
new_image = _persp_transform(image, page_contour)
return new_image
-
+
def _edges_detection(img, minVal, maxVal):
"""Preprocessing (gray, thresh, filter, border) + Canny edge detection."""
img = cv2.cvtColor(resize(img), cv2.COLOR_BGR2GRAY)
img = cv2.bilateralFilter(img, 9, 75, 75)
- img = cv2.adaptiveThreshold(img, 255,
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
- cv2.THRESH_BINARY, 115, 4)
+ img = cv2.adaptiveThreshold(
+ img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 115, 4
+ )
    # Median blur replaces center pixel by median of pixels under kernel
# => removes thin details
img = cv2.medianBlur(img, 11)
# Add black border - detection of border touching pages
- img = cv2.copyMakeBorder(img, 5, 5, 5, 5,
- cv2.BORDER_CONSTANT,
- value=[0, 0, 0])
+ img = cv2.copyMakeBorder(img, 5, 5, 5, 5, cv2.BORDER_CONSTANT, value=[0, 0, 0])
return cv2.Canny(img, minVal, maxVal)
@@ -49,10 +46,14 @@ def _four_corners_sort(pts):
"""Sort corners in order: top-left, bot-left, bot-right, top-right."""
diff = np.diff(pts, axis=1)
summ = pts.sum(axis=1)
- return np.array([pts[np.argmin(summ)],
- pts[np.argmax(diff)],
- pts[np.argmax(summ)],
- pts[np.argmin(diff)]])
+ return np.array(
+ [
+ pts[np.argmin(summ)],
+ pts[np.argmax(diff)],
+ pts[np.argmax(summ)],
+ pts[np.argmin(diff)],
+ ]
+ )
def _contour_offset(cnt, offset):
@@ -64,10 +65,10 @@ def _contour_offset(cnt, offset):
def _find_page_contours(edges, img):
"""Finding corner points of page contour."""
- im2, contours, hierarchy = cv2.findContours(edges,
- cv2.RETR_TREE,
- cv2.CHAIN_APPROX_SIMPLE)
-
+    # OpenCV 4.x (pinned in requirements.txt) returns (contours, hierarchy)
+    contours, hierarchy = cv2.findContours(
+ edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+ )
+
# Finding biggest rectangle otherwise return original corners
height = edges.shape[0]
width = edges.shape[1]
@@ -75,20 +76,21 @@ def _find_page_contours(edges, img):
MAX_COUNTOUR_AREA = (width - 10) * (height - 10)
max_area = MIN_COUNTOUR_AREA
- page_contour = np.array([[0, 0],
- [0, height-5],
- [width-5, height-5],
- [width-5, 0]])
+ page_contour = np.array(
+ [[0, 0], [0, height - 5], [width - 5, height - 5], [width - 5, 0]]
+ )
for cnt in contours:
perimeter = cv2.arcLength(cnt, True)
approx = cv2.approxPolyDP(cnt, 0.03 * perimeter, True)
# Page has 4 corners and it is convex
- if (len(approx) == 4 and
- cv2.isContourConvex(approx) and
- max_area < cv2.contourArea(approx) < MAX_COUNTOUR_AREA):
-
+ if (
+ len(approx) == 4
+ and cv2.isContourConvex(approx)
+ and max_area < cv2.contourArea(approx) < MAX_COUNTOUR_AREA
+ ):
+
max_area = cv2.contourArea(approx)
page_contour = approx[:, 0]
@@ -100,20 +102,21 @@ def _find_page_contours(edges, img):
def _persp_transform(img, s_points):
"""Transform perspective from start points to target points."""
# Euclidean distance - calculate maximum height and width
- height = max(np.linalg.norm(s_points[0] - s_points[1]),
- np.linalg.norm(s_points[2] - s_points[3]))
- width = max(np.linalg.norm(s_points[1] - s_points[2]),
- np.linalg.norm(s_points[3] - s_points[0]))
-
+ height = max(
+ np.linalg.norm(s_points[0] - s_points[1]),
+ np.linalg.norm(s_points[2] - s_points[3]),
+ )
+ width = max(
+ np.linalg.norm(s_points[1] - s_points[2]),
+ np.linalg.norm(s_points[3] - s_points[0]),
+ )
+
# Create target points
- t_points = np.array([[0, 0],
- [0, height],
- [width, height],
- [width, 0]], np.float32)
-
+ t_points = np.array([[0, 0], [0, height], [width, height], [width, 0]], np.float32)
+
# getPerspectiveTransform() needs float32
if s_points.dtype != np.float32:
s_points = s_points.astype(np.float32)
-
- M = cv2.getPerspectiveTransform(s_points, t_points)
- return cv2.warpPerspective(img, M, (int(width), int(height)))
\ No newline at end of file
+
+ M = cv2.getPerspectiveTransform(s_points, t_points)
+ return cv2.warpPerspective(img, M, (int(width), int(height)))
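
detection chains the helpers above: Canny edges on a downscaled grayscale copy, morphological closing, a search for the largest convex quadrilateral contour, then a perspective warp back at the original scale. A minimal driver, with a placeholder image path and assuming the package is installed:

import cv2

from handwriting_ocr.ocr import page

image = cv2.cvtColor(cv2.imread("photo.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path
cropped = page.detection(image)  # page region, perspective-corrected
cv2.imwrite("page.png", cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR))
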
diff --git a/src/ocr/tfhelpers.py b/handwriting_ocr/ocr/tfhelpers.py
similarity index 60%
rename from src/ocr/tfhelpers.py
rename to handwriting_ocr/ocr/tfhelpers.py
index 8501eed..a5b0259 100644
--- a/src/ocr/tfhelpers.py
+++ b/handwriting_ocr/ocr/tfhelpers.py
@@ -1,15 +1,23 @@
-# -*- coding: utf-8 -*-
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
"""
Provide functions and classes:
Model = Class for loading and using trained models from tensorflow
    create_cell = function for creating RNN cells with wrappers
"""
import tensorflow as tf
-from tensorflow.python.ops.rnn_cell_impl import LSTMCell, ResidualWrapper, DropoutWrapper, MultiRNNCell
+from tensorflow.python.ops.rnn_cell_impl import (
+ LSTMCell,
+ ResidualWrapper,
+ DropoutWrapper,
+ MultiRNNCell,
+)
-class Model():
+
+class Model:
"""Loading and running isolated tf graph."""
- def __init__(self, loc, operation='activation', input_name='x'):
+
+ def __init__(self, loc, operation="activation", input_name="x"):
"""
loc: location of file containing saved model
operation: name of operation for running the model
@@ -19,32 +27,31 @@ def __init__(self, loc, operation='activation', input_name='x'):
self.graph = tf.Graph()
self.sess = tf.Session(graph=self.graph)
with self.graph.as_default():
- saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
+ saver = tf.train.import_meta_graph(loc + ".meta", clear_devices=True)
saver.restore(self.sess, loc)
self.op = self.graph.get_operation_by_name(operation).outputs[0]
def run(self, data):
"""Run the specified operation on given data."""
return self.sess.run(self.op, feed_dict={self.input: data})
-
+
def eval_feed(self, feed):
"""Run the specified operation with given feed."""
return self.sess.run(self.op, feed_dict=feed)
-
+
def run_op(self, op, feed, output=True):
"""Run given operation with the feed."""
if output:
return self.sess.run(
- self.graph.get_operation_by_name(op).outputs[0],
- feed_dict=feed)
+ self.graph.get_operation_by_name(op).outputs[0], feed_dict=feed
+ )
else:
- self.sess.run(
- self.graph.get_operation_by_name(op),
- feed_dict=feed)
-
-
-
-def _create_single_cell(cell_fn, num_units, is_residual=False, is_dropout=False, keep_prob=None):
+ self.sess.run(self.graph.get_operation_by_name(op), feed_dict=feed)
+
+
+def _create_single_cell(
+ cell_fn, num_units, is_residual=False, is_dropout=False, keep_prob=None
+):
"""Create single RNN cell based on cell_fn."""
cell = cell_fn(num_units)
if is_dropout:
@@ -54,19 +61,28 @@ def _create_single_cell(cell_fn, num_units, is_residual=False, is_dropout=False,
return cell
-def create_cell(num_units, num_layers, num_residual_layers, is_dropout=False, keep_prob=None, cell_fn=LSTMCell):
+def create_cell(
+ num_units,
+ num_layers,
+ num_residual_layers,
+ is_dropout=False,
+ keep_prob=None,
+ cell_fn=LSTMCell,
+):
"""Create corresponding number of RNN cells with given wrappers."""
cell_list = []
-
+
for i in range(num_layers):
- cell_list.append(_create_single_cell(
- cell_fn=cell_fn,
- num_units=num_units,
- is_residual=(i >= num_layers - num_residual_layers),
- is_dropout=is_dropout,
- keep_prob=keep_prob
- ))
+ cell_list.append(
+ _create_single_cell(
+ cell_fn=cell_fn,
+ num_units=num_units,
+ is_residual=(i >= num_layers - num_residual_layers),
+ is_dropout=is_dropout,
+ keep_prob=keep_prob,
+ )
+ )
if num_layers == 1:
return cell_list[0]
- return MultiRNNCell(cell_list)
\ No newline at end of file
+ return MultiRNNCell(cell_list)
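
Note that Model wraps TF1-style graph loading (tf.Session, tf.train.import_meta_graph) while requirements.txt pins tensorflow==2.1.0; running it there would need the tf.compat.v1 layer, which this diff does not set up. A hypothetical usage sketch with placeholder paths and shapes:

import numpy as np

from handwriting_ocr.ocr.tfhelpers import Model

# Placeholder checkpoint path; "activation" and "x" are the defaults above.
model = Model("models/char-classifier/CharClassifier", operation="activation")
letters = np.zeros((10, 4096), dtype=np.float32)  # e.g. ten flattened 64x64 crops
predictions = model.run(letters)
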
diff --git a/src/ocr/words.py b/handwriting_ocr/ocr/words.py
similarity index 64%
rename from src/ocr/words.py
rename to handwriting_ocr/ocr/words.py
index e2f02d9..ff66ee4 100644
--- a/src/ocr/words.py
+++ b/handwriting_ocr/ocr/words.py
@@ -1,13 +1,12 @@
-# -*- coding: utf-8 -*-
-"""
-Detect words on the page
-return array of words' bounding boxes
-"""
-import numpy as np
-import matplotlib.pyplot as plt
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+"""Detecting bounding boxes of words in the page."""
+
import cv2
+import matplotlib.pyplot as plt
+import numpy as np
-from .helpers import *
+from handwriting_ocr.ocr.helpers import *
def detection(image, join=False):
@@ -18,8 +17,7 @@ def detection(image, join=False):
blurred = cv2.GaussianBlur(image, (5, 5), 18)
edge_img = _edge_detect(blurred)
ret, edge_img = cv2.threshold(edge_img, 50, 255, cv2.THRESH_BINARY)
- bw_img = cv2.morphologyEx(edge_img, cv2.MORPH_CLOSE,
- np.ones((15,15), np.uint8))
+ bw_img = cv2.morphologyEx(edge_img, cv2.MORPH_CLOSE, np.ones((15, 15), np.uint8))
return _text_detect(bw_img, image, join)
@@ -27,8 +25,8 @@ def detection(image, join=False):
def sort_words(boxes):
"""Sort boxes - (x, y, x+w, y+h) from left to right, top to bottom."""
mean_height = sum([y2 - y1 for _, y1, _, y2 in boxes]) / len(boxes)
-
- boxes.view('i8,i8,i8,i8').sort(order=['f1'], axis=0)
+
+ boxes.view("i8,i8,i8,i8").sort(order=["f1"], axis=0)
current_line = boxes[0][1]
lines = []
tmp_line = []
@@ -36,14 +34,14 @@ def sort_words(boxes):
if box[1] > current_line + mean_height:
lines.append(tmp_line)
tmp_line = [box]
- current_line = box[1]
+ current_line = box[1]
continue
tmp_line.append(box)
lines.append(tmp_line)
-
+
for line in lines:
line.sort(key=lambda box: box[0])
-
+
return lines
@@ -52,9 +50,16 @@ def _edge_detect(im):
Edge detection using sobel operator on each layer individually.
Sobel operator is applied for each image layer (RGB)
"""
- return np.max(np.array([_sobel_detect(im[:,:, 0]),
- _sobel_detect(im[:,:, 1]),
- _sobel_detect(im[:,:, 2])]), axis=0)
+ return np.max(
+ np.array(
+ [
+ _sobel_detect(im[:, :, 0]),
+ _sobel_detect(im[:, :, 1]),
+ _sobel_detect(im[:, :, 2]),
+ ]
+ ),
+ axis=0,
+ )
def _sobel_detect(channel):
@@ -66,22 +71,24 @@ def _sobel_detect(channel):
return np.uint8(sobel)
-def union(a,b):
+def union(a, b):
x = min(a[0], b[0])
y = min(a[1], b[1])
- w = max(a[0]+a[2], b[0]+b[2]) - x
- h = max(a[1]+a[3], b[1]+b[3]) - y
+ w = max(a[0] + a[2], b[0] + b[2]) - x
+ h = max(a[1] + a[3], b[1] + b[3]) - y
return [x, y, w, h]
-def _intersect(a,b):
+
+def _intersect(a, b):
x = max(a[0], b[0])
y = max(a[1], b[1])
- w = min(a[0]+a[2], b[0]+b[2]) - x
- h = min(a[1]+a[3], b[1]+b[3]) - y
- if w<0 or h<0:
+ w = min(a[0] + a[2], b[0] + b[2]) - x
+ h = min(a[1] + a[3], b[1] + b[3]) - y
+ if w < 0 or h < 0:
return False
return True
+
def _group_rectangles(rec):
"""
    Union intersecting rectangles.
@@ -95,7 +102,7 @@ def _group_rectangles(rec):
i = 0
while i < len(rec):
if not tested[i]:
- j = i+1
+ j = i + 1
while j < len(rec):
if not tested[j] and _intersect(rec[i], rec[j]):
rec[i] = union(rec[i], rec[j])
@@ -104,39 +111,41 @@ def _group_rectangles(rec):
j += 1
final += [rec[i]]
i += 1
-
+
return final
def _text_detect(img, image, join=False):
"""Text detection using contours."""
small = resize(img, 2000)
-
+
# Finding contours
mask = np.zeros(small.shape, np.uint8)
- im2, cnt, hierarchy = cv2.findContours(np.copy(small),
- cv2.RETR_CCOMP,
- cv2.CHAIN_APPROX_SIMPLE)
-
- index = 0
+    # OpenCV 4.x (pinned in requirements.txt) returns (contours, hierarchy)
+    cnt, hierarchy = cv2.findContours(
+ np.copy(small), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
+ )
+
+ index = 0
boxes = []
# Go through all contours in top level
- while (index >= 0):
- x,y,w,h = cv2.boundingRect(cnt[index])
+ while index >= 0:
+ x, y, w, h = cv2.boundingRect(cnt[index])
cv2.drawContours(mask, cnt, index, (255, 255, 255), cv2.FILLED)
- maskROI = mask[y:y+h, x:x+w]
+ maskROI = mask[y : y + h, x : x + w]
# Ratio of white pixels to area of bounding rectangle
r = cv2.countNonZero(maskROI) / (w * h)
-
+
# Limits for text
- if (r > 0.1
+ if (
+ r > 0.1
and 1600 > w > 10
and 1600 > h > 10
- and h/w < 3
- and w/h < 10
- and (60 // h) * w < 1000):
+ and h / w < 3
+ and w / h < 10
+ and (60 // h) * w < 1000
+ ):
boxes += [[x, y, w, h]]
-
+
index = hierarchy[0][index][0]
if join:
@@ -145,42 +154,39 @@ def _text_detect(img, image, join=False):
# image for drawing bounding boxes
small = cv2.cvtColor(small, cv2.COLOR_GRAY2RGB)
- bounding_boxes = np.array([0,0,0,0])
+ bounding_boxes = np.array([0, 0, 0, 0])
for (x, y, w, h) in boxes:
- cv2.rectangle(small, (x, y),(x+w,y+h), (0, 255, 0), 2)
- bounding_boxes = np.vstack((bounding_boxes,
- np.array([x, y, x+w, y+h])))
-
- implt(small, t='Bounding rectangles')
-
+ cv2.rectangle(small, (x, y), (x + w, y + h), (0, 255, 0), 2)
+ bounding_boxes = np.vstack((bounding_boxes, np.array([x, y, x + w, y + h])))
+
+ implt(small, t="Bounding rectangles")
+
boxes = bounding_boxes.dot(ratio(image, small.shape[0])).astype(np.int64)
- return boxes[1:]
-
+ return boxes[1:]
+
def textDetectWatershed(thresh):
"""NOT IN USE - Text detection using watershed algorithm.
Based on: http://docs.opencv.org/trunk/d3/db4/tutorial_py_watershed.html
"""
- img = cv2.cvtColor(cv2.imread("data/textdet/%s.jpg" % IMG),
- cv2.COLOR_BGR2RGB)
+ img = cv2.cvtColor(cv2.imread("data/textdet/%s.jpg" % IMG), cv2.COLOR_BGR2RGB)
img = resize(img, 3000)
thresh = resize(thresh, 3000)
# noise removal
- kernel = np.ones((3,3),np.uint8)
- opening = cv2.morphologyEx(thresh,cv2.MORPH_OPEN,kernel, iterations = 3)
-
+ kernel = np.ones((3, 3), np.uint8)
+ opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=3)
+
# sure background area
- sure_bg = cv2.dilate(opening,kernel,iterations=3)
+ sure_bg = cv2.dilate(opening, kernel, iterations=3)
# Finding sure foreground area
- dist_transform = cv2.distanceTransform(opening,cv2.DIST_L2,5)
- ret, sure_fg = cv2.threshold(dist_transform,
- 0.01*dist_transform.max(), 255, 0)
+ dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
+ ret, sure_fg = cv2.threshold(dist_transform, 0.01 * dist_transform.max(), 255, 0)
# Finding unknown region
sure_fg = np.uint8(sure_fg)
- unknown = cv2.subtract(sure_bg,sure_fg)
-
+ unknown = cv2.subtract(sure_bg, sure_fg)
+
# Marker labelling
ret, markers = cv2.connectedComponents(sure_fg)
@@ -189,12 +195,12 @@ def textDetectWatershed(thresh):
# Now, mark the region of unknown with zero
markers[unknown == 255] = 0
-
+
markers = cv2.watershed(img, markers)
- implt(markers, t='Markers')
+ implt(markers, t="Markers")
image = img.copy()
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-
+
for mark in np.unique(markers):
# mark == 0 --> background
if mark == 0:
@@ -204,20 +210,20 @@ def textDetectWatershed(thresh):
mask = np.zeros(gray.shape, dtype="uint8")
mask[markers == mark] = 255
- cnts = cv2.findContours(mask.copy(),
- cv2.RETR_EXTERNAL,
- cv2.CHAIN_APPROX_SIMPLE)[-2]
+ cnts = cv2.findContours(
+ mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )[-2]
c = max(cnts, key=cv2.contourArea)
-
+
# Draw a bounding rectangle if it contains text
- x,y,w,h = cv2.boundingRect(c)
+ x, y, w, h = cv2.boundingRect(c)
cv2.drawContours(mask, c, 0, (255, 255, 255), cv2.FILLED)
- maskROI = mask[y:y+h, x:x+w]
+ maskROI = mask[y : y + h, x : x + w]
# Ratio of white pixels to area of bounding rectangle
r = cv2.countNonZero(maskROI) / (w * h)
-
+
# Limits for text
if r > 0.2 and 2000 > w > 15 and 1500 > h > 15:
- cv2.rectangle(image, (x, y),(x+w,y+h), (0, 255, 0), 2)
-
+ cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)
+
implt(image)
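
Typical flow for this module: detection returns word boxes as (x1, y1, x2, y2) at the original image scale, and sort_words groups them into lines in reading order. A short sketch with a placeholder path, assuming the package is installed:

import cv2

from handwriting_ocr.ocr import words

image = cv2.cvtColor(cv2.imread("page.png"), cv2.COLOR_BGR2RGB)  # placeholder path
boxes = words.detection(image, join=False)
for line in words.sort_words(boxes):
    for x1, y1, x2, y2 in line:
        crop = image[y1:y2, x1:x2]  # one word image, in reading order
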
diff --git a/notebooks/ocr_evaluator.ipynb b/notebooks/ocr_evaluator.ipynb
index 887e8a1..63bb698 100644
--- a/notebooks/ocr_evaluator.ipynb
+++ b/notebooks/ocr_evaluator.ipynb
@@ -13,30 +13,24 @@
"metadata": {},
"outputs": [],
"source": [
+ "import math\n",
"import sys\n",
+ "import time\n",
+ "from abc import ABC, abstractmethod\n",
+ "from collections import Counter\n",
+ "\n",
+ "import cv2\n",
+ "import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
- "import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
- "import cv2\n",
- "import time\n",
- "import math\n",
- "from collections import Counter\n",
"import unidecode\n",
- "from abc import ABC, abstractmethod\n",
"\n",
- "# Import Widgets\n",
- "from ipywidgets import Button, Text, HBox, VBox\n",
- "from IPython.display import display, clear_output\n",
- "\n",
- "sys.path.append('../src')\n",
- "from ocr import characters\n",
- "from ocr.normalization import word_normalization, letter_normalization\n",
- "# Helpers\n",
- "from ocr.helpers import implt, resize, img_extend\n",
- "from ocr.datahelpers import load_words_data, idx2char\n",
- "from ocr.tfhelpers import Model\n",
- "from ocr.viz import print_progress_bar"
+ "from handwriting_ocr.ocr import characters\n",
+ "from handwriting_ocr.ocr.datahelpers import idx2char, load_words_data\n",
+ "from handwriting_ocr.ocr.helpers import img_extend, implt, resize\n",
+ "from handwriting_ocr.ocr.normalization import letter_normalization, word_normalization\n",
+ "from handwriting_ocr.ocr.tfhelpers import Model"
]
},
{
@@ -101,7 +95,7 @@
}
],
"source": [
- "images, labels = load_words_data('../data/sets/test.csv', is_csv=True)\n",
+ "images, labels = load_words_data(\"../data/sets/test.csv\", is_csv=True)\n",
"\n",
"\n",
"for i in range(len(images)):\n",
@@ -111,13 +105,14 @@
" 60,\n",
" border=False,\n",
" tilt=True,\n",
- " hystNorm=True)\n",
+ " hystNorm=True,\n",
+ " )\n",
"\n",
- "if LANG == 'en':\n",
+ "if LANG == \"en\":\n",
" for i in range(len(labels)):\n",
" labels[i] = unidecode.unidecode(labels[i])\n",
- "print() \n",
- "print('Number of chars:', sum(len(l) for l in labels))"
+ "print()\n",
+ "print(\"Number of chars:\", sum(len(l) for l in labels))"
]
},
{
@@ -135,47 +130,53 @@
"source": [
"# Load Words\n",
"WORDS = {}\n",
- "with open('../data/dictionaries' + LANG + '_50k.txt') as f:\n",
+ "with open(\"../data/dictionaries\" + LANG + \"_50k.txt\") as f:\n",
" for line in f:\n",
- " if LANG == 'en':\n",
+ " if LANG == \"en\":\n",
" WORDS[unidecode.unidecode(line.split(\" \")[0])] = int(line.split(\" \")[1])\n",
" else:\n",
" WORDS[line.split(\" \")[0]] = int(line.split(\" \")[1])\n",
"WORDS = Counter(WORDS)\n",
"\n",
- "def P(word, N=sum(WORDS.values())): \n",
+ "\n",
+ "def P(word, N=sum(WORDS.values())):\n",
" \"Probability of word.\"\n",
" return WORDS[word] / N\n",
"\n",
- "def correction(word): \n",
+ "\n",
+ "def correction(word):\n",
" \"Most probable spelling correction for word.\"\n",
" if word in WORDS:\n",
" return word\n",
" return max(candidates(word), key=P)\n",
"\n",
- "def candidates(word): \n",
+ "\n",
+ "def candidates(word):\n",
" \"Generate possible spelling corrections for word.\"\n",
- " return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])\n",
+ " return known([word]) or known(edits1(word)) or known(edits2(word)) or [word]\n",
"\n",
- "def known(words): \n",
+ "\n",
+ "def known(words):\n",
" \"The subset of words that appear in the dictionary of WORDS.\"\n",
" return set(w for w in words if w in WORDS)\n",
"\n",
+ "\n",
"def edits1(word):\n",
" \"All edits that are one edit away from `word`.\"\n",
- " \n",
- " if LANG == 'cz':\n",
- " letters = 'aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž'\n",
+ "\n",
+ " if LANG == \"cz\":\n",
+ " letters = \"aábcčdďeéěfghiíjklmnňoópqrřsštťuúůvwxyýzž\"\n",
" else:\n",
- " letters = 'abcdefghijklmnopqrstuvwxyz'\n",
- " splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]\n",
- " deletes = [L + R[1:] for L, R in splits if R]\n",
- " transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]\n",
- " replaces = [L + c + R[1:] for L, R in splits if R for c in letters]\n",
- " inserts = [L + c + R for L, R in splits for c in letters]\n",
+ " letters = \"abcdefghijklmnopqrstuvwxyz\"\n",
+ " splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]\n",
+ " deletes = [L + R[1:] for L, R in splits if R]\n",
+ " transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]\n",
+ " replaces = [L + c + R[1:] for L, R in splits if R for c in letters]\n",
+ " inserts = [L + c + R for L, R in splits for c in letters]\n",
" return set(deletes + transposes + replaces + inserts)\n",
"\n",
- "def edits2(word): \n",
+ "\n",
+ "def edits2(word):\n",
" \"All edits that are two edits away from `word`.\"\n",
" return (e2 for e1 in edits1(word) for e2 in edits1(e1))"
]
@@ -210,10 +211,10 @@
" insertion = d[i][j - 1] + 1\n",
" deletion = d[i - 1][j] + 1\n",
" d[i][j] = min(substitution, insertion, deletion)\n",
- "# result = float(d[len(r)][len(h)]) / len(r) * 100\n",
- "# print('CER %.4f %%' % result)\n",
- "# print(d[len(r)][len(h)])\n",
- " return(d[len(r)][len(h)])"
+ " # result = float(d[len(r)][len(h)]) / len(r) * 100\n",
+ " # print('CER %.4f %%' % result)\n",
+ " # print(d[len(r)][len(h)])\n",
+ " return d[len(r)][len(h)]"
]
},
{
@@ -230,16 +231,19 @@
"outputs": [],
"source": [
"class Cycler(ABC):\n",
- " \"\"\" Abstract cycler class \"\"\" \n",
- " def __init__(self,\n",
- " images,\n",
- " labels,\n",
- " charClass,\n",
- " stats=\"No Stats Provided\",\n",
- " slider=(60, 15),\n",
- " ctc=False,\n",
- " seq2seq=False,\n",
- " charRNN=False):\n",
+ " \"\"\" Abstract cycler class \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " images,\n",
+ " labels,\n",
+ " charClass,\n",
+ " stats=\"No Stats Provided\",\n",
+ " slider=(60, 15),\n",
+ " ctc=False,\n",
+ " seq2seq=False,\n",
+ " charRNN=False,\n",
+ " ):\n",
" self.images = images\n",
" self.labels = labels\n",
" self.charClass = charClass\n",
@@ -249,32 +253,32 @@
" self.seq2seq = seq2seq\n",
" self.charRNN = charRNN\n",
" self.stats = stats\n",
- " \n",
+ "\n",
" self.evaluate()\n",
- " \n",
+ "\n",
" @abstractmethod\n",
" def recogniseWord(self, img):\n",
" pass\n",
- " \n",
+ "\n",
" def countCorrect(self, pred, label, lower=False):\n",
" correct = 0\n",
" for i in range(min(len(pred), len(label))):\n",
- " if ((not lower and pred[i] == label[i])\n",
- " or (lower and pred[i] == label.lower()[i])):\n",
+ " if (not lower and pred[i] == label[i]) or (\n",
+ " lower and pred[i] == label.lower()[i]\n",
+ " ):\n",
" correct += 1\n",
- " \n",
- " return correct \n",
"\n",
- " \n",
+ " return correct\n",
+ "\n",
" def evaluate(self):\n",
" \"\"\" Evaluate accuracy of the word classification \"\"\"\n",
" print()\n",
" print(\"STATS:\", self.stats)\n",
- " print(self.labels[1], ':', self.recogniseWord(self.images[1]))\n",
+ " print(self.labels[1], \":\", self.recogniseWord(self.images[1]))\n",
" start_time = time.time()\n",
" for i in range(len(self.images)):\n",
" word = self.recogniseWord(self.images[i])\n",
- "# a = correction(word.lower()\n",
+ " # a = correction(word.lower()\n",
" print(\"--- %s seconds ---\" % round(time.time() - start_time, 2))\n",
" ccer = 0\n",
" correctLetters = 0\n",
@@ -283,12 +287,11 @@
" correctLettersCorrection = 0\n",
" for i in range(len(self.images)):\n",
" word = self.recogniseWord(self.images[i])\n",
- " correctLetters += self.countCorrect(word,\n",
- " self.labels[i])\n",
+ " correctLetters += self.countCorrect(word, self.labels[i])\n",
" # Correction works only for lower letters\n",
- " correctLettersCorrection += self.countCorrect(correction(word.lower()),\n",
- " self.labels[i],\n",
- " lower=True)\n",
+ " correctLettersCorrection += self.countCorrect(\n",
+ " correction(word.lower()), self.labels[i], lower=True\n",
+ " )\n",
" ccer += cer(word, self.labels[i])\n",
" # Words accuracy\n",
" if word == self.labels[i]:\n",
@@ -297,11 +300,21 @@
" correctWordsCorrection += 1\n",
"\n",
" print(\"Correct/Total: %s / %s\" % (correctLetters, self.totalChars))\n",
- " print(\"CERacc: %s %%\" % round(100 - ccer/self.totalChars * 100, 4))\n",
- " print(\"Letter Accuracy: %s %%\" % round(correctLetters/self.totalChars * 100, 4))\n",
- " print(\"Letter Accuracy with Correction: %s %%\" % round(correctLettersCorrection/self.totalChars * 100, 4))\n",
- " print(\"Word Accuracy: %s %%\" % round(correctWords/len(self.images) * 100, 4))\n",
- " print(\"Word Accuracy with Correction: %s %%\" % round(correctWordsCorrection/len(self.images) * 100, 4))\n",
+ " print(\"CERacc: %s %%\" % round(100 - ccer / self.totalChars * 100, 4))\n",
+ " print(\n",
+ " \"Letter Accuracy: %s %%\" % round(correctLetters / self.totalChars * 100, 4)\n",
+ " )\n",
+ " print(\n",
+ " \"Letter Accuracy with Correction: %s %%\"\n",
+ " % round(correctLettersCorrection / self.totalChars * 100, 4)\n",
+ " )\n",
+ " print(\"Word Accuracy: %s %%\" % round(correctWords / len(self.images) * 100, 4))\n",
+ " print(\n",
+ " \"Word Accuracy with Correction: %s %%\"\n",
+ " % round(correctWordsCorrection / len(self.images) * 100, 4)\n",
+ " )\n",
+ "\n",
+ "\n",
"# print(\"--- %s seconds ---\" % round(time.time() - start_time, 2))"
]
},
@@ -312,60 +325,80 @@
"outputs": [],
"source": [
"class WordCycler(Cycler):\n",
- " \"\"\" Cycle through the words and recognise them \"\"\" \n",
+ " \"\"\" Cycle through the words and recognise them \"\"\"\n",
+ "\n",
" def recogniseWord(self, img):\n",
" slider = self.slider\n",
- " \n",
+ "\n",
" if self.ctc:\n",
- " step = 10 # 10 for (60, 60) slider\n",
+ " step = 10 # 10 for (60, 60) slider\n",
" img = cv2.copyMakeBorder(\n",
" img,\n",
- " 0, 0, self.slider[1]//2, self.slider[1]//2,\n",
+ " 0,\n",
+ " 0,\n",
+ " self.slider[1] // 2,\n",
+ " self.slider[1] // 2,\n",
" cv2.BORDER_CONSTANT,\n",
- " value=[0, 0, 0])\n",
+ " value=[0, 0, 0],\n",
+ " )\n",
" img = img_extend(\n",
" img,\n",
- " (img.shape[0], max(-(-img.shape[1] // step) * step, self.slider[1] + step)))\n",
- " length = (img.shape[1]-slider[1]) // step\n",
+ " (\n",
+ " img.shape[0],\n",
+ " max(-(-img.shape[1] // step) * step, self.slider[1] + step),\n",
+ " ),\n",
+ " )\n",
+ " length = (img.shape[1] - slider[1]) // step\n",
" input_seq = np.zeros((1, length, slider[0] * slider[1]), dtype=np.float32)\n",
- " input_seq[0][:] = [img[:, loc*step: loc*step + slider[1]].flatten()\n",
- " for loc in range(length)]\n",
+ " input_seq[0][:] = [\n",
+ " img[:, loc * step : loc * step + slider[1]].flatten()\n",
+ " for loc in range(length)\n",
+ " ]\n",
" input_seq = input_seq.swapaxes(0, 1)\n",
- " \n",
- " pred = self.charClass.eval_feed({'inputs:0': input_seq,\n",
- " 'inputs_length:0': [length],\n",
- " 'keep_prob:0': 1})[0]\n",
- " \n",
- " word = ''\n",
+ "\n",
+ " pred = self.charClass.eval_feed(\n",
+ " {\"inputs:0\": input_seq, \"inputs_length:0\": [length], \"keep_prob:0\": 1}\n",
+ " )[0]\n",
+ "\n",
+ " word = \"\"\n",
" for i in pred:\n",
" if word == 0 and i != 0:\n",
" break\n",
" else:\n",
" word += idx2char(i)\n",
- " \n",
- " else: \n",
- " length = img.shape[1]//slider[1]\n",
+ "\n",
+ " else:\n",
+ " length = img.shape[1] // slider[1]\n",
"\n",
" input_seq = np.zeros((1, length, slider[0] * slider[1]), dtype=np.float32)\n",
- " input_seq[0][:] = [img[:, loc * slider[1]: (loc+1) * slider[1]].flatten()\n",
- " for loc in range(length)] \n",
+ " input_seq[0][:] = [\n",
+ " img[:, loc * slider[1] : (loc + 1) * slider[1]].flatten()\n",
+ " for loc in range(length)\n",
+ " ]\n",
" input_seq = input_seq.swapaxes(0, 1)\n",
"\n",
- "\n",
" if self.seq2seq:\n",
- " targets = np.zeros((1, 1), dtype=np.int32) \n",
- " pred = self.charClass.eval_feed({'encoder_inputs:0': input_seq,\n",
- " 'encoder_inputs_length:0': [length],\n",
- " 'decoder_targets:0': targets,\n",
- " 'keep_prob:0': 1})[0]\n",
+ " targets = np.zeros((1, 1), dtype=np.int32)\n",
+ " pred = self.charClass.eval_feed(\n",
+ " {\n",
+ " \"encoder_inputs:0\": input_seq,\n",
+ " \"encoder_inputs_length:0\": [length],\n",
+ " \"decoder_targets:0\": targets,\n",
+ " \"keep_prob:0\": 1,\n",
+ " }\n",
+ " )[0]\n",
" else:\n",
- " targets = np.zeros((1, 1, 4096), dtype=np.int32) \n",
- " pred = self.charClass.eval_feed({'encoder_inputs:0': input_seq,\n",
- " 'encoder_inputs_length:0': [length],\n",
- " 'letter_targets:0': targets,\n",
- " 'is_training:0': False,\n",
- " 'keep_prob:0': 1})[0]\n",
- " word = ''\n",
+ " targets = np.zeros((1, 1, 4096), dtype=np.int32)\n",
+ " pred = self.charClass.eval_feed(\n",
+ " {\n",
+ " \"encoder_inputs:0\": input_seq,\n",
+ " \"encoder_inputs_length:0\": [length],\n",
+ " \"letter_targets:0\": targets,\n",
+ " \"is_training:0\": False,\n",
+ " \"keep_prob:0\": 1,\n",
+ " }\n",
+ " )[0]\n",
+ " word = \"\"\n",
" for i in pred:\n",
" if word == 1:\n",
" break\n",
@@ -382,36 +415,36 @@
"outputs": [],
"source": [
"class CharCycler(Cycler):\n",
- " \"\"\" Cycle through the words and recognise them \"\"\" \n",
+ " \"\"\" Cycle through the words and recognise them \"\"\"\n",
+ "\n",
" def recogniseWord(self, img):\n",
- " img = cv2.copyMakeBorder(img,\n",
- " 0, 0, 30, 30,\n",
- " cv2.BORDER_CONSTANT,\n",
- " value=[0, 0, 0])\n",
+ " img = cv2.copyMakeBorder(\n",
+ " img, 0, 0, 30, 30, cv2.BORDER_CONSTANT, value=[0, 0, 0]\n",
+ " )\n",
" gaps = characters.segment(img, RNN=True)\n",
- " \n",
+ "\n",
" chars = []\n",
- " for i in range(len(gaps)-1):\n",
- " char = img[:, gaps[i]:gaps[i+1]]\n",
+ " for i in range(len(gaps) - 1):\n",
+ " char = img[:, gaps[i] : gaps[i + 1]]\n",
" # TODO None type error after treshold\n",
" char, dim = letter_normalization(char, is_thresh=True, dim=True)\n",
" # TODO Test different values\n",
" if dim[0] > 4 and dim[1] > 4:\n",
" chars.append(char.flatten())\n",
- " \n",
+ "\n",
" chars = np.array(chars)\n",
- " word = ''\n",
+ " word = \"\"\n",
" if len(chars) != 0:\n",
" if self.charRNN:\n",
- " pred = self.charClass.eval_feed({'inputs:0': [chars],\n",
- " 'length:0': [len(chars)],\n",
- " 'keep_prob:0': 1})[0]\n",
+ " pred = self.charClass.eval_feed(\n",
+ " {\"inputs:0\": [chars], \"length:0\": [len(chars)], \"keep_prob:0\": 1}\n",
+ " )[0]\n",
" else:\n",
" pred = self.charClass.run(chars)\n",
- " \n",
+ "\n",
" for c in pred:\n",
" # word += CHARS[charIdx]\n",
- " word += idx2char(c) \n",
+ " word += idx2char(c)\n",
" return word"
]
},
@@ -480,31 +513,13 @@
"source": [
"# Class cycling through words\n",
"\n",
- "WordCycler(images,\n",
- " labels,\n",
- " wordClass,\n",
- " stats='Seq2Seq',\n",
- " slider=(60, 2),\n",
- " seq2seq=True)\n",
- "\n",
- "WordCycler(images,\n",
- " labels,\n",
- " wordClass2,\n",
- " stats='Seq2SeqX',\n",
- " slider=(60, 2))\n",
- "\n",
- "WordCycler(images,\n",
- " labels,\n",
- " wordClass3,\n",
- " stats='CTC',\n",
- " slider=(60, 60),\n",
- " ctc=True)\n",
- "\n",
- "CharCycler(images,\n",
- " labels,\n",
- " charClass_1,\n",
- " stats='Bi-RNN and CNN',\n",
- " charRNN=False)\n",
+ "WordCycler(images, labels, wordClass, stats=\"Seq2Seq\", slider=(60, 2), seq2seq=True)\n",
+ "\n",
+ "WordCycler(images, labels, wordClass2, stats=\"Seq2SeqX\", slider=(60, 2))\n",
+ "\n",
+ "WordCycler(images, labels, wordClass3, stats=\"CTC\", slider=(60, 60), ctc=True)\n",
+ "\n",
+ "CharCycler(images, labels, charClass_1, stats=\"Bi-RNN and CNN\", charRNN=False)\n",
"\n",
"# Cycler(images,\n",
"# labels,\n",
@@ -534,9 +549,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.8"
+ "version": "3.7.5"
}
},
"nbformat": 4,
- "nbformat_minor": 1
+ "nbformat_minor": 4
}
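
The notebook's cer helper is a plain character-level Levenshtein distance. A standalone sketch of the same metric, runnable outside the notebook:

def cer(ref, hyp):
    """Character-level edit distance (substitutions, insertions, deletions)."""
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j - 1] + cost, d[i][j - 1] + 1, d[i - 1][j] + 1)
    return d[len(ref)][len(hyp)]

assert cer("hello", "helo") == 1  # one deletion
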
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..a1457f4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,51 @@
+# Project configuration
+title = "Handwriting OCR"
+
+[owner]
+name = "Břetislav Hájek"
+
+
+[build-system]
+requires = ["setuptools", "wheel"]
+
+
+# NOTE: you have to use single-quoted strings in TOML for regular expressions.
+# It's the equivalent of r-strings in Python. Multiline strings are treated as
+# verbose regular expressions by Black. Use [ ] to denote a significant space
+# character.
+[tool.black]
+line-length = 88
+target-version = ['py37', 'py38']
+include = '\.pyi?$'
+exclude = '''
+/(
+ \.eggs
+ | \.git
+ | \.hg
+ | \.mypy_cache
+ | \.tox
+ | \.venv
+ | _build
+ | buck-out
+ | build
+ | dist
+)/
+'''
+
+[tool.isort]
+multi_line_output = 3
+include_trailing_comma = true
+force_grid_wrap = 0
+use_parentheses = true
+line_length = 88
+
+# This will work once pylint 2.5 is released.
+# Update pylint as soon as possible
+[tool.pylint.'MASTER']
+extension-pkg-whitelist='cv2'
+
+[tool.pylint.'BASIC']
+variable-rgx='[a-z_][a-z0-9_]{0,30}$'
+
+[tool.pylint.'FORMAT']
+max-line-length=88
diff --git a/requirements-apt.txt b/requirements-apt.txt
new file mode 100644
index 0000000..40184c9
--- /dev/null
+++ b/requirements-apt.txt
@@ -0,0 +1,2 @@
+build-essential
+python3.7
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..7519436
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,6 @@
+black==19.10b0
+flake8==3.7.9
+isort==4.3.21
+jupyter==1.0.0
+pre-commit==2.2.0
+pylint==2.4.4
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7b4ba87
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+gdown==3.10.2
+matplotlib==3.2.1
+numpy==1.18.2
+opencv-python==4.2.0.32
+pandas==1.0.3
+tensorflow==2.1.0
+tensorflow-addons==0.8.1
+tqdm==4.44.1
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..17d1d9e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Břetislav Hájek
+# Licensed under the MIT License. See LICENSE for details.
+from pathlib import Path
+
+import pkg_resources
+import setuptools
+
+
+CURRENT_DIR = Path(__file__).parent
+
+
+# TODO: Later if possible move all requirements to setup.py
+with CURRENT_DIR.joinpath("requirements.txt").open() as f:
+ install_req = list(map(str, pkg_resources.parse_requirements(f)))
+
+with CURRENT_DIR.joinpath("requirements-dev.txt").open() as f:
+ dev_req = list(map(str, pkg_resources.parse_requirements(f)))
+
+
+def get_long_description() -> str:
+ return (
+ (CURRENT_DIR / "README.md").read_text(encoding="utf8")
+ + "\n\n"
+ + (CURRENT_DIR / "CHANGELOG.md").read_text(encoding="utf8")
+ )
+
+
+setuptools.setup(
+ name="handwriting-ocr",
+ version="0.0.0",
+ author="Břetislav Hájek",
+ author_email="info@bretahajek.com",
+ description="OCR tool for handwriting.",
+ long_description=get_long_description(),
+ long_description_content_type="text/markdown",
+ url="https://github.com/Breta01/handwriting-ocr",
+ packages=setuptools.find_packages(),
+ python_requires=">=3.7",
+ keywords="handwriting ocr",
+ license="MIT",
+ install_requires=install_req,
+ extras_require={"dev": dev_req,},
+ classifiers=[
+ "Development Status :: 4 - Beta",
+ "Intended Audience :: Developers",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ ],
+)
diff --git a/src/data/README.txt b/src/data/README.txt
deleted file mode 100644
index 403a0fb..0000000
--- a/src/data/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
- Name - Number of words (numbers)
-1. IAM - 85012
-2. Camb - 5260
-3. ORAND - 11719
-4. CVL - 84164
-5. Other - 2460
-Total number of words: 188615
-
-All final samples are stored in folders (archives) called 'words-final' under each dataset folder.
-
-The words are stored in form '__.png' (Way of labeling can be changed.)
-For example: car_1_1528457794.9072268.png
- - file corespons to image of a word 'car' from IAM dataset
-
-The word can contain all english alphabet characters (uppercase, lowercase), 0-9 digits,
-and four special characters ('.', '-', "+", "'").
-(IAM dataset has some other special characters which can be added.)
-
-If you want to recreate final dataset using Python scripts. You have to download and extract the original dataset files and then run the Python script in same folder.
diff --git a/src/data/create_csv.py b/src/data/create_csv.py
deleted file mode 100644
index 57d5836..0000000
--- a/src/data/create_csv.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import argparse
-import csv
-import glob
-import os
-import sys
-
-import cv2
-import numpy as np
-import simplejson
-
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../'))
-from ocr.viz import print_progress_bar
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- '--sets',
- default=os.path.join(location, '../../data/sets/'),
- help="Folder with sets for converting to CSV.")
-
-
-def create_csv(datadir):
- print('Converting word images to CSV...')
- img_paths = {
- 'train': glob.glob(os.path.join(datadir, 'train', '*.png')),
- 'dev': glob.glob(os.path.join(datadir, 'dev', '*.png')),
- 'test': glob.glob(os.path.join(datadir, 'test', '*.png'))}
-
- for split in ['train', 'dev', 'test']:
- labels = np.array([
- os.path.basename(name).split('_')[0] for name in img_paths[split]])
- length = len(img_paths[split])
- images = np.empty(length, dtype=object)
-
- for i, img in enumerate(img_paths[split]):
- gaplines = 'None'
- if os.path.isfile(img[:-3] + 'txt'):
- with open(img[:-3] + 'txt', 'r') as fp:
- gaplines = str(simplejson.load(fp))[1:-1]
- images[i] = (cv2.imread(img, 0), gaplines)
- print_progress_bar(i, length)
-
- with open(os.path.join(datadir, split + '.csv'), 'w') as csvfile:
- fieldnames = ['label', 'shape', 'image', 'gaplines']
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
- writer.writeheader()
- for i in range(length):
- writer.writerow({
- fieldnames[0]: labels[i],
- fieldnames[1]: str(images[i][0].shape)[1:-1],
- fieldnames[2]: str(list(images[i][0].flatten()))[1:-1],
- fieldnames[3]: images[i][1]
- })
-
- print('\tCSV files created!')
-
-
-if __name__ == '__main__':
- args = parser.parse_args()
- create_csv(args.sets)
diff --git a/src/data/data_create_sets.py b/src/data/data_create_sets.py
deleted file mode 100644
index ebfccf2..0000000
--- a/src/data/data_create_sets.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import argparse
-import glob
-import os
-import random
-import sys
-from shutil import copyfile
-
-import cv2
-import numpy as np
-
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../'))
-
-from create_csv import create_csv
-from data_extractor import datasets
-from ocr.viz import print_progress_bar
-
-
-random.seed(17) # Make the datasets split random, but reproducible
-data_folder = 'words_final'
-output_folder = os.path.join(location, '../../data/sets/')
-
-# Sets percent distribution
-test_set = 0.1
-validation_set = 0.1
-
-
-parser = argparse.ArgumentParser(
- description='Script sliting processed words into train, validation and test sets.')
-parser.add_argument(
- '-d', '--dataset',
- nargs='*',
- choices=datasets.keys(),
- help='Pick dataset(s) to be used.')
-parser.add_argument(
- '-p', '--path',
- nargs='*',
- default=[],
- help="""Path to folder containing the dataset. For multiple datasets
- provide path or ''. If not set, default paths will be used.""")
-parser.add_argument(
- '--output',
- default='data-handwriting/sets',
- help="Directory for normalized and split data")
-parser.add_argument(
- '--csv',
- action='store_true',
- default=False,
- help="Include flag if you want to create csv files along with split.")
-
-
-if __name__ == '__main__':
- args = parser.parse_args()
- if args.dataset == ['all']:
- args.dataset = list(datasets.keys())[:-1]
-
- assert args.path == [] or len(args.dataset) == len(args.path), \
- "provide same number of paths as datasets (use '' for default)"
- if args.path != []:
- for ds, path in zip(args.dataset, args.path):
- datasets[ds][1] = path
-
- if not os.path.exists(output_folder):
- os.makedirs(output_folder)
-
- imgs = []
- for ds in args.dataset:
- for loc, _, _ in os.walk(datasets[ds][1].replace("raw", "processed")):
- imgs += glob.glob(os.path.join(loc, '*.png'))
-
- imgs.sort()
- random.shuffle(imgs)
-
- length = len(imgs)
- sp1 = int((1 - test_set - validation_set) * length)
- sp2 = int((1 - test_set) * length)
- img_paths = {'train': imgs[:sp1], 'dev': imgs[sp1:sp2], 'test': imgs[sp2:]}
-
- i = 0
- for split in ['train', 'dev', 'test']:
- split_output = os.path.join(output_folder, split)
- if not os.path.exists(split_output):
- os.mkdir(split_output)
- for im_path in img_paths[split]:
- copyfile(im_path, os.path.join(split_output, os.path.basename(im_path)))
- if '_gaplines' in im_path:
- im_path = im_path[:-3] + 'txt'
- copyfile(
- im_path, os.path.join(split_output, os.path.basename(im_path)))
-
- print_progress_bar(i, length)
- i += 1
-
- print(
- "\n\tNumber of %s words: %s" % (split, len(os.listdir(split_output))))
-
- if args.csv:
- create_csv(output_folder)
diff --git a/src/data/data_extractor.py b/src/data/data_extractor.py
deleted file mode 100644
index 1ffbceb..0000000
--- a/src/data/data_extractor.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import argparse
-import os
-
-from datasets import breta, camb, cvl, iam, orand
-
-location = os.path.dirname(os.path.abspath(__file__))
-data_folder = os.path.join(location, '../../data/raw/')
-datasets = {
- 'breta': [breta.extract, os.path.join(data_folder, 'breta'), 1],
- 'iam': [iam.extract, os.path.join(data_folder, 'iam'), 2],
- 'cvl': [cvl.extract, os.path.join(data_folder, 'cvl'), 3],
- 'orand': [orand.extract, os.path.join(data_folder, 'orand'), 4],
- 'camb': [camb.extract, os.path.join(data_folder, 'camb'), 5],
- 'all': []}
-
-output_folder = 'words_final'
-
-
-parser = argparse.ArgumentParser(
- description='Script extracting words from raw dataset.')
-parser.add_argument(
- '-d', '--dataset',
- nargs='*',
- choices=datasets.keys(),
- help='Pick dataset(s) to be used.')
-parser.add_argument(
- '-p', '--path',
- nargs='*',
- default=[],
- help="""Path to folder containing the dataset. For multiple datasets
- provide path or ''. If not filled, default paths will be used.""")
-
-
-if __name__ == '__main__':
- args = parser.parse_args()
- if args.dataset == ['all']:
- args.dataset = list(datasets.keys())[:-1]
-
- assert args.path == [] or len(args.dataset) == len(args.path), \
- "provide same number of paths as datasets (use '' for default)"
- if args.path != []:
- for ds, path in zip(args.dataset, args.path):
- datasets[ds][1] = path
-
- for ds in args.dataset:
- print("Processing -", ds)
- entry = datasets[ds]
- entry[0](entry[1], output_folder, entry[2])
diff --git a/src/data/data_normalization.py b/src/data/data_normalization.py
deleted file mode 100644
index 723eb4b..0000000
--- a/src/data/data_normalization.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import argparse
-import glob
-import os
-import sys
-
-import cv2
-import numpy as np
-from PIL import Image
-
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../'))
-
-from data_extractor import datasets
-from ocr.normalization import word_normalization
-from ocr.viz import print_progress_bar
-
-
-data_folder = 'words_final'
-output_folder = os.path.join(location, '../../data/processed/')
-
-
-parser = argparse.ArgumentParser(
- description='Script normalizing words from datasts.')
-parser.add_argument(
- '-d', '--dataset',
- nargs='*',
- choices=datasets.keys(),
- help='Pick dataset(s) to be used.')
-parser.add_argument(
- '-p', '--path',
- nargs='*',
- default=[],
- help="""Path to folder containing the dataset. For multiple datasets
- provide path or ''. If not set, default paths will be used.""")
-
-
-def words_norm(location, output):
- output = os.path.join(location, output)
- if os.path.exists(output):
- print("THIS DATASET IS BEING SKIPPED")
- print("Output folder already exists:", output)
- return 1
- else:
- output = os.path.join(output, 'words_nolines')
- os.makedirs(output)
-
- imgs = glob.glob(os.path.join(location, data_folder, '*.png'))
- length = len(imgs)
-
- for i, img_path in enumerate(imgs):
- image = cv2.imread(img_path)
- # Simple check for invalid images
- if image.shape[0] > 20:
- cv2.imwrite(
- os.path.join(output, os.path.basename(img_path)),
- word_normalization(
- image,
- height=64,
- border=False,
- tilt=False,
- hyst_norm=False))
- print_progress_bar(i, len(imgs))
-
- print("\tNumber of normalized words:",
- len([n for n in os.listdir(output)]))
-
-
-if __name__ == '__main__':
- args = parser.parse_args()
- if args.dataset == ['all']:
- args.dataset = list(datasets.keys())[:-1]
-
- assert args.path == [] or len(args.dataset) == len(args.path), \
- "provide same number of paths as datasets (use '' for default)"
- if args.path != []:
- for ds, path in zip(args.dataset, args.path):
- datasets[ds][1] = path
-
- if not os.path.exists(output_folder):
- os.makedirs(output_folder)
-
- for ds in args.dataset:
- print("Processing -", ds)
- entry = datasets[ds]
- words_norm(entry[1], os.path.join(output_folder, ds))
diff --git a/src/data/datasets/breta.py b/src/data/datasets/breta.py
deleted file mode 100644
index 98168da..0000000
--- a/src/data/datasets/breta.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import os
-from PIL import Image
-import time
-import sys
-# Allow accesing files relative to this file
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../../'))
-from ocr.viz import print_progress_bar
-
-
-def extract(location, output, number=1):
- output = os.path.join(location, output)
- if not os.path.exists(output):
- os.makedirs(output)
-
- for sub in ['words', 'archive', 'cz_raw', 'en_raw']:
- folder = os.path.join(location, sub)
-
- img_list = os.listdir(os.path.join(folder))
- for i, data in enumerate(img_list):
- word = data.split('_')[0]
- img = os.path.join(folder, data)
- out = os.path.join(
- output,
- '%s_%s_%s.png' % (word, number, data.split('_')[-1][:-4]))
- Image.open(img).save(out)
- print_progress_bar(i, len(img_list))
-
- print("\tNumber of words:", len([n for n in os.listdir(output)]))
diff --git a/src/data/datasets/camb.py b/src/data/datasets/camb.py
deleted file mode 100644
index 53f2bea..0000000
--- a/src/data/datasets/camb.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import cv2
-import glob
-import numpy as np
-import os
-import sys
-import time
-import gzip
-import shutil
-# Allow accesing files relative to this file
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../../'))
-from ocr.viz import print_progress_bar
-
-
-def extract(location, output, number=5):
- output = os.path.join(location, output)
- if not os.path.exists(output):
- os.makedirs(output)
-
- for sub in ['lob', 'numbers']:
- folder = os.path.join(location, sub)
- seg_files = glob.glob(os.path.join(folder, '*.seg'))
- length = sum([int(open(l, 'r').readline()) for l in seg_files])
-
- itr = 0
- for fl in seg_files:
- # Uncompressing tiff files
- with gzip.open(fl[:-4] + '.tiff.gz', 'rb') as f_in:
- with open(fl[:-4] + '.tiff', 'wb') as f_out:
- shutil.copyfileobj(f_in, f_out)
- image = cv2.imread(fl[:-4] + ".tiff")
- with open(fl) as f:
- f.readline()
- for line in f:
- rect = [int(val) for val in line.strip().split(' ')[1:]]
- word = line.split(' ')[0].split('_')[0]
- im = image[rect[2]:rect[3], rect[0]:rect[1]]
-
- if 0 not in im.shape:
- cv2.imwrite(
- os.path.join(
- output,
- '%s_%s_%s.png' % (word, number, time.time())),
- im)
- print_progress_bar(itr, length)
- itr += 1
-
- print("\tNumber of words:", len([n for n in os.listdir(output)]))
diff --git a/src/data/datasets/cvl.py b/src/data/datasets/cvl.py
deleted file mode 100644
index d95af40..0000000
--- a/src/data/datasets/cvl.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import unidecode
-import glob
-import os
-import sys
-import time
-import re
-from PIL import Image
-# Allow accesing files relative to this file
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../../'))
-from ocr.viz import print_progress_bar
-
-
-def extract(location, output, number=3):
- output = os.path.join(location, output)
- if not os.path.exists(output):
- os.makedirs(output)
-
- for sub in ['cvl-database-1-1/testset', 'cvl-database-1-1/trainset']:
- folder = os.path.join(location, sub)
- images = glob.glob(os.path.join(folder, 'words', '*', '*.tif'))
-
- for i, im in enumerate(images):
- word = re.search('\/\d+-\d+-\d+-\d+-(.+?).tif', im).group(1)
- word = unidecode.unidecode(word)
-
- if os.stat(im).st_size != 0:
- outpath = os.path.join(
- output,
- '%s_%s_%s.png' % (word, number, time.time()))
- Image.open(im).save(outpath)
- print_progress_bar(i, len(images))
-
- print("\tNumber of words:", len([n for n in os.listdir(output)]))
diff --git a/src/data/datasets/iam.py b/src/data/datasets/iam.py
deleted file mode 100644
index 721be43..0000000
--- a/src/data/datasets/iam.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import time
-import os
-import sys
-from shutil import copyfile
-# Allow accesing files relative to this file
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../../'))
-from ocr.viz import print_progress_bar
-
-
-# Words with these characters are removed
-# you have to extend the alphabet in order to use them (ocr/datahelpers.py)
-prohibited = [',', '(', ')', ';', ':', '/', '\\',
- '#', '"', '?', '!', '*', '_', '&']
-
-
-def extract(location, output, number=2):
- output = os.path.join(location, output)
- err_output = os.path.join(location, 'words_with_error')
- if not os.path.exists(output):
- os.makedirs(output)
- if not os.path.exists(err_output):
- os.makedirs(err_output)
-
- folder = os.path.join(location, 'words')
- label_file = os.path.join(location, 'words.txt')
- length = len(open(label_file).readlines())
-
- with open(label_file) as fp:
- for i, line in enumerate(fp):
- if line[0] != '#':
- l = line.strip().split(" ")
- impath = os.path.join(
- folder,
- l[0].split('-')[0],
- l[0].split('-')[0] + '-' + l[0].split('-')[1],
- l[0] + '.png')
- word = l[-1]
-
- if (os.stat(impath).st_size != 0
- and word not in ['.', '-', "'"]
- and not any(i in word for i in prohibited)):
-
- out = output if l[1] == 'ok' else err_output
- outpath = os.path.join(
- out, "%s_%s_%s.png" % (word, number, time.time()))
- copyfile(impath, outpath)
-
- print_progress_bar(i, length)
- print("\tNumber of words:", len([n for n in os.listdir(output)]))
diff --git a/src/data/datasets/orand.py b/src/data/datasets/orand.py
deleted file mode 100644
index cc511cc..0000000
--- a/src/data/datasets/orand.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import glob
-import os
-import sys
-from shutil import copyfile
-import time
-# Allow accesing files relative to this file
-location = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.join(location, '../../'))
-from ocr.viz import print_progress_bar
-
-
-def extract(location, output, number=4):
- output = os.path.join(location, output)
- if not os.path.exists(output):
- os.makedirs(output)
-
- for sub in ['ORAND-CAR-2014/CAR-A', 'ORAND-CAR-2014/CAR-B']:
- folder = os.path.join(location, sub)
- l_files = glob.glob(os.path.join(folder, '*.txt'))
- length = sum(1 for fl in l_files for line in open(fl))
-
- itr = 0
- for fl in l_files:
- im_folder = fl[:-6] + 'images'
- with open(fl) as f:
- for line in f:
- im, word = line.strip().split('\t')
- impath = os.path.join(im_folder, im)
-
- if os.stat(impath).st_size != 0:
- outpath = os.path.join(
- output,
- '%s_%s_%s.png' % (word, number, time.time()))
- copyfile(impath, outpath)
- print_progress_bar(itr, length)
- itr += 1
-
- print("\tNumber of words:", len([n for n in os.listdir(output)]))
diff --git a/src/ocr/__init__.py b/src/ocr/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/ocr/imgtransform.py b/src/ocr/imgtransform.py
deleted file mode 100644
index 722b131..0000000
--- a/src/ocr/imgtransform.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Functions for transforming and preprocessing images for training
-"""
-import numpy as np
-import pandas as pd
-import cv2
-from scipy.ndimage.interpolation import map_coordinates
-
-
-def coordinates_remap(image, factor_alpha, factor_sigma):
- """Transforming image using remaping coordinates."""
- alpha = image.shape[1] * factor_alpha
- sigma = image.shape[1] * factor_sigma
- shape = image.shape
-
- blur_size = int(4*sigma) | 1
- dx = alpha * cv2.GaussianBlur((np.random.rand(*shape) * 2 - 1),
- ksize=(blur_size, blur_size),
- sigmaX=sigma)
- dy = alpha * cv2.GaussianBlur((np.random.rand(*shape) * 2 - 1),
- ksize=(blur_size, blur_size),
- sigmaX=sigma)
-
- x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
- indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
-
- # TODO use cv2.remap(image, dx, dy, interpolation=cv2.INTER_LINEAR)
- return np.array(map_coordinates(image, indices, order=1, mode='constant').reshape(shape))
\ No newline at end of file
diff --git a/src/ocr/viz.py b/src/ocr/viz.py
deleted file mode 100644
index 04670ca..0000000
--- a/src/ocr/viz.py
+++ /dev/null
@@ -1,22 +0,0 @@
-def print_progress_bar(iteration,
- total,
- prefix = '',
- suffix = ''):
- """Call in a loop to create terminal progress bar.
- Args:
- iteration: current iteration (Int)
- total: total iterations (Int)
- prefix: prefix string (Str)
- suffix: suffix string (Str)
- """
- # Printing slowes down the loop
- if iteration % (total // 100) == 0:
- length = 40
- iteration += 1
- percent = (100 * iteration) // (total * 99/100)
- filled_length = int(length * percent / 100)
- bar = '█' * filled_length + '-' * (length - filled_length)
- print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r')
-
- if iteration >= total * 99/100:
- print()