diff --git a/README.md b/README.md
index 49ae5b4..c1cb529 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,28 @@
 中文版说明请见[中文README](./README_cn.md)。
+
+
+# Update 2019.07.25: release cnocr V1.0.0
+
+`cnocr` `v1.0.0` is released, which makes prediction more efficient. **The new version of the model is not compatible with the previous version.** So if you are upgrading, please download the latest model files again. See below for the details (the process is the same as before).
+
+
+
+Main changes are:
+
+- **The new crnn model supports prediction for variable-width image files, so prediction is more efficient.**
+- Support fine-tuning the existing model with specific data.
+- Fix bugs, such as `train accuracy` always being `0`.
+- The dependency `mxnet` is upgraded from `1.3.1` to `1.4.1`.
+
+
+
 # cnocr
+
 A python package for Chinese OCR with available trained models.
 So it can be used directly after installed.
 
-The accuracy of the current crnn model is about `98.7%`.
+The accuracy of the current crnn model is about `98.8%`.
 
 The project originates from our own ([爱因互动 Ein+](https://einplus.cn)) internal needs.
 Thanks for the internal supports.
@@ -30,8 +48,38 @@ pip install cnocr
 
 ## Usage
 
+The first time cnocr is used, the model files will be downloaded automatically from
+[Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0) to `~/.cnocr`.
+
+The zip file will be extracted and you can find the resulting model files in `~/.cnocr/models` by default.
+In case the automatic download fails, you can download the zip file manually
+from [Baidu NetDisk](https://pan.baidu.com/s/1DWV3H2UWmzOU6d48UbTYVw) with extraction code `ss81`, and put the zip file in `~/.cnocr`. The code will do the rest.
+
+
+
 ### Predict
 
+Three functions are provided for prediction.
+
+
+
+#### 1. `CnOcr.ocr(img_fp)`
+
+The function `CnOcr.ocr(img_fp)` can recognize text in an image containing multiple lines of text (or a single line).
+
+
+
+**Function Description**
+
+- input parameter `img_fp`: image file path; or color image `mx.nd.NDArray` or `np.ndarray`, with shape `(height, width, 3)`, and the channels should be RGB formatted.
+- return: `List(List(Char))`, such as: `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
+
+
+
+**Usage Case**
+
 ```python
 from cnocr import CnOcr
 ocr = CnOcr()
@@ -39,30 +87,48 @@ res = ocr.ocr('examples/multi-line_cn1.png')
 print("Predicted Chars:", res)
 ```
 
-When you run the previous codes, the model files will be downloaded automatically from
-[Dropbox](https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip) to `~/.cnocr`.
-The zip file will be extracted and you can find the resulting model files in `~/.cnocr/models` by default.
-In case the automatic download can't perform well, you can download the zip file manually
-from [Baidu NetDisk](https://pan.baidu.com/s/1s91985r0YBGbk_1cqgHa1Q) with extraction code `pg26`,
-and put the zip file to `~/.cnocr`. The code will do else.
+or:
+
+```python
+import mxnet as mx
+from cnocr import CnOcr
+ocr = CnOcr()
+img_fp = 'examples/multi-line_cn1.png'
+img = mx.image.imread(img_fp, 1)
+res = ocr.ocr(img)
+print("Predicted Chars:", res)
+```
 
-Try the predict command for [examples/multi-line_cn1.png](./examples/multi-line_cn1.png):
+The previous code snippets recognize the text in the image file [examples/multi-line_cn1.png](./examples/multi-line_cn1.png):
 
 ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png)
 
+The OCR result should be:
+
 ```bash
-python scripts/cnocr_predict.py --file examples/multi-line_cn1.png
-```
-You will get:
-```python
-Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], ['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], ['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], ['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], ['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], ['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], ['等', '多', '种', '形', '式', '。']]
+Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'],
+                  ['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'],
+                  ['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'],
+                  ['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'],
+                  ['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'],
+                  ['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'],
+                  ['等', '多', '种', '形', '式', '。']]
 ```
 
+#### 2. `CnOcr.ocr_for_single_line(img_fp)`
+
+If you know that the image you're predicting contains only one line of text, the function `CnOcr.ocr_for_single_line(img_fp)` can be used instead. Compared with `CnOcr.ocr()`, the result of `CnOcr.ocr_for_single_line()` is more reliable, because the process of splitting lines is not required.
+
+
+**Function Description**
-### Predict for Single-line-characters Image
+
+- input parameter `img_fp`: image file path; or image `mx.nd.NDArray` or `np.ndarray`, with shape `[height, width]` or `[height, width, channel]`. The optional channel should be `1` (gray image) or `3` (color image).
+- return: `List(Char)`, such as: `['你', '好']`.
 
-If you know your image includes only one single line characters, you can use function `Cnocr.ocr_for_single_line()` instead of `Cnocr.ocr()`. `Cnocr.ocr_for_single_line()` should be more efficient.
+
+
+**Usage Case**:
 
 ```python
 from cnocr import CnOcr
@@ -71,31 +137,94 @@ res = ocr.ocr_for_single_line('examples/rand_cn1.png')
 print("Predicted Chars:", res)
 ```
 
-With file [examples/multi-line_cn1.png](./examples/multi-line_cn1.png):
+or:
+
+```python
+import mxnet as mx
+from cnocr import CnOcr
+ocr = CnOcr()
+img_fp = 'examples/rand_cn1.png'
+img = mx.image.imread(img_fp, 1)
+res = ocr.ocr_for_single_line(img)
+print("Predicted Chars:", res)
+```
+
+
+The previous code snippets recognize the text in the image file [examples/rand_cn1.png](./examples/rand_cn1.png):
 
 ![examples/rand_cn1.png](./examples/rand_cn1.png)
 
-you will get:
+The OCR result should be:
+
+```bash
+Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']
+```
+
+#### 3. `CnOcr.ocr_for_single_lines(img_list)`
+
+Function `CnOcr.ocr_for_single_lines(img_list)` can predict a batch of single-line-text image arrays.
+Actually, `CnOcr.ocr(img_fp)` and `CnOcr.ocr_for_single_line(img_fp)` both invoke `CnOcr.ocr_for_single_lines(img_list)` internally.
+
+
+
+**Function Description**
+
+- input parameter `img_list`: list of images, in which each element should be a line image array, with type `mx.nd.NDArray` or `np.ndarray`. Each element should be a tensor with values ranging from `0` to `255`, and with shape `[height, width]` or `[height, width, channel]`. The optional channel should be `1` (gray image) or `3` (color image).
+- return: `List(List(Char))`, such as: `[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`.
+
+
+
+**Usage Case**:
 
 ```python
-Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '子', '犷']
+import mxnet as mx
+from cnocr import CnOcr
+from cnocr.line_split import line_split
+ocr = CnOcr()
+img_fp = 'examples/multi-line_cn1.png'
+img = mx.image.imread(img_fp, 1).asnumpy()
+line_imgs = line_split(img, blank=True)
+line_img_list = [line_img for line_img, _ in line_imgs]
+res = ocr.ocr_for_single_lines(line_img_list)
+print("Predicted Chars:", res)
+```
+
+More usage cases can be found at [tests/test_cnocr.py](./tests/test_cnocr.py). A gray-scale input sketch is also given below, after the training notes.
+
+
+### Using the Script
+
 ```bash
 python scripts/cnocr_predict.py --file examples/multi-line_cn1.png
 ```
 
 ### (Not Necessary) Train
 
-You can use the package without any train. But if you really really want to train your own models,
-follow this:
+You can use the package without any training. But if you really want to train your own models, follow this:
 
 ```bash
 python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr
 ```
+
+
+
+Fine-tuning an existing model with specific data is also supported. Please refer to the following command:
+
+```bash
+python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr --load_epoch 20
+```
+
+
+
+More references can be found at [scripts/run_cnocr_train.sh](./scripts/run_cnocr_train.sh).
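+
+
+As a supplement to the prediction usage cases above, the following sketch (mirroring the cases in [tests/test_cnocr.py](./tests/test_cnocr.py)) shows that `CnOcr.ocr_for_single_line()` also accepts a gray-scale `np.ndarray` with shape `[height, width]`; the image file is a sample shipped in this repo:
+
+```python
+import numpy as np
+from PIL import Image
+from cnocr import CnOcr
+
+ocr = CnOcr()
+# read the sample image and convert it to a single-channel (gray) array
+img = np.array(Image.open('examples/rand_cn1.png').convert('L'))
+assert len(img.shape) == 2  # shape: [height, width], no channel dimension
+res = ocr.ocr_for_single_line(img)
+print("Predicted Chars:", res)
+```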
+
+
+
 ## Future Work
 
-* [x] support multi-line-characters recognition
-* Support space recognition
-* Bugfixes
-* Add Tests
-* Maybe use no symbol to rewrite the model
-* Try other models such as DenseNet, ResNet
+
+* [x] support multi-line-characters recognition (`Done`)
+* [x] crnn model supports prediction for variable-width image files (`Done`)
+* [ ] Add Unit Tests (`Doing`)
+* [ ] Bugfixes (`Doing`)
+* [ ] Support space recognition (Tried, but not successful for now)
+* [ ] Try other models such as DenseNet, ResNet
diff --git a/README_cn.md b/README_cn.md
index b2a9f4a..3c5dc26 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -1,8 +1,21 @@
+# Update 2019.07.25: 发布 cnocr V1.0.0
+
+`cnocr`发布了预测效率更高的新版本v1.0.0。**新版本的模型跟以前版本的模型不兼容**。所以如果大家是升级的话,需要重新下载最新的模型文件。具体说明见下面(流程和原来相同)。
+
+
+
+主要改动如下:
+
+- **crnn模型支持可变长预测,提升预测效率**
+- 支持利用特定数据对现有模型进行精调(继续训练)
+- 修复bugs,如训练时`accuracy`一直为`0`
+- 依赖的 `mxnet` 版本从`1.3.1`更新至 `1.4.1`
+
 # cnocr
 
 **cnocr**是用来做中文OCR的**Python 3**包。cnocr自带了训练好的识别模型,所以安装后即可直接使用。
-目前使用的识别模型是**crnn**,识别准确度约为 `98.7%`。
+目前使用的识别模型是**crnn**,识别准确度约为 `98.8%`。
 
 本项目起源于我们自己 ([爱因互动 Ein+](https://einplus.cn)) 内部的项目需求,所以非常感谢公司的支持。
 
@@ -28,9 +41,36 @@ pip install cnocr
 
 ## 使用方法
 
-### 预测
+首次使用cnocr时,系统会自动从[Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0)下载zip格式的模型压缩文件,并存于 `~/.cnocr`目录。
+代码会自动对下载后的zip文件解压,然后把解压后的模型相关文件放于`~/.cnocr/models`目录。
+
+如果系统不能自动从[Dropbox](https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=0)成功下载zip文件,则需要手动下载此zip文件并把它放于 `~/.cnocr`目录。
+另一个下载地址是[百度云盘](https://pan.baidu.com/s/1DWV3H2UWmzOU6d48UbTYVw)(提取码为`ss81`)。
+放置好zip文件后,后面的事代码就会自动执行了。
+
+
+
+### 代码预测
+
+主要包含三个函数,下面分别说明。
+
+
+
+#### 1. 函数`CnOcr.ocr(img_fp)`
+
+函数`CnOcr.ocr(img_fp)`可以对包含多行文字(或单行)的图片进行文字识别。
+
+
+
+**函数说明**:
+
+- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为`mx.nd.NDArray` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是`(height, width, 3)`,第三个维度是channel,它应该是`RGB`格式的。
+- 返回值:为一个嵌套的`list`,类似这样`[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`。
+
+
+
+**调用示例**:
 
-#### 代码引用
 
 ```python
 from cnocr import CnOcr
@@ -39,11 +79,18 @@ res = ocr.ocr('examples/multi-line_cn1.png')
 print("Predicted Chars:", res)
 ```
 
-首次使用cnocr时,系统会自动从[Dropbox](https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip)下载zip格式的模型压缩文件,并存于 `~/.cnocr`目录。
-下载后的zip文件代码会自动对其解压,然后把解压后的模型相关文件放于`~/.cnocr/models`目录。
-如果系统不能自动从[Dropbox](https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip)成功下载zip文件,则需要手动下载此zip文件并把它放于 `~/.cnocr`目录。
-另一个下载地址是[百度云盘](https://pan.baidu.com/s/1s91985r0YBGbk_1cqgHa1Q)(提取码为`pg26`)。
-放置好zip文件后,后面的事代码就会自动执行了。
+或:
+
+```python
+import mxnet as mx
+from cnocr import CnOcr
+ocr = CnOcr()
+img_fp = 'examples/multi-line_cn1.png'
+img = mx.image.imread(img_fp, 1)
+res = ocr.ocr(img)
+print("Predicted Chars:", res)
+```
+
+
 
 上面使用的图片文件 [examples/multi-line_cn1.png](./examples/multi-line_cn1.png)内容如下:
 
@@ -53,15 +100,30 @@ print("Predicted Chars:", res)
 
 上面预测代码段的返回结果如下:
 
-```python
-Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], ['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], ['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], ['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], ['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], ['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], ['等', '多', '种', '形', '式', '。']]
+```bash
+Predicted Chars: [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'],
+                  ['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'],
+                  ['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'],
+                  ['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'],
+                  ['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'],
+                  ['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'],
+                  ['等', '多', '种', '形', '式', '。']]
 ```
 
-##### 单行文字图片的预测
+#### 2. 函数`CnOcr.ocr_for_single_line(img_fp)`
+
+如果明确知道要预测的图片中只包含了单行文字,可以使用函数`CnOcr.ocr_for_single_line(img_fp)`进行识别。和 `CnOcr.ocr()`相比,`CnOcr.ocr_for_single_line()`结果可靠性更强,因为它不需要做额外的分行处理。
+
+
+
+**函数说明**:
+
+- 输入参数 `img_fp`: 可以是需要识别的单行文字图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为`mx.nd.NDArray` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是`(height, width)`或`(height, width, channel)`。如果没有channel,表示传入的就是灰度图片。第三个维度channel可以是`1`(灰度图片)或者`3`(彩色图片)。如果是彩色图片,它应该是`RGB`格式的。
+- 返回值:为一个`list`,类似这样`['你', '好']`。
 
-如果明确知道要预测的图片中只包含了单行文字,可以使用`Cnocr.ocr_for_single_line()` 接口,和 `Cnocr.ocr()`相比,`Cnocr.ocr_for_single_line()`结果可靠性更强。
+
+
+**调用示例**:
 
 ```python
 from cnocr import CnOcr
@@ -70,21 +132,65 @@ res = ocr.ocr_for_single_line('examples/rand_cn1.png')
 print("Predicted Chars:", res)
 ```
 
+或:
+
+```python
+import mxnet as mx
+from cnocr import CnOcr
+ocr = CnOcr()
+img_fp = 'examples/rand_cn1.png'
+img = mx.image.imread(img_fp, 1)
+res = ocr.ocr_for_single_line(img)
+print("Predicted Chars:", res)
+```
+
-对图片文件 [examples/multi-line_cn1.png](./examples/multi-line_cn1.png):
+对图片文件 [examples/rand_cn1.png](./examples/rand_cn1.png):
 
 ![examples/rand_cn1.png](./examples/rand_cn1.png)
 
 的预测结果如下:
 
 ```bash
-Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '子', '犷']
+Predicted Chars: ['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']
+```
+
+
+
+#### 3. 函数`CnOcr.ocr_for_single_lines(img_list)`
+
+函数`CnOcr.ocr_for_single_lines(img_list)`可以**对多个单行文字图片进行批量预测**。函数`CnOcr.ocr(img_fp)`和`CnOcr.ocr_for_single_line(img_fp)`内部其实都是调用的函数`CnOcr.ocr_for_single_lines(img_list)`。
+
+
+
+**函数说明**:
+
+- 输入参数 `img_list`: 为一个`list`;其中每个元素是已经从图片文件中读入的数组,类型可以为`mx.nd.NDArray` 或 `np.ndarray`,取值应该是`[0,255]`的整数,维数应该是`(height, width)`或`(height, width, channel)`。如果没有channel,表示传入的就是灰度图片。第三个维度channel可以是`1`(灰度图片)或者`3`(彩色图片)。如果是彩色图片,它应该是`RGB`格式的。
+- 返回值:为一个嵌套的`list`,类似这样`[['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]`。
+
+
+
+**调用示例**:
+
+```python
+import mxnet as mx
+from cnocr import CnOcr
+from cnocr.line_split import line_split
+ocr = CnOcr()
+img_fp = 'examples/multi-line_cn1.png'
+img = mx.image.imread(img_fp, 1).asnumpy()
+line_imgs = line_split(img, blank=True)
+line_img_list = [line_img for line_img, _ in line_imgs]
+res = ocr.ocr_for_single_lines(line_img_list)
+print("Predicted Chars:", res)
 ```
 
-#### 脚本引用
+更详细的使用方法,可参考[tests/test_cnocr.py](./tests/test_cnocr.py)中提供的测试用例。
+
+
+
+### 脚本引用
 
 也可以使用脚本模式预测:
 
@@ -105,12 +211,25 @@ python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr
 ```
 
+
+现在也支持从已有模型利用特定数据精调模型,请参考下面命令:
+
+```bash
+python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr --load_epoch 20
+```
+
+
+
+更多可参考脚本[scripts/run_cnocr_train.sh](./scripts/run_cnocr_train.sh)中的命令。
+
+
+
 ## 未来工作
 
-* [x] 支持图片包含多行文字
-* 支持`空格`识别
-* 修bugs(目前代码还比较凌乱。。)
-* 完善测试用例
-* 考虑使用MxNet的命令式编程重写代码,提升灵活性
+* [x] 支持图片包含多行文字 (`Done`)
+* [x] crnn模型支持可变长预测,提升预测效率 (`Done`)
+* [ ] 完善测试用例 (`Doing`)
+* [ ] 修bugs(目前代码还比较凌乱。。) (`Doing`)
+* [ ] 支持`空格`识别(`V1.0.0`在训练集中加入了空格,但从预测结果看,空格依旧是识别不出来)
 * 尝试新模型,如 DenseNet、ResNet,进一步提升识别准确率
diff --git a/cnocr/__version__.py b/cnocr/__version__.py
new file mode 100644
index 0000000..1f356cc
--- /dev/null
+++ b/cnocr/__version__.py
@@ -0,0 +1 @@
+__version__ = '1.0.0'
diff --git a/cnocr/cn_ocr.py b/cnocr/cn_ocr.py
index b255f93..b786286 100644
--- a/cnocr/cn_ocr.py
+++ b/cnocr/cn_ocr.py
@@ -16,18 +16,18 @@
 # specific language governing permissions and limitations
 # under the License.
import os -from copy import deepcopy import mxnet as mx import numpy as np from PIL import Image +from cnocr.__version__ import __version__ from cnocr.consts import MODEL_EPOCE from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams from cnocr.fit.lstm import init_states from cnocr.fit.ctc_metrics import CtcMetrics from cnocr.data_utils.data_iter import SimpleBatch from cnocr.symbols.crnn import crnn_lstm -from cnocr.utils import data_dir, get_model_file, read_charset +from cnocr.utils import data_dir, get_model_file, read_charset, normalize_img_array from cnocr.line_split import line_split @@ -93,7 +93,7 @@ def load_module(prefix, epoch, data_names, data_shapes, network=None): class CnOcr(object): - MODEL_FILE_PREFIX = 'model' + MODEL_FILE_PREFIX = 'model-v{}'.format(__version__) def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE): self._model_dir = os.path.join(root, 'models') @@ -102,7 +102,9 @@ def __init__(self, root=data_dir(), model_epoch=MODEL_EPOCE): self._alphabet, _ = read_charset(os.path.join(self._model_dir, 'label_cn.txt')) self._hp = Hyperparams() - self._mods = {} + self._hp._loss_type = None # infer mode + + self._mod = self._get_module(self._hp) def _assert_and_prepare_model_files(self, root): model_dir = self._model_dir @@ -123,65 +125,169 @@ def _assert_and_prepare_model_files(self, root): os.removedirs(model_dir) get_model_file(root) - def _get_module(self, hp, sample): + def _get_module(self, hp): network = crnn_lstm(hp) prefix = os.path.join(self._model_dir, self.MODEL_FILE_PREFIX) - mod = load_module(prefix, MODEL_EPOCE, sample.data_names, sample.provide_data, network=network) + # import pdb; pdb.set_trace() + data_names = ['data'] + data_shapes = [(data_names[0], (hp.batch_size, 1, hp.img_height, hp.img_width))] + mod = load_module(prefix, self._model_epoch, data_names, data_shapes, network=network) return mod def ocr(self, img_fp): """ :param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray, with shape (height, width, 3), and the channels should be RGB formatted. - :return: List(List(Letter)), such as: + :return: List(List(Char)), such as: [['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']] """ if isinstance(img_fp, str) and os.path.isfile(img_fp): img = mx.image.imread(img_fp, 1).asnumpy() - elif isinstance(img_fp, mx.nd.NDArray) or isinstance(img_fp, np.ndarray): + elif isinstance(img_fp, mx.nd.NDArray): + img = img_fp.asnumpy() + elif isinstance(img_fp, np.ndarray): img = img_fp else: raise TypeError('Inappropriate argument type.') if min(img.shape[0], img.shape[1]) < 2: return '' line_imgs = line_split(img, blank=True) - line_chars_list = [] - for line_idx, (line_img, _) in enumerate(line_imgs): - line_img = np.array(Image.fromarray(line_img).convert('L')) - line_chars = self.ocr_for_single_line(line_img) - line_chars_list.append(line_chars) + line_img_list = [line_img for line_img, _ in line_imgs] + line_chars_list = self.ocr_for_single_lines(line_img_list) return line_chars_list def ocr_for_single_line(self, img_fp): """ - Recognize characters from an image with characters with only one line - :param img_fp: image file path; or gray image mx.nd.NDArray; or gray image np.ndarray, - with shape [height, width] or [height, width, 1]. - :return: charector list, such as ['你', '好'] + Recognize characters from an image with only one-line characters. + :param img_fp: image file path; or image mx.nd.NDArray or np.ndarray, + with shape [height, width] or [height, width, channel]. 
+            The optional channel should be 1 (gray image) or 3 (color image).
+        :return: character list, such as ['你', '好']
         """
-        hp = deepcopy(self._hp)
         if isinstance(img_fp, str) and os.path.isfile(img_fp):
             img = read_ocr_img(img_fp)
         elif isinstance(img_fp, mx.nd.NDArray) or isinstance(img_fp, np.ndarray):
             img = img_fp
         else:
             raise TypeError('Inappropriate argument type.')
-        img = rescale_img(img, hp)
+        res = self.ocr_for_single_lines([img])
+        return res[0]
+
+    def ocr_for_single_lines(self, img_list):
+        """
+        Batch recognize characters from a list of one-line-characters images.
+        :param img_list: list of images, in which each element should be a line image array,
+            with type mx.nd.NDArray or np.ndarray.
+            Each element should be a tensor with values ranging from 0 to 255,
+            and with shape [height, width] or [height, width, channel].
+            The optional channel should be 1 (gray image) or 3 (color image).
+        :return: list of list of chars, such as
+            [['第', '一', '行'], ['第', '二', '行'], ['第', '三', '行']]
+        """
+        if len(img_list) == 0:
+            return []
+        img_list = [self._preprocess_img_array(img) for img in img_list]
 
-        init_state_names, init_state_arrays = lstm_init_states(batch_size=1, hp=hp)
+        batch_size = len(img_list)
+        img_list, img_widths = self._pad_arrays(img_list)
+
+        # import pdb; pdb.set_trace()
         sample = SimpleBatch(
-            data_names=['data'] + init_state_names,
-            data=[mx.nd.array([img])] + init_state_arrays)
+            data_names=['data'],
+            data=[mx.nd.array(img_list)])
 
-        mod = self._get_module(hp, sample)
+        prob = self._predict(sample)
+        prob = np.reshape(prob, (-1, batch_size, prob.shape[1]))  # [seq_len, batch_size, num_classes]
+        max_width = max(img_widths)
+        res = []
+        for i in range(batch_size):
+            res.append(self._gen_line_pred_chars(prob[:, i, :], img_widths[i], max_width))
+        return res
+
+    def _preprocess_img_array(self, img):
+        """
+        :param img: image array with type mx.nd.NDArray or np.ndarray,
+            with shape [height, width] or [height, width, channel].
+            channel should be 1 (gray image) or 3 (color image).
+
+        :return: np.ndarray, with shape (1, height, width)
+        """
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            if isinstance(img, mx.nd.NDArray):
+                img = img.asnumpy()
+            if img.dtype != np.dtype('uint8'):
+                img = img.astype('uint8')
+            # color to gray
+            img = np.array(Image.fromarray(img).convert('L'))
+        img = rescale_img(img, self._hp)
+        return normalize_img_array(img)
+
+    def _pad_arrays(self, img_list):
+        """Padding to make sure all the elements have the same width."""
+        img_widths = [img.shape[2] for img in img_list]
+        if len(img_list) <= 1:
+            return img_list, img_widths
+        max_width = max(img_widths)
+        pad_width = [(0, 0), (0, 0), (0, 0)]
+        padded_img_list = []
+        for img in img_list:
+            if img.shape[2] < max_width:
+                pad_width[2] = (0, max_width - img.shape[2])
+                img = np.pad(img, pad_width, 'constant', constant_values=0.0)
+            padded_img_list.append(img)
+        return padded_img_list, img_widths
+
+    def _predict(self, sample):
+        mod = self._mod
         mod.forward(sample)
         prob = mod.get_outputs()[0].asnumpy()
+        return prob
 
-        prediction, start_end_idx = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist())
-        # print(start_end_idx)
+    def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
+        """
+        Get the predicted characters.
+ :param line_prob: with shape of [seq_length, num_classes] + :param img_width: + :param max_img_width: + :return: + """ + class_ids = np.argmax(line_prob, axis=-1) + # idxs = list(zip(range(len(class_ids)), class_ids)) + # probs = [line_prob[e[0], e[1]] for e in idxs] + if img_width < max_img_width: + comp_ratio = self._hp.seq_len_cmpr_ratio + end_idx = img_width // comp_ratio + if end_idx < len(class_ids): + class_ids[end_idx:] = 0 + prediction, start_end_idx = CtcMetrics.ctc_label(class_ids.tolist()) + # print(start_end_idx) alphabet = self._alphabet res = [alphabet[p] for p in prediction] + + # res = self._insert_space_char(res, start_end_idx) return res + + def _insert_space_char(self, pred_chars, start_end_idx, min_interval=None): + if len(pred_chars) < 2: + return pred_chars + assert len(pred_chars) == len(start_end_idx) + + if min_interval is None: + # 自动计算最小区间值 + intervals = {start_end_idx[idx][0] - start_end_idx[idx-1][1] for idx in range(1, len(start_end_idx))} + if len(intervals) >= 3: + intervals = sorted(list(intervals)) + if intervals[0] < 1: # 排除间距为0的情况 + intervals = intervals[1:] + min_interval = intervals[2] + else: + min_interval = start_end_idx[-1][1] # no space will be inserted + + res_chars = [pred_chars[0]] + for idx in range(1, len(pred_chars)): + if start_end_idx[idx][0] - start_end_idx[idx-1][1] >= min_interval: + res_chars.append(' ') + res_chars.append(pred_chars[idx]) + return res_chars diff --git a/cnocr/consts.py b/cnocr/consts.py index 9147212..a7ea653 100644 --- a/cnocr/consts.py +++ b/cnocr/consts.py @@ -1,2 +1,7 @@ -MODEL_BASE_URL = 'https://www.dropbox.com/s/5n09nxf4x95jprk/cnocr-models-v0.1.0.zip?dl=1' +from .__version__ import __version__ + + +MODEL_BASE_URL = 'https://www.dropbox.com/s/7w8l3mk4pvkt34w/cnocr-models-v1.0.0.zip?dl=1' MODEL_EPOCE = 20 + +ZIP_FILE_NAME = 'cnocr-models-v{}.zip'.format(__version__) diff --git a/cnocr/data_utils/captcha_generator.py b/cnocr/data_utils/captcha_generator.py index 5d49528..1dd9c73 100644 --- a/cnocr/data_utils/captcha_generator.py +++ b/cnocr/data_utils/captcha_generator.py @@ -67,6 +67,7 @@ def image(self, captcha_str): img = cv2.resize(img, (self.h, self.w)) img = img.transpose(1, 0) img = np.multiply(img, 1 / 255.0) + # print(np.mean(img), np.std(img)) return img @@ -113,7 +114,7 @@ def get(self): np.ndarray A captcha image, normalized to [0, 1] """ - return self._gen_sample() + return self._gen_sample(0) @staticmethod def get_rand(num_digit_min, num_digit_max): @@ -130,7 +131,7 @@ def get_rand(num_digit_min, num_digit_max): buf += str(random.randint(0, 9)) return buf - def _gen_sample(self): + def _gen_sample(self, _): """ Generate a random captcha image sample Returns @@ -139,7 +140,10 @@ def _gen_sample(self): Tuple of image (numpy ndarray) and character string of digits used to generate the image """ num_str = self.get_rand(self.num_digit_min, self.num_digit_max) - return self.captcha.image(num_str), num_str + num_array = np.zeros(self.num_digit_max) + num = list(map(lambda x: int(x) + 1, list(num_str))) + num_array[:len(num)] = num + return self.captcha.image(num_str), num_array class MPDigitCaptcha(DigitCaptcha): diff --git a/cnocr/data_utils/data_iter.py b/cnocr/data_utils/data_iter.py index 1cb9646..c28c98c 100644 --- a/cnocr/data_utils/data_iter.py +++ b/cnocr/data_utils/data_iter.py @@ -6,6 +6,7 @@ import mxnet as mx import random +from ..utils import normalize_img_array from .multiproc_data import MPData @@ -138,7 +139,7 @@ def __init__(self, data_root, data_list, batch_size, data_shape, num_label, 
lstm self.name = name def __iter__(self): - init_state_names = [x[0] for x in self.init_states] + # init_state_names = [x[0] for x in self.init_states] data = [] label = [] cnt = 0 @@ -157,9 +158,9 @@ def __iter__(self): label.append(ret) if cnt % self.batch_size == 0: - data_all = [mx.nd.array(data)] + self.init_state_arrays + data_all = [mx.nd.array(data)] label_all = [mx.nd.array(label)] - data_names = ['data'] + init_state_names + data_names = ['data'] label_names = ['label'] data = [] label = [] @@ -192,11 +193,16 @@ def __init__(self, data_root, data_list, data_shape, num_label, num_processes, m self.data_root = data_root self.dataset_lines = open(data_list).readlines() + self.total_size = len(self.dataset_lines) + self.cur_proc_idxs = list(range(num_processes)) + self.num_proc = num_processes self.mp_data = MPData(num_processes, max_queue_size, self._gen_sample) - def _gen_sample(self): - m_line = random.choice(self.dataset_lines) + def _gen_sample(self, proc_id): + # m_line = random.choice(self.dataset_lines) + cur_idx = self.cur_proc_idxs[proc_id] + m_line = self.dataset_lines[cur_idx] img_lst = m_line.strip().split(' ') img_path = os.path.join(self.data_root, img_lst[0]) @@ -204,6 +210,8 @@ def _gen_sample(self): img = np.array(img) # print(img.shape) img = np.transpose(img, (1, 0)) # res: [1, width, height] + img = normalize_img_array(img) + # print(np.mean(img), np.std(img)) # if len(img.shape) == 2: # img = np.expand_dims(np.transpose(img, (1, 0)), axis=0) # res: [1, width, height] @@ -211,6 +219,10 @@ def _gen_sample(self): for idx in range(1, len(img_lst)): labels[idx - 1] = int(img_lst[idx]) + self.cur_proc_idxs[proc_id] += self.num_proc + if self.cur_proc_idxs[proc_id] >= self.total_size: + self.cur_proc_idxs[proc_id] -= self.total_size + return img, labels @property @@ -249,7 +261,7 @@ class OCRIter(mx.io.DataIter): """ Iterator class for generating captcha image data """ - def __init__(self, count, batch_size, lstm_init_states, captcha, num_label, name): + def __init__(self, count, batch_size, captcha, num_label, name): """ Parameters ---------- @@ -265,16 +277,17 @@ def __init__(self, count, batch_size, lstm_init_states, captcha, num_label, name super(OCRIter, self).__init__() self.batch_size = batch_size self.count = count if count > 0 else captcha.size // batch_size - self.init_states = lstm_init_states - self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states] + # self.init_states = lstm_init_states + # self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states] data_shape = captcha.shape - self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states + # self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states + self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] self.provide_label = [('label', (self.batch_size, num_label))] self.mp_captcha = captcha self.name = name def __iter__(self): - init_state_names = [x[0] for x in self.init_states] + # init_state_names = [x[0] for x in self.init_states] for k in range(self.count): data = [] label = [] @@ -282,12 +295,17 @@ def __iter__(self): img, labels = self.mp_captcha.get() # print(img.shape) img = np.expand_dims(np.transpose(img, (1, 0)), axis=0) # size: [1, height, width] - # import pdb; pdb.set_trace() + # print(img.shape) data.append(img) + # print('labels', type(labels), labels) label.append(labels) - data_all = [mx.nd.array(data)] + self.init_state_arrays + # data_all = [mx.nd.array(data)] + 
self.init_state_arrays + data_all = [mx.nd.array(data)] + # print(data_all[0].shape) label_all = [mx.nd.array(label)] - data_names = ['data'] + init_state_names + # print(label_all[0]) + # data_names = ['data'] + init_state_names + data_names = ['data'] label_names = ['label'] data_batch = SimpleBatch(data_names, data_all, label_names, label_all) diff --git a/cnocr/data_utils/multiproc_data.py b/cnocr/data_utils/multiproc_data.py index 62d8160..5aec968 100644 --- a/cnocr/data_utils/multiproc_data.py +++ b/cnocr/data_utils/multiproc_data.py @@ -67,8 +67,7 @@ def start(self): """ self._init_proc() - @staticmethod - def _proc_loop(proc_id, alive, queue, fn): + def _proc_loop(self, proc_id, alive, queue, fn): """ Thread loop for generating data @@ -86,7 +85,7 @@ def _proc_loop(proc_id, alive, queue, fn): print("proc {} started".format(proc_id)) try: while alive.value: - data = fn() + data = fn(proc_id) put_success = False while alive.value and not put_success: try: diff --git a/cnocr/fit/ctc_loss.py b/cnocr/fit/ctc_loss.py index 3fcef2c..14e8535 100644 --- a/cnocr/fit/ctc_loss.py +++ b/cnocr/fit/ctc_loss.py @@ -9,7 +9,8 @@ def _add_warp_ctc_loss(pred, seq_len, num_label, label): def _add_mxnet_ctc_loss(pred, seq_len, label): """ Adds Symbol.WapCTC on top of pred symbol and returns the resulting symbol """ - pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0)) + pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0)) # res: (seq_len, batch_size, num_classes) + # print('pred_ctc', pred_ctc.infer_shape()[1]) loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label) ctc_loss = mx.sym.MakeLoss(loss) @@ -23,6 +24,7 @@ def _add_mxnet_ctc_loss(pred, seq_len, label): def add_ctc_loss(pred, seq_len, num_label, loss_type): """ Adds CTC loss on top of pred symbol and returns the resulting symbol """ label = mx.sym.Variable('label') + # label = mx.sym.Variable('label', shape=(128, 4)) if loss_type == 'warpctc': # print("Using WarpCTC Loss") sm = _add_warp_ctc_loss(pred, seq_len, num_label, label) diff --git a/cnocr/fit/ctc_metrics.py b/cnocr/fit/ctc_metrics.py index 0eae5ad..f218784 100644 --- a/cnocr/fit/ctc_metrics.py +++ b/cnocr/fit/ctc_metrics.py @@ -42,14 +42,14 @@ def ctc_label(p): for i, _ in enumerate(p): c1 = p1[i] c2 = p1[i+1] - if c2 == 0 and c1 != 0 and len(ret) > 0: + if (c2 == 0 or c2 != c1) and c1 != 0 and len(ret) > 0: ret[-1][-1] = i if c2 == 0 or c2 == c1: continue ret.append([c2, i, -1]) if len(ret) == 0: - return [0], [(0, 0)] + return [], [] if ret[-1][-1] < 0: ret[-1][-1] = len(p) @@ -93,7 +93,8 @@ def accuracy(self, label, pred): p = [] for k in range(self.seq_len): p.append(np.argmax(pred[k * batch_size + i])) - p = self.ctc_label(p) + p, _ = self.ctc_label(p) + # print('real: {}, pred: {}'.format(l, p)) if len(p) == len(l): match = True for k, _ in enumerate(p): @@ -116,7 +117,7 @@ def accuracy_lcs(self, label, pred): p = [] for k in range(self.seq_len): p.append(np.argmax(pred[k * batch_size + i])) - p = self.ctc_label(p) + p, _ = self.ctc_label(p) hit += self._lcs(p, l) * 1.0 / len(l) total += 1.0 assert total == batch_size diff --git a/cnocr/fit/fit.py b/cnocr/fit/fit.py index bd0cab3..b3c65db 100644 --- a/cnocr/fit/fit.py +++ b/cnocr/fit/fit.py @@ -3,17 +3,15 @@ import mxnet as mx -def _load_model(args, rank=0): +def _load_model(args): if 'load_epoch' not in args or args.load_epoch is None: - return (None, None, None) + return None, None, None assert args.prefix is not None model_prefix = args.prefix - if rank > 0 and os.path.exists("%s-%d-symbol.json" % 
(model_prefix, rank)): - model_prefix += "-%d" % (rank) sym, arg_params, aux_params = mx.model.load_checkpoint( model_prefix, args.load_epoch) - logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch) - return (sym, arg_params, aux_params) + logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch) + return sym, arg_params, aux_params def fit(network, data_train, data_val, metrics, args, hp, data_names=None): @@ -29,15 +27,38 @@ def fit(network, data_train, data_val, metrics, args, hp, data_names=None): os.makedirs(os.path.dirname(args.prefix)) module = mx.mod.Module( - symbol = network, - data_names= ["data"] if data_names is None else data_names, + symbol=network, + data_names=["data"] if data_names is None else data_names, label_names=['label'], context=contexts) + # from mxnet import nd + # import numpy as np + # data = nd.random.uniform(shape=(128, 1, 32, 100)) + # label = np.random.randint(1, 11, size=(128, 4)) + # module.bind(data_shapes=[('data', (128, 1, 32, 100))], label_shapes=[('label', (128, 4))]) + # # e = module.bind() + # # f = e.forward(is_train=False) + # module.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34)) + # from ..data_utils.data_iter import SimpleBatch + # data_all = [data] + # label_all = [mx.nd.array(label)] + # # print(label_all[0]) + # # data_names = ['data'] + init_state_names + # data_names = ['data'] + # label_names = ['label'] + # + # data_batch = SimpleBatch(data_names, data_all, label_names, label_all) + # module.forward(data_batch) + # f = module.get_outputs() + # import pdb; pdb.set_trace() + + begin_epoch = args.load_epoch if args.load_epoch else 0 + num_epoch = hp.num_epoch + begin_epoch module.fit(train_data=data_train, eval_data=data_val, - begin_epoch=args.load_epoch if args.load_epoch else 0, - num_epoch=hp.num_epoch, + begin_epoch=begin_epoch, + num_epoch=num_epoch, # use metrics.accuracy or metrics.accuracy_lcs eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True), optimizer='AdaDelta', diff --git a/cnocr/fit/lstm.py b/cnocr/fit/lstm.py index 7a89f69..b16126d 100644 --- a/cnocr/fit/lstm.py +++ b/cnocr/fit/lstm.py @@ -2,6 +2,7 @@ from collections import namedtuple import mxnet as mx +from mxnet.gluon.rnn.rnn_layer import LSTM LSTMState = namedtuple("LSTMState", ["c", "h"]) LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", @@ -54,6 +55,22 @@ def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx): next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") return LSTMState(c=next_c, h=next_h) + +def lstm2(net, num_lstm_layer, num_hidden): + net = mx.symbol.squeeze(net, axis=2) # res: bz x f x 35 + net = mx.symbol.transpose(net, axes=(2, 0, 1)) + # print('6', net.infer_shape()[1]) + + lstm = LSTM(num_hidden, num_lstm_layer, bidirectional=True) + # import pdb; pdb.set_trace() + output = lstm(net) # res: `(sequence_length, batch_size, 2*num_hidden)` + # print('7', output.infer_shape()[1]) + return mx.symbol.reshape(output, shape=(-3, -2)) # res: (bz * 35, c) + # - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)` + # when `layout` is "TNC". 
If `bidirectional` is True, output shape will instead + # be `(sequence_length, batch_size, 2*num_hidden)` + + def lstm(net, num_lstm_layer, num_hidden, seq_length): last_states = [] forward_param = [] diff --git a/cnocr/hyperparams/cn_hyperparams.py b/cnocr/hyperparams/cn_hyperparams.py index 87885af..78ae02d 100644 --- a/cnocr/hyperparams/cn_hyperparams.py +++ b/cnocr/hyperparams/cn_hyperparams.py @@ -17,7 +17,7 @@ def __init__(self): self._loss_type = "ctc" # ["warpctc" "ctc"] self._batch_size = 128 - self._num_classes = 6425 # 应该是6426的。。 5990 + self._num_classes = 6426 # 应该是6426的。。 5990 self._img_width = 280 self._img_height = 32 @@ -30,7 +30,8 @@ def __init__(self): self._num_hidden = 100 self._num_lstm_layer = 2 # self._seq_length = 35 - self._seq_length = self._img_width // 8 + self.seq_len_cmpr_ratio = 8 # 模型对于图片宽度压缩的比例(模型中的卷积层造成的) + self._seq_length = self._img_width // self.seq_len_cmpr_ratio self._num_label = 10 self._drop_out = 0.5 diff --git a/cnocr/hyperparams/hyperparams2.py b/cnocr/hyperparams/hyperparams2.py index e2f528c..5bd24e1 100644 --- a/cnocr/hyperparams/hyperparams2.py +++ b/cnocr/hyperparams/hyperparams2.py @@ -29,7 +29,8 @@ def __init__(self): # LSTM hyper parameters self._num_hidden = 100 self._num_lstm_layer = 2 - self._seq_length = self._img_width // 8 + self.seq_len_cmpr_ratio = 8 # 模型对于图片宽度压缩的比例(模型中的卷积层造成的) + self._seq_length = self._img_width // self.seq_len_cmpr_ratio self._num_label = 4 self._drop_out = 0.5 diff --git a/cnocr/symbols/crnn.py b/cnocr/symbols/crnn.py index 312c958..61733da 100644 --- a/cnocr/symbols/crnn.py +++ b/cnocr/symbols/crnn.py @@ -22,7 +22,7 @@ """ import mxnet as mx from ..fit.ctc_loss import add_ctc_loss -from ..fit.lstm import lstm +from ..fit.lstm import lstm2 def crnn_no_lstm(hp): @@ -99,7 +99,7 @@ def convRelu(i, input_data, bn=True): layer = mx.sym.LeakyReLU(data=layer, name='leakyrelu-%d-1x1' % i) return layer - net = convRelu(0, data) # bz x f x 32 x 280 + net = convRelu(0, data) # bz x f x 32 x 280 # print('0', net.infer_shape()[1]) max = mx.sym.Pooling(data=net, name='pool-0_m', pool_type='max', kernel=(2, 2), stride=(2, 2)) avg = mx.sym.Pooling(data=net, name='pool-0_a', pool_type='avg', kernel=(2, 2), stride=(2, 2)) @@ -120,18 +120,19 @@ def convRelu(i, input_data, bn=True): if hp.dropout > 0: net = mx.symbol.Dropout(data=net, p=hp.dropout) - hidden_concat = lstm(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden, seq_length=hp.seq_length) + hidden_concat = lstm2(net, num_lstm_layer=hp.num_lstm_layer, num_hidden=hp.num_hidden) # import pdb; pdb.set_trace() # mx.sym.transpose(net, []) - pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes, name='pred_fc') # (bz x 25) x num_classes + pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=hp.num_classes, name='pred_fc') # (bz x 35) x num_classes + # print('pred', pred.infer_shape()[1]) if hp.loss_type: # Training mode, add loss return add_ctc_loss(pred, hp.seq_length, hp.num_label, hp.loss_type) - else: - # Inference mode, add softmax - return mx.sym.softmax(data=pred, name='softmax') + # else: + # # Inference mode, add softmax + # return mx.sym.softmax(data=pred, name='softmax') from ..hyperparams.cn_hyperparams import CnHyperparams as Hyperparams diff --git a/cnocr/utils.py b/cnocr/utils.py index 73924fa..758558f 100644 --- a/cnocr/utils.py +++ b/cnocr/utils.py @@ -18,9 +18,10 @@ import os import platform import zipfile - +import numpy as np from mxnet.gluon.utils import download -from .consts import MODEL_BASE_URL + +from .consts 
import MODEL_BASE_URL, ZIP_FILE_NAME def data_dir_default(): @@ -59,12 +60,11 @@ def get_model_file(root=data_dir()): file_path Path to the requested pretrained model file. """ - file_name = 'cnocr-models.zip' root = os.path.expanduser(root) os.makedirs(root, exist_ok=True) - zip_file_path = os.path.join(root, file_name) + zip_file_path = os.path.join(root, ZIP_FILE_NAME) if not os.path.exists(zip_file_path): download(MODEL_BASE_URL, path=zip_file_path, overwrite=True) with zipfile.ZipFile(zip_file_path) as zf: @@ -75,7 +75,7 @@ def get_model_file(root=data_dir()): def read_charset(charset_fp): - alphabet = [] + alphabet = [None] # 第0个元素是预留id,在CTC中用来分割字符。它不对应有意义的字符 with open(charset_fp, encoding='utf-8') as fp: for line in fp: @@ -85,3 +85,8 @@ def read_charset(charset_fp): # inv_alph_dict[' '] = inv_alph_dict[''] # 对应空格 return alphabet, inv_alph_dict + +def normalize_img_array(img): + """ rescale to [-1.0, 1.0] """ + # return (img / 255.0 - 0.5) * 2 + return (img - np.mean(img)) / (np.std(img) + 1e-6) diff --git a/examples/00010991.jpg b/examples/00010991.jpg new file mode 100644 index 0000000..909af97 Binary files /dev/null and b/examples/00010991.jpg differ diff --git a/examples/helloworld.jpg b/examples/helloworld.jpg new file mode 100644 index 0000000..0b73eef Binary files /dev/null and b/examples/helloworld.jpg differ diff --git a/examples/label_cn.txt b/examples/label_cn.txt index 981af90..c22d12e 100644 --- a/examples/label_cn.txt +++ b/examples/label_cn.txt @@ -1,4 +1,3 @@ - , 的 。 @@ -6423,3 +6422,4 @@ $ 丨 ‖ ˇ + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7d6de99..db9e547 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ #click==6.7 numpy==1.14.0 pillow==5.3.0 -mxnet==1.3.1 +mxnet==1.4.1 gluoncv==0.3.0 #opencv-python==3.4.4.19 diff --git a/scripts/cnocr_train.py b/scripts/cnocr_train.py index b1f60ee..08f354f 100644 --- a/scripts/cnocr_train.py +++ b/scripts/cnocr_train.py @@ -23,7 +23,8 @@ import sys sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from cnocr.data_utils.captcha_generator import MPDigitCaptcha +from cnocr.__version__ import __version__ +from cnocr.utils import data_dir from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams from cnocr.hyperparams.hyperparams2 import Hyperparams as Hyperparams2 from cnocr.data_utils.data_iter import ImageIterLstm, MPOcrImages, OCRIter @@ -35,6 +36,7 @@ def parse_args(): # Parse command line arguments parser = argparse.ArgumentParser() + default_model_prefix = os.path.join(data_dir(), 'models', 'model-v{}'.format(__version__)) parser.add_argument("--dataset", help="use which kind of dataset, captcha or cn_ocr", @@ -51,8 +53,9 @@ def parse_args(): type=int, default=2) parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int) parser.add_argument('--load_epoch', type=int, - help='load the model on an epoch using the model-load-prefix') - parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='./models/model') + help='load the model on an epoch using the model-load-prefix [Default: no trained model will be loaded]') + parser.add_argument("--prefix", help="Checkpoint prefix [Default '{}']".format(default_model_prefix), + default=default_model_prefix) parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc') parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4) 
parser.add_argument("--font_path", help="Path to ttf font file or directory containing ttf files") @@ -71,6 +74,8 @@ def get_fonts(path): def run_captcha(args): + from cnocr.data_utils.captcha_generator import MPDigitCaptcha + hp = Hyperparams2() network = crnn_lstm(hp) @@ -85,23 +90,26 @@ def run_captcha(args): num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2) mp_captcha.start() # img, num = mp_captcha.get() - # print(img.shape) + # print(img.shape, num) # import numpy as np # import cv2 # img = np.transpose(img, (1, 0)) # cv2.imwrite('captcha1.png', img * 255) + # import sys + # sys.exit(0) # import pdb; pdb.set_trace() - init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_states = init_c + init_h - data_names = ['data'] + [x[0] for x in init_states] + # init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] + # init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] + # init_states = init_c + init_h + # data_names = ['data'] + [x[0] for x in init_states] + data_names = ['data'] data_train = OCRIter( - hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, num_label=hp.num_label, + hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_captcha, num_label=hp.num_label, name='train') data_val = OCRIter( - hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, num_label=hp.num_label, + hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_captcha, num_label=hp.num_label, name='val') head = '%(asctime)-15s %(message)s' @@ -120,7 +128,7 @@ def run_cn_ocr(args): network = crnn_lstm(hp) mp_data_train = MPOcrImages(args.data_root, args.train_file, (hp.img_width, hp.img_height), hp.num_label, - num_processes=args.num_proc, max_queue_size=hp.batch_size * 2) + num_processes=args.num_proc, max_queue_size=hp.batch_size * 100) # img, num = mp_data_train.get() # print(img.shape) # print(mp_data_train.shape) @@ -131,20 +139,21 @@ def run_cn_ocr(args): # cv2.imwrite('captcha1.png', img * 255) # import pdb; pdb.set_trace() mp_data_test = MPOcrImages(args.data_root, args.test_file, (hp.img_width, hp.img_height), hp.num_label, - num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 2) + num_processes=max(args.num_proc // 2, 1), max_queue_size=hp.batch_size * 10) mp_data_train.start() mp_data_test.start() - init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] - init_states = init_c + init_h - data_names = ['data'] + [x[0] for x in init_states] + # init_c = [('l%d_init_c' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] + # init_h = [('l%d_init_h' % l, (hp.batch_size, hp.num_hidden)) for l in range(hp.num_lstm_layer * 2)] + # init_states = init_c + init_h + # data_names = ['data'] + [x[0] for x in init_states] + data_names = ['data'] data_train = OCRIter( - hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_data_train, num_label=hp.num_label, + hp.train_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_train, num_label=hp.num_label, name='train') data_val = OCRIter( - hp.eval_epoch_size // hp.batch_size, 
hp.batch_size, init_states, captcha=mp_data_test, num_label=hp.num_label, + hp.eval_epoch_size // hp.batch_size, hp.batch_size, captcha=mp_data_test, num_label=hp.num_label, name='val') # data_train = ImageIterLstm( # args.data_root, args.train_file, hp.batch_size, (hp.img_width, hp.img_height), hp.num_label, init_states, name="train") @@ -159,7 +168,7 @@ def run_cn_ocr(args): fit(network=network, data_train=data_train, data_val=data_val, metrics=metrics, args=args, hp=hp, data_names=data_names) mp_data_train.reset() - mp_data_test.start() + mp_data_test.reset() if __name__ == '__main__': diff --git a/scripts/infer_captcha_ocr.py b/scripts/infer_captcha_ocr.py index 2aadc2d..df6b0d6 100644 --- a/scripts/infer_captcha_ocr.py +++ b/scripts/infer_captcha_ocr.py @@ -36,8 +36,11 @@ def read_captcha_img(path, hp): """ Reads image specified by path into numpy.ndarray""" import cv2 tgt_h, tgt_w = hp.img_height, hp.img_width - img = cv2.resize(cv2.imread(path, 0), (tgt_h, tgt_w)).astype(np.float32) / 255 - img = np.expand_dims(img.transpose(1, 0), 0) # res: [1, height, width] + img = cv2.imread(path, 0) + # import pdb; pdb.set_trace() + # img = img.astype(np.float32) / 255.0 + img = cv2.resize(img, (tgt_w, tgt_h)).astype(np.float32) / 255.0 + img = np.expand_dims(img, 0) # res: [1, height, width] return img @@ -107,10 +110,10 @@ def read_charset(charset_fp): def main(): parser = argparse.ArgumentParser() parser.add_argument("--dataset", help="use which kind of dataset, captcha or cn_ocr", - choices=['captcha', 'cn_ocr'], type=str, default='cn_ocr') + choices=['captcha', 'cn_ocr'], type=str, default='captcha') parser.add_argument("--file", help="Path to the CAPTCHA image file") parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='./models/model') - parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=100) + parser.add_argument("--epoch", help="Checkpoint epoch [Default 100]", type=int, default=20) parser.add_argument('--charset_file', type=str, help='存储了每个字对应哪个id的关系.') args = parser.parse_args() if args.dataset == 'cn_ocr': @@ -120,12 +123,10 @@ def main(): hp = Hyperparams2() img = read_captcha_img(args.file, hp) - init_state_names, init_state_arrays = lstm_init_states(batch_size=1, hp=hp) + # init_state_names, init_state_arrays = lstm_init_states(batch_size=1, hp=hp) # import pdb; pdb.set_trace() - sample = SimpleBatch( - data_names=['data'] + init_state_names, - data=[mx.nd.array([img])] + init_state_arrays) + sample = SimpleBatch(data_names=['data'], data=[mx.nd.array([img])]) network = crnn_lstm(hp) mod = load_module(args.prefix, args.epoch, sample.data_names, sample.provide_data, network=network) @@ -133,7 +134,7 @@ def main(): mod.forward(sample) prob = mod.get_outputs()[0].asnumpy() - prediction = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist()) + prediction, start_end_idx = CtcMetrics.ctc_label(np.argmax(prob, axis=-1).tolist()) if args.charset_file: alphabet, _ = read_charset(args.charset_file) diff --git a/scripts/run_cnocr_train.sh b/scripts/run_cnocr_train.sh new file mode 100644 index 0000000..dbcc2be --- /dev/null +++ b/scripts/run_cnocr_train.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# -*- coding: utf-8 -*- + +cd `dirname $0`/../ + +## 训练captcha +#python scripts/cnocr_train.py --cpu 2 --num_proc 2 --loss ctc --dataset captcha --font_path /Users/king/Documents/WhatIHaveDone/Test/text_renderer/data/fonts/chn/msyh.ttf + +# 训练中文ocr模型crnn +python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr 
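+
+## 精调已有模型(README 中给出的精调示例命令;假设默认前缀下已存在第20个epoch的模型文件)
+#python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr --load_epoch 20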
+ +## gpu版本 +#python scripts/cnocr_train.py --gpu 1 --num_proc 8 --loss ctc --dataset cn_ocr --data_root /jfs/jinlong/data/ocr/outer/images \ +# --train_file /jfs/jinlong/data/ocr/outer/train.txt --test_file /jfs/jinlong/data/ocr/outer/test.txt + +## 预测中文图片 +#python scripts/cnocr_predict.py --file examples/rand_cn1.png \ No newline at end of file diff --git a/scripts/run_crnn.sh b/scripts/run_crnn.sh deleted file mode 100644 index 48a0a7d..0000000 --- a/scripts/run_crnn.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env bash -# -*- coding: utf-8 -*- - -cd `dirname $0` - -# 训练中文ocr模型crnn -python scripts/cnocr_train.py --cpu 2 --num_proc 4 --loss ctc --dataset cn_ocr - - -## 预测中文图片 -#python scripts/cnocr_predict.py --file examples/rand_cn1.png \ No newline at end of file diff --git a/setup.py b/setup.py index 6fc3771..755a9ee 100644 --- a/setup.py +++ b/setup.py @@ -1,22 +1,32 @@ #!/usr/bin/env python3 import os from setuptools import find_packages, setup +from pathlib import Path -dir_path = os.path.dirname(os.path.realpath(__file__)) +PACKAGE_NAME = "cnocr" + +here = Path(__file__).parent + +long_description = (here / "README.md").read_text(encoding="utf-8") + +about = {} +exec( + (here / PACKAGE_NAME.replace('.', os.path.sep) / "__version__.py").read_text( + encoding="utf-8" + ), + about, +) required = [ 'numpy>=1.14.0,<1.15.0', 'pillow>=5.3.0', - 'mxnet>=1.3.1,<1.4.0', + 'mxnet>=1.4.1,<1.5.0', 'gluoncv>=0.3.0,<0.4.0', ] -with open("README.md", "r") as fh: - long_description = fh.read() - setup( - name='cnocr', - version='0.2.0', + name=PACKAGE_NAME, + version=about['__version__'], description="Package for Chinese OCR, which can be used after installed without training yourself OCR model", long_description=long_description, long_description_content_type="text/markdown", diff --git a/tests/test_cnocr.py b/tests/test_cnocr.py new file mode 100644 index 0000000..8227600 --- /dev/null +++ b/tests/test_cnocr.py @@ -0,0 +1,102 @@ +# coding: utf-8 +import os +import sys +import pytest +import numpy as np +import mxnet as mx +from mxnet import nd +from PIL import Image + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) + +from cnocr import CnOcr +from cnocr.line_split import line_split + +CNOCR = CnOcr() + +SINGLE_LINE_CASES = [ + ('20457890_2399557098.jpg', [['就', '会', '哈', '哈', '大', '笑', '。', '3', '.', '0']]), + ('rand_cn1.png', [['笠', '淡', '嘿', '骅', '谧', '鼎', '臭', '姚', '歼', '蠢', '驼', '耳', '裔', '挝', '涯', '狗', '蒽', '子', '犷']]) +] +MULTIPLE_LINE_CASES = [ + ('multi-line_cn1.png', [['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], + ['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], + ['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '―', '这', '个', '账'], + ['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], + ['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], + ['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], + ['等', '多', '种', '形', '式', '。']]), + ('multi-line_cn2.png', [['。', '当', '然', ',', '在', '媒', '介', '越', '来', '越', '多', '的', '情', '形', '下', ','], + ['意', '味', '着', '传', '播', '方', '式', '的', '变', '化', '。', '过', '去', '主', '流'], + ['的', '是', '大', '众', '传', '播', ',', '现', '在', '互', '动', '性', '和', '定', '制'], + ['性', '带', '来', '了', '新', '的', '挑', '战', '—', '—', '如', '何', '让', '品', '牌'], + ['与', '消', '费', '者', '更', '加', '互', '动', '。']]), +] +CASES = SINGLE_LINE_CASES + MULTIPLE_LINE_CASES + + 
+@pytest.mark.parametrize('img_fp, expected', CASES) +def test_ocr(img_fp, expected): + ocr = CNOCR + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + img_fp = os.path.join(root_dir, 'examples', img_fp) + pred = ocr.ocr(img_fp) + print('\n') + print("Predicted Chars:", pred) + assert expected == pred + img = mx.image.imread(img_fp, 1) + pred = ocr.ocr(img) + print("Predicted Chars:", pred) + assert expected == pred + img = mx.image.imread(img_fp, 1).asnumpy() + pred = ocr.ocr(img) + print("Predicted Chars:", pred) + assert expected == pred + + +@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES) +def test_ocr_for_single_line(img_fp, expected): + ocr = CNOCR + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + img_fp = os.path.join(root_dir, 'examples', img_fp) + pred = ocr.ocr_for_single_line(img_fp) + print('\n') + print("Predicted Chars:", pred) + assert expected[0] == pred + img = mx.image.imread(img_fp, 1) + pred = ocr.ocr_for_single_line(img) + print("Predicted Chars:", pred) + assert expected[0] == pred + img = mx.image.imread(img_fp, 1).asnumpy() + pred = ocr.ocr_for_single_line(img) + print("Predicted Chars:", pred) + assert expected[0] == pred + img = np.array(Image.fromarray(img).convert('L')) + assert len(img.shape) == 2 + pred = ocr.ocr_for_single_line(img) + print("Predicted Chars:", pred) + assert expected[0] == pred + img = np.expand_dims(img, axis=2) + assert len(img.shape) == 3 and img.shape[2] == 1 + pred = ocr.ocr_for_single_line(img) + print("Predicted Chars:", pred) + assert expected[0] == pred + + +@pytest.mark.parametrize('img_fp, expected', MULTIPLE_LINE_CASES) +def test_ocr_for_single_lines(img_fp, expected): + ocr = CNOCR + root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + img_fp = os.path.join(root_dir, 'examples', img_fp) + img = mx.image.imread(img_fp, 1).asnumpy() + line_imgs = line_split(img, blank=True) + line_img_list = [line_img for line_img, _ in line_imgs] + pred = ocr.ocr_for_single_lines(line_img_list) + print('\n') + print("Predicted Chars:", pred) + assert expected == pred + line_img_list = [nd.array(line_img) for line_img in line_img_list] + pred = ocr.ocr_for_single_lines(line_img_list) + print("Predicted Chars:", pred) + assert expected == pred diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..c281738 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,26 @@ +# coding: utf-8 +import os +import sys +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) + +from cnocr.fit.ctc_metrics import CtcMetrics + + +@pytest.mark.parametrize('input, expected', [ + ('1100220030', '123'), + ('111010220030', '1123'), + ('121000220030', '12123'), + ('12100022003', '12123'), + ('012100022003', '12123'), + ('0121000220030', '12123'), + ('0000', ''), + ('0300120200220030', '312223'), +]) +def test_ctc_metrics(input, expected): + input = list(map(int, list(input))) + expected = list(map(int, list(expected))) + p, _ = CtcMetrics.ctc_label(input) + assert expected == p diff --git a/tests/test_mxnet.py b/tests/test_mxnet.py new file mode 100644 index 0000000..4297d84 --- /dev/null +++ b/tests/test_mxnet.py @@ -0,0 +1,20 @@ +# coding: utf-8 +import os +import sys +import mxnet as mx +import numpy as np +from mxnet import nd +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(1, 
os.path.dirname(os.path.abspath(__file__))) + + +def test_nd(): + ele = np.reshape(np.array(range(2*3)), (2, 3)) + data = [ele, ele + 10] + new = nd.array([ele]) + assert new.shape == (1, 2, 3) + new = nd.array(data) + assert new.shape == (2, 2, 3) + print(new)
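+
+
+def test_pad_to_max_width():
+    # Illustrative sketch (added for documentation, not from the original commit):
+    # it mirrors the zero-padding idea used by CnOcr._pad_arrays, which right-pads
+    # each (1, height, width) line image with zeros up to the widest image in the batch.
+    imgs = [np.zeros((1, 32, w), dtype=np.float32) for w in (40, 100, 72)]
+    max_width = max(img.shape[2] for img in imgs)
+    padded = [
+        np.pad(img, ((0, 0), (0, 0), (0, max_width - img.shape[2])),
+               'constant', constant_values=0.0)
+        for img in imgs
+    ]
+    assert all(img.shape == (1, 32, max_width) for img in padded)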