详细
+- 2023/04/01
+1. 增加新模型
+ - 关键信息抽取[LayoutLMv3](configs/kie/layoutlmv3/)
+
+- 2024/03/20
+1. 增加新模型
+ - OCR大模型[Vary-toy](configs/llm/vary/vary_toy.yaml),支持基于通义千问1.8B LLM的检测和OCR功能
+
- 2023/12/25
1. 增加新模型
- 表格识别[TableMaster](configs/table/table_master.yaml)
@@ -283,8 +299,8 @@ MindOCR提供了[数据格式转换工具](tools/dataset_converters) ,以支
- 2023/12/14
1. 增加新模型
- - 关键信息抽取[LayoutXLM SER](configs/kie/vi_layoutxlm)
- - 关键信息抽取[VI-LayoutXLM SER](configs/kie/layoutlm_series)
+ - 关键信息抽取[LayoutXLM](configs/kie/layoutxlm)
+ - 关键信息抽取[VI-LayoutXLM](configs/kie/vi_layoutxlm)
- 文本检测[PP-OCRv3 DBNet](configs/det/dbnet/db_mobilenetv3_ppocrv3.yaml)和文本识别[PP-OCRv3 SVTR](configs/rec/svtr/svtr_ppocrv3_ch.yaml),支持在线推理和微调训练
2. 添加更多基准数据集及其结果
- [XFUND](configs/kie/vi_layoutxlm/README_CN.md)
diff --git a/configs/kie/layoutlmv3/README.md b/configs/kie/layoutlmv3/README.md
new file mode 100644
index 000000000..cc18fd725
--- /dev/null
+++ b/configs/kie/layoutlmv3/README.md
@@ -0,0 +1,250 @@
+English | [中文](README_CN.md)
+
+# LayoutLMv3
+
+
+> [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387)
+
+
+## 1. Introduction
+Unlike previous LayoutLM series models, LayoutLMv3 does not rely on complex CNN or Faster R-CNN networks to represent images in its model architecture. Instead, it directly utilizes image blocks of document images, thereby greatly reducing parameters and avoiding complex document preprocessing such as manual annotation of target region boxes and document object detection. Its simple unified architecture and training objectives make LayoutLMv3 a versatile pretraining model suitable for both text-centric and image-centric document AI tasks.
+
+The experimental results demonstrate that LayoutLMv3 achieves better performance with fewer parameters on the following datasets:
+
+- Text-centric datasets: Form Understanding FUNSD dataset, Receipt Understanding CORD dataset, and Document Visual Question Answering DocVQA dataset.
+- Image-centric datasets: Document Image Classification RVL-CDIP dataset and Document Layout Analysis PubLayNet dataset.
+
+LayoutLMv3 also employs a text-image multimodal Transformer architecture to learn cross-modal representations. Text vectors are obtained by adding word vectors, one-dimensional positional vectors, and two-dimensional positional vectors of words. Text from document images and their corresponding two-dimensional positional information (layout information) are extracted using optical character recognition (OCR) tools. As adjacent words in text often convey similar semantics, LayoutLMv3 shares the two-dimensional positional vectors of adjacent words, while each word in LayoutLM and LayoutLMv2 has different two-dimensional positional vectors.
+
+The representation of image vectors typically relies on CNN-extracted feature grid features or Faster R-CNN-extracted region features, which increase computational costs or depend on region annotations. Therefore, the authors obtain image features by linearly mapping image blocks, a representation method initially proposed in ViT, which incurs minimal computational cost and does not rely on region annotations, effectively addressing the aforementioned issues. Specifically, the image is first resized to a uniform size (e.g., 224x224), then divided into fixed-size blocks (e.g., 16x16), and image features are obtained through linear mapping to form an image feature sequence, followed by addition of a learnable one-dimensional positional vector to obtain the image vector.[[1](#references)]
+
+
+
+
+
+ Figure 1. LayoutLMv3 architecture [1]
+
+
+## 2. Results
+
+
+### Accuracy
+
+
+According to our experiments, the performance and accuracy evaluation([Model Evaluation](#33-Model-Evaluation)) results of training ([Model Training](#32-Model-Training)) on the XFUND Chinese dataset are as follows:
+
+
+
+| **Model** | **Task** | **Context** | **Dateset** | **Model Params** | **Batch size** | **Graph train 1P (s/epoch)** | **Graph train 1P (ms/step)** | **Graph train 1P (FPS)** | **hmean** | **Config** | **Download** |
+| :----------: | :------: | :-------------: | :--------: | :--------: | :----------: | :--------------------------: | :--------------------------: | :----------------------: | :-------: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------: |
+| LayoutLMv3 | SER | D910x1-MS2.1-G | XFUND_zh | 265.8 M | 8 | 19.53 | 1094.86 | 7.37 | 91.88% | [yaml](../layoutlmv3/ser_layoutlmv3_xfund_zh.yaml) | ckpt(TODO) |
+
+
+
+
+
+## 3. Quick Start
+### 3.1 Preparation
+
+#### 3.1.1 Installation
+Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
+
+#### 3.1.2 Dataset Download
+
+[The XFUND dataset](https://github.com/doc-analysis/XFUND) is used as the experimental dataset. The XFUND dataset is a multilingual dataset proposed by Microsoft for the Knowledge-Intensive Extraction (KIE) task. It consists of seven datasets, each containing 149 training samples and 50 validation samples.
+
+Respectively: ZH (Chinese), JA (Japanese), ES (Spanish), FR (French), IT (Italian), DE (German), PT (Portuguese)
+
+a preprocessed [Chinese dataset](https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar) that can be directly used is provided for everyone to download.
+
+```bash
+mkdir train_data
+cd train_data
+wget https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar && tar -xf XFUND.tar
+cd ..
+```
+
+#### 3.1.3 Dataset Usage
+
+After decompression, the data folder structure is as follows:
+
+```bash
+ └─ zh_train/ Training set
+ ├── image/ Folder for storing images
+ ├── train.json Annotation information
+ └─ zh_val/ Validation set
+ ├── image/ Folder for storing images
+ ├── val.json Annotation information
+
+```
+
+The annotation format of this dataset is:
+
+```bash
+{
+ "height": 3508, # Image height
+ "width": 2480, # Image width
+ "ocr_info": [
+ {
+ "text": "邮政地址:", # Single text content
+ "label": "question", # Category of the text
+ "bbox": [261, 802, 483, 859], # Single text box
+ "id": 54, # Text index
+ "linking": [[54, 60]], # Relationships between the current text and other texts [question, answer]
+ "words": []
+ },
+ {
+ "text": "湖南省怀化市市辖区",
+ "label": "answer",
+ "bbox": [487, 810, 862, 859],
+ "id": 60,
+ "linking": [[54, 60]],
+ "words": []
+ }
+ ]
+}
+```
+
+**The data configuration for model training.**
+
+If you want to reproduce the training of the model, it is recommended to modify the dataset-related fields in the configuration YAML file as follows:
+
+```yaml
+...
+train:
+ ...
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Root directory of the training dataset
+ data_dir: XFUND/zh_train/image/ # Directory of the training dataset, concatenated with `dataset_root` to form the complete directory of the training dataset
+ label_file: XFUND/zh_train/train.json # Path to the label file of the training dataset, concatenated with `dataset_root` to form the complete path of the label file of the training dataset
+...
+eval:
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Root directory of the validation dataset
+ data_dir: XFUND/zh_val/image/ # Directory of the validation dataset, concatenated with `dataset_root` to form the complete directory of the validation dataset
+ label_file: XFUND/zh_val/val.json # Path to the label file of the validation dataset, concatenated with `dataset_root` to form the complete path of the label file of the validation dataset
+ ...
+
+```
+
+#### 3.1.4 Check YAML Config Files
+Apart from the dataset setting, please also check the following important args: `system.distribute`, `system.val_while_train`, `common.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_path`, `eval.ckpt_load_path`, `eval.dataset.dataset_path`, `eval.loader.batch_size`. Explanations of these important args:
+
+```yaml
+system:
+ mode:
+ distribute: False # `True` for distributed training, `False` for standalone training
+ amp_level: 'O0'
+ seed: 42
+ val_while_train: True # Validate while training
+ drop_overflow_update: False
+model:
+ type: kie
+ transform: null
+ backbone:
+ name: layoutlmv3
+ head:
+ name: TokenClassificationHead
+ num_classes: 7
+ use_visual_backbone: True
+ use_float16: True
+ pretrained:
+...
+train:
+ ckpt_save_dir: './tmp_kie_ser' # The training result (including checkpoints, per-epoch performance and curves) saving directory
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Path of training dataset
+ data_dir: XFUND/zh_train/image/ # Path of training dataset data dir
+ label_file: XFUND/zh_train/train.json # Path of training dataset label file
+...
+eval:
+ ckpt_load_path: './tmp_kie_ser/best.ckpt' # checkpoint file path
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Path of evaluation dataset
+ data_dir: XFUND/zh_val/image/ # Path of evaluation dataset data dir
+ label_file: XFUND/zh_val/val.json # Path of evaluation dataset label file
+...
+ ...
+...
+```
+
+**Notes:**
+- As the global batch size (batch_size x num_devices) is important for reproducing the result, please adjust `batch_size` accordingly to keep the global batch size unchanged for a different number of GPUs/NPUs, or adjust the learning rate linearly to a new global batch size.
+
+
+### 3.2 Model Training
+
+* Distributed Training
+
+It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please modify the configuration parameter `distribute` as True and run:
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun --allow-run-as-root -n 8 python tools/train.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please modify the configuration parameter`distribute` as False and run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python tools/train.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+The training result (including checkpoints, per-epoch performance and curves) will be saved in the directory parsed by the arg `ckpt_save_dir`. The default directory is `./tmp_kie_ser`.
+
+### 3.3 Model Evaluation
+
+To evaluate the accuracy of the trained model, you can use `eval.py`. Please set the checkpoint path to the arg `ckpt_load_path` in the `eval` section of yaml config file, set `distribute` to be False, and then run:
+
+```
+python tools/eval.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+### 3.4 Model Inference
+
+To perform inference using a pre-trained model, you can utilize `tools/infer/text/predict_ser.py` for inference and visualize the results.
+
+```
+python tools/infer/text/predict_ser.py --rec_algorithm CRNN_CH --image_dir {dir of images or path of image}
+```
+
+As an example of entity recognition in Chinese forms, use the script to recognize entities in the form of `configs/kie/vi_layoutxlm/example.jpg`. The results will be stored in the `./inference_results` folder by default, and you can also customize the result storage path through the `--draw_img_save_dir` command-line parameter.
+
+
+
+
+
+ example.jpg
+
+Recognition results are as shown in the image, and the image is saved as`inference_results/example_ser.jpg`:
+
+
+
+
+
+ example_ser.jpg
+
+
+
+
+## References
+
+
+[1] Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking. arXiv preprint arXiv:2204.08387, 2022.
diff --git a/configs/kie/layoutlmv3/README_CN.md b/configs/kie/layoutlmv3/README_CN.md
new file mode 100644
index 000000000..1f2628ad9
--- /dev/null
+++ b/configs/kie/layoutlmv3/README_CN.md
@@ -0,0 +1,246 @@
+[English](README.md) | 中文
+
+# LayoutLMv3
+
+
+> [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387)
+
+## 1. 模型描述
+
+
+不同于以往的LayoutLM系列模型,在模型架构设计上,LayoutLMv3 不依赖复杂的 CNN 或 Faster R-CNN 网络来表征图像,而是直接利用文档图像的图像块,从而大大节省了参数并避免了复杂的文档预处理(如人工标注目标区域框和文档目标检测)。简单的统一架构和训练目标使 LayoutLMv3 成为通用的预训练模型,可适用于以文本为中心和以图像为中心的文档 AI 任务。
+
+实验结果表明,LayoutLMv3在以下数据集以更少的参数量达到了更优的性能:
+- 以文本为中心的数据集:表单理解FUNSD数据集、票据理解CORD数据集以及文档视觉问答DocVQA数据集。
+- 以图像为中心的数据集:文档图像分类RVL-CDIP数据集以及文档布局分析PubLayNet数据集。
+
+LayoutLMv3 还应用了文本——图像多模态 Transformer 架构来学习跨模态表征。文本向量由词向量、词的一维位置向量和二维位置向量相加得到。文档图像的文本和其相应的二维位置信息(布局信息)则利用光学字符识别(OCR)工具抽取。因为文本的邻接词通常表达了相似的语义,LayoutLMv3 共享了邻接词的二维位置向量,而 LayoutLM 和 LayoutLMv2 的每个词则用了不同的二维位置向量。
+
+图像向量的表示通常依赖于 CNN 抽取特征图网格特征或 Faster R-CNN 提取区域特征,这些方式增加了计算开销或依赖于区域标注。因此,作者将图像块经过线性映射获得图像特征,这种图像表示方式最早在 ViT 中被提出,计算开销极小且不依赖于区域标注,有效解决了以上问题。具体来说,首先将图像缩放为统一的大小(例如224x224),然后将图像切分成固定大小的块(例如16x16),并通过线性映射获得图像特征序列,再加上可学习的一维位置向量后得到图像向量。[1]
+
+
+
+
+
+
+
+ 图1. LayoutLMv3架构图 [1]
+
+
+## 2. 评估结果
+
+
+### 训练端
+
+根据我们的实验,在XFUND中文数据集上训练([模型训练](#32-模型训练))性能和精度评估([模型评估](#33-模型评估))结果如下:
+
+
+
+| **模型** | **任务** | **环境配置** | **训练集** | **参数量** | **单卡批量** | **图模式单卡训练 (s/epoch)** | **图模式单卡训练 (ms/step)** | **图模式单卡训练 (FPS)** | **hmean** | **配置文件** | **模型权重下载** |
+| :----------: | :------: | :-------------: | :--------: | :--------: | :----------: | :--------------------------: | :--------------------------: | :----------------------: | :-------: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------: |
+| LayoutLMv3 | SER | D910x1-MS2.1-G | XFUND_zh | 265.8 M | 8 | 19.53 | 1094.86 | 7.37 | 91.88% | [yaml](../layoutlmv3/ser_layoutlmv3_xfund_zh.yaml) | ckpt(TODO) |
+
+
+
+
+## 3. 快速开始
+### 3.1 环境及数据准备
+
+#### 3.1.1 安装
+环境安装教程请参考MindOCR的 [installation instruction](https://github.com/mindspore-lab/mindocr#installation).
+
+#### 3.1.2 数据集下载
+这里使用[XFUND数据集](https://github.com/doc-analysis/XFUND)做为实验数据集。 XFUN数据集是微软提出的一个用于KIE任务的多语言数据集,共包含七个数据集,每个数据集包含149张训练集和50张验证集
+
+分别为:ZH(中文)、JA(日语)、ES(西班牙)、FR(法语)、IT(意大利)、DE(德语)、PT(葡萄牙)
+
+这里提供了已经过预处理,可以直接用于训练的[中文数据集](https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar)下载。
+
+```bash
+mkdir train_data
+cd train_data
+wget https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar && tar -xf XFUND.tar
+cd ..
+```
+
+#### 3.1.3 数据集使用
+
+解压文件后,数据文件夹结构如下:
+
+```bash
+ └─ zh_train/ 训练集
+ ├── image/ 图片存放文件夹
+ ├── train.json 标注信息
+ └─ zh_val/ 验证集
+ ├── image/ 图片存放文件夹
+ ├── val.json 标注信息
+
+```
+
+该数据集的标注格式为
+
+```bash
+{
+ "height": 3508, # 图像高度
+ "width": 2480, # 图像宽度
+ "ocr_info": [
+ {
+ "text": "邮政地址:", # 单个文本内容
+ "label": "question", # 文本所属类别
+ "bbox": [261, 802, 483, 859], # 单个文本框
+ "id": 54, # 文本索引
+ "linking": [[54, 60]], # 当前文本和其他文本的关系 [question, answer]
+ "words": []
+ },
+ {
+ "text": "湖南省怀化市市辖区",
+ "label": "answer",
+ "bbox": [487, 810, 862, 859],
+ "id": 60,
+ "linking": [[54, 60]],
+ "words": []
+ }
+ ]
+}
+```
+
+**模型训练的数据配置**
+
+如欲重现模型的训练,建议修改配置yaml的数据集相关字段如下:
+
+```yaml
+...
+train:
+ ...
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # 训练数据集根目录
+ data_dir: XFUND/zh_train/image/ # 训练数据集目录,将与`dataset_root`拼接形成完整训练数据集目录
+ label_file: XFUND/zh_train/train.json # 训练数据集的标签文件路径,将与`dataset_root`拼接形成完整的训练数据的标签文件路径。
+...
+eval:
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # 验证数据集根目录
+ data_dir: XFUND/zh_val/image/ # 验证数据集目录,将与`dataset_root`拼接形成完整验证数据集目录
+ label_file: XFUND/zh_val/val.json # 验证数据集的标签文件路径,将与`dataset_root`拼接形成完整的验证或评估数据的标签文件路径。
+ ...
+```
+
+#### 3.1.4 检查配置文件
+除了数据集的设置,请同时重点关注以下配置项:`system.distribute`, `system.val_while_train`, `train.loader.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_root`, `train.dataset.data_dir`, `train.dataset.label_file`,
+`eval.ckpt_load_path`, `eval.dataset.dataset_root`, `eval.dataset.data_dir`, `eval.dataset.label_file`, `eval.loader.batch_size`。说明如下:
+
+```yaml
+system:
+ mode:
+ distribute: False # 分布式训练为True,单卡训练为False
+ amp_level: 'O0'
+ seed: 42
+ val_while_train: True # 边训练边验证
+ drop_overflow_update: False
+model:
+ type: kie
+ transform: null
+ backbone:
+ name: layoutlmv3
+ pretrained: False
+ checkpoints: path/to/layoutlmv3.ckpt # 导入ckpt位置
+ num_classes: &num_classes 7
+ mode: vi
+...
+train:
+ ckpt_save_dir: './tmp_kie_ser' # 训练结果(包括checkpoint、每个epoch的性能和曲线图)保存目录
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # 训练数据集根目录
+ data_dir: XFUND/zh_train/image/ # 训练数据集目录,将与`dataset_root`拼接形成完整训练数据集目录
+ label_file: XFUND/zh_train/train.json # 训练数据集的标签文件路径,将与`dataset_root`拼接形成完整的训练数据的标签文件路径。
+...
+eval:
+ ckpt_load_path: './tmp_kie_ser/best.ckpt' # checkpoint文件路径
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # 验证数据集根目录
+ data_dir: XFUND/zh_val/image/ # 验证数据集目录,将与`dataset_root`拼接形成完整验证数据集目录
+ label_file: XFUND/zh_val/val.json # 验证数据集的标签文件路径,将与`dataset_root`拼接形成完整的验证或评估数据的标签文件路径。
+ ...
+...
+```
+
+**注意:**
+- 由于全局批大小 (batch_size x num_devices) 是对结果复现很重要,因此当GPU/NPU卡数发生变化时,调整`batch_size`以保持全局批大小不变,或根据新的全局批大小线性调整学习率。
+
+
+### 3.2 模型训练
+
+* 多卡数据并行训练
+
+使用预定义的训练配置可以轻松重现报告的结果。对于在多个昇腾910设备上的分布式训练,请将配置参数`distribute`修改为True,并运行:
+
+```shell
+# 在多个 GPU/Ascend 设备上进行分布式训练
+mpirun --allow-run-as-root -n 8 python tools/train.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+
+* 单卡训练
+
+如果要在没有分布式训练的情况下在较小的数据集上训练或微调模型,请将配置参数`distribute`修改为False 并运行:
+
+```shell
+# CPU/GPU/Ascend 设备上的单卡训练
+python tools/train.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+训练结果(包括checkpoint、每个epoch的性能和曲线图)将被保存在yaml配置文件的`ckpt_save_dir`参数配置的目录下,默认为`./tmp_kie_ser`。
+
+### 3.3 模型评估
+
+若要评估已训练模型的准确性,可以使用`eval.py`。请在yaml配置文件的`eval`部分将参数`ckpt_load_path`设置为模型checkpoint的文件路径,然后运行:
+
+```
+python tools/eval.py --config configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
+```
+
+### 3.4 模型推理
+
+若要使用已训练的模型进行推理,可使用`tools/infer/text/predict_ser.py`进行推理并将结果进行可视化展示。
+
+```
+python tools/infer/text/predict_ser.py --rec_algorithm CRNN_CH --image_dir {dir of images or path of image}
+```
+
+以中文表单的实体识别为例,使用脚本识别`configs/kie/vi_layoutxlm/example.jpg`表单中的实体,结果将默认存放在`./inference_results`文件夹内,也可以通过`--draw_img_save_dir`命令行参数自定义结果存储路径。
+
+
+
+
+
+ example.jpg
+
+识别结果如图,图片保存为`inference_results/example_ser.jpg`:
+
+
+
+
+
+ example_ser.jpg
+
+
+
+
+## 参考文献
+
+
+[1] Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei. LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking. arXiv preprint arXiv:2204.08387, 2022.
diff --git a/configs/kie/layoutlmv3/layoutlmv3_arch.jpg b/configs/kie/layoutlmv3/layoutlmv3_arch.jpg
new file mode 100644
index 000000000..9f23f3fc9
Binary files /dev/null and b/configs/kie/layoutlmv3/layoutlmv3_arch.jpg differ
diff --git a/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml b/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
new file mode 100644
index 000000000..b42c84116
--- /dev/null
+++ b/configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml
@@ -0,0 +1,128 @@
+system:
+ mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
+ distribute: False
+ amp_level: "O0"
+ seed: 42
+ log_interval: 10
+ val_start_epoch: 50
+ val_while_train: True
+ drop_overflow_update: False
+
+model:
+ type: kie
+ transform: null
+ backbone:
+ name: layoutlmv3
+ head:
+ name: TokenClassificationHead
+ num_classes: 7
+ use_visual_backbone: True
+ use_float16: True
+ pretrained:
+
+postprocess:
+ name: VQASerTokenLayoutLMPostProcess
+ class_path: &class_path
+
+metric:
+ name: VQASerTokenMetric
+ main_indicator: hmean
+
+loss:
+ name: VQASerTokenLayoutLMLoss
+ num_classes: 7
+
+scheduler:
+ scheduler: polynomial_decay
+ lr: 5.0e-5
+ min_lr: 2.0e-7
+ num_epochs: 200
+ warmup_epochs: 2
+
+optimizer:
+ opt: adam
+ filter_bias_and_bn: False
+ weight_decay: 0.0005
+
+train:
+ ckpt_save_dir: "./tmp_kie_ser"
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/train_data/
+ data_dir: XFUND/zh_train/image
+ label_file: XFUND/zh_train/train.json
+ sample_ratio: 1.0
+ transform_pipeline:
+ - DecodeImage:
+ img_mode: RGB
+ to_float32: False
+ - VQATokenLabelEncode:
+ contains_re: False
+ algorithm: &algorithm LayoutLMv3
+ class_path: *class_path
+ - VQATokenPad:
+ max_seq_len: &max_seq_len 512
+ return_attention_mask: True
+ - VQASerTokenChunk:
+ max_seq_len: *max_seq_len
+ - LayoutResize:
+ size: [224, 224]
+ - NormalizeImage:
+ bgr_to_rgb: False
+ is_hwc: True
+ mean: imagenet
+ std: imagenet
+ - ToCHWImage:
+ # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
+ output_columns: [ "input_ids", "bbox", "attention_mask", "token_type_ids", "image", "labels"]
+ net_input_column_index: [0, 1, 2, 3, 4] # input indices for network forward func in output_columns
+ label_column_index: [2, 5] # input indices marked as label
+
+ loader:
+ shuffle: True
+ batch_size: 8
+ drop_remainder: True
+ num_workers: 8
+
+eval:
+ ckpt_load_path: "./tmp_kie_ser/best.ckpt"
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/train_data/
+ data_dir: XFUND/zh_val/image
+ label_file: XFUND/zh_val/val.json
+ sample_ratio: 1.0
+ shuffle: False
+ transform_pipeline:
+ - DecodeImage:
+ img_mode: RGB
+ to_float32: False
+ - VQATokenLabelEncode:
+ contains_re: False
+ algorithm: *algorithm
+ class_path: *class_path
+ - VQATokenPad:
+ max_seq_len: *max_seq_len
+ return_attention_mask: True
+ - VQASerTokenChunk:
+ max_seq_len: *max_seq_len
+ - LayoutResize:
+ size: [224, 224]
+ - NormalizeImage:
+ bgr_to_rgb: False
+ is_hwc: True
+ mean: imagenet
+ std: imagenet
+ - ToCHWImage:
+ # the order of the dataloader list, matching the network input and the labels for evaluation
+ output_columns: ["input_ids", "bbox", "attention_mask", "token_type_ids", "image", "labels"]
+ net_input_column_index: [0, 1, 2, 3, 4] # input indices for network forward func in output_columns
+ label_column_index: [2, 5] # input indices marked as label
+
+ loader:
+ shuffle: False
+ batch_size: 1
+ drop_remainder: False
+ num_workers: 1
diff --git a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yaml b/configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml
similarity index 100%
rename from configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yaml
rename to configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml
diff --git a/configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml b/configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml
similarity index 100%
rename from configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yaml
rename to configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml
diff --git a/configs/kie/vi_layoutxlm/README.md b/configs/kie/vi_layoutxlm/README.md
new file mode 100644
index 000000000..5326f11f5
--- /dev/null
+++ b/configs/kie/vi_layoutxlm/README.md
@@ -0,0 +1,283 @@
+English| [中文](README_CN.md)
+
+# LayoutXLM
+
+
+> [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836)
+
+## 1. Introduction
+
+****
+LayoutXLM is the multilingual version of LayoutLMv2[2]. Unlike the original LayoutLM, which integrates image embeddings during the fine-tuning stage, LayoutXLM integrates visual information during the pre-training stage and utilizes a Transformer architecture to learn cross-modal interactions between text and images. Additionally, inspired by 1-D relative positional representation, the paper proposes a spatial-aware self-attention mechanism, which provides 2-D relative positional representation for token pairs. Unlike using absolute 2-D position embeddings to model document layout, relative positional embeddings can provide a larger receptive field for modeling contextual spatial relationships clearly.
+
+As shown in the architecture diagram [Figure 1](#-Multi-modal-Encoder-with-Spatial-Aware-Self-Attention-Mechanism), LayoutXLM (LayoutLMv2) adopts a multimodal Transformer architecture as its backbone. The backbone takes text, image, and layout information as input, establishing deep cross-modal interactions. At the same time, it introduces the spatial-aware self-attention mechanism, allowing the model to better model document layout.
+
+### Text Embedding
+Tokenizing the OCR text sequence with WordPiece, each token is marked as {[A], [B]}. Then, [CLS] is added to the beginning of the sequence, and [SEP] is added to the end of each text segment. Additional [PAD] tokens are added to the end of the sequence to match the maximum sequence length, denoted as L. The final text embedding is the sum of three embeddings: token embedding representing the token itself, 1-D position embedding representing the token index, and segment embedding used to distinguish different text segments.
+
+### Visual Embedding
+Although all the required information is present in the page image, the model finds it challenging to capture detailed features through a single information-rich representation. Therefore, leveraging a CNN-based visual encoder outputs the page feature map, which also converts the page image into a fixed-length sequence. Using the ResNeXt-FPN architecture as the backbone, its parameters can be trained through backpropagation.
+
+For a given page image I, it is resized to 224×224 before entering the visual backbone. The output feature map is then average-pooled to a fixed size: width W and height H. Afterwards, it is flattened into a visual embedding sequence of length W×H, and its dimension is aligned with the text embedding through a linear projection layer. Since the CNN-based visual backbone cannot acquire position information, 1-D position embedding is also added, which is shared with the text embedding. For segment embedding, all visual tokens are assigned to [C].
+
+### Layout Embedding
+The layout embedding layer is used to represent spatial layout information, which originates from the axis-aligned token bounding boxes obtained from OCR recognition, including the length, width, and coordinates of the boxes. Following the approach of LayoutLM, the coordinates are normalized and discretized, rounding them to integers between 0 and 1000. Two embedding layers are used to embed features along the x-axis and y-axis, respectively.
+
+Given a normalized bounding box with xmin, xmax, ymin, ymax, width, and height, the layout embedding layer concatenates the six bounding box features to construct a 2-D position embedding, which is the layout embedding. Since CNN supports local transformations, image token embeddings can be mapped back to the original image one-to-one, without overlapping or missing tokens. Therefore, when calculating bounding boxes, visual tokens can be assigned to the corresponding grid. For special tokens such as [CLS], [SEP], and [PAD] in the text embedding, zero features for bounding boxes are appended.
+
+### Multi-modal Encoder with Spatial-Aware Self-Attention Mechanism
+The encoder concatenates visual embeddings and text embeddings into a unified sequence and adds them to the layout embeddings to blend spatial information. Following the Transformer architecture, the model constructs a multimodal encoder with a stack of multi-head self-attention layers followed by feed-forward networks. However, the original self-attention mechanism only captures absolute positional relationships between input tokens. To effectively model local invariance in document layout, it is necessary to explicitly insert relative positional information. Therefore, we propose the spatial-aware self-attention mechanism and incorporate it into the self-attention layer.
+
+After obtaining αij from the original self-attention layer, considering the large range of positions, we model semantic relative positions and spatial relative positions as bias terms to avoid introducing too many parameters. We use three biases to represent learnable 1-D and 2-D (x, y) relative positional biases. These biases are different for each attention head but consistent across layers. Assuming a bounding box (xi, yi), the three biases are added to αij to obtain the self-attention map, and finally, the final attention scores are computed in the manner of Transformer.
+ [1] [2]
+
+
+
+
+
+
+
+ Figure 1. LayoutXLM(LayoutLMv2) architecture [1]
+
+
+## 2. Results
+
+
+### Accuracy
+
+According to our experiments, the performance and accuracy evaluation([Model Evaluation](#33-Model-Evaluation)) results of training ([Model Training](#32-Model-Training)) on the XFUND Chinese dataset are as follows:
+
+
+
+| **Model** | **Task** | **Context** | **Dateset** | **Model Params** | **Batch size** | **Graph train 1P (s/epoch)** | **Graph train 1P (ms/step)** | **Graph train 1P (FPS)** | **hmean** | **Config** | **Download** |
+| :----------: | :------: | :-------------: | :--------: | :--------: | :----------: | :--------------------------: | :--------------------------: | :----------------------: | :-------: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------: |
+| LayoutXLM | SER | D910Ax1-MS2.1-G | XFUND_zh | 352.0 M | 8 | 3.41 | 189.50 | 42.24 | 90.41% | [yaml](../layoutxlm/ser_layoutxlm_xfund_zh.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/layoutxlm/ser_layoutxlm_base-a4ea148e.ckpt) |
+| VI-LayoutXLM | SER | D910Ax1-MS2.1-G | XFUND_zh | 265.7 M | 8 | 3.06 | 169.7 | 47.2 | 93.31% | [yaml](ser_vi_layoutxlm_xfund_zh.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/ser_vi_layoutxlm-f3c83585.ckpt) |
+
+
+
+
+
+## 3. Quick Start
+### 3.1 Preparation
+
+#### 3.1.1 Installation
+Please refer to the [installation instruction](https://github.com/mindspore-lab/mindocr#installation) in MindOCR.
+
+#### 3.1.2 Dataset Download
+
+[The XFUND dataset](https://github.com/doc-analysis/XFUND) is used as the experimental dataset. The XFUND dataset is a multilingual dataset proposed by Microsoft for the Knowledge-Intensive Extraction (KIE) task. It consists of seven datasets, each containing 149 training samples and 50 validation samples.
+
+Respectively: ZH (Chinese), JA (Japanese), ES (Spanish), FR (French), IT (Italian), DE (German), PT (Portuguese)
+
+a preprocessed [Chinese dataset](https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar) that can be directly used is provided for everyone to download.
+
+```bash
+mkdir train_data
+cd train_data
+wget https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/XFUND.tar && tar -xf XFUND.tar
+cd ..
+```
+
+#### 3.1.3 Dataset Usage
+
+After decompression, the data folder structure is as follows:
+
+```bash
+ └─ zh_train/ Training set
+ ├── image/ Folder for storing images
+ ├── train.json Annotation information
+ └─ zh_val/ Validation set
+ ├── image/ Folder for storing images
+ ├── val.json Annotation information
+
+```
+
+The annotation format of this dataset is:
+
+```bash
+{
+ "height": 3508, # Image height
+ "width": 2480, # Image width
+ "ocr_info": [
+ {
+ "text": "邮政地址:", # Single text content
+ "label": "question", # Category of the text
+ "bbox": [261, 802, 483, 859], # Single text box
+ "id": 54, # Text index
+ "linking": [[54, 60]], # Relationships between the current text and other texts [question, answer]
+ "words": []
+ },
+ {
+ "text": "湖南省怀化市市辖区",
+ "label": "answer",
+ "bbox": [487, 810, 862, 859],
+ "id": 60,
+ "linking": [[54, 60]],
+ "words": []
+ }
+ ]
+}
+```
+
+**The data configuration for model training.**
+
+If you want to reproduce the training of the model, it is recommended to modify the dataset-related fields in the configuration YAML file as follows:
+
+```yaml
+...
+train:
+ ...
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Root directory of the training dataset
+ data_dir: XFUND/zh_train/image/ # Directory of the training dataset, concatenated with `dataset_root` to form the complete directory of the training dataset
+ label_file: XFUND/zh_train/train.json # Path to the label file of the training dataset, concatenated with `dataset_root` to form the complete path of the label file of the training dataset
+...
+eval:
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Root directory of the validation dataset
+ data_dir: XFUND/zh_val/image/ # Directory of the validation dataset, concatenated with `dataset_root` to form the complete directory of the validation dataset
+ label_file: XFUND/zh_val/val.json # Path to the label file of the validation dataset, concatenated with `dataset_root` to form the complete path of the label file of the validation dataset
+ ...
+
+```
+
+#### 3.1.4 Check YAML Config Files
+Apart from the dataset setting, please also check the following important args: `system.distribute`, `system.val_while_train`, `common.batch_size`, `train.ckpt_save_dir`, `train.dataset.dataset_path`, `eval.ckpt_load_path`, `eval.dataset.dataset_path`, `eval.loader.batch_size`. Explanations of these important args:
+
+```yaml
+system:
+ mode:
+ distribute: False # `True` for distributed training, `False` for standalone training
+ amp_level: 'O0'
+ seed: 42
+ val_while_train: True # Validate while training
+ drop_overflow_update: False
+model:
+ type: kie
+ transform: null
+ backbone:
+ name: layoutxlm
+ pretrained: True
+ num_classes: &num_classes 7
+ use_visual_backbone: False
+ use_float16: True
+ head :
+ name: TokenClassificationHead
+ num_classes: 7
+ use_visual_backbone: False
+ use_float16: True
+ pretrained:
+...
+train:
+ ckpt_save_dir: './tmp_kie_ser' # The training result (including checkpoints, per-epoch performance and curves) saving directory
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Path of training dataset
+ data_dir: XFUND/zh_train/image/ # Path of training dataset data dir
+ label_file: XFUND/zh_train/train.json # Path of training dataset label file
+...
+eval:
+ ckpt_load_path: './tmp_kie_ser/best.ckpt' # checkpoint file path
+ dataset_sink_mode: False
+ dataset:
+ type: KieDataset
+ dataset_root: path/to/dataset/ # Path of evaluation dataset
+ data_dir: XFUND/zh_val/image/ # Path of evaluation dataset data dir
+ label_file: XFUND/zh_val/val.json # Path of evaluation dataset label file
+...
+ ...
+...
+```
+
+**Notes:**
+- As the global batch size (batch_size x num_devices) is important for reproducing the result, please adjust `batch_size` accordingly to keep the global batch size unchanged for a different number of GPUs/NPUs, or adjust the learning rate linearly to a new global batch size.
+
+
+### 3.2 Model Training
+
+* Convert PaddleOCR model
+
+If you want to import the PaddleOCR LayoutXLM model, you can use the `tools/param_converter.py` script to convert the pdparams file to the ckpt format supported by MindSpore, and then import it for further training.
+
+```shell
+python tools/param_converter.py \
+ --input_path path/to/paddleocr.pdparams \
+ --json_path mindocr/models/backbones/layoutxlm/ser_vi_layoutxlm_param_map.json \
+ --output_path path/to/from_paddle.ckpt
+```
+
+* Distributed Training
+
+It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please modify the configuration parameter `distribute` as True and run:
+
+```shell
+# distributed training on multiple GPU/Ascend devices
+mpirun --allow-run-as-root -n 8 python tools/train.py --config configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml
+```
+
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please modify the configuration parameter`distribute` as False and run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python tools/train.py --config configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml
+```
+
+The training result (including checkpoints, per-epoch performance and curves) will be saved in the directory parsed by the arg `ckpt_save_dir`. The default directory is `./tmp_kie_ser`.
+
+### 3.3 Model Evaluation
+
+To evaluate the accuracy of the trained model, you can use `eval.py`. Please set the checkpoint path to the arg `ckpt_load_path` in the `eval` section of yaml config file, set `distribute` to be False, and then run:
+
+```
+python tools/eval.py --config configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml
+```
+
+
+### 3.4 Model Inference
+
+To perform inference using a pre-trained model, you can utilize `tools/infer/text/predict_ser.py` for inference and visualize the results.
+
+```
+python tools/infer/text/predict_ser.py --rec_algorithm CRNN_CH --image_dir {dir of images or path of image}
+```
+
+As an example of entity recognition in Chinese forms, use the script to recognize entities in the form of `configs/kie/vi_layoutxlm/example.jpg`. The results will be stored in the `./inference_results` folder by default, and you can also customize the result storage path through the `--draw_img_save_dir` command-line parameter.
+
+
+
+
+
+ example.jpg
+
+Recognition results are as shown in the image, and the image is saved as`inference_results/example_ser.jpg`:
+
+
+
+
+
+ example_ser.jpg
+
+
+
+
+## References
+
+
+[1] Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding. arXiv preprint arXiv:2012.14740, 2020.
+
+[2] Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding. arXiv preprint arXiv:2104.08836, 2021.
diff --git a/configs/kie/vi_layoutxlm/README_CN.md b/configs/kie/vi_layoutxlm/README_CN.md
index a0fbd468f..f6f843807 100644
--- a/configs/kie/vi_layoutxlm/README_CN.md
+++ b/configs/kie/vi_layoutxlm/README_CN.md
@@ -48,43 +48,17 @@ Table Format:
### 训练端
-根据我们的实验,训练([模型训练](#32-模型训练))性能和精度评估([模型评估](#33-模型评估))结果如下:
+根据我们的实验,在XFUND中文数据集上训练([模型训练](#32-模型训练))性能和精度评估([模型评估](#33-模型评估))结果如下:
| **模型** | **任务** | **环境配置** | **训练集** | **参数量** | **单卡批量** | **图模式单卡训练 (s/epoch)** | **图模式单卡训练 (ms/step)** | **图模式单卡训练 (FPS)** | **hmean** | **配置文件** | **模型权重下载** |
| :----------: | :------: | :-------------: | :--------: | :--------: | :----------: | :--------------------------: | :--------------------------: | :----------------------: | :-------: | :----------------------------------------------------: | :------------------------------------------------------------------------------------------------: |
-| LayoutXLM | SER | D910Ax1-MS2.1-G | XFUND_zh | 352.0 M | 8 | 3.41 | 189.50 | 42.24 | 90.41% | [yaml](../layoutlm_series/ser_layoutxlm_xfund_zh.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/layoutxlm/ser_layoutxlm_base-a4ea148e.ckpt) |
+| LayoutXLM | SER | D910Ax1-MS2.1-G | XFUND_zh | 352.0 M | 8 | 3.41 | 189.50 | 42.24 | 90.41% | [yaml](../layoutxlm/ser_layoutxlm_xfund_zh.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/layoutxlm/ser_layoutxlm_base-a4ea148e.ckpt) |
| VI-LayoutXLM | SER | D910Ax1-MS2.1-G | XFUND_zh | 265.7 M | 8 | 3.06 | 169.7 | 47.2 | 93.31% | [yaml](ser_vi_layoutxlm_xfund_zh.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/vi-layoutxlm/ser_vi_layoutxlm-f3c83585.ckpt) |
-### 3.4 模型推理
-
-若要使用已训练的模型进行推理,可使用`tools/infer/text/predict_ser.py`进行推理并将结果进行可视化展示。
-
-```
-python tools/infer/text/predict_ser.py --rec_algorithm CRNN_CH --image_dir {dir of images or path of image}
-```
-
-以中文表单的实体识别为例,使用脚本识别`configs/kie/vi_layoutxlm/example.jpg`表单中的实体,结果将默认存放在`./inference_results`文件夹内,也可以通过`--draw_img_save_dir`命令行参数自定义结果存储路径。
-
-
-
-
-
- example.jpg
-
-识别结果如图,图片保存为`inference_results/example_ser.jpg`:
-
-
-
-
-
- example_ser.jpg
-
-
-
### 推理端
TODO
@@ -190,11 +164,17 @@ model:
type: kie
transform: null
backbone:
- name: layoutxlm_for_ser
- pretrained: False
- checkpoints: path/to/ser_vi_layoutxlm.ckpt # 导入ckpt位置
+ name: layoutxlm
+ pretrained: True
num_classes: &num_classes 7
- mode: vi
+ use_visual_backbone: False
+ use_float16: True
+ head :
+ name: TokenClassificationHead
+ num_classes: 7
+ use_visual_backbone: False
+ use_float16: True
+ pretrained:
...
train:
ckpt_save_dir: './tmp_kie_ser' # 训练结果(包括checkpoint、每个epoch的性能和曲线图)保存目录
@@ -264,6 +244,32 @@ python tools/eval.py --config configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh
```
+### 3.4 模型推理
+
+若要使用已训练的模型进行推理,可使用`tools/infer/text/predict_ser.py`进行推理并将结果进行可视化展示。
+
+```
+python tools/infer/text/predict_ser.py --rec_algorithm CRNN_CH --image_dir {dir of images or path of image}
+```
+
+以中文表单的实体识别为例,使用脚本识别`configs/kie/vi_layoutxlm/example.jpg`表单中的实体,结果将默认存放在`./inference_results`文件夹内,也可以通过`--draw_img_save_dir`命令行参数自定义结果存储路径。
+
+
+
+
+
+ example.jpg
+
+识别结果如图,图片保存为`inference_results/example_ser.jpg`:
+
+
+
+
+
+ example_ser.jpg
+
+
+
## 4. MindSpore Lite 推理
**TODO**
diff --git a/configs/llm/vary/README.md b/configs/llm/vary/README.md
new file mode 100644
index 000000000..ece3db84a
--- /dev/null
+++ b/configs/llm/vary/README.md
@@ -0,0 +1,125 @@
+English | [中文](README_CN.md)
+
+# Vary-toy
+
+> [Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models](https://arxiv.org/abs/2312.06109)
+> [Small Language Model Meets with Reinforced Vision Vocabulary](https://arxiv.org/abs/2401.12503)
+
+## 1. Model Description
+
+Vary is an effective method to extend the visual vocabulary of the Large Visual Language Model (LVLM). Vary is divided into two parts: the generation and integration of new visual vocabulary. In the first stage, Vary designed a "vocabulary network" and a small decoder Transformer, which produced the required vocabulary through autoregression. Vary then scales up the common visual vocabulary by merging the new visual vocabulary with the original visual vocabulary (CLIP), enabling LVLM to quickly acquire new functionality. Vary-toy is a smaller version of Vary's official open source.
+
+## 2. Assessment Result
+
+According to our experiments, the inference performance of Vary-toy is as follows:
+
+| **model** | **Environment Configuration** | **Total Time** | **Token generation speed** | **Configuration File** |
+|:---------:|:-----------------------------:|:--------------:|:--------------------------:|:------------------------------------------------------------------------------------------:|
+| Vary-toy | D910x1-MS2.2-G | 23.38 s | 30.75 tokens/s | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/llm/vary/vary_toy.yaml) |
+
+**Note:**
+
+- Environment configuration: The training environment configuration is represented as \{processor\} x \{processor quantity\} - \{MS mode\}. The MindSpore mode can be G-graph mode or F-pynative mode. For example, D910x1-MS2.2-G uses graph mode to train on one 910 NPU depending on MindSpore 2.2.
+- If you need to reproduce the training result in another environment, ensure that the global batch size is the same as that in the original configuration file.
+
+## 3. Quick Start
+
+### 3.1 Environment and Model Preparation
+
+#### 3.1.1 Installation
+
+Note: if you want to experience Vary-toy, you shall upgrade your python to 3.8 or above version.
+
+For details about environment installation, see MindOCR [installation instruction](https://github.com/mindspore-lab/mindocr#installation).
+
+In addition, you need to install the `tiktoken` using the following shell command:
+
+``` shell
+pip install tiktoken
+```
+
+#### 3.1.2 Configuration File
+
+Pay special attention to the configuration of the following variables:`seq_length`,`checkpoint_name_or_path`,`repetition_penalty`,`max_decode_length`,`max_new_tokens`,`vocab_file`. Description:
+
+```yaml
+model:
+ ...
+ seq_length: 2048 # Sentence length
+ checkpoint_name_or_path: "/path/to/vary_toy.ckpt" # Weight path
+ repetition_penalty: 1.5 # The penalty for generating duplicate values.
+ max_decode_length: 2048 # Maximum generated sentence length.
+ max_new_tokens: 1024 # Number of new tokens.
+ ...
+...
+processor:
+ ...
+ tokenizer:
+ vocab_file: "/path/to/qwen.tiktoken" # Path of the tokenizer
+ ...
+...
+```
+
+#### 3.1.3 Model Preparation
+
+Users can download tokenizer model from the following link:
+
+- [qwen.tiktoken](https://huggingface.co/HaoranWei/Vary-toy/blob/main/qwen.tiktoken)
+
+Users can download weights from the following link:
+
+- [Vary-toy](https://download-mindspore.osinfra.cn/toolkits/mindocr/vary/vary_toy-e62a3564.ckpt)
+
+Users can also download weights from the following huggingface link:
+
+- [Vary-toy](https://huggingface.co/HaoranWei/Vary-toy/blob/main/pytorch_model.bin)
+
+Then perform the weight conversion according to the following steps:
+
+Note: Install `torch` before starting the conversion script:
+
+```shell
+pip install torch
+```
+
+After the download is complete, run the mindocr/models/llm/convert_weight.py conversion script to convert the huggingface weight to the MindSpore ckpt weight.
+
+```shell
+python mindocr/models/llm/convert_weight.py \
+ --torch_ckpt_path="/path/to/pytorch_model.bin" \
+ --mindspore_ckpt_path="/path/to/vary_toy.ckpt"
+
+# Parameter description:
+# torch_ckpt_path: weight path for downloading huggingface.
+# mindspore_ckpt_path: path of the exported MindSpore weight.
+```
+
+### 3.2 Model Inference
+
+```shell
+python ./tools/infer/text/predict_llm.py \
+ --image_dir=/path/to/image.jpg \
+ --query="Provide the ocr results of this image." \
+ --config_path="/path/to/vary_toy.yaml" \
+ --chat_mode=False
+
+# Parameter description:
+# image_dir: image path.
+# query: query statement
+# config_path: indicates the configuration file path.
+# chat_mode: indicates whether to use the dialog mode.
+```
+
+The execution result is printed on the screen.
+
+For example, you can input the query statement "Describe this image in within 100 words" to generate analysis text of the following image:
+
+![PMC4055390_00006](./images/PMC4055390_00006.jpg)
+
+```
+The article discusses the analysis of traffic signals using deep learning models, specifically focusing on pedestrian crossing data. The authors propose a method to extract features from videos captured by cameras and use them to train a model for predicting pedestrian behavior. They compare their approach with other methods and show that their model outperforms others in terms of accuracy and robustness. The study also highlights the limitations of their approach, such as the need for accurate hand-crafted features and the lack of consideration for different types of vehicles. Overall, the findings suggest the potential of using machine learning models to improve traffic signal analysis and enhance safety.This article is about the use of deep learning models for predicting pedestrian behavior in traffic signals. It compares the performance of different models and highlights the limitations of these approaches.
+```
+
+### 3.3 Model Training
+
+coming soon
diff --git a/configs/llm/vary/README_CN.md b/configs/llm/vary/README_CN.md
new file mode 100644
index 000000000..c03741cb6
--- /dev/null
+++ b/configs/llm/vary/README_CN.md
@@ -0,0 +1,124 @@
+[English](README.md) | 中文
+
+# Vary-toy
+> [Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models](https://arxiv.org/abs/2312.06109)
+> [Small Language Model Meets with Reinforced Vision Vocabulary](https://arxiv.org/abs/2401.12503)
+
+## 1. 模型描述
+Vary是扩展大视觉语言模型(LVLM)视觉词汇的一种有效方法。Vary分为两个部分:新视觉词汇的生成和整合。在第一阶段,Vary设计了一个“词汇网络”以及一个很小的解码器Transformer,并通过自回归产生所需的词汇。然后,Vary通过将新的视觉词汇与原始的视觉词汇(CLIP)合并来扩大普通视觉词汇的规模,使LVLM能够快速获得新的功能。Vary-toy是Vary官方开源的较小规模版本。
+
+## 2. 评估结果
+
+根据我们的实验,Vary-toy的推理性能如下:
+
+
+
+| **模型** | **环境配置** | **总时间** | **token生成速度** | **配置文件** | **模型权重下载** |
+| :-----: | :--------: | :--------: | :-----: | :---------: | :---------: |
+| Vary-toy | D910x1-MS2.2-G | 23.38 s | 30.75 tokens/s | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/llm/vary/vary_toy.yaml)| [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/vary/vary_toy-e62a3564.ckpt) |
+
+
+**注意:**
+
+- 环境配置:训练的环境配置表示为 {处理器}x{处理器数量}-{MS模式},其中 Mindspore 模式可以是 G-graph 模式或 F-pynative 模式。例如,D910x1-MS2.2-G 使用图模式在1张昇腾910 NPU上依赖Mindspore2.2版本进行训练。
+- 如需在其他环境配置重现训练结果,请确保全局批量大小与原配置文件保持一致。
+
+## 3. 快速开始
+### 3.1 环境及模型准备
+
+#### 3.1.1 安装
+
+注:若你想实验Vary-toy,你要将python升级到3.8或以上版本。
+
+环境安装教程请参考MindOCR的 [installation instruction](https://github.com/mindspore-lab/mindocr#installation) 。
+
+此外,还需要使用如下shell命令安装`tiktoken`:
+
+```shell
+pip install tiktoken
+```
+
+#### 3.1.2 配置文件
+请重点关注以下变量的配置:`seq_length`、`checkpoint_name_or_path`、`repetition_penalty`、`max_decode_length`、`max_new_tokens`、`vocab_file`。说明如下:
+
+```yaml
+model:
+ ...
+ seq_length: 2048 # 句子长度
+ checkpoint_name_or_path: "/path/to/vary_toy.ckpt" # 权重路径
+ repetition_penalty: 1.5 # 生成重复值的惩罚项
+ max_decode_length: 2048 # 最大生成的句子长度
+ max_new_tokens: 1024 # 生成的新token的个数
+ ...
+...
+processor:
+ ...
+ tokenizer:
+ vocab_file: "/path/to/qwen.tiktoken" # 分词器路径
+ ...
+...
+```
+
+#### 3.1.3 模型准备
+
+用户可以从下方链接下载分词器模型:
+
+- [qwen.tiktoken](https://huggingface.co/HaoranWei/Vary-toy/blob/main/qwen.tiktoken)
+
+用户可以从下方链接下载权重:
+
+- [Vary-toy](https://download-mindspore.osinfra.cn/toolkits/mindocr/vary/vary_toy-e62a3564.ckpt)
+
+用户也可以从下方huggingface链接下载权重:
+
+- [Vary-toy](https://huggingface.co/HaoranWei/Vary-toy/blob/main/pytorch_model.bin)
+
+然后根据以下步骤进行权重转换:
+
+注:启动转换脚本前请安装`torch`:
+
+```shell
+pip install torch
+```
+
+下载完成后,运行mindocr/models/llm/convert_weight.py转换脚本,将huggingface的权重转换为MindSpore的ckpt权重。
+
+```shell
+python mindocr/models/llm/convert_weight.py \
+ --torch_ckpt_path="/path/to/pytorch_model.bin" \
+ --mindspore_ckpt_path="/path/to/vary_toy.ckpt"
+
+# 参数说明:
+# torch_ckpt_path:huggingface下载的权重路径
+# mindspore_ckpt_path:导出的MindSpore权重路径
+```
+
+### 3.2 模型推理
+
+```shell
+python ./tools/infer/text/predict_llm.py \
+ --image_dir=/path/to/image.jpg \
+ --query="Provide the ocr results of this image." \
+ --config_path="/path/to/vary_toy.yaml" \
+ --chat_mode=False
+
+# 参数说明:
+# image_dir:图片路径
+# query:查询语句
+# config_path:配置文件路径
+# chat_mode:是否使用对话模式
+```
+
+执行结果将打印到屏幕上。
+
+例如,可使用查询语句:"Describe this image in within 100 words.",生成对下图文本的分析结果:
+
+![PMC4055390_00006](./images/PMC4055390_00006.jpg)
+
+```txt
+The article discusses the analysis of traffic signals using deep learning models, specifically focusing on pedestrian crossing data. The authors propose a method to extract features from videos captured by cameras and use them to train a model for predicting pedestrian behavior. They compare their approach with other methods and show that their model outperforms others in terms of accuracy and robustness. The study also highlights the limitations of their approach, such as the need for accurate hand-crafted features and the lack of consideration for different types of vehicles. Overall, the findings suggest the potential of using machine learning models to improve traffic signal analysis and enhance safety.This article is about the use of deep learning models for predicting pedestrian behavior in traffic signals. It compares the performance of different models and highlights the limitations of these approaches.
+```
+
+### 3.3 模型训练
+
+coming soon
diff --git a/configs/llm/vary/images/PMC4055390_00006.jpg b/configs/llm/vary/images/PMC4055390_00006.jpg
new file mode 100644
index 000000000..610e8c3a9
Binary files /dev/null and b/configs/llm/vary/images/PMC4055390_00006.jpg differ
diff --git a/configs/llm/vary/vary_toy.yaml b/configs/llm/vary/vary_toy.yaml
new file mode 100644
index 000000000..1b6b59b9b
--- /dev/null
+++ b/configs/llm/vary/vary_toy.yaml
@@ -0,0 +1,44 @@
+model:
+ name: VaryQwenForCausalLM
+ batch_size: 1
+ seq_length: 2048
+ hidden_size: 2048
+ num_layers: 24
+ num_heads: 16
+ vocab_size: 151860
+ intermediate_size: 5504
+ rms_norm_eps: 1.0e-6
+ emb_dropout_prob: 0.0
+ eos_token_id: 151643
+ pad_token_id: 151643
+ compute_dtype: "float16"
+ layernorm_compute_type: "float32"
+ softmax_compute_type: "float16"
+ rotary_dtype: "float16"
+ param_init_type: "float16"
+ ln_param_init_type: "float16"
+ use_past: True
+ use_flash_attention: False
+ use_past_shard: False
+ offset: 0
+ checkpoint_name_or_path: "/path/to/vary_toy.ckpt"
+ repetition_penalty: 1.5
+ max_decode_length: 2048
+ top_k: 0
+ top_p: 0.8
+ do_sample: False
+ max_new_tokens: 1024
+ temperature: 1.0
+ num_beams: 1
+
+ # configuration items copied from Qwen
+ rotary_pct: 1.0
+ rotary_emb_base: 10000
+ kv_channels: 128
+
+processor:
+ return_tensors: ms
+ tokenizer:
+ vocab_file: "/path/to/qwen.tiktoken"
+ pad_token: "<|endoftext|>"
+ type: QwenProcessor
diff --git a/configs/rec/svtr/README.md b/configs/rec/svtr/README.md
index 3b2f2d3fc..c9b9d0dcc 100644
--- a/configs/rec/svtr/README.md
+++ b/configs/rec/svtr/README.md
@@ -39,6 +39,7 @@ According to our experiments, the evaluation results on public benchmark dataset
| **Model** | **Context** | **Avg Accuracy** | **Train T.** | **FPS** | **Recipe** | **Download** |
| :-----: | :-----------: | :--------------: | :----------: | :--------: | :--------: |:----------: |
| SVTR-Tiny | D910x4-MS1.10-G | 90.23% | 3638 s/epoch | 4560 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
+| SVTR-Tiny-8P | D910x8-MS2.2-G | 90.32% | 1646 s/epoch | 9840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir) |
@@ -48,6 +49,7 @@ According to our experiments, the evaluation results on public benchmark dataset
| **Model** | **IC03_860** | **IC03_867** | **IC13_857** | **IC13_1015** | **IC15_1811** | **IC15_2077** | **IIIT5k_3000** | **SVT** | **SVTP** | **CUTE80** | **Average** |
| :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
| SVTR-Tiny | 95.70% | 95.50% | 95.33% | 93.99% | 83.60% | 79.83% | 94.70% | 91.96% | 85.58% | 86.11% | 90.23% |
+ | SVTR-Tiny-8P | 95.93% | 95.62% | 95.33% | 93.89% | 84.32% | 80.55% | 94.33% | 90.57% | 86.20% | 86.46% | 90.32% |
diff --git a/configs/rec/svtr/README_CN.md b/configs/rec/svtr/README_CN.md
index 03cf014ee..a0579ac82 100644
--- a/configs/rec/svtr/README_CN.md
+++ b/configs/rec/svtr/README_CN.md
@@ -39,7 +39,7 @@ Table Format:
| **模型** | **环境配置** | **平均准确率** | **训练时间** | **FPS** | **配置文件** | **模型权重下载** |
|:------------:|:---------------:|:---------:|:------------:|:-------:|:---------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|
| SVTR-Tiny | D910x4-MS1.10-G | 90.23% | 3638 s/epoch | 4560 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny.yaml) | [ckpt](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3.ckpt) \| [mindir](https://download.mindspore.cn/toolkits/mindocr/svtr/svtr_tiny-950be1c3-86ece8c8.mindir) |
-| SVTR-Tiny-8P | D910x8-MS2.2-G | 90.32% | 1646 s/epoch | 9840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | Coming soon |
+| SVTR-Tiny-8P | D910x8-MS2.2-G | 90.32% | 1646 s/epoch | 9840 | [yaml](https://github.com/mindspore-lab/mindocr/blob/main/configs/rec/svtr/svtr_tiny_8p.yaml) | [ckpt](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6.ckpt) \| [mindir](https://download-mindspore.osinfra.cn/toolkits/mindocr/svtr/svtr_tiny_8p-0afc75d6-255191ef.mindir |
diff --git a/deploy/py_infer/src/configs/cls/ppocr/cls_mv3.yaml b/deploy/py_infer/src/configs/cls/ppocr/cls_mv3.yaml
index 1c59a2416..c62b667e1 100644
--- a/deploy/py_infer/src/configs/cls/ppocr/cls_mv3.yaml
+++ b/deploy/py_infer/src/configs/cls/ppocr/cls_mv3.yaml
@@ -10,7 +10,7 @@ eval:
channel_first: False
- RecResizeNormForInfer:
target_height: 48
- target_width: 320
+ target_width: 192 # 320 for ppocrv3
keep_ratio: True
padding: True
norm_before_pad: True
diff --git a/deploy/py_infer/src/configs/det/ppocr/ch_PP-OCRv4_det_cml.yaml b/deploy/py_infer/src/configs/det/ppocr/ch_PP-OCRv4_det_cml.yaml
index 7211bc988..54e401d64 100644
--- a/deploy/py_infer/src/configs/det/ppocr/ch_PP-OCRv4_det_cml.yaml
+++ b/deploy/py_infer/src/configs/det/ppocr/ch_PP-OCRv4_det_cml.yaml
@@ -3,7 +3,7 @@ postprocess:
binary_thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
- expand_ratio: 1.5
+ expand_ratio: 1.6
if_merge_longedge_bbox: True
merge_inter_area_thres: 300
merge_ratio: 1.3
diff --git a/docs/cn/inference/inference_thirdparty_quickstart.md b/docs/cn/inference/inference_thirdparty_quickstart.md
index b0f563642..2bc06af17 100644
--- a/docs/cn/inference/inference_thirdparty_quickstart.md
+++ b/docs/cn/inference/inference_thirdparty_quickstart.md
@@ -71,6 +71,7 @@ graph LR;
```
### 3. 第三方模型推理方法
+对于`ppocrv4`模型,这里提供了[快速转换工具](#35-快速转换工具),供用户快速将paddle模型转换为MindIR模型。
#### 3.1 文本检测
下面主要以[第三方模型支持列表](#11-文本检测)中的`ch_pp_det_OCRv4`为例介绍推理方法:
@@ -428,5 +429,25 @@ python deploy/py_infer/infer.py \
--res_save_dir=/path/to/infer_results
```
+### 3.5 快速转换工具
+对于ppocrv4,我们提供了快速转换工具,方便用户将ppocrv4的paddle模型转换为MindIR模型,使用方法如下
+ - 确认`MindSpore Lite`已成功下载并配置,详见[MindSpore Lite](https://www.mindspore.cn/lite),且`converter_lite`已加入环境变量
+ - 执行以下命令开始转换
+```bash
+cd tools
+bash paddle2mindir.sh -m=${ppocr_model_name} -p=${save_dir}
+```
+ - `$ppocr_model_name`: 进行转换的ppocr模型,支持`ch_PP-OCRv4`, `ch_PP-OCRv4_server`
+ - `$save_dir`: 保存模型下载与转换的路径
+
+转换过程将执行较久,请等待。执行后,将得到以下转换后的MindIR文件
+```
+ppocr_models
+├── ${PPOCR_MODEL_NAME}_det_db_dynamic_output.mindir
+├── ${PPOCR_MODEL_NAME}_rec_crnn_static_output.mindir
+├── ${PPOCR_MODEL_NAME}_cls_mv4_static_output.mindir
+├── ...
+```
+
## 4 FAQ
转换与推理相关问题可参考[FAQ](../tutorials/frequently_asked_questions.md)
diff --git a/docs/cn/tutorials/frequently_asked_questions.md b/docs/cn/tutorials/frequently_asked_questions.md
index ad54abd11..40622584e 100644
--- a/docs/cn/tutorials/frequently_asked_questions.md
+++ b/docs/cn/tutorials/frequently_asked_questions.md
@@ -11,6 +11,8 @@
- [`libgomp-d22c30c5.so.1.0.0`相关错误](#q10-libgomp-d22c30c5so100相关错误)
- [当在lmdb dataset上训练abinet报数据管道错误](#q11-当在lmdb-dataset上训练abinet报数据管道错误)
- [当在synthtext数据集上训练dbnet报运行时错误](#q12-当在synthtext数据集上训练dbnet报运行时错误)
+ - [安装seqeval相关错误](#q13-安装seqeval相关错误)
+ - [安装lanms相关错误](#q14-安装lanms相关错误)
### Q1 未定义符号
@@ -755,3 +757,99 @@ RuntimeError: Run task for graph:kernel_graph_1 error! The details reger to 'Asc
```
请尝试将CANN更新到7.1。
+
+
+### Q13 安装seqeval相关错误
+当运行`pip install -r requirements.txt`时,报以下错误
+```bash
+Collecting seqeval>=1.2.2 (from -r requirements.txt (line 19))
+ Downloading http://mirrors.aliyun.com/pypi/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43 kB)
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.6/43.6 kB 181.0 kB/s eta 0:00:00
+ Preparing metadata (setup.py) ... error
+ error: subprocess-exited-with-error
+
+ × python setup.py egg_info did not run successfully.
+ │ exit code: 1
+ ╰─> [48 lines of output]
+ /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py:80: _DeprecatedInstaller: setuptools.installer and fetch_build_eggs are deprecated.
+ !!
+
+ ********************************************************************************
+ Requirements should be satisfied by a PEP 517 installer.
+ If you are using pip, you can try `pip install --use-pep517`.
+ ********************************************************************************
+
+ !!
+ dist.fetch_build_eggs(dist.setup_requires)
+ WARNING: The repository located at mirrors.aliyun.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host mirrors.aliyun.com'.
+ ERROR: Could not find a version that satisfies the requirement setuptools_scm (from versions: none)
+ ERROR: No matching distribution found for setuptools_scm
+ Traceback (most recent call last):
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 101, in _fetch_build_egg_no_warn
+ subprocess.check_call(cmd)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/subprocess.py", line 373, in check_call
+ raise CalledProcessError(retcode, cmd)
+ subprocess.CalledProcessError: Command '['/home/ma-user/anaconda3/envs/MindSpore/bin/python3.9', '-m', 'pip', '--disable-pip-version-check', 'wheel', '--no-deps', '-w', '/tmp/tmpusgt0k69', '--quiet', 'setuptools_scm']' returned non-zero exit status 1.
+
+ The above exception was the direct cause of the following exception:
+
+ Traceback (most recent call last):
+ File "", line 2, in
+ File "", line 34, in
+ File "/tmp/pip-install-m2kqztlz/seqeval_da00f708dc0e483b92cd18083513d5e7/setup.py", line 27, in
+ setup(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 102, in setup
+ _install_setup_requires(attrs)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 75, in _install_setup_requires
+ _fetch_build_eggs(dist)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 80, in _fetch_build_eggs
+ dist.fetch_build_eggs(dist.setup_requires)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/dist.py", line 636, in fetch_build_eggs
+ return _fetch_build_eggs(self, requires)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 38, in _fetch_build_eggs
+ resolved_dists = pkg_resources.working_set.resolve(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 829, in resolve
+ dist = self._resolve_dist(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 865, in _resolve_dist
+ dist = best[req.key] = env.best_match(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 1135, in best_match
+ return self.obtain(req, installer)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 1147, in obtain
+ return installer(requirement)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 103, in _fetch_build_egg_no_warn
+ raise DistutilsError(str(e)) from e
+ distutils.errors.DistutilsError: Command '['/home/ma-user/anaconda3/envs/MindSpore/bin/python3.9', '-m', 'pip', '--disable-pip-version-check', 'wheel', '--no-deps', '-w', '/tmp/tmpusgt0k69', '--quiet', 'setuptools_scm']' returned non-zero exit status 1.
+ [end of output]
+
+ note: This error originates from a subprocess, and is likely not a problem with pip.
+error: metadata-generation-failed
+
+× Encountered error while generating package metadata.
+╰─> See above for output.
+
+note: This is an issue with the package mentioned above, not pip.
+
+```
+尝试以下步骤修复:
+ - 更新`setuptools`: `pip3 install --upgrade setuptools`
+ - 更新`setuptools_scm`: `pip3 install --upgrade setuptools_scm`
+ - 安装`seqeval`:`pip3 install seqeval -i https://pypi.tuna.tsinghua.edu.cn/simple`
+
+
+### Q14 安装lanms相关错误
+当安装lanms时,报
+```bash
+ImportError: Python version mismatch: module was compiled for version 3.8, while the interpreter is running version 3.7.
+```
+该问题可能是当前存在多个python3环境导致,你可使用以下步骤解决该问题
+ - 执行`pip3 install lanms -i https://pypi.tuna.tsinghua.edu.cn/simple`,得到`lanms-1.0.2.tar.gz`的下载链接(如https://pypi.tuna.tsinghua.edu.cn/packages/96/c0/50dc2c857ed060e907adaef31184413a7706e475c322236d346382e45195/lanms-1.0.2.tar.gz)
+ - 使用该下载链接,下载`lanms-1.0.2.tar.gz`,执行`tar -zxvf lanms-1.0.2.tar.gz`以解压该包
+ - `cd lanms-1.0.2`
+ - 编辑`Makefile`,在第1,2行中,用`python3.7-config`替代`python3-config`,得到如下修改
+ ```bash
+ CXXFLAGS = -I include -std=c++11 -O3 $(shell python3.7-config --cflags)
+ LDFLAGS = $(shell python3.7-config --ldflags)
+ ...
+ ```
+ 保存该`Makefile`, 执行过程将匹配到python 3.7环境
+ - 执行`python setup.py install`以安装`lanms`
diff --git a/docs/en/inference/inference_thirdparty_quickstart.md b/docs/en/inference/inference_thirdparty_quickstart.md
index 8ca9e840e..1010ff704 100644
--- a/docs/en/inference/inference_thirdparty_quickstart.md
+++ b/docs/en/inference/inference_thirdparty_quickstart.md
@@ -73,7 +73,7 @@ graph LR;
```
### 3. Third-Party Model Inference Methods
-
+For ppocrv4, we provide [Quick Convertion Tool](#35-quick-convertion-tool) for converting Paddle model to MindIR model.
#### 3.1 Text Detection
Let's take `ch_pp_det_OCRv4` in [Third-Party Model Support List](#11-text-detection) as an example to introduce the inference method:
@@ -448,5 +448,25 @@ python deploy/py_infer/infer.py \
--res_save_dir=/path/to/infer_results
```
+### 3.5 Quick Convertion Tool
+For ppocrv4,we provide tools for converting Paddle model to MindIR model, the guidence is as following:
+ - Make sure `MindSpore Lite` has been downloaded and configured successfully, please refer to [MindSpore Lite](https://www.mindspore.cn/lite). And make sure `converter_lite` has been added into the environment variable.
+ - Run the following command:
+```bash
+cd tools
+bash paddle2mindir.sh -m=${ppocr_model_name} -p=${save_dir}
+```
+ - `$ppocr_model_name`: ppocr models to be converted. `ch_PP-OCRv4`, `ch_PP-OCRv4_server` are supported
+ - `$save_dir`: folder to save downloaded ppocr models and converted mindir. Default: ppocr_models
+
+The convertion may cost minutes, please wait. And You could get the following MindIR models after convertion:
+```
+ppocr_models
+├── ${PPOCR_MODEL_NAME}_det_db_dynamic_output.mindir
+├── ${PPOCR_MODEL_NAME}_rec_crnn_static_output.mindir
+├── ${PPOCR_MODEL_NAME}_cls_mv4_static_output.mindir
+├── ...
+```
+
## 4.FAQ about converting and inference
For problems about converting model and inference, please refer to [FAQ](../tutorials/frequently_asked_questions.md) for solutions.
diff --git a/docs/en/tutorials/frequently_asked_questions.md b/docs/en/tutorials/frequently_asked_questions.md
index 3e91c0d87..ef303e296 100644
--- a/docs/en/tutorials/frequently_asked_questions.md
+++ b/docs/en/tutorials/frequently_asked_questions.md
@@ -11,6 +11,9 @@
- [Error about `libgomp-d22c30c5.so.1.0.0`](#q10-error-about-libgomp-d22c30c5so100)
- [Dataset Pipeline Error when training abinet on lmdb dataset](#q11-dataset-pipeline-error-when-training-abinet-on-lmdb-dataset)
- [Runtime Error when training dbnet on synthtext dataset](#q12-runtime-error-when-training-dbnet-on-synthtext-dataset)
+ - [Failed to install seqeval](#q13-failed-to-install-seqeval)
+ - [Failed to install lanms](#q14-failed-to-install-lanms)
+>>>>>>> dca11cc9989deabe86985f0729502266e5ba6f42
### Q1 Undefined symbol
@@ -744,3 +747,100 @@ Traceback (most recent call last):
RuntimeError: Run task for graph:kernel_graph_1 error! The details reger to 'Ascend Error Message'
```
Please update CANN to 7.1 version.
+
+
+### Q13 Failed to install seqeval
+The following error occur when run `pip install -r requirements.txt`
+```bash
+Collecting seqeval>=1.2.2 (from -r requirements.txt (line 19))
+ Downloading http://mirrors.aliyun.com/pypi/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43 kB)
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.6/43.6 kB 181.0 kB/s eta 0:00:00
+ Preparing metadata (setup.py) ... error
+ error: subprocess-exited-with-error
+
+ × python setup.py egg_info did not run successfully.
+ │ exit code: 1
+ ╰─> [48 lines of output]
+ /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py:80: _DeprecatedInstaller: setuptools.installer and fetch_build_eggs are deprecated.
+ !!
+
+ ********************************************************************************
+ Requirements should be satisfied by a PEP 517 installer.
+ If you are using pip, you can try `pip install --use-pep517`.
+ ********************************************************************************
+
+ !!
+ dist.fetch_build_eggs(dist.setup_requires)
+ WARNING: The repository located at mirrors.aliyun.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host mirrors.aliyun.com'.
+ ERROR: Could not find a version that satisfies the requirement setuptools_scm (from versions: none)
+ ERROR: No matching distribution found for setuptools_scm
+ Traceback (most recent call last):
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 101, in _fetch_build_egg_no_warn
+ subprocess.check_call(cmd)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/subprocess.py", line 373, in check_call
+ raise CalledProcessError(retcode, cmd)
+ subprocess.CalledProcessError: Command '['/home/ma-user/anaconda3/envs/MindSpore/bin/python3.9', '-m', 'pip', '--disable-pip-version-check', 'wheel', '--no-deps', '-w', '/tmp/tmpusgt0k69', '--quiet', 'setuptools_scm']' returned non-zero exit status 1.
+
+ The above exception was the direct cause of the following exception:
+
+ Traceback (most recent call last):
+ File "", line 2, in
+ File "", line 34, in
+ File "/tmp/pip-install-m2kqztlz/seqeval_da00f708dc0e483b92cd18083513d5e7/setup.py", line 27, in
+ setup(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 102, in setup
+ _install_setup_requires(attrs)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 75, in _install_setup_requires
+ _fetch_build_eggs(dist)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/__init__.py", line 80, in _fetch_build_eggs
+ dist.fetch_build_eggs(dist.setup_requires)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/dist.py", line 636, in fetch_build_eggs
+ return _fetch_build_eggs(self, requires)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 38, in _fetch_build_eggs
+ resolved_dists = pkg_resources.working_set.resolve(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 829, in resolve
+ dist = self._resolve_dist(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 865, in _resolve_dist
+ dist = best[req.key] = env.best_match(
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 1135, in best_match
+ return self.obtain(req, installer)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/pkg_resources/__init__.py", line 1147, in obtain
+ return installer(requirement)
+ File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/setuptools/installer.py", line 103, in _fetch_build_egg_no_warn
+ raise DistutilsError(str(e)) from e
+ distutils.errors.DistutilsError: Command '['/home/ma-user/anaconda3/envs/MindSpore/bin/python3.9', '-m', 'pip', '--disable-pip-version-check', 'wheel', '--no-deps', '-w', '/tmp/tmpusgt0k69', '--quiet', 'setuptools_scm']' returned non-zero exit status 1.
+ [end of output]
+
+ note: This error originates from a subprocess, and is likely not a problem with pip.
+error: metadata-generation-failed
+
+× Encountered error while generating package metadata.
+╰─> See above for output.
+
+note: This is an issue with the package mentioned above, not pip.
+
+```
+Please try the following steps to fix this problem:
+ - Update `setuptools`: `pip3 install --upgrade setuptools`
+ - Update `setuptools_scm`: `pip3 install --upgrade setuptools_scm`
+ - Install `seqeval`:`pip3 install seqeval -i https://pypi.tuna.tsinghua.edu.cn/simple`
+
+
+### Q14 Failed to install lanms
+The following error occur when installing lanms
+```bash
+ImportError: Python version mismatch: module was compiled for version 3.8, while the interpreter is running version 3.7.
+```
+Some Python 3.7 environment may meet this problem when multiple python3 environment exists. You could try the following steps to solve this problem:
+1. run `pip3 install lanms -i https://pypi.tuna.tsinghua.edu.cn/simple`, and get the url for downloading `lanms-1.0.2.tar.gz`(like https://pypi.tuna.tsinghua.edu.cn/packages/96/c0/50dc2c857ed060e907adaef31184413a7706e475c322236d346382e45195/lanms-1.0.2.tar.gz)
+2. use this url and dowload the `lanms-1.0.2.tar.gz`, run `tar -zxvf lanms-1.0.2.tar.gz` to decompress the package.
+3. `cd lanms-1.0.2`
+4. edit the `Makefile`, replace `python3-config` with `python3.7-config` in line 1 and line 2, and you could get
+ ```text
+ CXXFLAGS = -I include -std=c++11 -O3 $(shell python3.7-config --cflags)
+ LDFLAGS = $(shell python3.7-config --ldflags)
+ ...
+ ```
+ save `Makefile`. So that the make process would exactly compile with python3.7 environment
+5. run `python setup.py install` and completely install lanms.
+>>>>>>> dca11cc9989deabe86985f0729502266e5ba6f42
diff --git a/mindocr/data/transforms/layoutlm_transforms.py b/mindocr/data/transforms/layoutlm_transforms.py
index 2766e05ad..4b1095a9e 100644
--- a/mindocr/data/transforms/layoutlm_transforms.py
+++ b/mindocr/data/transforms/layoutlm_transforms.py
@@ -7,6 +7,7 @@
from mindspore import nn
+from mindocr.models.backbones.layoutlmv3 import LayoutLMv3Tokenizer
from mindocr.models.backbones.layoutxlm import LayoutXLMTokenizer
from mindocr.utils.kie_utils import load_vqa_bio_label_maps
@@ -65,6 +66,7 @@ def __init__(
super(VQATokenLabelEncode, self).__init__()
tokenizer_dict = {
"LayoutXLM": {"class": LayoutXLMTokenizer, "pretrained_model": "layoutxlm-base-uncased"},
+ "LayoutLMv3": {"class": LayoutLMv3Tokenizer, "pretrained_model": "layoutxlm-base-uncased"},
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
diff --git a/mindocr/data/transforms/llm_transform.py b/mindocr/data/transforms/llm_transform.py
new file mode 100644
index 000000000..92bfaa35d
--- /dev/null
+++ b/mindocr/data/transforms/llm_transform.py
@@ -0,0 +1,302 @@
+import albumentations as alb
+import numpy as np
+from PIL import Image
+
+import mindspore as ms
+from mindspore.dataset import vision
+from mindspore.dataset.vision.utils import Inter
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+INTERPOLATION = {
+ "nearest": Inter.NEAREST,
+ "antialias": Inter.ANTIALIAS,
+ "linear": Inter.LINEAR,
+ "cubic": Inter.PILCUBIC,
+ "bicubic": Inter.BICUBIC,
+}
+
+
+def alb_wrapper(transform):
+ def f(im):
+ img = transform(image=np.asarray(im))["image"]
+ img = np.transpose(img, (2, 0, 1))
+ img = np.expand_dims(img, axis=0)
+ return img
+
+ return f
+
+
+image_processor_high = alb_wrapper(
+ alb.Compose(
+ [
+ alb.Resize(1024, 1024),
+ alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+ ]
+ )
+)
+
+
+class BCHW2BHWC:
+ """
+ Transform a batch of image from CHW to HWC.
+
+ Args:
+ image_batch (tensor, numpy.array, PIL.Image, list): for tensor or numpy input, the
+ channel should be (bz, c, h, w) or (c, h, w). for list, the item should be
+ PIL.Image or numpy.array (c, h, w).
+
+ Return:
+ transformed image batch: for numpy or tensor input, return a numpy array, the channel
+ is (bz, h, w, c) or (h, w, c); for PIL.Image input, it is returned directly.
+ """
+
+ def __call__(self, image_batch):
+ """the call function"""
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, list):
+ return [self(item) for item in image_batch]
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 4:
+ return image_batch.transpose(0, 2, 3, 1)
+ if len(image_batch.shape) == 3:
+ return image_batch.transpose(1, 2, 0)
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+ if isinstance(image_batch, Image.Image):
+ return image_batch
+ raise TypeError(f"the type {type(image_batch)} of image_batch is unsupported.")
+
+
+class BatchPILize:
+ """transform a batch of image to PIL.Image list."""
+
+ def __call__(self, image_batch):
+ """
+ The forward process.
+
+ Args:
+ image_batch (tensor, numpy.array, list): for tensor or numpy input,
+ the rank should be 4 or 3. for list, the item should be PIL.Image.
+
+ Returns:
+ return a tensor or a list of tensor.
+ """
+ if isinstance(image_batch, Image.Image):
+ return image_batch
+
+ if isinstance(image_batch, list):
+ for item in image_batch:
+ if not isinstance(item, Image.Image):
+ raise TypeError(
+ "unsupported type in list,"
+ " when the image_batch is a list,"
+ " the item in list should be PIL.Image."
+ )
+ return image_batch
+
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 4:
+ return [Image.fromarray(item.astype(np.uint8)) for item in image_batch]
+ if len(image_batch.shape) == 3:
+ return Image.fromarray(image_batch.astype(np.uint8))
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+
+ raise ValueError("unsupported input type.")
+
+
+class BatchResize:
+ """
+ Resize a batch of image to the given shape.
+
+ Args:
+ image_resolution (int): the target size.
+ interpolation: interpolate method, default is "cubic".
+ """
+
+ def __init__(self, image_resolution, interpolation="cubic"):
+ self.interpolation = INTERPOLATION.get(interpolation)
+ self.resize = vision.c_transforms.Resize(image_resolution, self.interpolation)
+
+ def __call__(self, image_batch):
+ """
+ The forward process.
+
+ Args:
+ image_batch (tensor, numpy.array, PIL.Image, list): for tensor or numpy input,
+ the shape should be (bz, h, w, c) or (h, w, c). for list, the item should be
+ PIL.Image or numpy.array (h, w, c).
+
+ Returns:
+ resized image batch: for numpy or tensor input, return a numpy array;
+ for PIL.Image input, it returns PIL.Image.
+ """
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, list):
+ return [self.resize(item) for item in image_batch]
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 4:
+ return np.row_stack([self.resize(item)[np.newaxis, :] for item in image_batch])
+ if len(image_batch.shape) == 3:
+ return self.resize(image_batch)
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+ if isinstance(image_batch, Image.Image):
+ return self.resize(image_batch)
+ raise TypeError(f"the type {type(image_batch)} of image_batch is unsupported.")
+
+
+class BatchCenterCrop:
+ """
+ CenterCrop a batch of image to the given shape.
+
+ Args:
+ image_resolution (int): the target size.
+ """
+
+ def __init__(self, image_resolution):
+ self.crop = vision.CenterCrop(image_resolution)
+
+ def __call__(self, image_batch):
+ """
+ The forward process.
+
+ Args:
+ image_batch (tensor, numpy.array, PIL.Image, list): for tensor or numpy input,
+ the shape should be (bz, h, w, c) or (h, w, c). for list, the item should be
+ PIL.Image or numpy.array (h, w, c).
+
+ Returns:
+ center cropped image batch: for numpy or tensor input, return a numpy array, the shape
+ is (bz, image_resolution, image_resolution, c) or (image_resolution,
+ image_resolution, c); for PIL.Image input, it is returned with shape (image_resolution,
+ image_resolution).
+ """
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, list):
+ return [self.crop(item) for item in image_batch]
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 4:
+ return np.row_stack([self.crop(item)[np.newaxis, :] for item in image_batch])
+ if len(image_batch.shape) == 3:
+ return self.crop(image_batch)
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+ if isinstance(image_batch, Image.Image):
+ return self.crop(image_batch)
+ raise TypeError(f"the type {type(image_batch)} of image_batch is unsupported.")
+
+
+class BatchToTensor:
+ """Transform a batch of image to tensor and scale to (0, 1)."""
+
+ def __init__(self):
+ self.totensor = ms.dataset.vision.ToTensor()
+
+ def __call__(self, image_batch):
+ """
+ The forward process.
+
+ Args:
+ image_batch (tensor, numpy.array, PIL.Image, list): for tensor or numpy input,
+ the rank should be 4 or 3. for list, the item should be PIL.Image or numpy.array.
+
+ Returns:
+ return a tensor or a list of tensor.
+ """
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, list):
+ return [self.totensor(item) for item in image_batch]
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 4:
+ return np.row_stack([self.totensor(item)[np.newaxis, :] for item in image_batch])
+ if len(image_batch.shape) == 3:
+ return self.totensor(image_batch)
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+ if isinstance(image_batch, Image.Image):
+ return self.totensor(image_batch)
+ raise TypeError(f"the type {type(image_batch)} of image_batch is unsupported.")
+
+
+class BatchNormalize:
+ """Normalize a batch of image."""
+
+ def __init__(
+ self, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), is_hwc=False
+ ):
+ self.normalize = vision.Normalize(mean=mean, std=std, is_hwc=is_hwc)
+
+ def __call__(self, image_batch):
+ """
+ The forward process.
+
+ Args:
+ image_batch (tensor, numpy.array, list): for tensor or numpy input,
+ the rank should be 4 or 3. for list, the item should be numpy.array.
+
+ Returns:
+ return a tensor or a list of tensor.
+ """
+ if isinstance(image_batch, ms.Tensor):
+ image_batch = image_batch.asnumpy()
+
+ if isinstance(image_batch, list):
+ return [self.normalize(item) for item in image_batch]
+ if isinstance(image_batch, np.ndarray):
+ if len(image_batch.shape) == 3:
+ return self.normalize(image_batch)
+ if len(image_batch.shape) == 4:
+ return np.row_stack([self.normalize(item)[np.newaxis, :] for item in image_batch])
+ raise ValueError(f"the rank of image_batch should be 3 or 4, but got {len(image_batch.shape)}")
+ raise TypeError(f"the type {type(image_batch)} of image_batch is unsupported.")
+
+
+class VaryCLIPImageProcessor:
+ def __init__(self, image_resolution=224):
+ self.image_resolution = image_resolution
+ self.bchw2bhwc = BCHW2BHWC()
+ self.batch_pilizer = BatchPILize()
+ self.batch_resizer = BatchResize(self.image_resolution)
+ self.batch_crop = BatchCenterCrop(self.image_resolution)
+ self.batch_totensor = BatchToTensor()
+ self.batch_normalizer = BatchNormalize()
+
+ def preprocess(self, images):
+ if not self._bhwc_check(images):
+ images = self.bchw2bhwc(images)
+ images = self.batch_pilizer(images)
+ images = self.batch_resizer(images)
+ images = self.batch_crop(images)
+ images = self.batch_totensor(images)
+ images = self.batch_normalizer(images)
+
+ if isinstance(images, list):
+ return np.row_stack([np.expand_dims(item, axis=0) for item in images])
+ if len(images.shape) == 4:
+ return images
+ return np.expand_dims(images, axis=0)
+
+ @staticmethod
+ def _bhwc_check(image_batch):
+ r"""Bhwc_check"""
+ if isinstance(image_batch, np.ndarray):
+ if image_batch.shape[-1] == 3:
+ return True
+ if isinstance(image_batch, ms.Tensor):
+ if image_batch.asnumpy().shape[-1] == 3:
+ return True
+ if isinstance(image_batch, (list, Image.Image)):
+ return True
+ return False
+
+
+image_processor = VaryCLIPImageProcessor().preprocess
diff --git a/mindocr/metrics/kie_metrics.py b/mindocr/metrics/kie_metrics.py
index 34d07b51c..2f0cdee3e 100644
--- a/mindocr/metrics/kie_metrics.py
+++ b/mindocr/metrics/kie_metrics.py
@@ -3,7 +3,7 @@
from glob import glob
import numpy as np
-import seqeval
+import seqeval.metrics
import sklearn
from mindspore import get_context, nn
diff --git a/mindocr/models/backbones/__init__.py b/mindocr/models/backbones/__init__.py
index b6d423152..8f71299ab 100644
--- a/mindocr/models/backbones/__init__.py
+++ b/mindocr/models/backbones/__init__.py
@@ -8,6 +8,7 @@
from .cls_mobilenet_v3 import *
from .det_mobilenet import *
from .det_resnet import *
+from .layoutlmv3 import layoutlmv3
from .layoutxlm import layoutxlm
from .rec_abinet_backbone import *
from .rec_master import *
diff --git a/mindocr/models/backbones/layoutlmv3/__init__.py b/mindocr/models/backbones/layoutlmv3/__init__.py
new file mode 100644
index 000000000..d86a97b37
--- /dev/null
+++ b/mindocr/models/backbones/layoutlmv3/__init__.py
@@ -0,0 +1,5 @@
+from .configuration import LayoutLMv3PretrainedConfig
+from .layoutlmv3 import LayoutLMv3Model
+from .tokenizer import LayoutLMv3Tokenizer
+
+__all__ = ["LayoutLMv3PretrainedConfig", "LayoutLMv3Model", "LayoutLMv3Tokenizer"]
diff --git a/mindocr/models/backbones/layoutlmv3/configuration.py b/mindocr/models/backbones/layoutlmv3/configuration.py
new file mode 100644
index 000000000..93243ddb5
--- /dev/null
+++ b/mindocr/models/backbones/layoutlmv3/configuration.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class LayoutLMv3PretrainedConfig:
+ def __init__(self, use_float16=False):
+ pretrained_config = {
+ "use_float16": use_float16,
+ "fast_qkv": False,
+ "vocab_size": 250002,
+ "hidden_size": 768,
+ "num_hidden_layers": 12,
+ "num_attention_heads": 12,
+ "intermediate_size": 3072,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.1,
+ "max_position_embeddings": 514,
+ "type_vocab_size": 1,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-5,
+ "pad_token_id": 1,
+ "bos_token_id": 0,
+ "eos_token_id": 2,
+ "max_2d_position_embeddings": 1024,
+ "coordinate_size": 128,
+ "shape_size": 128,
+ "has_relative_attention_bias": True,
+ "rel_pos_bins": 32,
+ "max_rel_pos": 128,
+ "rel_2d_pos_bins": 64,
+ "max_rel_2d_pos": 256,
+ "has_spatial_attention_bias": True,
+ "text_embed": True,
+ "visual_embed": True,
+ "input_size": 224,
+ "num_channels": 3,
+ "patch_size": 16,
+ "classifier_dropout": None,
+ "num_labels": None,
+ }
+
+ for key, value in pretrained_config.items():
+ setattr(self, key, value)
diff --git a/mindocr/models/backbones/layoutlmv3/layoutlmv3.py b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py
new file mode 100644
index 000000000..1e1bc1f9b
--- /dev/null
+++ b/mindocr/models/backbones/layoutlmv3/layoutlmv3.py
@@ -0,0 +1,524 @@
+import collections
+
+import numpy as np
+
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore.common import dtype as mstype
+
+from mindocr.models.backbones._registry import register_backbone, register_backbone_class
+
+from ..transformer_common.layer import (
+ LayoutXLMAttention,
+ LayoutXLMEmbeddings,
+ LayoutXLMEncoder,
+ LayoutXLMLayer,
+ LayoutXLMSelfAttention,
+ finfo,
+)
+from .configuration import LayoutLMv3PretrainedConfig
+
+
+class LayoutLMv3PatchEmbeddings(nn.Cell):
+ """
+ LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
+ image sizes.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ image_size = (
+ config.input_size
+ if isinstance(config.input_size, collections.abc.Iterable)
+ else (config.input_size, config.input_size)
+ )
+ patch_size = (
+ config.patch_size
+ if isinstance(config.patch_size, collections.abc.Iterable)
+ else (config.patch_size, config.patch_size)
+ )
+ self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+ self.proj = nn.Conv2d(
+ config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size, has_bias=True
+ )
+
+ def construct(self, pixel_values: Tensor, position_embedding: Tensor = None):
+ embeddings = self.proj(pixel_values)
+
+ if position_embedding is not None:
+ # interpolate the position embedding to the corresponding size
+ position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
+ position_embedding = position_embedding.transpose(0, 3, 1, 2)
+ patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
+ position_embedding = ops.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
+ embeddings = embeddings + position_embedding
+
+ embeddings = embeddings.flatten(start_dim=2).transpose(0, 2, 1)
+ return embeddings
+
+
+class LayoutLMv3TextEmbeddings(LayoutXLMEmbeddings):
+ """
+ LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def create_position_ids_from_input_ids(self, input_ids: Tensor, padding_idx):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+ symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).astype(mstype.int32)
+ incremental_indices = (ops.cumsum(mask, axis=1)) * mask
+ return incremental_indices.astype(mstype.int64) + padding_idx
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = Tensor(np.arange(self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=np.int64))
+ return position_ids.unsqueeze(0).broadcast_to(input_shape)
+
+ def construct(
+ self,
+ input_ids=None,
+ bbox=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ else:
+ input_shape = inputs_embeds.shape[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = ops.zeros(input_shape, dtype=mstype.int64)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+
+ spatial_position_embeddings = self._cal_spatial_position_embeddings(bbox)
+
+ embeddings = embeddings + spatial_position_embeddings
+
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class LayoutLMv3SelfAttention(LayoutXLMSelfAttention):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def cogview_attention(self, attention_scores: Tensor, alpha=32):
+ """
+ https://arxiv.org/abs/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
+ (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
+ will result in a slower speed and a little bias. Can use allclose(standard_attention_probs,
+ cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
+ """
+ scaled_attention_scores = attention_scores / alpha
+ max_value = scaled_attention_scores.max(axis=-1).unsqueeze(-1)
+ new_attention_scores = (scaled_attention_scores - max_value) * alpha
+ return nn.Softmax(axis=-1)(new_attention_scores)
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ q, k, v = self.compute_qkv(hidden_states)
+
+ # (B, L, H*D) -> (B, H, L, D)
+ query_layer = self.transpose_for_scores(q)
+ key_layer = self.transpose_for_scores(k)
+ value_layer = self.transpose_for_scores(v)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
+ # Changing the computational order into QT(K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
+ attention_scores = ops.matmul(query_layer / self.attention_head_size_sqrt, key_layer.transpose(0, 1, 3, 2))
+ if self.has_relative_attention_bias and self.has_spatial_attention_bias:
+ attention_scores += (rel_pos + rel_2d_pos) / self.attention_head_size_sqrt
+ elif self.has_relative_attention_bias:
+ attention_scores += rel_pos / self.attention_head_size_sqrt
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
+ attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+
+ # Normalize the attention scores to probabilities.
+ # Use the trick of the CogView paper to stablize training
+ attention_probs = self.cogview_attention(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = ops.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.transpose(0, 2, 1, 3)
+ new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class LayoutLMv3Attention(LayoutXLMAttention):
+ def __init__(self, config):
+ super().__init__(config)
+ self.self_attention = LayoutLMv3SelfAttention(config)
+
+
+class LayoutLMv3Layer(LayoutXLMLayer):
+ def __init__(self, config):
+ super().__init__(config)
+ self.attention = LayoutLMv3Attention(config)
+
+
+class LayoutLMv3Encoder(LayoutXLMEncoder):
+ def __init__(self, config):
+ super().__init__(config)
+ self.layer = nn.CellList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
+
+
+@register_backbone_class
+class LayoutLMv3Model(nn.Cell):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.num_hidden_layers = config.num_hidden_layers
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+ self.patch_size = config.patch_size
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+ self.min = finfo(self.dense_dtype)
+ self.out_channels = 1
+ self.use_visual_backbone = True
+
+ if config.text_embed:
+ self.embeddings = LayoutLMv3TextEmbeddings(config)
+
+ if config.visual_embed:
+ # use the default pre-training parameters for fine-tuning (e.g., input_size)
+ # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
+ self.patch_embed = LayoutLMv3PatchEmbeddings(config)
+
+ size = int(config.input_size / config.patch_size)
+ self.cls_token = Parameter(ops.zeros((1, 1, config.hidden_size)))
+ self.pos_embed = Parameter(ops.zeros((1, size * size + 1, config.hidden_size)))
+ self.pos_drop = nn.Dropout(p=0.0)
+
+ self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+ if config.has_relative_attention_bias or config.has_spatial_attention_bias:
+ self.init_visual_bbox(image_size=(size, size))
+
+ self.norm = nn.LayerNorm((config.hidden_size,), epsilon=1e-6)
+
+ self.encoder = LayoutLMv3Encoder(config)
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
+ """
+ Create the bounding boxes for the visual (patch) tokens.
+ """
+ visual_bbox_x = ops.truncate_div(Tensor(np.arange(0, max_len * (image_size[1] + 1), max_len)), image_size[1])
+ visual_bbox_y = ops.truncate_div(Tensor(np.arange(0, max_len * (image_size[0] + 1), max_len)), image_size[0])
+ visual_bbox = ops.stack(
+ [
+ visual_bbox_x[:-1].broadcast_to((image_size[0], -1)),
+ visual_bbox_y[:-1].broadcast_to((image_size[1], -1)).transpose(0, 1),
+ visual_bbox_x[1:].broadcast_to((image_size[0], -1)),
+ visual_bbox_y[1:].broadcast_to((image_size[1], -1)).transpose(0, 1),
+ ],
+ axis=-1,
+ ).reshape(-1, 4)
+
+ cls_token_box = Tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
+ self.visual_bbox = ops.cat([cls_token_box, visual_bbox], axis=0)
+
+ def calculate_visual_bbox(self, dtype, batch_size):
+ final_shape = self.visual_bbox.shape
+ visual_bbox = self.visual_bbox.broadcast_to((batch_size, final_shape[0], final_shape[1]))
+ visual_bbox = visual_bbox.astype(dtype)
+ return visual_bbox
+
+ def visual_embeddings(self, pixel_values):
+ embeddings = self.patch_embed(pixel_values)
+
+ # add [CLS] token
+ batch_size, seq_len, _ = embeddings.shape
+ cls_tokens = self.cls_token.broadcast_to((batch_size, -1, -1))
+ embeddings = ops.cat((cls_tokens, embeddings), axis=1)
+
+ # add position embeddings
+ if self.pos_embed is not None:
+ embeddings = embeddings + self.pos_embed
+
+ embeddings = self.pos_drop(embeddings)
+ embeddings = self.norm(embeddings)
+
+ return embeddings
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`Tuple[int]`):
+ The shape of the input to the model.
+
+ Returns:
+ `Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
+ """
+ if attention_mask.ndim == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.ndim == 2:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely. # fp16 compatibility
+ extended_attention_mask = extended_attention_mask.astype(dtype)
+ extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+ return extended_attention_mask
+
+ def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
+ """
+ Prepare the head mask if needed.
+
+ Args:
+ head_mask (`Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+ num_hidden_layers (`int`):
+ The number of hidden layers in the model.
+ is_attention_chunked (`bool`, *optional*, defaults to `False`):
+ Whether or not the attentions scores are computed by chunks or not.
+
+ Returns:
+ `Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+ `[None]` for each layer.
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ if is_attention_chunked is True:
+ head_mask = head_mask.unsqueeze(-1)
+ else:
+ head_mask = [None] * num_hidden_layers
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask: Tensor, num_hidden_layers: int):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.ndim == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.broadcast_to((self.num_hidden_layers, -1, -1, -1, -1))
+ elif head_mask.ndim == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
+ if head_mask.ndim != 5:
+ raise ValueError(f"head_mask.dim != 5, instead {head_mask.ndim}")
+ head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
+ return head_mask
+
+ def construct(
+ self,
+ input_ids=None, # input_ids
+ bbox=None, # b_box
+ attention_mask=None, # attention_mask
+ token_type_ids=None, # token_type_ids
+ pixel_values=None, # image
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ """
+ Constructs the LayoutLMv3 model according to the input provided.
+
+ Args:
+ input_ids (Tensor, optional): Tensor containing the token IDs of the input text sequence.
+ bbox (Tensor, optional): Tensor containing the bounding box information of the input text sequence.
+ attention_mask (Tensor, optional): Tensor containing the attention mask for the input sequence.
+ token_type_ids (Tensor, optional): Tensor containing the token type IDs to distinguish different sequences.
+ pixel_values (Tensor, optional): Tensor containing the pixel values of the input image.
+ position_ids (Tensor, optional): Tensor containing the position IDs indicating the position of tokens.
+ head_mask (Tensor, optional): Mask to control which heads of the attention mechanism should be used.
+ inputs_embeds (Tensor, optional): Pre-computed embeddings for the input tokens.
+ output_attentions (bool, optional): Whether to return attention weights.
+ output_hidden_states (bool, optional): Whether to return hidden states.
+ return_dict (bool, optional): Whether to return a dictionary or a tuple of outputs.
+
+ Returns:
+ Tensor or Tuple[Tensor]: Depending on the configuration, returns either a tensor or a tuple
+ containing the output sequence and additional outputs such as hidden states and attention weights.
+ """
+ output_attentions = output_attentions if output_attentions is not None else False
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+ return_dict = return_dict if return_dict is not None else False
+ seq_length = None
+ input_shape = None
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ batch_size, seq_length = input_shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.shape[:-1]
+ batch_size, seq_length = input_shape
+ elif pixel_values is not None:
+ batch_size = len(pixel_values)
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
+ embedding_output = None
+
+ if input_ids is not None or inputs_embeds is not None:
+ if attention_mask is None:
+ attention_mask = ops.ones(((batch_size, seq_length)))
+ if token_type_ids is None:
+ token_type_ids = ops.zeros(input_shape, dtype=mstype.int64)
+ if bbox is None:
+ bbox = ops.zeros(tuple(list(input_shape) + [4]), dtype=mstype.int64)
+
+ # ocr information text embeddings
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ )
+ final_bbox = final_position_ids = None
+ if pixel_values is not None:
+ visual_embeddings = self.visual_embeddings(pixel_values)
+ visual_embeddings_shape = visual_embeddings.shape
+ visual_attention_mask = ops.ones((batch_size, visual_embeddings_shape[1]), dtype=mstype.int64)
+ if attention_mask is not None:
+ attention_mask = ops.cat([attention_mask, visual_attention_mask.astype(attention_mask.dtype)], axis=1)
+ else:
+ attention_mask = visual_attention_mask
+
+ if self.has_relative_attention_bias or self.has_spatial_attention_bias:
+ if self.has_spatial_attention_bias:
+ visual_bbox = self.calculate_visual_bbox(dtype=mstype.int64, batch_size=batch_size)
+ if bbox is not None:
+ final_bbox = ops.cat([bbox, visual_bbox], axis=1)
+ else:
+ final_bbox = visual_bbox
+
+ visual_embeddings_shape = visual_embeddings.shape
+ visual_position_ids = ops.arange(0, visual_embeddings_shape[1], dtype=mstype.int64).broadcast_to(
+ (batch_size, visual_embeddings_shape[1])
+ )
+ if input_ids is not None or inputs_embeds is not None:
+ position_ids = ops.arange(0, input_shape[1], dtype=mstype.int64).unsqueeze(0)
+ position_ids = position_ids.broadcast_to(input_shape)
+ final_position_ids = ops.cat([position_ids, visual_position_ids], axis=1)
+ else:
+ final_position_ids = visual_position_ids
+
+ if input_ids is not None or inputs_embeds is not None:
+ embedding_output = ops.cat([embedding_output, visual_embeddings], axis=1)
+ else:
+ embedding_output = visual_embeddings
+
+ embedding_output = self.LayerNorm(embedding_output)
+ embedding_output = self.dropout(embedding_output)
+ elif self.has_relative_attention_bias or self.has_spatial_attention_bias:
+ if self.has_spatial_attention_bias:
+ final_bbox = bbox
+ if self.has_relative_attention_bias:
+ position_ids = self.embeddings.position_ids[:, : input_shape[1]]
+ position_ids = position_ids.expand_as(input_ids)
+ final_position_ids = position_ids
+
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, None, embedding_output.dtype)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ bbox=final_bbox,
+ position_ids=final_position_ids,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ return (sequence_output,) + encoder_outputs[1:]
+
+
+@register_backbone
+def layoutlmv3(use_float16: bool = True, **kwargs):
+ pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+ model = LayoutLMv3Model(pretrained_config)
+ return model
diff --git a/mindocr/models/backbones/layoutlmv3/tokenizer.py b/mindocr/models/backbones/layoutlmv3/tokenizer.py
new file mode 100644
index 000000000..eb34972df
--- /dev/null
+++ b/mindocr/models/backbones/layoutlmv3/tokenizer.py
@@ -0,0 +1,7 @@
+from ..layoutxlm.tokenizer import LayoutXLMTokenizer
+
+
+class LayoutLMv3Tokenizer(LayoutXLMTokenizer):
+ """
+ Tokenizer of LayoutLMv3-chinese, same as LayoutXLMTokenizer.
+ """
diff --git a/mindocr/models/backbones/layoutxlm/layoutxlm.py b/mindocr/models/backbones/layoutxlm/layoutxlm.py
index bd075159b..4737c03a4 100644
--- a/mindocr/models/backbones/layoutxlm/layoutxlm.py
+++ b/mindocr/models/backbones/layoutxlm/layoutxlm.py
@@ -1,13 +1,11 @@
-import math
-
import numpy as np
-import mindspore as ms
-from mindspore import Parameter, nn, ops, set_context
-from mindspore.common.initializer import Constant, initializer
+from mindspore import Parameter, Tensor, nn, ops, set_context
+from mindspore.common import dtype as mstype
from .._registry import register_backbone, register_backbone_class
from ..mindcv_models.utils import load_pretrained
+from ..transformer_common.layer import LayoutXLMEmbeddings, LayoutXLMEncoder, LayoutXLMPooler
from .configuration import LayoutXLMPretrainedConfig
from .visual_backbone import build_resnet_fpn_backbone, read_config
@@ -34,23 +32,21 @@ def _cfg(url="", use_visual_backbone=True, **kwargs):
class VisualBackbone(nn.Cell):
def __init__(self, config):
- super(VisualBackbone, self).__init__()
+ super().__init__()
self.cfg = read_config()
self.backbone = build_resnet_fpn_backbone(self.cfg)
if len(self.cfg.MODEL.PIXEL_MEAN) != len(self.cfg.MODEL.PIXEL_STD):
- raise ValueError(
- "cfg.model.pixel_mean is not equal with cfg.model.pixel_std."
- )
+ raise ValueError("cfg.model.pixel_mean is not equal with cfg.model.pixel_std.")
num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
self.pixel_mean = Parameter(
- ms.Tensor(self.cfg.MODEL.PIXEL_MEAN).reshape((num_channels, 1, 1)),
+ Tensor(self.cfg.MODEL.PIXEL_MEAN).reshape((num_channels, 1, 1)),
name="pixel_mean",
requires_grad=False,
)
self.pixel_std = Parameter(
- ms.Tensor(self.cfg.MODEL.PIXEL_STD).reshape((num_channels, 1, 1)),
+ Tensor(self.cfg.MODEL.PIXEL_STD).reshape((num_channels, 1, 1)),
name="pixel_std",
requires_grad=False,
)
@@ -58,9 +54,7 @@ def __init__(self, config):
self.out_feature_key = "p2"
self.pool_shape = tuple(config.image_feature_pool_shape[:2]) # (7,7)
if len(config.image_feature_pool_shape) == 2:
- config.image_feature_pool_shape.append(
- self.backbone.output_shape()[self.out_feature_key].channels
- )
+ config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels)
input_shape = (224, 224)
outsize = config.image_feature_pool_shape[0] # (7,7)
@@ -70,12 +64,12 @@ def __init__(self, config):
stride = insize // outsize
kernel = insize - (outsize - 1) * stride
- self.weight = ms.Tensor(np.ones([channels, 1, kernel, kernel]), dtype=ms.float32) / (kernel * kernel)
+ self.weight = Tensor(np.ones([channels, 1, kernel, kernel]), dtype=mstype.float32) / (kernel * kernel)
self.conv2d = ops.Conv2D(channels, kernel, stride=stride, group=channels)
def pool(self, features):
"""
- To enhance performance, customize the AdaptiveAvgPool2d layer
+ Custom AvgPool2d
"""
features = self.conv2d(features, self.weight)
return features
@@ -97,628 +91,34 @@ def construct(self, images):
return features.flatten(start_dim=2).transpose(0, 2, 1)
-def relative_position_bucket(
- relative_position, bidirectional=True, num_buckets=32, max_distance=128
-):
- ret = 0
- if bidirectional:
- num_buckets //= 2
- ret += (relative_position > 0).astype(ms.int64) * num_buckets
- n = ops.abs(relative_position)
- else:
- n = ops.maximum(
- -relative_position, ops.zeros_like(relative_position)
- ) # to be confirmed
- # Now n is in the range [0, inf)
- # half of the buckets are for exact increments in positions
- max_exact = num_buckets // 2
- is_small = n < max_exact
-
- # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
- val_if_large = max_exact + (
- ops.log(n.astype(ms.float32) / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
- ).astype(ms.int64)
-
- val_if_large = ops.minimum(
- val_if_large, ops.full_like(val_if_large, num_buckets - 1)
- )
-
- ret += ops.where(is_small, n, val_if_large)
- return ret
-
-
-class LayoutXLMEmbeddings(nn.Cell):
- """
- Include embeddings from word, position and token_type embeddings
- """
-
- def __init__(self, config):
- super(LayoutXLMEmbeddings, self).__init__()
- self.word_embeddings = nn.Embedding(
- config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
- )
- self.position_embeddings = nn.Embedding(
- config.max_position_embeddings, config.hidden_size
- )
-
- self.x_position_embeddings = nn.Embedding(
- config.max_2d_position_embeddings, config.coordinate_size
- )
- self.y_position_embeddings = nn.Embedding(
- config.max_2d_position_embeddings, config.coordinate_size
- )
- self.h_position_embeddings = nn.Embedding(
- config.max_2d_position_embeddings, config.shape_size
- )
- self.w_position_embeddings = nn.Embedding(
- config.max_2d_position_embeddings, config.shape_size
- )
- self.token_type_embeddings = nn.Embedding(
- config.type_vocab_size, config.hidden_size
- )
-
- self.LayerNorm = nn.LayerNorm(
- (config.hidden_size,), epsilon=config.layer_norm_eps
- )
- self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
- self.position_ids = Parameter(
- ms.Tensor(np.arange(0, config.max_position_embeddings)).broadcast_to(
- (1, -1)
- ),
- name="position_ids",
- requires_grad=False,
- )
-
- def _cal_spatial_position_embeddings(self, bbox):
- bbox_0 = bbox[:, :, 0]
- bbox_1 = bbox[:, :, 1]
- bbox_2 = bbox[:, :, 2]
- bbox_3 = bbox[:, :, 3]
- left_position_embeddings = self.x_position_embeddings(bbox_0)
- upper_position_embeddings = self.y_position_embeddings(bbox_1)
- right_position_embeddings = self.x_position_embeddings(bbox_2)
- lower_position_embeddings = self.y_position_embeddings(bbox_3)
-
- h_position_embeddings = self.h_position_embeddings(bbox_3 - bbox_1)
- w_position_embeddings = self.w_position_embeddings(bbox_2 - bbox_0)
-
- spatial_position_embeddings = ops.concat(
- (
- left_position_embeddings,
- upper_position_embeddings,
- right_position_embeddings,
- lower_position_embeddings,
- h_position_embeddings,
- w_position_embeddings,
- ),
- axis=-1,
- )
- return spatial_position_embeddings
-
- def construct(self, input_ids, bbox=None, token_type_ids=None, position_ids=None):
- if position_ids is None:
- ones = ops.ones_like(input_ids, dtype=ms.int64)
- seq_length = ops.cumsum(ones, axis=-1)
-
- position_ids = seq_length - ones
- position_ids = ops.stop_gradient(
- position_ids
- ) # position_ids.stop_gradient = True
- if token_type_ids is None:
- token_type_ids = ops.zeros_like(input_ids, dtype=ms.int64)
-
- input_embedings = self.word_embeddings(input_ids)
- position_embeddings = self.position_embeddings(position_ids)
-
- left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
- upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
- right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
- lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
- h_position_embeddings = self.h_position_embeddings(
- bbox[:, :, 3] - bbox[:, :, 1]
- )
- w_position_embeddings = self.w_position_embeddings(
- bbox[:, :, 2] - bbox[:, :, 0]
- )
-
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
-
- embeddings = (
- input_embedings
- + position_embeddings
- + left_position_embeddings
- + upper_position_embeddings
- + right_position_embeddings
- + lower_position_embeddings
- + h_position_embeddings
- + w_position_embeddings
- + token_type_embeddings
- )
-
- embeddings = self.layer_norm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
-
-
-class LayoutXLMSelfAttention(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMSelfAttention, self).__init__()
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
- config, "embedding_size"
- ):
- raise ValueError(
- "The hidden size {} is not a multiple of the number of attention "
- "heads {}".format(config.hidden_size, config.num_attention_heads)
- )
- self.fast_qkv = config.fast_qkv
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.attention_head_size_sqrt = math.sqrt(self.attention_head_size)
-
- self.has_relative_attention_bias = config.has_relative_attention_bias
- self.has_spatial_attention_bias = config.has_spatial_attention_bias
-
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
-
- if config.fast_qkv:
- self.qkv_linear = nn.Dense(
- config.hidden_size, 3 * self.all_head_size, has_bias=False
- ).to_float(self.dense_dtype)
- self.q_bias = Parameter(
- initializer(Constant(0.0), [1, 1, self.all_head_size], ms.float32)
- )
- self.v_bias = Parameter(
- initializer(Constant(0.0), [1, 1, self.all_head_size], ms.float32)
- )
- else:
- self.query = nn.Dense(config.hidden_size, self.all_head_size).to_float(
- self.dense_dtype
- )
- self.key = nn.Dense(config.hidden_size, self.all_head_size).to_float(
- self.dense_dtype
- )
- self.value = nn.Dense(config.hidden_size, self.all_head_size).to_float(
- self.dense_dtype
- )
-
- self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob)
-
- def transpose_for_scores(self, x):
- new_x_shape = list(x.shape[:-1]) + [
- self.num_attention_heads,
- self.attention_head_size,
- ]
-
- x = x.reshape(tuple(new_x_shape))
- return x.transpose((0, 2, 1, 3))
-
- def compute_qkv(self, hidden_states):
- if self.fast_qkv:
- qkv = self.qkv_linear(hidden_states)
- q, k, v = ops.chunk(qkv, 3, axis=-1)
- if q.ndimension() == self.q_bias.ndimension():
- q = q + self.q_bias
- v = v + self.v_bias
- else:
- _sz = (1,) * (q.ndimension() - 1) + (-1,)
- q = q + self.q_bias.reshape(_sz)
- v = v + self.v_bias.reshape(_sz)
- else:
- q = self.query(hidden_states)
- k = self.key(hidden_states)
- v = self.value(hidden_states)
- return q, k, v
-
- def construct(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- rel_pos=None,
- rel_2d_pos=None,
- ):
- q, k, v = self.compute_qkv(hidden_states)
-
- # (B, L, H*D) -> (B, H, L, D)
- query_layer = self.transpose_for_scores(q)
- key_layer = self.transpose_for_scores(k)
- value_layer = self.transpose_for_scores(v)
-
- query_layer = query_layer / self.attention_head_size_sqrt
- # [BSZ, NAT, L, L]
- attention_scores = ops.matmul(
- query_layer.astype(ms.float16),
- key_layer.transpose((0, 1, 3, 2)).astype(ms.float16),
- ).astype(ms.float32)
- if self.has_relative_attention_bias:
- attention_scores += rel_pos
- if self.has_spatial_attention_bias:
- attention_scores += rel_2d_pos
- attention_scores = ops.masked_fill(
- attention_scores.astype(ms.float32), ops.stop_gradient(attention_mask.astype(ms.bool_)), float("-1e10")
- )
- attention_probs = ops.softmax(attention_scores, axis=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- # attention_probs = self.dropout(attention_probs)
- context_layer = ops.matmul(
- attention_probs.astype(ms.float16), value_layer.astype(ms.float16)
- ).astype(ms.float32)
-
- context_layer = context_layer.transpose((0, 2, 1, 3))
- new_context_layer_shape = list(context_layer.shape[:-2]) + [self.all_head_size]
- context_layer = context_layer.reshape(new_context_layer_shape)
-
- if output_attentions:
- outputs = [context_layer, attention_probs]
- else:
- outputs = [context_layer]
- return outputs
-
-
-class LayoutXLMSelfOutput(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMSelfOutput, self).__init__()
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
- self.dense = nn.Dense(config.hidden_size, config.hidden_size).to_float(
- self.dense_dtype
- )
- self.LayerNorm = nn.LayerNorm(
- (config.hidden_size,), epsilon=config.layer_norm_eps
- )
- self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
-
- def construct(self, hidden_states, input_tensor):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
-
-
-class LayoutXLMAttention(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMAttention, self).__init__()
- self.self_attention = LayoutXLMSelfAttention(config)
- self.output = LayoutXLMSelfOutput(config)
-
- def construct(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- rel_pos=None,
- rel_2d_pos=None,
- ):
- self_outputs = self.self_attention(
- hidden_states,
- attention_mask,
- head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_value,
- output_attentions,
- rel_pos=rel_pos,
- rel_2d_pos=rel_2d_pos,
- )
- attention_output = self.output(self_outputs[0], hidden_states)
- # add attentions if we output them
- if output_attentions:
- outputs = [
- attention_output,
- ] + self_outputs[1:]
- else:
- outputs = [attention_output]
- return outputs
-
-
-class LayoutXLMIntermediate(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMIntermediate, self).__init__()
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
- self.dense = nn.Dense(config.hidden_size, config.intermediate_size).to_float(
- self.dense_dtype
- )
- if config.hidden_act == "gelu":
- self.intermediate_act_fn = nn.GELU()
- else:
- raise ValueError(
- "hidden_act is set as: {}, please check it..".format(config.hidden_act)
- )
-
- def construct(self, hidden_states):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.intermediate_act_fn(hidden_states)
- return hidden_states
-
-
-class LayoutXLMOutput(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMOutput, self).__init__()
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
- self.dense = nn.Dense(config.intermediate_size, config.hidden_size).to_float(
- self.dense_dtype
- )
- self.LayerNorm = nn.LayerNorm(
- (config.hidden_size,), epsilon=config.layer_norm_eps
- )
- self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
-
- def construct(self, hidden_states, input_tensor):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
-
-
-class LayoutXLMLayer(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMLayer, self).__init__()
- # since chunk_size_feed_forward is 0 as default, no chunk is needed here.
- self.seq_len_dim = 1
- self.attention = LayoutXLMAttention(config)
- self.add_cross_attention = False # default as false
- self.intermediate = LayoutXLMIntermediate(config)
- self.output = LayoutXLMOutput(config)
-
- def feed_forward_chunk(self, attention_output):
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.output(intermediate_output, attention_output)
- return layer_output
-
- def construct(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- rel_pos=None,
- rel_2d_pos=None,
- ):
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = (
- past_key_value[:2] if past_key_value is not None else None
- )
- self_attention_outputs = self.attention(
- hidden_states,
- attention_mask=attention_mask,
- head_mask=head_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- past_key_value=self_attn_past_key_value,
- output_attentions=output_attentions,
- rel_pos=rel_pos,
- rel_2d_pos=rel_2d_pos,
- )
- attention_output = self_attention_outputs[0]
- layer_output = self.feed_forward_chunk(attention_output)
-
- if output_attentions:
- outputs = self_attention_outputs[
- 1:
- ] # add self attentions if we output attention weights
- outputs = [
- layer_output,
- ] + list(outputs)
- else:
- outputs = [layer_output]
- return outputs
-
-
-class LayoutXLMEncoder(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMEncoder, self).__init__()
- self.config = config
- self.layer = nn.CellList(
- [LayoutXLMLayer(config) for _ in range(config.num_hidden_layers)]
- )
-
- self.has_relative_attention_bias = config.has_relative_attention_bias
- self.has_spatial_attention_bias = config.has_spatial_attention_bias
-
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
-
- if self.has_relative_attention_bias:
- self.rel_pos_bins = config.rel_pos_bins
- self.max_rel_pos = config.max_rel_pos
- self.rel_pos_onehot_size = config.rel_pos_bins
- self.rel_pos_bias = nn.Dense(
- self.rel_pos_onehot_size, config.num_attention_heads, has_bias=False
- ).to_float(ms.float16)
-
- if self.has_spatial_attention_bias:
- self.max_rel_2d_pos = config.max_rel_2d_pos
- self.rel_2d_pos_bins = config.rel_2d_pos_bins
- self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
- self.rel_pos_x_bias = nn.Dense(
- self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False
- ).to_float(self.dense_dtype)
- self.rel_pos_y_bias = nn.Dense(
- self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False
- ).to_float(self.dense_dtype)
-
- def _cal_1d_pos_emb(self, hidden_states, position_ids):
- rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
- rel_pos = relative_position_bucket(
- rel_pos_mat,
- num_buckets=self.rel_pos_bins,
- max_distance=self.max_rel_pos,
- )
- on_value, off_value = ms.Tensor(1.0, ms.float32), ms.Tensor(0.0, ms.float32)
- rel_pos = ops.one_hot(
- rel_pos, self.rel_pos_onehot_size, on_value, off_value
- ).astype(hidden_states.dtype)
- rel_pos = self.rel_pos_bias(rel_pos).transpose((0, 3, 1, 2))
- return rel_pos
-
- def _cal_2d_pos_emb(self, hidden_states, bbox):
- position_coord_x = bbox[:, :, 0]
- position_coord_y = bbox[:, :, 3]
- rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(
- -1
- )
- rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(
- -1
- )
- rel_pos_x = relative_position_bucket(
- rel_pos_x_2d_mat,
- num_buckets=self.rel_2d_pos_bins,
- max_distance=self.max_rel_2d_pos,
- )
- rel_pos_y = relative_position_bucket(
- rel_pos_y_2d_mat,
- num_buckets=self.rel_2d_pos_bins,
- max_distance=self.max_rel_2d_pos,
- )
- on_value, off_value = ms.Tensor(1.0, ms.float32), ms.Tensor(0.0, ms.float32)
- rel_pos_x = ops.one_hot(
- rel_pos_x, self.rel_2d_pos_onehot_size, on_value, off_value
- ).astype(hidden_states.dtype)
- rel_pos_y = ops.one_hot(
- rel_pos_y, self.rel_2d_pos_onehot_size, on_value, off_value
- ).astype(hidden_states.dtype)
- rel_pos_x = self.rel_pos_x_bias(rel_pos_x).transpose((0, 3, 1, 2))
- rel_pos_y = self.rel_pos_y_bias(rel_pos_y).transpose((0, 3, 1, 2))
- rel_2d_pos = rel_pos_x + rel_pos_y
- return rel_2d_pos
-
- def construct(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- output_attentions=False,
- output_hidden_states=False,
- bbox=None,
- position_ids=None,
- ):
- all_hidden_states = () if output_hidden_states else None
-
- rel_pos = (
- self._cal_1d_pos_emb(hidden_states, position_ids)
- if self.has_relative_attention_bias
- else None
- )
- rel_2d_pos = (
- self._cal_2d_pos_emb(hidden_states, bbox)
- if self.has_spatial_attention_bias
- else None
- )
-
- hidden_save = dict()
- hidden_save["input_hidden_states"] = hidden_states
-
- for i, layer_module in enumerate(self.layer):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- layer_head_mask = None
- past_key_value = None
- # gradient_checkpointing is set as False here so we remove some codes here
- hidden_save["input_attention_mask"] = attention_mask
- layer_outputs = layer_module(
- hidden_states,
- attention_mask,
- layer_head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_value,
- output_attentions,
- rel_pos=rel_pos,
- rel_2d_pos=rel_2d_pos,
- )
-
- hidden_states = layer_outputs[0]
-
- hidden_save["{}_data".format(i)] = hidden_states
-
- return hidden_states, hidden_save
-
-
-class LayoutXLMPooler(nn.Cell):
- def __init__(self, config):
- super(LayoutXLMPooler, self).__init__()
- self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
- if self.use_float16 is True:
- self.dense_dtype = ms.float16
- self.dense = nn.Dense(config.hidden_size, config.hidden_size).to_float(
- self.dense_dtype
- )
- self.activation = nn.Tanh()
-
- def construct(self, hidden_states):
- # We "pool" the model by simply taking the hidden state corresponding
- # to the first token.
- first_token_tensor = hidden_states[:, 0]
- pooled_output = self.dense(first_token_tensor)
- pooled_output = self.activation(pooled_output)
- return pooled_output
-
-
@register_backbone_class
class LayoutXLMModel(nn.Cell):
def __init__(self, config):
- super(LayoutXLMModel, self).__init__()
+ super().__init__()
self.config = config
self.has_visual_segment_embedding = config.has_visual_segment_embedding
self.embeddings = LayoutXLMEmbeddings(config)
self.use_visual_backbone = config.use_visual_backbone
self.use_float16 = config.use_float16
- self.dense_dtype = ms.float32
+ self.dense_dtype = mstype.float32
if self.use_float16 is True:
- self.dense_dtype = ms.float16
+ self.dense_dtype = mstype.float16
if self.use_visual_backbone is True:
set_context(jit_syntax_level=0)
self.visual = VisualBackbone(config)
self.visual.freeze()
- self.visual_proj = nn.Dense(
- config.image_feature_pool_shape[-1], config.hidden_size
- ).to_float(self.dense_dtype)
- if self.has_visual_segment_embedding:
- self.visual_segment_embedding = Parameter(
- nn.Embedding(1, config.hidden_size).embedding_table[0]
+ self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size).to_float(
+ self.dense_dtype
)
- self.visual_LayerNorm = nn.LayerNorm(
- (config.hidden_size,), epsilon=config.layer_norm_eps
- )
+ if self.has_visual_segment_embedding:
+ self.visual_segment_embedding = Parameter(nn.Embedding(1, config.hidden_size).embedding_table[0])
+ self.visual_LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
self.visual_dropout = nn.Dropout(p=config.hidden_dropout_prob)
self.encoder = LayoutXLMEncoder(config)
self.pooler = LayoutXLMPooler(config)
- self.image_feature_pool_shape_size = (
- config.image_feature_pool_shape[0] * config.image_feature_pool_shape[1]
- )
+ self.image_feature_pool_shape_size = config.image_feature_pool_shape[0] * config.image_feature_pool_shape[1]
self.image_feature_pool_shape = config.image_feature_pool_shape
self.num_hidden_layers = config.num_hidden_layers
self.max_position_embeddings = config.max_position_embeddings
@@ -734,16 +134,9 @@ def set_input_embeddings(self, value):
def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):
words_embeddings = self.embeddings.word_embeddings(input_ids)
position_embeddings = self.embeddings.position_embeddings(position_ids)
- spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(
- bbox
- )
+ spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
- embeddings = (
- words_embeddings
- + position_embeddings
- + spatial_position_embeddings
- + token_type_embeddings
- )
+ embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings
embeddings = self.embeddings.LayerNorm(embeddings)
embeddings = self.embeddings.dropout(embeddings)
return embeddings
@@ -751,14 +144,10 @@ def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):
def _calc_img_embeddings(self, image, bbox, position_ids):
use_image_info = self.use_visual_backbone and image is not None
position_embeddings = self.embeddings.position_embeddings(position_ids)
- spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(
- bbox
- )
+ spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
if use_image_info:
- visual_embeddings = self.visual_proj(self.visual(image.astype(ms.float32)))
- embeddings = (
- visual_embeddings + position_embeddings + spatial_position_embeddings
- )
+ visual_embeddings = self.visual_proj(self.visual(image.astype(mstype.float32)))
+ embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
else:
embeddings = position_embeddings + spatial_position_embeddings
if self.has_visual_segment_embedding:
@@ -777,9 +166,7 @@ def resize_position_embeddings(self, new_num_position_embeddings):
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
end.
"""
- num_position_embeds_diff = (
- new_num_position_embeddings - self.max_position_embeddings
- )
+ num_position_embeds_diff = new_num_position_embeddings - self.max_position_embeddings
# no resizing needs to be done if the length stays the same
if num_position_embeds_diff == 0:
@@ -787,32 +174,24 @@ def resize_position_embeddings(self, new_num_position_embeddings):
self.max_position_embeddings = new_num_position_embeddings
- old_position_embeddings_weight = (
- self.embeddings.position_embeddings.embedding_table
- )
+ old_position_embeddings_weight = self.embeddings.position_embeddings.embedding_table
- self.embeddings.position_embeddings = nn.Embedding(
- self.max_position_embeddings, self.hidden_size
- )
+ self.embeddings.position_embeddings = nn.Embedding(self.max_position_embeddings, self.hidden_size)
if num_position_embeds_diff > 0:
self.embeddings.position_embeddings.embedding_table[
:-num_position_embeds_diff
] = old_position_embeddings_weight
else:
- self.embeddings.position_embeddings.embedding_table = (
- old_position_embeddings_weight[:num_position_embeds_diff]
- )
+ self.embeddings.position_embeddings.embedding_table = old_position_embeddings_weight[
+ :num_position_embeds_diff
+ ]
def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):
x_size = image_feature_pool_shape[1]
y_size = image_feature_pool_shape[0]
- visual_bbox_x = ms.Tensor(
- np.arange(0, 1000 * (x_size + 1), 1000) // x_size, dtype=ms.int64
- )
- visual_bbox_y = ms.Tensor(
- np.arange(0, 1000 * (y_size + 1), 1000) // y_size, dtype=ms.int64
- )
+ visual_bbox_x = Tensor(np.arange(0, 1000 * (x_size + 1), 1000) // x_size, dtype=mstype.int64)
+ visual_bbox_y = Tensor(np.arange(0, 1000 * (y_size + 1), 1000) // y_size, dtype=mstype.int64)
expand_shape = image_feature_pool_shape[0:2]
expand_shape = tuple(expand_shape)
visual_bbox = ops.stack(
@@ -824,16 +203,12 @@ def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):
],
axis=-1,
).reshape((expand_shape[0] * expand_shape[1], ops.shape(bbox)[-1]))
- visual_bbox = visual_bbox.broadcast_to(
- (visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1])
- )
+ visual_bbox = visual_bbox.broadcast_to((visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]))
return visual_bbox
def _get_input_shape(self, input_ids=None, inputs_embeds=None):
if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time"
- )
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
return input_ids.shape
elif inputs_embeds is not None:
@@ -845,9 +220,9 @@ def construct(
self,
input_ids=None,
bbox=None,
- image=None,
attention_mask=None,
token_type_ids=None,
+ image=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
@@ -857,9 +232,7 @@ def construct(
input_shape = self._get_input_shape(input_ids, inputs_embeds)
visual_shape = list(input_shape)
visual_shape[1] = self.image_feature_pool_shape_size
- visual_bbox = self._calc_visual_bbox(
- self.image_feature_pool_shape, bbox, visual_shape
- )
+ visual_bbox = self._calc_visual_bbox(self.image_feature_pool_shape, bbox, visual_shape)
final_bbox = ops.concat([bbox, visual_bbox], axis=1)
if attention_mask is None:
@@ -872,21 +245,17 @@ def construct(
attention_mask = attention_mask.astype(visual_attention_mask.dtype)
- final_attention_mask = ops.concat(
- [attention_mask, visual_attention_mask], axis=1
- )
+ final_attention_mask = ops.concat([attention_mask, visual_attention_mask], axis=1)
if token_type_ids is None:
- token_type_ids = ops.zeros(input_shape, dtype=ms.int64)
+ token_type_ids = ops.zeros(input_shape, dtype=mstype.int64)
if position_ids is None:
seq_length = input_shape[1]
position_ids = self.embeddings.position_ids[:, :seq_length]
position_ids = position_ids.broadcast_to(input_shape)
- visual_position_ids = ms.Tensor(np.arange(0, visual_shape[1])).broadcast_to(
- (input_shape[0], visual_shape[1])
- )
+ visual_position_ids = Tensor(np.arange(0, visual_shape[1])).broadcast_to((input_shape[0], visual_shape[1]))
final_position_ids = ops.concat([position_ids, visual_position_ids], axis=1)
if bbox is None:
@@ -911,12 +280,8 @@ def construct(
if head_mask is not None:
if head_mask.dim() == 1:
- head_mask = (
- head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
- )
- head_mask = head_mask.broadcast_to(
- (self.num_hidden_layers, -1, -1, -1, -1)
- )
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ head_mask = head_mask.broadcast_to((self.num_hidden_layers, -1, -1, -1, -1))
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
else:
@@ -924,7 +289,7 @@ def construct(
encoder_outputs = self.encoder(
final_emb,
- extended_attention_mask,
+ attention_mask=extended_attention_mask,
bbox=final_bbox,
position_ids=final_position_ids,
head_mask=head_mask,
@@ -937,12 +302,7 @@ def construct(
@register_backbone
-def layoutxlm(
- pretrained: bool = True,
- use_visual_backbone: bool = True,
- use_float16: bool = False,
- **kwargs
-):
+def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs):
pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16)
model = LayoutXLMModel(pretrained_config)
if pretrained:
diff --git a/mindocr/models/backbones/transformer_common/activation.py b/mindocr/models/backbones/transformer_common/activation.py
new file mode 100644
index 000000000..a75123065
--- /dev/null
+++ b/mindocr/models/backbones/transformer_common/activation.py
@@ -0,0 +1,23 @@
+from collections import OrderedDict
+
+from mindspore import nn
+
+
+class ClassInstantier(OrderedDict):
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+act_cls = {
+ "gelu": nn.GELU,
+ "relu": nn.ReLU,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": nn.SiLU,
+ "swish": nn.SiLU,
+ "tanh": nn.Tanh,
+}
+
+act_fn = ClassInstantier(act_cls)
diff --git a/mindocr/models/backbones/transformer_common/layer.py b/mindocr/models/backbones/transformer_common/layer.py
new file mode 100644
index 000000000..d906a26a8
--- /dev/null
+++ b/mindocr/models/backbones/transformer_common/layer.py
@@ -0,0 +1,522 @@
+import math
+
+import numpy as np
+
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore.common import dtype as mstype
+from mindspore.common.initializer import Constant, initializer
+
+from .activation import act_fn
+
+
+def finfo(dtype):
+ if dtype == mstype.float32:
+ return np.finfo(np.float32).min
+ elif dtype == mstype.float16:
+ return np.finfo(np.float16).min
+ else:
+ raise TypeError(f"For 'finfo', the input dtype should be float32 or float16, bug got {dtype}")
+
+
+class LayoutXLMEmbeddings(nn.Cell):
+ """
+ Include embeddings from word, position and token_type embeddings
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
+ self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_ids = Parameter(
+ Tensor(np.arange(0, config.max_position_embeddings)).broadcast_to((1, -1)),
+ name="position_ids",
+ requires_grad=False,
+ )
+
+ def _cal_spatial_position_embeddings(self, bbox):
+ bbox_0 = bbox[:, :, 0]
+ bbox_1 = bbox[:, :, 1]
+ bbox_2 = bbox[:, :, 2]
+ bbox_3 = bbox[:, :, 3]
+ left_position_embeddings = self.x_position_embeddings(bbox_0)
+ upper_position_embeddings = self.y_position_embeddings(bbox_1)
+ right_position_embeddings = self.x_position_embeddings(bbox_2)
+ lower_position_embeddings = self.y_position_embeddings(bbox_3)
+
+ h_position_embeddings = self.h_position_embeddings(bbox_3 - bbox_1)
+ w_position_embeddings = self.w_position_embeddings(bbox_2 - bbox_0)
+
+ spatial_position_embeddings = ops.concat(
+ (
+ left_position_embeddings,
+ upper_position_embeddings,
+ right_position_embeddings,
+ lower_position_embeddings,
+ h_position_embeddings,
+ w_position_embeddings,
+ ),
+ axis=-1,
+ )
+ return spatial_position_embeddings
+
+ def construct(self, input_ids, bbox=None, token_type_ids=None, position_ids=None):
+ raise NotImplementedError(
+ f"'construct' is not implemented for {self.__class__}. "
+ f"For implement it, you should overwrite this method."
+ )
+
+
+class LayoutXLMSelfAttention(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size {} is not a multiple of the number of attention "
+ "heads {}".format(config.hidden_size, config.num_attention_heads)
+ )
+ self.fast_qkv = config.fast_qkv
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.attention_head_size_sqrt = math.sqrt(self.attention_head_size)
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+
+ if config.fast_qkv:
+ self.qkv_linear = nn.Dense(config.hidden_size, 3 * self.all_head_size, has_bias=False).to_float(
+ self.dense_dtype
+ )
+ self.q_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size], self.dense_dtype))
+ self.v_bias = Parameter(initializer(Constant(0.0), [1, 1, self.all_head_size], self.dense_dtype))
+ else:
+ self.query = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype)
+ self.key = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype)
+ self.value = nn.Dense(config.hidden_size, self.all_head_size).to_float(self.dense_dtype)
+
+ self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob)
+ self.min = finfo(self.dense_dtype)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = list(x.shape[:-1]) + [
+ self.num_attention_heads,
+ self.attention_head_size,
+ ]
+
+ x = x.reshape(tuple(new_x_shape))
+ return x.transpose((0, 2, 1, 3))
+
+ def compute_qkv(self, hidden_states):
+ if self.fast_qkv:
+ qkv = self.qkv_linear(hidden_states)
+ q, k, v = ops.chunk(qkv, 3, axis=-1)
+ if q.ndimension() == self.q_bias.ndimension():
+ q = q + self.q_bias
+ v = v + self.v_bias
+ else:
+ _sz = (1,) * (q.ndimension() - 1) + (-1,)
+ q = q + self.q_bias.reshape(_sz)
+ v = v + self.v_bias.reshape(_sz)
+ else:
+ q = self.query(hidden_states)
+ k = self.key(hidden_states)
+ v = self.value(hidden_states)
+ return q, k, v
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ q, k, v = self.compute_qkv(hidden_states)
+
+ # (B, L, H*D) -> (B, H, L, D)
+ query_layer = self.transpose_for_scores(q)
+ key_layer = self.transpose_for_scores(k)
+ value_layer = self.transpose_for_scores(v)
+
+ query_layer = query_layer / self.attention_head_size_sqrt
+ # [BSZ, NAT, L, L]
+ attention_scores = ops.matmul(
+ query_layer,
+ key_layer.transpose((0, 1, 3, 2)),
+ )
+ if self.has_relative_attention_bias:
+ attention_scores += rel_pos
+ if self.has_spatial_attention_bias:
+ attention_scores += rel_2d_pos
+ attention_scores = ops.masked_fill(
+ attention_scores,
+ ops.stop_gradient(attention_mask.astype(mstype.bool_)),
+ self.min,
+ )
+ attention_probs = ops.softmax(attention_scores, axis=-1)
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+ context_layer = ops.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.transpose((0, 2, 1, 3))
+ new_context_layer_shape = list(context_layer.shape[:-2]) + [self.all_head_size]
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ if output_attentions:
+ outputs = [context_layer, attention_probs]
+ else:
+ outputs = [context_layer]
+ return outputs
+
+
+class LayoutXLMSelfOutput(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+ self.dense = nn.Dense(config.hidden_size, config.hidden_size).to_float(self.dense_dtype)
+ self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+ def construct(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class LayoutXLMAttention(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.self_attention = LayoutXLMSelfAttention(config)
+ self.output = LayoutXLMSelfOutput(config)
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ self_outputs = self.self_attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ # add attentions if we output them
+ if output_attentions:
+ outputs = [
+ attention_output,
+ ] + self_outputs[1:]
+ else:
+ outputs = [attention_output]
+ return outputs
+
+
+class LayoutXLMIntermediate(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+ self.dense = nn.Dense(config.hidden_size, config.intermediate_size).to_float(self.dense_dtype)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = act_fn[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def construct(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class LayoutXLMOutput(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+ self.dense = nn.Dense(config.intermediate_size, config.hidden_size).to_float(self.dense_dtype)
+ self.LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+ def construct(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class LayoutXLMLayer(nn.Cell):
+ def __init__(self, config):
+ super().__init__(config)
+ # since chunk_size_feed_forward is 0 as default, no chunk is needed here.
+ self.seq_len_dim = 1
+ self.attention = LayoutXLMAttention(config)
+ self.add_cross_attention = False # default as false
+ self.intermediate = LayoutXLMIntermediate(config)
+ self.output = LayoutXLMOutput(config)
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ rel_pos=None,
+ rel_2d_pos=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_value=self_attn_past_key_value,
+ output_attentions=output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+ attention_output = self_attention_outputs[0]
+ layer_output = self.feed_forward_chunk(attention_output)
+
+ if output_attentions:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+ outputs = [
+ layer_output,
+ ] + list(outputs)
+ else:
+ outputs = [layer_output]
+ return outputs
+
+
+class LayoutXLMEncoder(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.CellList([LayoutXLMLayer(config) for _ in range(config.num_hidden_layers)])
+
+ self.has_relative_attention_bias = config.has_relative_attention_bias
+ self.has_spatial_attention_bias = config.has_spatial_attention_bias
+
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+
+ if self.has_relative_attention_bias:
+ self.rel_pos_bins = config.rel_pos_bins
+ self.max_rel_pos = config.max_rel_pos
+ self.rel_pos_onehot_size = config.rel_pos_bins
+ self.rel_pos_bias = nn.Dense(self.rel_pos_onehot_size, config.num_attention_heads, has_bias=False).to_float(
+ mstype.float16
+ )
+
+ if self.has_spatial_attention_bias:
+ self.max_rel_2d_pos = config.max_rel_2d_pos
+ self.rel_2d_pos_bins = config.rel_2d_pos_bins
+ self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
+ self.rel_pos_x_bias = nn.Dense(
+ self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False
+ ).to_float(self.dense_dtype)
+ self.rel_pos_y_bias = nn.Dense(
+ self.rel_2d_pos_onehot_size, config.num_attention_heads, has_bias=False
+ ).to_float(self.dense_dtype)
+
+ def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ def test(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+ ret = 0
+ if bidirectional:
+ num_buckets //= 2
+ ret += (relative_position > 0).astype(mstype.int64) * num_buckets
+ n = ops.abs(relative_position)
+ else:
+ n = ops.maximum(-relative_position, ops.zeros_like(relative_position)) # to be confirmed
+ # Now n is in the range [0, inf)
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ scaling_val = ops.log(n.astype(mstype.float32) / max_exact) / math.log(max_distance / max_exact)
+ scaling_val = scaling_val * (num_buckets - max_exact)
+ val_if_large = max_exact + scaling_val.astype(mstype.int64)
+
+ val_if_large = ops.minimum(val_if_large, ops.full_like(val_if_large, num_buckets - 1))
+
+ ret += ops.where(is_small, n, val_if_large)
+ return ret
+
+ # test(relative_position.copy(), num_buckets=num_buckets, max_distance=max_distance)
+ ret = 0
+ if bidirectional:
+ num_buckets //= 2
+ ret += (relative_position > 0).long() * num_buckets
+ n = ops.abs(relative_position)
+ else:
+ n = ops.maximum(-relative_position, ops.zeros_like(relative_position))
+ # now n is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+ val_if_large = max_exact + (
+ ops.log(n.astype(mstype.float32) / max_exact)
+ / ops.log(Tensor(max_distance / max_exact))
+ * (num_buckets - max_exact)
+ ).astype(mstype.int64)
+ val_if_large = ops.minimum(val_if_large, ops.full_like(val_if_large, num_buckets - 1))
+
+ ret += ops.where(is_small, n, val_if_large)
+ return ret
+
+ def _cal_1d_pos_emb(self, hidden_states, position_ids):
+ rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
+ rel_pos = self.relative_position_bucket(
+ rel_pos_mat,
+ num_buckets=self.rel_pos_bins,
+ max_distance=self.max_rel_pos,
+ )
+ on_value, off_value = Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)
+ rel_pos = ops.one_hot(rel_pos, self.rel_pos_onehot_size, on_value, off_value).astype(hidden_states.dtype)
+ rel_pos = self.rel_pos_bias(rel_pos).transpose((0, 3, 1, 2))
+ return rel_pos
+
+ def _cal_2d_pos_emb(self, hidden_states, bbox):
+ position_coord_x = bbox[:, :, 0]
+ position_coord_y = bbox[:, :, 3]
+ rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
+ rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
+ rel_pos_x = self.relative_position_bucket(
+ rel_pos_x_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ rel_pos_y = self.relative_position_bucket(
+ rel_pos_y_2d_mat,
+ num_buckets=self.rel_2d_pos_bins,
+ max_distance=self.max_rel_2d_pos,
+ )
+ on_value, off_value = Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32)
+ rel_pos_x = ops.one_hot(rel_pos_x, self.rel_2d_pos_onehot_size, on_value, off_value).astype(hidden_states.dtype)
+ rel_pos_y = ops.one_hot(rel_pos_y, self.rel_2d_pos_onehot_size, on_value, off_value).astype(hidden_states.dtype)
+ rel_pos_x = self.rel_pos_x_bias(rel_pos_x).transpose((0, 3, 1, 2))
+ rel_pos_y = self.rel_pos_y_bias(rel_pos_y).transpose((0, 3, 1, 2))
+ rel_2d_pos = rel_pos_x + rel_pos_y
+ return rel_2d_pos
+
+ def construct(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ bbox=None,
+ position_ids=None,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+
+ rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
+ rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
+
+ hidden_save = dict()
+ hidden_save["input_hidden_states"] = hidden_states
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = None
+ past_key_value = None
+ # gradient_checkpointing is set as False here so we remove some codes here
+ hidden_save["input_attention_mask"] = attention_mask
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ rel_pos=rel_pos,
+ rel_2d_pos=rel_2d_pos,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ hidden_save["{}_data".format(i)] = hidden_states
+
+ return hidden_states, hidden_save
+
+
+class LayoutXLMPooler(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.use_float16 = config.use_float16
+ self.dense_dtype = mstype.float32
+ if self.use_float16 is True:
+ self.dense_dtype = mstype.float16
+ self.dense = nn.Dense(config.hidden_size, config.hidden_size).to_float(self.dense_dtype)
+ self.activation = nn.Tanh()
+
+ def construct(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
diff --git a/mindocr/models/base_model.py b/mindocr/models/base_model.py
index 7ed0699c1..de3e56691 100644
--- a/mindocr/models/base_model.py
+++ b/mindocr/models/base_model.py
@@ -62,34 +62,19 @@ def __init__(self, config: dict):
self.model_name = f"{backbone_name}_{neck_name}_{self.head_name}"
def ser(self, *inputs):
- image = inputs[4]
-
- x = self.backbone(
- input_ids=inputs[0],
- bbox=inputs[1],
- attention_mask=inputs[2],
- token_type_ids=inputs[3],
- pixel_values=image,
- )
- x = self.head(x, inputs[0])
+ input_ids, bbox, attention_mask, token_type_ids = inputs[:4]
+ image = inputs[4] if self.backbone.use_visual_backbone else None
+ x = self.backbone(input_ids, bbox, attention_mask, token_type_ids, image)
+ x = self.head(x, input_ids)
return x
def re(self, *inputs):
- if self.backbone.use_visual_backbone is True:
- image = inputs[4]
- else:
- image = None
-
- x = self.backbone(
- input_ids=inputs[0],
- bbox=inputs[1],
- attention_mask=inputs[2],
- token_type_ids=inputs[3],
- image=image,
- )
- x = self.head(x, inputs[0], inputs[5], inputs[6], inputs[7], inputs[8])
+ (input_ids, bbox, attention_mask, token_type_ids, question, question_label, answer, answer_label) = inputs[:8]
+ image = inputs[8] if self.backbone.use_visual_backbone else None
+ x = self.backbone(input_ids, bbox, attention_mask, token_type_ids, image)
+ x = self.head(x, input_ids, question, question_label, answer, answer_label)
return x
def construct(self, *args):
diff --git a/mindocr/models/builder.py b/mindocr/models/builder.py
index 28f6a1575..0ee0ae85d 100644
--- a/mindocr/models/builder.py
+++ b/mindocr/models/builder.py
@@ -7,7 +7,7 @@
from ._registry import is_model, list_models, model_entrypoint
from .base_model import BaseModel
-from .utils import load_model
+from .utils import load_model, set_amp_attr
__all__ = ["build_model"]
@@ -74,5 +74,6 @@ def build_model(name_or_config: Union[str, dict], **kwargs):
if "amp_level" in kwargs:
auto_mixed_precision(network, amp_level=kwargs["amp_level"])
+ set_amp_attr(network, kwargs["amp_level"])
return network
diff --git a/mindocr/models/necks/rnn.py b/mindocr/models/necks/rnn.py
index 629439eb3..3caf55868 100644
--- a/mindocr/models/necks/rnn.py
+++ b/mindocr/models/necks/rnn.py
@@ -2,7 +2,9 @@
import numpy as np
-from mindspore import Tensor, nn, ops
+import mindspore.ops.functional as F
+from mindspore import Tensor, nn, ops, version
+from mindspore.common import dtype
__all__ = ['RNNEncoder']
@@ -37,6 +39,11 @@ def __init__(self, in_channels: int, hidden_size: int = 512, batch_size: Option
has_bias=True,
dropout=0.,
bidirectional=True)
+ self.encoder_cast_to_fp16 = False
+ if version.__version__ >= "2.3":
+ # Adapted to MindSpore r2.3, nn.LSTM has bugs when input is FP32.
+ self.seq_encoder.to_float(dtype.float16)
+ self.encoder_cast_to_fp16 = True
self.hx = None
if batch_size is not None:
@@ -49,9 +56,15 @@ def construct(self, features: List[Tensor]) -> Tensor:
x = ops.squeeze(x, axis=2) # [N, C, W]
x = ops.transpose(x, (2, 0, 1)) # [W, N, C]
+ if self.encoder_cast_to_fp16 and self._amp_level == "O0":
+ x = F.cast(x, dtype.float16)
+
if self.hx is None:
x, _ = self.seq_encoder(x)
else:
x, _ = self.seq_encoder(x, self.hx)
- return x
+ if self.encoder_cast_to_fp16 and self._amp_level == "O0":
+ return F.cast(x, dtype.float32)
+ else:
+ return x
diff --git a/mindocr/models/transforms/tps_spatial_transformer.py b/mindocr/models/transforms/tps_spatial_transformer.py
index a49736fe3..006f72420 100644
--- a/mindocr/models/transforms/tps_spatial_transformer.py
+++ b/mindocr/models/transforms/tps_spatial_transformer.py
@@ -10,7 +10,8 @@
def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor:
- output = ops.grid_sample(input, grid)
+ out_type = input.dtype
+ output = ops.grid_sample(input.astype(ms.float64), grid.astype(ms.float64)).astype(out_type)
if canvas is None:
return output
else:
diff --git a/mindocr/models/utils/__init__.py b/mindocr/models/utils/__init__.py
index 264b95fe2..a5f9047df 100644
--- a/mindocr/models/utils/__init__.py
+++ b/mindocr/models/utils/__init__.py
@@ -1,3 +1,3 @@
from .attention_cells import *
-from .load_model import load_model
+from .load_model import load_model, set_amp_attr
from .rnn_cells import GRUCell
diff --git a/mindocr/models/utils/load_model.py b/mindocr/models/utils/load_model.py
index ec0f3f652..c44d01032 100644
--- a/mindocr/models/utils/load_model.py
+++ b/mindocr/models/utils/load_model.py
@@ -2,11 +2,11 @@
import os
from typing import Callable, Dict, Optional
-from mindspore import load_checkpoint, load_param_into_net
+from mindspore import load_checkpoint, load_param_into_net, nn
from ..backbones.mindcv_models.utils import auto_map, download_pretrained
-__all__ = ["load_model", "drop_inconsistent_shape_parameters"]
+__all__ = ["load_model", "drop_inconsistent_shape_parameters", "set_amp_attr"]
_logger = logging.getLogger(__name__)
@@ -78,3 +78,9 @@ def load_model(
f"Finish loading model checkoint from {load_from}. "
"If no parameter fail-load warning displayed, all checkpoint params have been successfully loaded."
)
+
+
+def set_amp_attr(network : nn.Cell, amp_level : str):
+ cells = network.name_cells()
+ for name in cells:
+ setattr(network._cells[name], "_amp_level", amp_level)
diff --git a/mindocr/nlp/__init__.py b/mindocr/nlp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindocr/nlp/generation/__init__.py b/mindocr/nlp/generation/__init__.py
new file mode 100644
index 000000000..2a41f6588
--- /dev/null
+++ b/mindocr/nlp/generation/__init__.py
@@ -0,0 +1,6 @@
+from mindocr.nlp.generation.text_generator import GeneratorMixin
+
+from . import text_generator
+
+__all__ = []
+__all__.extend(text_generator.__all__)
diff --git a/mindocr/nlp/generation/beam_search.py b/mindocr/nlp/generation/beam_search.py
new file mode 100644
index 000000000..a5c931156
--- /dev/null
+++ b/mindocr/nlp/generation/beam_search.py
@@ -0,0 +1,399 @@
+"""Beam search for text generation."""
+from abc import ABC, abstractmethod
+from collections import UserDict
+from typing import List, Optional, Union
+
+import numpy as np
+
+
+class BeamScorer(ABC):
+ """Abstract base class for all beam scorers"""
+
+ @abstractmethod
+ def process(
+ self, input_ids, next_scores, next_tokens, next_indices, pad_token_id, eos_token_id, beam_indices, group_index
+ ):
+ r"""
+ Args:
+ input_ids:
+ Indices of input sequence tokens in the vocabulary.
+ next_scores:
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
+ next_tokens:
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
+ next_indices:
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
+ pad_token_id:
+ The id of the *padding* token.
+ eos_token_id:
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ beam_indices:
+ Beam indices indicating to which beam hypothesis each token correspond.
+ group_index:
+ The index of the group of beams.
+ """
+ raise NotImplementedError("This is an abstract method.")
+
+ def finalize(self, input_ids, final_beam_scores, max_length, pad_token_id, eos_token_id, beam_indices):
+ r"""
+ Args:
+ input_ids:
+ Indices of input sequence tokens in the vocabulary.
+ final_beam_scores:
+ The final scores of all non-finished beams.
+ max_length:
+ The max_length of output ids.
+ pad_token_id:
+ The id of the *padding* token.
+ eos_token_id:
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ beam_indices:
+ Beam indices indicating to which beam hypothesis each token correspond.
+ """
+ raise NotImplementedError("This is an abstract method.")
+
+
+class BeamSearchScorer(BeamScorer):
+ r"""
+ [`BeamScorer`] implementing standard beam search decoding.
+
+ Args:
+ batch_size (`int`):
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
+ num_beams (`int`):
+ Number of beams for beam search.
+ length_penalty (`float`, *optional*, defaults to 1.0):
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
+ `length_penalty` < 0.0 encourages shorter sequences.
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+ beam search algorithm).
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
+ The number of beam hypotheses that shall be returned upon calling
+ [`~transformer.BeamSearchScorer.finalize`].
+ num_beam_groups (`int`):
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
+ See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+ max_length (`int`, *optional*):
+ The maximum length of the sequence to be generated.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ num_beams: int,
+ length_penalty: Optional[float] = 1.0,
+ do_early_stopping: Optional[Union[bool, str]] = False,
+ num_beam_hyps_to_keep: Optional[int] = 1,
+ num_beam_groups: Optional[int] = 1,
+ max_length: Optional[int] = None,
+ ):
+ self.num_beams = num_beams
+ self.length_penalty = length_penalty
+ self.do_early_stopping = do_early_stopping
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
+ self.num_beam_groups = num_beam_groups
+ self.group_size = self.num_beams // self.num_beam_groups
+
+ self._is_init = False
+ # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
+ # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
+ self._beam_hyps = [
+ BeamHypotheses(
+ num_beams=self.group_size,
+ length_penalty=self.length_penalty,
+ early_stopping=self.do_early_stopping,
+ max_length=max_length,
+ )
+ for _ in range(batch_size * self.num_beam_groups)
+ ]
+ # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
+ # in the i-th mini-batch is complete.
+ self._done = np.array([False for _ in range(batch_size * self.num_beam_groups)])
+
+ if not isinstance(num_beams, int) or num_beams <= 1:
+ raise ValueError(
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
+ " one should make use of `greedy_search` instead."
+ )
+
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
+ raise ValueError(
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
+ )
+
+ @property
+ def is_done(self) -> bool:
+ return self._done.all()
+
+ def process(
+ self,
+ input_ids,
+ next_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ beam_indices=None,
+ group_index: Optional[int] = 0,
+ ):
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
+
+ if not batch_size == (input_ids.shape[0] // self.group_size):
+ if self.num_beam_groups > 1:
+ raise ValueError(
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
+ f"size of {self.group_size} is expected by the beam scorer."
+ )
+ raise ValueError(
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
+ f"{self.group_size} is expected by the beam scorer."
+ )
+
+ next_beam_scores = np.zeros((batch_size, self.group_size), dtype=next_scores.dtype)
+ next_beam_tokens = np.zeros((batch_size, self.group_size), dtype=next_tokens.dtype)
+ next_beam_indices = np.zeros((batch_size, self.group_size), dtype=next_indices.dtype)
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ for batch_idx in range(batch_size):
+ batch_group_idx = batch_idx * self.num_beam_groups + group_index
+ if self._done[batch_group_idx]:
+ if self.num_beams < len(self._beam_hyps[batch_group_idx]):
+ raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
+ if eos_token_id is None or pad_token_id is None:
+ raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
+ # pad the batch
+ next_beam_scores[batch_idx, :] = 0
+ next_beam_tokens[batch_idx, :] = pad_token_id
+ next_beam_indices[batch_idx, :] = 0
+ continue
+
+ # next tokens for this sentence
+ beam_idx = 0
+ cur_len = np.min(np.where(input_ids[beam_idx] == pad_token_id))
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+ zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+ ):
+ batch_beam_idx = batch_idx * self.group_size + next_index
+ # add to generated hypotheses if end of sentence
+ if (eos_token_id is not None) and (next_token in eos_token_id):
+ # if beam_token does not belong to top num_beams tokens, it should not be added
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
+ if is_beam_token_worse_than_top_num_beams:
+ continue
+ if beam_indices is not None:
+ beam_index = beam_indices[batch_beam_idx]
+ beam_index = beam_index + (batch_beam_idx,)
+ else:
+ beam_index = None
+
+ self._beam_hyps[batch_group_idx].add(
+ input_ids[batch_beam_idx].copy(),
+ next_score,
+ beam_indices=beam_index,
+ )
+ else:
+ # add next predicted token since it is not eos_token
+ next_beam_scores[batch_idx, beam_idx] = next_score
+ next_beam_tokens[batch_idx, beam_idx] = next_token
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+ beam_idx += 1
+
+ # once the beam for next step is full, don't add more tokens to it.
+ if beam_idx == self.group_size:
+ break
+
+ if beam_idx < self.group_size:
+ raise ValueError(
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+ )
+
+ # Check if we are done so that we can save a pad step if all(done)
+ self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
+ next_scores[batch_idx].max(), cur_len
+ )
+
+ return UserDict(
+ {
+ "next_beam_scores": next_beam_scores.flatten(),
+ "next_beam_tokens": next_beam_tokens.flatten(),
+ "next_beam_indices": next_beam_indices.flatten(),
+ }
+ )
+
+ def finalize(
+ self,
+ input_ids,
+ final_beam_scores,
+ max_length: int,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ beam_indices=None,
+ ):
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ # finalize all open beam hypotheses and add to generated hypotheses
+ for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
+ if self._done[batch_group_idx]:
+ continue
+
+ # all open beam hypotheses are added to the beam hypothesis
+ # beam hypothesis class automatically keeps the best beams
+ for index_per_group in range(self.group_size):
+ batch_beam_idx = batch_group_idx * self.group_size + index_per_group
+ final_score = final_beam_scores[batch_beam_idx].item()
+ final_tokens = input_ids[batch_beam_idx]
+ beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
+
+ # select the best hypotheses
+ sent_lengths = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=np.int32)
+ best = []
+ best_indices = []
+ best_scores = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=np.float32)
+
+ # retrieve best hypotheses
+ for i in range(batch_size):
+ beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
+ candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
+ sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
+ for j in range(self.num_beam_hyps_to_keep):
+ best_hyp_tuple = sorted_hyps.pop()
+ best_score = best_hyp_tuple[0]
+ best_hyp = best_hyp_tuple[1]
+ best_index = best_hyp_tuple[2]
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+
+ # append hyp to lists
+ best.append(best_hyp)
+
+ # append indices to list
+ best_indices.append(best_index)
+
+ best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
+
+ # prepare for adding eos
+ sent_lengths_max = max(sent_lengths) + 1
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
+ decoded = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=np.int32)
+
+ if best_indices and best_indices[0] is not None:
+ indices = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=np.int32)
+ else:
+ indices = None
+
+ # shorter batches are padded if needed
+ if min(sent_lengths) != max(sent_lengths):
+ if pad_token_id is None:
+ raise ValueError("`pad_token_id` has to be defined")
+ decoded.fill(pad_token_id)
+
+ if indices is not None:
+ indices.fill(-1)
+
+ # fill with hypotheses and eos_token_id if the latter fits in
+ for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
+ sent_length = min(decoded.shape[-1], sent_lengths[i])
+ decoded[i, :sent_length] = hypo[:sent_length]
+
+ if indices is not None:
+ indices[i, : len(best_idx)] = best_idx
+
+ if sent_lengths[i] < sent_max_len:
+ # inserting only the first eos_token_id
+ decoded[i, sent_lengths[i]] = eos_token_id[0]
+
+ return UserDict(
+ {
+ "sequences": decoded,
+ "sequence_scores": best_scores,
+ "beam_indices": indices,
+ }
+ )
+
+
+class BeamHypotheses:
+ """
+ Beam hypotheses maintaining n-best list
+ """
+
+ def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
+ """
+ Initialize n-best list of hypotheses.
+ """
+ self.length_penalty = length_penalty
+ self.early_stopping = early_stopping
+ self.max_length = max_length
+ self.num_beams = num_beams
+ self.beams = []
+ self.worst_score = 1e9
+
+ if not isinstance(self.early_stopping, bool) and self.max_length is None:
+ raise ValueError(
+ "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
+ " BeamScorer class instance at initialization time."
+ )
+
+ def __len__(self):
+ """
+ Number of hypotheses in the list.
+ """
+ return len(self.beams)
+
+ def add(self, hyp, sum_logprobs: float, beam_indices=None):
+ """
+ Add a new hypothesis to the list.
+ """
+ score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+ if len(self) < self.num_beams or score > self.worst_score:
+ self.beams.append((score, hyp, beam_indices))
+ if len(self) > self.num_beams:
+ sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+ del self.beams[sorted_next_scores[0][1]]
+ self.worst_score = sorted_next_scores[1][0]
+ else:
+ self.worst_score = min(score, self.worst_score)
+
+ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+ """
+ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
+ one in the heap, then we are done with this sentence.
+ """
+
+ if len(self) < self.num_beams:
+ return False
+
+ # `True`: stop as soon as at least `num_beams` hypotheses are finished
+ if self.early_stopping is True:
+ return True
+
+ # `False`: heuristic compute the best possible score from `cur_len`, even though it is not entirely accurate
+ # when `length_penalty` is positive.
+ if self.early_stopping is False:
+ highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+ ret = self.worst_score >= highest_attainable_score
+ return ret
+
+ # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
+ # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+ # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+ # its max this way
+ if self.length_penalty > 0.0:
+ highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+ # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+ else:
+ highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+ ret = self.worst_score >= highest_attainable_score
+ return ret
diff --git a/mindocr/nlp/generation/generation_config.py b/mindocr/nlp/generation/generation_config.py
new file mode 100644
index 000000000..5c4b874f9
--- /dev/null
+++ b/mindocr/nlp/generation/generation_config.py
@@ -0,0 +1,166 @@
+"""generation config."""
+import copy
+import logging
+from typing import Any, Dict
+
+__all__ = ["GenerationConfig"]
+_logger = logging.getLogger(__name__)
+
+
+class GenerationConfig:
+ """Class that holds a configuration for a generation task.
+ Args:
+ > Parameters that control the length of the output
+
+ max_length (`int`, *optional*, defaults to 20):
+ The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
+ `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
+ max_new_tokens (`int`, *optional*):
+ The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
+
+ > Parameters that control the generation strategy used
+
+ do_sample (`bool`, *optional*, defaults to `False`):
+ Whether to use sampling ; use greedy decoding otherwise.
+ use_past (`bool`, *optional*, defaults to `False`):
+ Whether the model should use the past last key/values attentions
+ (if applicable to the model) to speed up decoding.
+ num_beams(`int`, *optional*, defaults to 1):
+ Number of beams for beam search. 1 means no beam search. If larger than 1, use beam search strategy.
+
+ > Parameters for manipulation of the model output logits
+
+ temperature (`float`, *optional*, defaults to 1.0):
+ The value used to modulate the next token probabilities.
+ top_k (`int`, *optional*, defaults to 50):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ top_p (`float`, *optional*, defaults to 1.0):
+ If set to float < 1, only the smallest set of most probable tokens with probabilities
+ that add up to `top_p` or higher are kept for generation.
+ repetition_penalty (`float`, *optional*, defaults to 1.0):
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ encoder_repetition_penalty (`float`, *optional*, defaults to 1.0):
+ The parameter for encoder_repetition_penalty. An exponential penalty on sequences
+ that are not in the original input. 1.0 means no penalty.
+ renormalize_logits (`bool`, *optional*, defaults to `False`):
+ Whether to renormalize the logits after applying all the logits processors or wrappers (including the custom
+ ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
+ are normalized but some logit processors or wrappers break the normalization.
+
+ > Special tokens that can be used at generation time
+
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ bos_token_id (`int`, *optional*):
+ The id of the *beginning-of-sequence* token.
+ eos_token_id (`Union[int, List[int]]`, *optional*):
+ The id of the *end-of-sequence* token. Optionally, use a list to
+ set multiple *end-of-sequence* tokens.
+
+ > Wild card
+
+ generation_kwargs:
+ Additional generation kwargs will be forwarded to the `generate` function of the model.
+ Kwargs that are not present in `generate`'s signature will be used in the
+ model forward pass.
+ """
+
+ def __init__(self, **kwargs):
+ # max generate length
+ self.max_length = kwargs.pop("max_decode_length", 20)
+ self.max_length = kwargs.pop("max_length", self.max_length)
+ self.max_new_tokens = kwargs.pop("max_new_tokens", None)
+
+ # number of beams
+ self.num_beams = kwargs.pop("num_beams", 1)
+ # do sample or not
+ self.do_sample = kwargs.pop("do_sample", False)
+ # incremental infer
+ self.use_past = kwargs.pop("use_past", False)
+ # logits processors
+ self.temperature = kwargs.pop("temperature", 1.0)
+ self.top_k = kwargs.pop("top_k", 50)
+ self.top_p = kwargs.pop("top_p", 1.0)
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+ self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0)
+ self.renormalize_logits = kwargs.pop("renormalize_logits", False)
+
+ # Special tokens that can be used at generation time
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
+
+ # interface.
+ self._from_model_config = kwargs.pop("_from_model_config", False)
+ # Additional attributes without default values
+ if not self._from_model_config:
+ # we don't want to copy values from the model config
+ # if we're initializing a `GenerationConfig` from a
+ # model's default configuration file
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ _logger.error("Can't set %s with value %s for %s", key, value, self)
+ raise err
+
+ def __str__(self) -> str:
+ return str(self.__dict__)
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
+ """
+ Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
+
+ Args:
+ config_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the configuration object.
+ kwargs:
+ Additional parameters from which to initialize the configuration object.
+
+ Returns:
+ [`GenerationConfig`]: The configuration object instantiated from those parameters.
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ config = cls(**{**config_dict, **kwargs})
+ unused_kwargs = config.update(**kwargs)
+ _logger.debug("Generate config %s", config)
+ if return_unused_kwargs:
+ return config, unused_kwargs
+ return config
+
+ @classmethod
+ def from_model_config(cls, model_config) -> "GenerationConfig":
+ config_dict = model_config
+ config_dict.pop("_from_model_config", None)
+ config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)
+
+ return config
+
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs`
+ if they match existing attributes, returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs
+ that were not used to update the instance.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
+ def to_dict(self) -> Dict[str, Any]:
+ """to dict convert function."""
+ output = copy.deepcopy(self.__dict__)
+ return output
diff --git a/mindocr/nlp/generation/logits_process.py b/mindocr/nlp/generation/logits_process.py
new file mode 100644
index 000000000..cd8e9e9d4
--- /dev/null
+++ b/mindocr/nlp/generation/logits_process.py
@@ -0,0 +1,220 @@
+"""Logits Processor for generation."""
+import inspect
+from threading import Thread
+
+import numpy as np
+
+from .utils import log_softmax, softmax, topk
+
+__all__ = [
+ "LogitsProcessor",
+ "LogitsWarper",
+ "LogitsProcessorList",
+ "RepetitionPenaltyLogitsProcessor",
+ "LogitNormalization",
+ "TemperatureLogitsWarper",
+ "TopKLogitsWarper",
+ "TopPLogitsWarper",
+]
+
+
+class LogitsProcessor:
+ """Abstract base class for all logit processors that can be applied during generation."""
+
+ def __call__(self, input_ids, scores):
+ """Torch method for processing logits."""
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsWarper:
+ """Abstract base class for all logit warpers that can be applied during generation
+ with multinomial sampling."""
+
+ def __call__(self, input_ids, scores):
+ """Torch method for warping logits."""
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsProcessorList(list):
+ """
+ This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently
+ process a `scores` input tensor. This class inherits from list and adds a specific *__call__* method
+ to apply each [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
+ """
+
+ def __call__(self, input_ids, scores, is_finished=None, **kwargs):
+ all_threads = []
+ for i in range(0, input_ids.shape[0]):
+ if is_finished and is_finished[i]:
+ continue
+ thread = Thread(target=self.process, args=(i, input_ids, scores), kwargs=kwargs)
+ all_threads.append(thread)
+ thread.start()
+ for thread in all_threads:
+ thread.join()
+ return scores
+
+ def process(self, i, input_ids, scores, **kwargs):
+ """apply process"""
+ input_ids = input_ids[i : i + 1]
+ scores_i = scores[i : i + 1]
+ for processor in self:
+ function_args = inspect.signature(processor.__call__).parameters
+ if len(function_args) > 2:
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
+ raise ValueError(
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
+ f"{processor.__class__} are passed to the logits processor."
+ )
+ scores_i = processor(input_ids, scores_i, **kwargs)
+ else:
+ scores_i = processor(input_ids, scores_i)
+ scores[i] = scores_i
+
+
+class TemperatureLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
+
+ Args:
+ temperature (`float`):
+ The value used to module the logits distribution.
+ """
+
+ def __init__(self, temperature: float):
+ temperature = float(temperature)
+ if temperature <= 0:
+ raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
+
+ self.temperature = temperature
+
+ def __call__(self, input_ids, scores):
+ scores = scores / self.temperature
+ return scores
+
+
+class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+
+ Args:
+ repetition_penalty (`float`):
+ The parameter for repetition penalty. 1.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ """
+
+ def __init__(self, repetition_penalty: float):
+ repetition_penalty = float(repetition_penalty)
+ if repetition_penalty <= 0:
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {repetition_penalty}")
+
+ self.penalty = repetition_penalty
+
+ def __call__(self, input_ids, scores):
+ score = np.take_along_axis(scores, input_ids, axis=1)
+
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+ negative_index = score < 0
+ positive_index = ~negative_index
+ score[negative_index] = score[negative_index] * self.penalty
+ score[positive_index] = score[positive_index] / self.penalty
+
+ np.put_along_axis(scores, input_ids, score, axis=1)
+ return scores
+
+
+class TopPLogitsWarper(LogitsWarper):
+ """
+ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+
+ Args:
+ top_p (`float`):
+ If set to < 1, only the smallest set of most probable tokens with probabilities
+ that add up to `top_p` or higher are kept for generation.
+ filter_value (`float`, *optional*, defaults to `-50000`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ candidate_token_num (`int`, *optional*, defaults to 200):
+ Number of candidate tokens to calculate top_p. this can avoid sorting a huge seq,
+ save time to speed up generation.
+ """
+
+ def __init__(
+ self, top_p: float, filter_value: float = -50000, min_tokens_to_keep: int = 1, candidate_token_num: int = 200
+ ):
+ top_p = float(top_p)
+ if top_p < 0 or top_p > 1.0:
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
+ raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
+
+ self.top_p = top_p
+ self.filter_value = float(filter_value)
+ self.min_tokens_to_keep = min_tokens_to_keep
+ self.candicate_token_num = candidate_token_num
+
+ def __call__(self, input_ids, scores):
+ candidate_logits, candidate_indices = topk(scores, self.candicate_token_num)
+ cumulative_probs = softmax(candidate_logits)
+ cumulative_probs = np.cumsum(cumulative_probs, axis=-1)
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+ sorted_indices_to_keep = cumulative_probs < self.top_p
+ # add the last token that exceed top_p
+ sorted_indices_to_keep = np.concatenate(
+ [np.ones(shape=(scores.shape[0], 1)).astype(np.bool_), sorted_indices_to_keep[..., :-1]], axis=-1
+ )
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_keep[..., : self.min_tokens_to_keep] = 1
+
+ # set remove indices, filter negative value
+ indices_to_remove = np.ones_like(scores).astype(np.bool_)
+ np.put_along_axis(indices_to_remove, candidate_indices, ~sorted_indices_to_keep, axis=-1)
+ scores[indices_to_remove] = self.filter_value
+
+ return scores
+
+
+class TopKLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+
+ Args:
+ top_k (`int`):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ filter_value (`float`, *optional*, defaults to `-float("Inf")`):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(self, top_k: int, filter_value: float = -50000, min_tokens_to_keep: int = 1):
+ if not isinstance(top_k, int) or top_k <= 0:
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
+
+ self.top_k = max(top_k, min_tokens_to_keep)
+ self.filter_value = float(filter_value)
+
+ def __call__(self, input_ids, scores: np.ndarray):
+ top_k = min(self.top_k, scores.shape[-1]) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = scores < topk(scores, top_k)[0][..., -1, None]
+ scores[indices_to_remove] = self.filter_value
+ return scores
+
+
+class LogitNormalization(LogitsProcessor, LogitsWarper):
+ r"""
+ [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
+ the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
+ this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
+ the scores are normalized when comparing the hypotheses.
+ """
+
+ def __call__(self, input_ids, scores):
+ scores = log_softmax(scores, axis=-1)
+ return scores
diff --git a/mindocr/nlp/generation/text_generator.py b/mindocr/nlp/generation/text_generator.py
new file mode 100644
index 000000000..ed427fe09
--- /dev/null
+++ b/mindocr/nlp/generation/text_generator.py
@@ -0,0 +1,1082 @@
+"""For text generation"""
+import copy
+import logging
+import time
+from typing import List, Optional, Union
+
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import ops
+from mindspore.common.tensor import Tensor
+
+from mindocr.nlp.generation.beam_search import BeamSearchScorer
+from mindocr.nlp.generation.generation_config import GenerationConfig
+from mindocr.nlp.generation.logits_process import (
+ LogitNormalization,
+ LogitsProcessorList,
+ RepetitionPenaltyLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+)
+from mindocr.nlp.generation.utils import softmax_with_threads, topk
+
+__all__ = ["GeneratorMixin"]
+_logger = logging.getLogger(__name__)
+
+
+class GenerationMode:
+ """
+ Possible generation modes.
+ """
+
+ # Non-beam methods
+ GREEDY_SEARCH = "greedy_search"
+ SAMPLE = "sample"
+ # Beam methods
+ BEAM_SEARCH = "beam_search"
+
+
+class GeneratorMixin:
+ """Generator For the nlp models"""
+
+ def __init__(self):
+ pass
+
+ # pylint: disable=W0613
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
+ """
+ prepare inputs for generation.
+ A model class needs to define a `prepare_inputs_for_generation` method
+ in order to use `.generate()`
+
+ Raises:
+ RuntimeError: Not implemented in model but call `.generate()`
+ """
+ raise RuntimeError(
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
+ )
+
+ # pylint: disable=W0613
+ def update_model_kwargs_before_generate(self, input_ids, model_kwargs: dict):
+ """
+ update model kwargs before generate.
+ If your model needs to update model kwargs before generate, implement
+ this method in your model, else do nothing.
+ """
+ return
+
+ @staticmethod
+ def slice_incremental_inputs(model_inputs: dict, current_index):
+ """used for non-first iterations, slice the inputs to length 1."""
+ input_ids = model_inputs.pop("input_ids")
+ if isinstance(input_ids, Tensor):
+ input_ids = input_ids.asnumpy()
+ inputs_tmp = []
+ for i, index_value in enumerate(current_index):
+ current_index_tmp = int(index_value) - i * input_ids.shape[1] # multi batch
+ # use numpy to slice array to avoid compile ascend slice op
+ inputs_tmp.append(input_ids[i][current_index_tmp : current_index_tmp + 1])
+ inputs_tmp = np.array(inputs_tmp, dtype=np.int32)
+ model_inputs["input_ids"] = Tensor(inputs_tmp, mstype.int32)
+
+ @staticmethod
+ def process_logits(logits, current_index=None, keep_all=False):
+ """Process the logits"""
+ logits = logits.reshape(-1, logits.shape[-1])
+ if not keep_all and current_index is not None:
+ index = current_index.view(
+ -1,
+ )
+ logits = ops.Gather()(logits, index, 0)
+ outputs = ops.LogSoftmax(-1)(logits)
+ outputs = ops.tensor_pow(np.e, outputs)
+ return outputs
+
+ def _get_logits_processor(
+ self, generation_config: GenerationConfig, logits_processor: Optional[LogitsProcessorList]
+ ):
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
+ instances used to modify the scores of the language model head.
+ """
+ # instantiate processors list
+ processors = LogitsProcessorList()
+
+ if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
+ processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty=generation_config.repetition_penalty))
+ processors = self._merge_processor_list(processors, logits_processor)
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ processors.append(LogitNormalization())
+ return processors
+
+ def _merge_processor_list(self, default_list: LogitsProcessorList, custom_list: LogitsProcessorList):
+ """merge custom processor list with default list."""
+ if not custom_list:
+ return default_list
+ for default in default_list:
+ for custom in custom_list:
+ if type(custom) is type(default):
+ object_type = "logits processor"
+ raise ValueError(
+ f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
+ f" `.generate()`, but it has already been created with the values {default}."
+ f" {default} has been created by passing the corresponding arguments to generate or"
+ f" by the model's config default values. If you just want to change the default values"
+ f" of {object_type} consider passing them as arguments to `.generate()`"
+ f" instead of using a custom {object_type}."
+ )
+ default_list.extend(custom_list)
+ return default_list
+
+ def _get_logits_warper(self, generation_config: GenerationConfig):
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+ used for multinomial sampling.
+ """
+
+ # instantiate wrappers list
+ wrappers = LogitsProcessorList()
+
+ # all samplers can be found in `generation_utils_samplers.py`
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ wrappers.append(TemperatureLogitsWarper(generation_config.temperature))
+ min_tokens_to_keep = 1
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ wrappers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ wrappers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ wrappers.append(LogitNormalization())
+ return wrappers
+
+ @staticmethod
+ def _get_generation_mode(generation_config: GenerationConfig):
+ """determine the generation mode by config"""
+ if generation_config.num_beams == 1:
+ if generation_config.do_sample:
+ _logger.info("The generation mode will be **SAMPLE**.")
+ return GenerationMode.SAMPLE
+ _logger.info("The generation mode will be **GREEDY_SEARCH**.")
+ return GenerationMode.GREEDY_SEARCH
+ _logger.info("The generation mode will be **BEAM_SEARCH**.")
+ return GenerationMode.BEAM_SEARCH
+
+ def _prepare_model_inputs_for_decoder(self, input_ids, input_mask):
+ """generate the inputs for the decoder"""
+ batch_size = input_ids.shape[0]
+
+ encoder_mask = Tensor(input_mask, mstype.float32)
+
+ encoder_output = self.encoder_forward(Tensor(input_ids, mstype.int32), encoder_mask)
+
+ input_ids = np.zeros((batch_size, self.config.max_decode_length))
+ _logger.debug("Decoder: pad the origin inputs into shape: %s", input_ids.shape)
+ target_mask = np.zeros_like(input_ids)
+ target_mask[:, 0] = 1
+
+ # As the decoder is generating from [START] token
+ return encoder_output, encoder_mask, input_ids, target_mask
+
+ def _pad_inputs_using_max_length(self, origin_inputs, pad_token_id=0):
+ """pad the input_ids to the max_length"""
+ pad_length = self.config.seq_length - origin_inputs.shape[-1]
+ if pad_length < 0:
+ raise ValueError(
+ f"origin_inputs size is {origin_inputs.shape}, you should"
+ f"increase the seq_length of the model {self.config.seq_length}."
+ )
+ # Pad original inputs to model_origin_max_length
+ input_ids = np.pad(
+ origin_inputs,
+ ((0, 0), (0, pad_length)),
+ "constant",
+ constant_values=(0, pad_token_id),
+ )
+ return input_ids
+
+ def _incremental_infer(self, model_inputs: dict, current_index, valid_length_each_example):
+ """model forward for incremental infer."""
+ # Claim the first graph
+ if self.is_first_iteration:
+ self.add_flags_recursive(is_first_iteration=True)
+ model_inputs["input_position"] = Tensor(current_index, mstype.int32)
+ model_inputs["init_reset"] = Tensor([False], mstype.bool_) # init_reset (1,) bool False
+ model_inputs["batch_valid_length"] = Tensor([valid_length_each_example], mstype.int32)
+ # pylint: disable=E1102
+ res = self(**model_inputs)
+ # first iter done, go to other iters
+ self.is_first_iteration = False
+ self.add_flags_recursive(is_first_iteration=False)
+ else:
+ # slice model inputs for incremental infer
+ self.slice_incremental_inputs(model_inputs, current_index)
+ model_inputs["input_position"] = Tensor(current_index, mstype.int32)
+ model_inputs["init_reset"] = Tensor([True], mstype.bool_) # init_reset (1,) bool True
+ model_inputs["batch_valid_length"] = Tensor([valid_length_each_example], mstype.int32)
+ # pylint: disable=E1102
+ res = self(
+ **model_inputs,
+ )
+
+ return res
+
+ def _greedy_search(
+ self,
+ origin_inputs,
+ generation_config: GenerationConfig,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ streamer=None,
+ **model_kwargs,
+ ):
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
+ used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+ instead.
+
+ Parameters:
+ origin_inputs (`List(str), List(List(str))`):
+ The sequence used as a prompt for the generation.
+ generation_config (`GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation
+ call. `**kwargs` passed to generate matching the attributes of `generation_config`
+ will override them. If `generation_config` is not provided, the default config
+ from the model configuration will be used. Please note that unspecified parameters
+ will inherit [`GenerationConfig`]'s default values, whose documentation should be
+ checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ streamer (`TextStreamer, *optional*`):
+ The streamer that generator uses.
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ A list of the generated token ids
+ """
+ total_time = time.time()
+ prepare_time = time.time()
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = 0
+
+ if streamer is not None:
+ streamer.put(origin_inputs)
+
+ batch_size = origin_inputs.shape[0]
+ is_encoder_decoder = self.config.is_encoder_decoder
+ _logger.debug("The input shape is: %s", origin_inputs.shape)
+ valid_length_each_example = []
+ for i in range(batch_size):
+ # As the nonzero returns the index and we need length
+ valid_length_each_example.append(
+ np.max(np.argwhere(origin_inputs[i] != generation_config.pad_token_id)) + 1
+ )
+ valid_length_each_example = np.array(valid_length_each_example)
+ _logger.debug("Get the valid for each example is: %s", valid_length_each_example)
+
+ # Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = np.max(valid_length_each_example)
+ if generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+
+ if generation_config.max_length > self.config.seq_length:
+ _logger.warning(
+ "max_length %s can not exceeds model seq_length %s, set max_length = seq_length.",
+ generation_config.max_length,
+ self.config.seq_length,
+ )
+ generation_config.max_length = self.config.seq_length
+
+ _logger.debug("max length is: %s", generation_config.max_length)
+ if not is_encoder_decoder and input_ids_length >= generation_config.max_length:
+ raise ValueError(
+ f"the input_ids length {input_ids_length} exceeds the max length config {generation_config.max_length}."
+ f"check your inputs and set max_length larger than your inputs length."
+ )
+
+ input_ids = self._pad_inputs_using_max_length(
+ origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id
+ )
+
+ _logger.debug(
+ "pad the origin inputs from %s into shape: %s",
+ origin_inputs.shape,
+ input_ids.shape,
+ )
+
+ input_mask = np.zeros_like(input_ids)
+ for i in range(valid_length_each_example.shape[0]):
+ input_mask[i, : valid_length_each_example[i]] = 1
+ encoder_output = None
+ encoder_mask = None
+ if is_encoder_decoder:
+ if generation_config.max_length > self.config.max_decode_length:
+ generation_config.max_length = self.config.max_decode_length
+ _logger.debug("max decode length is: %s", generation_config.max_length)
+
+ # When do encoder and decoder prediction, the encoder can be cached
+ # to speed up the inference
+ (
+ encoder_output,
+ encoder_mask,
+ input_ids,
+ target_mask,
+ ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask)
+ valid_length_each_example = [1 for _ in range(batch_size)]
+ # A single loop generates one token, loop until reaching target
+ # model_origin_max_length or generating eod token
+ is_finished = [False] * batch_size
+
+ # update model kwargs once, before go into generate loop.
+ self.update_model_kwargs_before_generate(input_ids, model_kwargs)
+
+ # setup is_first_iteration flag for incremental infer
+ if generation_config.use_past:
+ self.is_first_iteration = True
+ need_gather_logits = True
+
+ origin_len = np.sum(valid_length_each_example)
+ prepare_time = time.time() - prepare_time
+ _logger.debug("forward prepare time: %s s", prepare_time)
+
+ while np.sum(is_finished) != batch_size:
+ forward_time = time.time()
+ seq_length = input_ids.shape[1]
+ current_index = [valid_length_each_example[i] - 1 + i * seq_length for i in range(batch_size)]
+ _logger.debug("validate length: %s", valid_length_each_example)
+ if is_encoder_decoder:
+ inputs = Tensor(input_ids, mstype.int32)
+ # pylint: disable=E1102
+ res = self(
+ input_ids=None,
+ attention_mask=encoder_mask,
+ encoder_outputs=encoder_output,
+ decoder_input_ids=inputs,
+ decoder_attention_mask=Tensor(target_mask, mstype.float32),
+ )
+ else:
+ model_kwargs["current_index"] = current_index
+ # model prepare input dict
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # incremental generate
+ if generation_config.use_past:
+ # when first iteration, gather last logits; others keep all logits.
+ need_gather_logits = self.is_first_iteration
+ # incremental generate
+ res = self._incremental_infer(
+ model_inputs=model_inputs,
+ current_index=current_index,
+ valid_length_each_example=valid_length_each_example,
+ )
+ # auto-aggressive generate
+ else:
+ res = self(**model_inputs) # pylint: disable=E1102
+ forward_time = time.time() - forward_time
+
+ search_time = time.time()
+ # post process logits; skip this phase if post process is done in graph
+ if not self.config.is_sample_acceleration:
+ # convert to numpy for post process
+ logits = res[0] if isinstance(res, tuple) else res
+ if isinstance(logits, Tensor):
+ logits = logits.asnumpy()
+ logits = np.reshape(logits, (-1, logits.shape[-1]))
+ # need gather last seq logits using current_index
+ # compare length to determine if need gather; if not, gather should be done in model construct
+ if need_gather_logits and logits.shape[0] > len(current_index):
+ logits = logits[current_index]
+
+ # post process logits, without changing logits shape and order
+ probs = logits_processor(input_ids, logits, is_finished)
+ p_args = np.tile(np.arange(logits.shape[-1]), (batch_size, 1))
+ else:
+ probs, p_args = res
+ if isinstance(probs, Tensor):
+ probs = probs.asnumpy()
+ if isinstance(p_args, Tensor):
+ p_args = p_args.asnumpy()
+ search_time = time.time() - search_time
+
+ update_time = time.time()
+
+ # Random select a token as final output for this round
+ target_list = [[] for _ in range(batch_size)]
+ for i in range(batch_size):
+ if is_finished[i]:
+ continue
+
+ target_index = np.argmax(probs[i])
+
+ # get target token id
+ target = p_args[i][target_index]
+ input_ids[i, valid_length_each_example[i]] = target
+
+ if streamer is not None:
+ # assign target element
+ target_list[i] = [target]
+
+ if is_encoder_decoder:
+ target_mask[i][valid_length_each_example[i]] = int(1)
+
+ valid_length_each_example[i] += int(1)
+ input_mask[i][valid_length_each_example[i] - 1] = 1
+
+ # Stop judgment
+ if (
+ p_args[i][target_index] == generation_config.eos_token_id
+ or valid_length_each_example[i] == generation_config.max_length
+ ):
+ is_finished[i] = True
+ continue
+ if streamer is not None:
+ if batch_size == 1:
+ streamer.put(target_list[0])
+ else:
+ streamer.put(target_list)
+ update_time = time.time() - update_time
+ _logger.debug(
+ "forward time: %s s; greedy search time: %s s; update time: %s s; total count: %s s",
+ forward_time,
+ search_time,
+ update_time,
+ forward_time + search_time + update_time,
+ )
+
+ # Return valid outputs out of padded outputs
+ output_ids = []
+ for i in range(batch_size):
+ output_ids.append(input_ids[i, : int(valid_length_each_example[i])].astype(np.int32))
+ _logger.debug("The output is: %s", output_ids)
+ if streamer is not None:
+ streamer.end()
+
+ generate_len = np.sum(valid_length_each_example) - origin_len
+ total_time = time.time() - total_time
+ _logger.info(
+ "total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s",
+ total_time,
+ generate_len,
+ generate_len / total_time,
+ )
+
+ return output_ids
+
+ def _sample(
+ self,
+ origin_inputs,
+ generation_config: GenerationConfig,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ streamer=None,
+ **model_kwargs,
+ ):
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ Parameters:
+ origin_inputs (`List(str), List(List(str))`):
+ The sequence used as a prompt for the generation.
+ generation_config (`GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation
+ call. `**kwargs` passed to generate matching the attributes of `generation_config`
+ will override them. If `generation_config` is not provided, the default config
+ from the model configuration will be used. Please note that unspecified parameters
+ will inherit [`GenerationConfig`]'s default values, whose documentation should be
+ checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step.
+ streamer (`TextStreamer, *optional*`):
+ The streamer that generator uses.
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ A list of the generated token ids
+ """
+ total_time = time.time()
+ prepare_time = time.time()
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = 0
+
+ if streamer is not None:
+ streamer.put(origin_inputs)
+
+ batch_size = origin_inputs.shape[0]
+ is_encoder_decoder = self.config.is_encoder_decoder
+ _logger.debug("The input shape is: %s", origin_inputs.shape)
+ valid_length_each_example = []
+ for i in range(batch_size):
+ # As the nonzero returns the index and we need length
+ valid_length_each_example.append(
+ np.max(np.argwhere(origin_inputs[i] != generation_config.pad_token_id)) + 1
+ )
+ valid_length_each_example = np.array(valid_length_each_example)
+ _logger.debug("Get the valid for each example is: %s", valid_length_each_example)
+
+ # Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = np.max(valid_length_each_example)
+ if generation_config.max_new_tokens is not None:
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+
+ if generation_config.max_length > self.config.seq_length:
+ _logger.warning(
+ "max_length %s can not exceeds model seq_length %s, set max_length = seq_length.",
+ generation_config.max_length,
+ self.config.seq_length,
+ )
+ generation_config.max_length = self.config.seq_length
+
+ _logger.debug("max length is: %s", generation_config.max_length)
+ if not is_encoder_decoder and input_ids_length >= generation_config.max_length:
+ raise ValueError(
+ f"the input_ids length {input_ids_length} exceeds the max length config {generation_config.max_length}."
+ f"check your inputs and set max_length larger than your inputs length."
+ )
+
+ input_ids = self._pad_inputs_using_max_length(
+ origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id
+ )
+
+ _logger.debug(
+ "pad the origin inputs from %s into shape: %s",
+ origin_inputs.shape,
+ input_ids.shape,
+ )
+
+ input_mask = np.zeros_like(input_ids)
+ for i in range(valid_length_each_example.shape[0]):
+ input_mask[i, : valid_length_each_example[i]] = 1
+ encoder_output = None
+ encoder_mask = None
+ if is_encoder_decoder:
+ if generation_config.max_length > self.config.max_decode_length:
+ generation_config.max_length = self.config.max_decode_length
+ _logger.debug("max decode length is: %s", generation_config.max_length)
+
+ # When do encoder and decoder prediction, the encoder can be cached
+ # to speed up the inference
+ (
+ encoder_output,
+ encoder_mask,
+ input_ids,
+ target_mask,
+ ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask)
+ valid_length_each_example = [1 for _ in range(batch_size)]
+ # A single loop generates one token, loop until reaching target
+ # model_origin_max_length or generating eod token
+ is_finished = [False] * batch_size
+
+ # update model kwargs once, before go into generate loop.
+ self.update_model_kwargs_before_generate(input_ids, model_kwargs)
+
+ # setup is_first_iteration flag for incremental infer
+ if generation_config.use_past:
+ self.is_first_iteration = True
+ need_gather_logits = True
+
+ origin_len = np.sum(valid_length_each_example)
+ prepare_time = time.time() - prepare_time
+ _logger.debug("forward prepare time: %s s", prepare_time)
+
+ while np.sum(is_finished) != batch_size:
+ forward_time = time.time()
+ seq_length = input_ids.shape[1]
+ current_index = [valid_length_each_example[i] - 1 + i * seq_length for i in range(batch_size)]
+ _logger.debug("validate length: %s", valid_length_each_example)
+ if is_encoder_decoder:
+ inputs = Tensor(input_ids, mstype.int32)
+ # pylint: disable=E1102
+ res = self(
+ input_ids=None,
+ attention_mask=encoder_mask,
+ encoder_outputs=encoder_output,
+ decoder_input_ids=inputs,
+ decoder_attention_mask=Tensor(target_mask, mstype.float32),
+ )
+ else:
+ model_kwargs["current_index"] = current_index
+ # model prepare input dict
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # incremental generate
+ if generation_config.use_past:
+ # when first iteration, gather last logits; others keep all logits.
+ need_gather_logits = self.is_first_iteration
+ # incremental generate
+ res = self._incremental_infer(
+ model_inputs=model_inputs,
+ current_index=current_index,
+ valid_length_each_example=valid_length_each_example,
+ )
+ # auto-aggressive generate
+ else:
+ res = self(**model_inputs) # pylint: disable=E1102
+ forward_time = time.time() - forward_time
+
+ sample_time = time.time()
+ # post process logits; skip this phase if post process is done in graph
+ if not self.config.is_sample_acceleration:
+ # convert to numpy for post process
+ logits = res[0] if isinstance(res, tuple) else res
+ if isinstance(logits, Tensor):
+ logits = logits.asnumpy()
+ logits = np.reshape(logits, (-1, logits.shape[-1]))
+ # need gather last seq logits using current_index
+ # compare length to determine if need gather; if not, gather should be done in model construct
+ if need_gather_logits and logits.shape[0] > len(current_index):
+ logits = logits[current_index]
+
+ # post process logits, without changing logits shape and order
+ probs = logits_processor(input_ids, logits, is_finished)
+ probs = logits_warper(input_ids, probs, is_finished)
+ p_args = np.tile(np.arange(logits.shape[-1]), (batch_size, 1))
+ else:
+ probs, p_args = res
+ if isinstance(probs, Tensor):
+ probs = probs.asnumpy()
+ if isinstance(p_args, Tensor):
+ p_args = p_args.asnumpy()
+ sample_time = time.time() - sample_time
+
+ update_time = time.time()
+ p_norms = softmax_with_threads(probs, is_finished)
+
+ # Random select a token as final output for this round
+ target_list = [[] for _ in range(batch_size)]
+ for i in range(batch_size):
+ if is_finished[i]:
+ continue
+
+ p_norm = p_norms[i]
+ target_index = np.random.choice(len(probs[i]), p=p_norm)
+
+ # get target token id
+ target = p_args[i][target_index]
+ input_ids[i, valid_length_each_example[i]] = target
+
+ if streamer is not None:
+ # assign target element
+ target_list[i] = [target]
+
+ if is_encoder_decoder:
+ target_mask[i][valid_length_each_example[i]] = int(1)
+
+ valid_length_each_example[i] += int(1)
+ input_mask[i][valid_length_each_example[i] - 1] = 1
+
+ # Stop judgment
+ if (
+ p_args[i][target_index] == generation_config.eos_token_id
+ or valid_length_each_example[i] == generation_config.max_length
+ ):
+ is_finished[i] = True
+ continue
+ if streamer is not None:
+ if batch_size == 1:
+ streamer.put(target_list[0])
+ else:
+ streamer.put(target_list)
+ update_time = time.time() - update_time
+ _logger.debug(
+ "forward time: %s s; sample time: %s s; update time: %s s; total count: %s s",
+ forward_time,
+ sample_time,
+ update_time,
+ forward_time + sample_time + update_time,
+ )
+
+ # Return valid outputs out of padded outputs
+ output_ids = []
+ for i in range(batch_size):
+ output_ids.append(input_ids[i, : int(valid_length_each_example[i])].astype(np.int32))
+ _logger.debug("The output is: %s", output_ids)
+
+ if streamer is not None:
+ streamer.end()
+
+ generate_len = np.sum(valid_length_each_example) - origin_len
+ total_time = time.time() - total_time
+ _logger.info(
+ "total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s",
+ total_time,
+ generate_len,
+ generate_len / total_time,
+ )
+
+ return output_ids
+
+ def _beam_search(
+ self,
+ origin_inputs,
+ beam_scorer: BeamSearchScorer,
+ generation_config: GenerationConfig,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ streamer=None,
+ **model_kwargs,
+ ):
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ Parameters:
+ origin_inputs (`List(str), List(List(str))`):
+ The sequence used as a prompt for the generation.
+ beam_scorer (`BeamScorer`):
+ An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
+ generation_config (`GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation
+ call. `**kwargs` passed to generate matching the attributes of `generation_config`
+ will override them. If `generation_config` is not provided, the default config
+ from the model configuration will be used. Please note that unspecified parameters
+ will inherit [`GenerationConfig`]'s default values, whose documentation should be
+ checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ streamer (`TextStreamer, *optional*`):
+ The streamer that generator uses.
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ A list of the generated token ids
+ """
+ if streamer is not None:
+ raise ValueError("Streamer does not support in beam search method yet!")
+ if generation_config.use_past:
+ raise ValueError("Beam search does not support incremental inference yet! Please set use_past to False.")
+ if self.config.is_sample_acceleration:
+ raise ValueError(
+ "Beam search does not support sample acceleration yet! Please set is_sample_acceleration to False."
+ )
+
+ total_time = time.time()
+ prepare_time = time.time()
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = 0
+
+ batch_size = len(beam_scorer._beam_hyps) # pylint: disable=W0212
+ num_beams = beam_scorer.num_beams
+ batch_beam_size = origin_inputs.shape[0]
+ _logger.debug("The input shape is: %s", origin_inputs.shape)
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
+
+ is_encoder_decoder = self.config.is_encoder_decoder
+
+ # get the valid length of each example
+ valid_length_each_example = []
+ for i in range(batch_beam_size):
+ # As the nonzero returns the index and we need length
+ valid_length_each_example.append(
+ np.max(np.argwhere(origin_inputs[i] != generation_config.pad_token_id)) + 1
+ )
+ valid_length_each_example = np.array(valid_length_each_example)
+ _logger.debug("Get the valid for each example is: %s", valid_length_each_example)
+ if not is_encoder_decoder and np.max(valid_length_each_example) > generation_config.max_length:
+ raise ValueError(
+ "The max_length set is smaller than the length in the input_ids."
+ f"You shout set max_length to {np.max(valid_length_each_example)}"
+ )
+
+ target_length = (
+ self.config.seq_length
+ if generation_config.max_length > self.config.seq_length
+ else generation_config.max_length
+ )
+ _logger.debug("max target_length is: %s", target_length)
+ input_ids = self._pad_inputs_using_max_length(
+ origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id
+ )
+
+ _logger.debug(
+ "pad the origin inputs from %s into shape: %s",
+ origin_inputs.shape,
+ input_ids.shape,
+ )
+
+ beam_scores = np.zeros((batch_size, num_beams), dtype=np.float64)
+ beam_scores[:, 1:] = -1e9
+
+ input_mask = np.zeros_like(input_ids)
+ for i in range(valid_length_each_example.shape[0]):
+ input_mask[i, : valid_length_each_example[i]] = 1
+ encoder_output = None
+ encoder_mask = None
+ if is_encoder_decoder:
+ if target_length > self.config.max_decode_length:
+ target_length = self.config.max_decode_length
+ _logger.debug("target_length is: %s", target_length)
+
+ # When do encoder and decoder prediction, the encoder can be cached
+ # to speed up the inference
+ (
+ encoder_output,
+ encoder_mask,
+ input_ids,
+ target_mask,
+ ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask)
+ valid_length_each_example = np.ones((batch_beam_size, 1)).astype(np.int32)
+
+ # update model kwargs once, before go into generate loop.
+ self.update_model_kwargs_before_generate(input_ids, model_kwargs)
+
+ # setup is_first_iteration flag for incremental infer
+ if generation_config.use_past:
+ self.is_first_iteration = True
+ need_gather_logits = True
+
+ is_first_token = True
+
+ origin_len = np.sum(valid_length_each_example) / num_beams
+ prepare_time = time.time() - prepare_time
+ _logger.debug("forward prepare time: %s s", prepare_time)
+
+ while True:
+ forward_time = time.time()
+ seq_length = input_ids.shape[1]
+ current_index = [valid_length_each_example[i] - 1 + i * seq_length for i in range(batch_beam_size)]
+ _logger.debug("validate length: %s", valid_length_each_example)
+ if is_encoder_decoder:
+ inputs = Tensor(input_ids, mstype.int32)
+ # pylint: disable=E1102
+ res = self(
+ input_ids=None,
+ attention_mask=encoder_mask,
+ encoder_outputs=encoder_output,
+ decoder_input_ids=inputs,
+ decoder_attention_mask=Tensor(target_mask, mstype.float32),
+ )
+ else:
+ model_kwargs["current_index"] = current_index
+ # model prepare input dict
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # incremental generate
+ if generation_config.use_past:
+ _logger.warning(
+ "Beam search currently not support incremental, auto-aggressive generate will be performed."
+ )
+ # auto-aggressive generate
+ res = self(**model_inputs)
+ forward_time = time.time() - forward_time
+
+ search_time = time.time()
+ # post process logits
+ # convert to numpy for post process
+ logits = res[0] if isinstance(res, tuple) else res
+ if isinstance(logits, Tensor):
+ logits = logits.asnumpy().astype(np.float32)
+ logits = np.reshape(logits, (-1, logits.shape[-1])) # (batch_size * num_beams * seq_length, vocab_size)
+ # need gather last seq logits using current_index
+ # compare length to determine if need gather; if not, gather should be done in model construct
+ if need_gather_logits and logits.shape[0] > len(current_index):
+ logits = logits[current_index] # (total_batch_size, vocab_size)
+ logits_processor.append(LogitNormalization())
+
+ # post process logits, without changing logits shape and order
+ next_token_scores = logits_processor(input_ids, logits) # (batch_size * num_beams, vocab_size)
+
+ # reshape for beam search
+ vocab_size = next_token_scores.shape[-1]
+ next_token_scores = np.reshape(next_token_scores, (batch_size, -1)) # (batch_size, num_beams * vocab_size)
+
+ if is_first_token:
+ next_token_scores = next_token_scores[:, :vocab_size]
+ is_first_token = False
+
+ # sample 2 next tokens for each beam, so we have at least 1 non eos token per beam
+ next_token_scores, next_tokens = topk(next_token_scores, 2 * num_beams, axis=1, largest=True, sort=True)
+
+ next_indices = np.floor_divide(next_tokens, vocab_size)
+ next_tokens = next_tokens % vocab_size
+
+ beam_outputs = beam_scorer.process(
+ input_ids, # (batch_size * num_beams, seq_length)
+ next_token_scores,
+ next_tokens,
+ next_indices,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ )
+ beam_scores = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
+ search_time = time.time() - search_time
+
+ update_time = time.time()
+ # reorder model inputs
+ old_input_ids = input_ids.copy()
+ for i in range(batch_beam_size):
+ input_ids[i] = old_input_ids[beam_idx[i], :]
+
+ # add new tokens to input_ids
+ for i in range(batch_beam_size):
+ input_ids[i, valid_length_each_example[i]] = beam_next_tokens[i]
+ if is_encoder_decoder:
+ target_mask[i][valid_length_each_example[i]] = int(1)
+
+ input_mask[i][valid_length_each_example[i]] = 1
+ valid_length_each_example[i] += int(1)
+
+ update_time = time.time() - update_time
+ _logger.debug(
+ "forward time: %s s; beam search time: %s s; update time: %s s; total count: %s s",
+ forward_time,
+ search_time,
+ update_time,
+ forward_time + search_time + update_time,
+ )
+
+ if beam_scorer.is_done or np.min(valid_length_each_example) >= generation_config.max_length:
+ break
+
+ sequence_outputs = beam_scorer.finalize(
+ input_ids,
+ beam_scores,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ max_length=generation_config.max_length,
+ )
+
+ generate_len = np.sum(valid_length_each_example) / num_beams - origin_len
+ total_time = time.time() - total_time
+ _logger.info(
+ "total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s",
+ total_time,
+ generate_len,
+ generate_len / total_time,
+ )
+
+ return sequence_outputs["sequences"]
+
+ def generate(
+ self,
+ input_ids: Optional[Union[List[int], List[List[int]]]],
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ streamer=None,
+ seed: Optional[int] = None,
+ **kwargs,
+ ):
+ origin_phase = self.phase
+ self.set_train(False)
+ try:
+ input_ids = np.array(input_ids)
+ except ValueError as e:
+ raise ValueError(
+ str(e) + " Please check your inputs of model.generate(),"
+ " and make sure the inputs are padded to same length."
+ )
+ input_ids = np.reshape(input_ids, (-1, np.shape(input_ids)[-1]))
+ batch_size = input_ids.shape[0]
+ seed = 0 if seed is None else seed
+ np.random.seed(seed)
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+
+ # Handle `generation_config` and kwargs that might update it
+ # priority: `generation_config` argument > `model.generation_config` (default config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation
+ # model attribute accordingly, if it was created from the model config
+ generation_config = GenerationConfig.from_model_config(self.config)
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+
+ if generation_config.num_beams > 1:
+ _logger.warning(
+ "When num_beams is set to a value greater than 1, do_sample will be set to False, "
+ "due to the current beam search does not support sampling."
+ )
+ generation_config.do_sample = False
+ if not generation_config.do_sample:
+ _logger.warning(
+ "When do_sample is set to False, top_k will be set to 1 and top_p will be set to 0, "
+ "making them inactive."
+ )
+ generation_config.top_p = 1.0
+ generation_config.top_k = 0
+ _logger.info("Generation Config is: %s", generation_config)
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ )
+
+ # determine generation mode
+ generation_mode = self._get_generation_mode(generation_config)
+
+ if streamer is not None and (generation_config.num_beams > 1):
+ raise ValueError("`streamer` cannot be used with beam search yet. Make sure that `num_beams` is set to 1.")
+
+ if generation_mode == GenerationMode.GREEDY_SEARCH:
+ # run greedy search
+ output_ids = self._greedy_search(
+ origin_inputs=input_ids,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ elif generation_mode == GenerationMode.SAMPLE:
+ # prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # run sample
+ output_ids = self._sample(
+ origin_inputs=input_ids,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ elif generation_mode == GenerationMode.BEAM_SEARCH:
+ # prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size, num_beams=generation_config.num_beams, max_length=generation_config.max_length
+ )
+ # interleave input_ids with `num_beams` additional sequences per batch
+ input_ids = np.repeat(input_ids, generation_config.num_beams, 0)
+
+ # run beam search
+ output_ids = self._beam_search(
+ origin_inputs=input_ids,
+ beam_scorer=beam_scorer,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ # set to original phase
+ self.set_train(origin_phase == "train")
+ return output_ids
diff --git a/mindocr/nlp/generation/utils.py b/mindocr/nlp/generation/utils.py
new file mode 100644
index 000000000..298891cdd
--- /dev/null
+++ b/mindocr/nlp/generation/utils.py
@@ -0,0 +1,72 @@
+"""utils for text generation."""
+
+from threading import Thread
+
+import numpy as np
+
+
+def log_softmax(x, axis=None):
+ """numpy implemented log softmax function.
+ refers to https://github.com/scipy/scipy/blob/v1.11.1/scipy/special/_logsumexp.py"""
+ x_max = np.amax(x, axis=axis, keepdims=True)
+
+ if x_max.ndim > 0:
+ x_max[~np.isfinite(x_max)] = 0
+ elif not np.isfinite(x_max):
+ x_max = 0
+
+ tmp = x - x_max
+ exp_tmp = np.exp(tmp)
+
+ # suppress warnings about log of zero
+ with np.errstate(divide="ignore"):
+ s = np.sum(exp_tmp, axis=axis, keepdims=True)
+ out = np.log(s)
+
+ out = tmp - out
+ return out
+
+
+def softmax(x, axis=None):
+ """numpy implemented softmax function.
+ refers to https://github.com/scipy/scipy/blob/v1.11.1/scipy/special/_logsumexp.py"""
+ x_max = np.amax(x, axis=axis, keepdims=True)
+ exp_x_shifted = np.exp(x - x_max)
+ return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
+
+
+def softmax_single(i, res, x):
+ res[i] = softmax(x)
+
+
+def softmax_with_threads(x, is_finished=None):
+ """calculate softmax with threads"""
+ res = np.ones_like(x)
+ all_threads = []
+ for i in range(0, res.shape[0]):
+ if is_finished and is_finished[i]:
+ continue
+ thread = Thread(target=softmax_single, args=(i, res, x[i]))
+ all_threads.append(thread)
+ thread.start()
+ for thread in all_threads:
+ thread.join()
+ return res
+
+
+def topk(x, top_k, axis=-1, largest=True, sort=True):
+ """numpy implemented topk sample."""
+ # safety check
+ if x.shape[axis] < top_k:
+ top_k = x.shape[axis] - 1
+ if largest:
+ topk_index = np.argpartition(-x, top_k, axis=axis)
+ else:
+ topk_index = np.argpartition(x, top_k, axis=axis)
+ topk_index = np.take(topk_index, np.arange(top_k), axis=axis)
+ topk_data = np.take_along_axis(x, topk_index, axis=axis)
+ if sort:
+ sort_index = np.argsort(-topk_data, axis=axis) if largest else np.argsort(topk_data, axis=axis)
+ topk_data = np.take_along_axis(topk_data, sort_index, axis=axis)
+ topk_index = np.take_along_axis(topk_index, sort_index, axis=axis)
+ return topk_data, topk_index
diff --git a/mindocr/nlp/llm/__init__.py b/mindocr/nlp/llm/__init__.py
new file mode 100644
index 000000000..ce614c47a
--- /dev/null
+++ b/mindocr/nlp/llm/__init__.py
@@ -0,0 +1,3 @@
+from ._registry import register_llm
+from .builder import build_llm_model
+from .vary_qwen_model import VaryQwenForCausalLM
diff --git a/mindocr/nlp/llm/_registry.py b/mindocr/nlp/llm/_registry.py
new file mode 100644
index 000000000..b9ae12e62
--- /dev/null
+++ b/mindocr/nlp/llm/_registry.py
@@ -0,0 +1,72 @@
+"""llm registry and list"""
+
+__all__ = [
+ "list_llms",
+ "is_llm",
+ "llm_entrypoint",
+ "list_llm_classes",
+ "is_llm_class",
+ "llm_class_entrypoint",
+ "register_llm",
+]
+
+_llm_entrypoints = {}
+_llm_class_entrypoints = {}
+
+
+def register_llm(fn):
+ # add llm to __all__ in module
+ llm_name = fn.__name__
+ # add entries to registry dict/sets
+ _llm_entrypoints[llm_name] = fn
+
+ return fn
+
+
+def list_llms():
+ all_llms = _llm_entrypoints.keys()
+
+ return sorted(list(all_llms))
+
+
+def is_llm(llm_name):
+ """
+ Check if a llm name exists
+ """
+ return llm_name in _llm_entrypoints
+
+
+def llm_entrypoint(llm_name):
+ """
+ Fetch a llm entrypoint for specified llm name
+ """
+ return _llm_entrypoints[llm_name]
+
+
+def register_llm_class(cls):
+ # add llm to __all__ in module
+ llm_class_name = cls.__name__
+ # add entries to registry dict/sets
+ _llm_class_entrypoints[llm_class_name] = cls
+
+ return cls
+
+
+def list_llm_classes():
+ all_llm_classes = _llm_class_entrypoints.keys()
+
+ return sorted(list(all_llm_classes))
+
+
+def is_llm_class(llm_class_name):
+ """
+ Check if a llm name exists
+ """
+ return llm_class_name in _llm_class_entrypoints
+
+
+def llm_class_entrypoint(llm_class_name):
+ """
+ Fetch a llm entrypoint for specified llm name
+ """
+ return _llm_class_entrypoints[llm_class_name]
diff --git a/mindocr/nlp/llm/base_llm_model.py b/mindocr/nlp/llm/base_llm_model.py
new file mode 100644
index 000000000..d18279cdc
--- /dev/null
+++ b/mindocr/nlp/llm/base_llm_model.py
@@ -0,0 +1,83 @@
+"""BaseModel"""
+import os
+
+from mindspore import nn
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from mindocr.nlp.generation import GeneratorMixin
+from mindocr.nlp.llm.builder import build_llm_model
+from mindocr.nlp.llm.configs import BaseConfig, LLMConfig
+
+
+class BaseLLMModel(nn.Cell, GeneratorMixin):
+ """
+ The base model that contains the class method `from_pretrained` and `save_pretrained`, any new model that should
+ inherit the class.
+
+ Note:
+ GeneratorMixin provides the method `generate` that enable the generation for nlp models.
+
+ Args:
+ config(BaseConfig): The model configuration that inherits the `BaseConfig`.
+ """
+
+ def __init__(self, config: BaseConfig, **kwargs):
+ super(BaseLLMModel, self).__init__(**kwargs)
+ self.config = config
+
+ def load_checkpoint(self, config):
+ """
+ load checkpoint for models.
+
+ Args:
+ config (ModelConfig): a model config instance, which could have attribute
+ "checkpoint_name_or_path (str)". set checkpoint_name_or_path to a supported
+ model name or a path to checkpoint, to load model weights.
+ """
+ checkpoint_name_or_path = config.checkpoint_name_or_path
+ if checkpoint_name_or_path:
+ if not isinstance(checkpoint_name_or_path, str):
+ raise TypeError(f"checkpoint_name_or_path should be a str, but got {type(checkpoint_name_or_path)}")
+
+ if os.path.exists(checkpoint_name_or_path):
+ param = load_checkpoint(checkpoint_name_or_path)
+ else:
+ raise ValueError(
+ f"{checkpoint_name_or_path} is not a supported default model"
+ f" or a valid path to checkpoint,"
+ f" please select from {self._support_list}."
+ )
+
+ load_param_into_net(self, param)
+
+ @classmethod
+ def _get_config_args(cls, pretrained_model_name_or_dir, **kwargs):
+ """build config args."""
+ is_dir = os.path.isdir(pretrained_model_name_or_dir)
+
+ if is_dir:
+ yaml_list = [file for file in os.listdir(pretrained_model_name_or_dir) if file.endswith(".yaml")]
+ yaml_list.sort()
+ config_args = None
+ for yaml_file in yaml_list:
+ if config_args is None:
+ config_args = LLMConfig(yaml_file)
+ else:
+ sub_config_args = LLMConfig(yaml_file)
+ config_args.model.update(**sub_config_args)
+ config_args.model.update(**kwargs)
+ else:
+ yaml_file = pretrained_model_name_or_dir
+ config_args = LLMConfig(yaml_file)
+ config_args.model.update(**kwargs)
+ return config_args
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_dir: str, **kwargs):
+ if not isinstance(pretrained_model_name_or_dir, str):
+ raise TypeError(
+ f"pretrained_model_name_or_dir should be a str, but got {type(pretrained_model_name_or_dir)}"
+ )
+ config_args = cls._get_config_args(pretrained_model_name_or_dir, **kwargs)
+ model = build_llm_model(config_args.model)
+ return model
diff --git a/mindocr/nlp/llm/builder.py b/mindocr/nlp/llm/builder.py
new file mode 100644
index 000000000..490311b09
--- /dev/null
+++ b/mindocr/nlp/llm/builder.py
@@ -0,0 +1,29 @@
+from ._registry import is_llm, is_llm_class, list_llms, llm_class_entrypoint, llm_entrypoint
+
+__all__ = ["build_llm_model"]
+
+
+def build_llm_model(config):
+ """
+
+ Example:
+ >>> from mindocr.nlp.llm import build_llm_model
+ >>> llm_model = build_llm_model(dict(name='VaryQwenForCausalLM'))
+ >>> print(llm_model)
+ """
+ if "name" not in config:
+ raise ValueError("name must in `config`.")
+ name = config["name"]
+ if is_llm(name):
+ create_fn = llm_entrypoint(name)
+ llm = create_fn(config)
+ elif is_llm_class(name):
+ llm_class = llm_class_entrypoint(name)
+ llm = llm_class(config)
+ else:
+ raise ValueError(f"Invalid llm name: {name}, supported llms are: {list_llms()}")
+
+ if "checkpoint_name_or_path" in config:
+ llm.load_checkpoint(config)
+
+ return llm
diff --git a/mindocr/nlp/llm/configs.py b/mindocr/nlp/llm/configs.py
new file mode 100644
index 000000000..76dba33cd
--- /dev/null
+++ b/mindocr/nlp/llm/configs.py
@@ -0,0 +1,318 @@
+import copy
+import os
+
+import yaml
+
+import mindspore.common.dtype as mstype
+
+
+def convert_mstype(ms_type: str = "float16"):
+ """Convert the string type to MindSpore type."""
+ if isinstance(ms_type, mstype.Float):
+ return ms_type
+ if ms_type == "float16":
+ return mstype.float16
+ if ms_type == "bfloat16":
+ return mstype.bfloat16
+ if ms_type == "float32":
+ return mstype.float32
+ raise KeyError(f"Supported data type keywords include: [float16, float32, bfloat16], but get {ms_type}")
+
+
+class LLMConfig(dict):
+ def __init__(self, *args, **kwargs):
+ super(LLMConfig, self).__init__()
+ cfg_dict = {}
+
+ # load from file
+ for arg in args:
+ if isinstance(arg, str):
+ if arg.endswith("yaml") or arg.endswith("yml"):
+ raw_dict = LLMConfig._file2dict(arg)
+ cfg_dict.update(raw_dict)
+
+ # load dictionary configs
+ if kwargs is not None:
+ cfg_dict.update(kwargs)
+
+ LLMConfig._dict2config(self, cfg_dict)
+
+ def __getattr__(self, key):
+ """Get a object attr by its `key`
+
+ Args:
+ key (str) : the name of object attr.
+
+ Returns:
+ attr of object that name is `key`
+ """
+ if key not in self:
+ return None
+
+ return self[key]
+
+ def __setattr__(self, key, value):
+ """Set a object value `key` with `value`
+
+ Args:
+ key (str) : The name of object attr.
+ value : the `value` need to set to the target object attr.
+ """
+ self[key] = value
+
+ def __delattr__(self, key):
+ """Delete a object attr by its `key`.
+
+ Args:
+ key (str) : The name of object attr.
+ """
+ del self[key]
+
+ def __deepcopy__(self):
+ """Deep copy operation on arbitrary LLMConfig objects.
+
+ Returns:
+ LLMConfig : The deep copy of the given LLMConfig object.
+ """
+ config = LLMConfig()
+ for key in self.keys():
+ config.__setattr__(copy.deepcopy(key), copy.deepcopy(self.__getattr__(key)))
+ return config
+
+ @staticmethod
+ def _file2dict(filename=None):
+ """Convert config file to dictionary.
+
+ Args:
+ filename (str) : config file.
+ """
+ if filename is None:
+ raise NameError("This {} cannot be empty.".format(filename))
+
+ filepath = os.path.realpath(filename)
+ with open(filepath, encoding="utf-8") as fp:
+ cfg_dict = yaml.load(fp, yaml.Loader)
+
+ return cfg_dict
+
+ @staticmethod
+ def _dict2config(config, dic):
+ """Convert dictionary to config.
+
+ Args:
+ config : Config object
+ dic (dict) : dictionary
+ Returns:
+
+ Exceptions:
+
+ """
+ if isinstance(dic, dict):
+ for key, value in dic.items():
+ if isinstance(value, dict):
+ sub_config = LLMConfig()
+ dict.__setitem__(config, key, sub_config)
+ LLMConfig._dict2config(sub_config, value)
+ else:
+ config[key] = dic[key]
+
+
+class BaseConfig(dict):
+ def __init__(self, **kwargs):
+ super(BaseConfig, self).__init__()
+ self.update(kwargs)
+
+ def __getattr__(self, key):
+ if key not in self:
+ return None
+ return self[key]
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ del self[key]
+
+ @classmethod
+ def from_pretrained(cls, yaml_name_or_path, **kwargs):
+ """
+ From pretrain method, which instantiates a config by yaml name or path.
+
+ Args:
+ yaml_name_or_path (str): A supported model path to model config (.yaml).
+
+ Returns:
+ A model config, which inherited from BaseConfig.
+ """
+ pretrained_model_name_or_path = kwargs.pop("pretrained_model_name_or_path", None)
+ if pretrained_model_name_or_path is not None:
+ yaml_name_or_path = pretrained_model_name_or_path
+
+ if not isinstance(yaml_name_or_path, str):
+ raise TypeError(f"yaml_name_or_path should be a str, but got {type(yaml_name_or_path)}.")
+
+ if os.path.exists(yaml_name_or_path):
+ if not yaml_name_or_path.endswith(".yaml"):
+ raise ValueError(f"{yaml_name_or_path} should be a .yaml file for model config.")
+
+ config_args = LLMConfig(yaml_name_or_path)
+ else:
+ raise ValueError(f"{yaml_name_or_path} is not a supported model type or a valid path to model config.")
+ config_args.model.update(**kwargs)
+ config = config_args.model
+ return config
+
+
+class QwenConfig(BaseConfig):
+ def __init__(
+ self,
+ batch_size: int = 1,
+ seq_length: int = 2048,
+ hidden_size: int = 4096,
+ num_layers: int = 32,
+ num_heads: int = 32,
+ n_kv_heads: int = None,
+ max_position_embedding: int = None,
+ intermediate_size: int = None,
+ vocab_size: int = 32000, # defined later by tokenizer
+ multiple_of: int = 256, # make SwiGLU hidden layer size multiple of large power of 2
+ ffn_dim_multiplier: int = None,
+ rms_norm_eps: float = 1e-5,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ pad_token_id: int = 0,
+ ignore_token_id: int = -100,
+ theta: float = 10000.0,
+ compute_dtype: str = "float16",
+ layernorm_compute_type: str = "float32",
+ softmax_compute_type: str = "float32",
+ rotary_dtype: str = "float32",
+ param_init_type: str = "float16",
+ ln_param_init_type: str = "float32",
+ qkv_has_bias: bool = False,
+ qkv_concat: bool = False,
+ use_past: bool = False,
+ pretrain_seqlen=None,
+ extend_method: str = "None",
+ scaling_factor: float = 1.0,
+ is_dynamic: bool = False,
+ use_kvcache_op: bool = False,
+ is_flexible_shape: bool = False,
+ use_rope_slice: bool = False,
+ use_flash_attention: bool = False,
+ use_paged_attention: bool = False,
+ fine_grain_interleave: int = 1,
+ offset: int = 0,
+ checkpoint_name_or_path: str = "",
+ repetition_penalty: float = 1.0,
+ max_decode_length: int = 1024,
+ block_size: int = 16,
+ num_blocks: int = 512,
+ top_k: int = 5,
+ top_p: float = 1.0,
+ do_sample: bool = True,
+ **kwargs,
+ ):
+ super(QwenConfig, self).__init__(**kwargs)
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.max_position_embedding = max_position_embedding if max_position_embedding else seq_length
+ self.intermediate_size = intermediate_size
+ self.multiple_of = multiple_of
+ self.n_kv_heads = n_kv_heads
+ self.ffn_dim_multiplier = ffn_dim_multiplier
+ self.rms_norm_eps = rms_norm_eps
+ self.qkv_concat = qkv_concat
+ self.param_init_type = convert_mstype(param_init_type)
+ self.qkv_has_bias = qkv_has_bias
+ self.layernorm_compute_type = convert_mstype(layernorm_compute_type)
+ self.softmax_compute_type = convert_mstype(softmax_compute_type)
+ self.rotary_dtype = convert_mstype(rotary_dtype)
+ self.compute_dtype = convert_mstype(compute_dtype)
+ self.ln_param_init_type = convert_mstype(ln_param_init_type)
+ self.checkpoint_name_or_path = checkpoint_name_or_path
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.pad_token_id = pad_token_id
+ self.ignore_token_id = ignore_token_id
+ self.use_past = use_past
+ self.pretrain_seqlen = pretrain_seqlen
+ self.extend_method = extend_method
+ self.scaling_factor = scaling_factor
+ self.is_dynamic = is_dynamic
+ self.use_kvcache_op = use_kvcache_op
+ self.is_flexible_shape = is_flexible_shape
+ self.use_rope_slice = use_rope_slice
+ self.use_flash_attention = use_flash_attention
+ self.fine_grain_interleave = fine_grain_interleave
+ self.offset = offset
+ self.repetition_penalty = repetition_penalty
+ self.max_decode_length = max_decode_length
+ self.top_k = top_k
+ self.top_p = top_p
+ self.do_sample = do_sample
+ self.theta = theta
+ self.use_paged_attention = use_paged_attention
+ self.block_size = block_size
+ self.num_blocks = num_blocks
+
+
+class VaryConfig(QwenConfig):
+ def __init__(self, **kwargs):
+ super(VaryConfig, self).__init__(**kwargs)
+
+
+class SAMConfig(BaseConfig):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: int = 4,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ layer_norm_eps: float = 1.0e-6,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = True,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 14,
+ global_attn_indexes: tuple = (2, 5, 8, 11),
+ checkpoint_name_or_path: str = "",
+ compute_dtype: str = "float16",
+ layernorm_compute_type: str = "float32",
+ softmax_compute_type: str = "float16",
+ param_init_type: str = "float16",
+ ln_param_init_type: str = "float32",
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.depth = depth
+ self.num_heads = num_heads
+ self.mlp_ratio = mlp_ratio
+ self.out_chans = out_chans
+ self.qkv_bias = qkv_bias
+ self.layer_norm_eps = layer_norm_eps
+ self.use_abs_pos = use_abs_pos
+ self.use_rel_pos = use_rel_pos
+ self.rel_pos_zero_init = rel_pos_zero_init
+ self.window_size = window_size
+ self.global_attn_indexes = global_attn_indexes
+
+ self.param_init_type = convert_mstype(param_init_type)
+ self.layernorm_compute_type = convert_mstype(layernorm_compute_type)
+ self.softmax_compute_type = convert_mstype(softmax_compute_type)
+ self.compute_dtype = convert_mstype(compute_dtype)
+ self.ln_param_init_type = convert_mstype(ln_param_init_type)
+
+ self.checkpoint_name_or_path = checkpoint_name_or_path
diff --git a/mindocr/nlp/llm/convert_weight.py b/mindocr/nlp/llm/convert_weight.py
new file mode 100644
index 000000000..eabe1d223
--- /dev/null
+++ b/mindocr/nlp/llm/convert_weight.py
@@ -0,0 +1,114 @@
+"""Convert Vary Toy weight."""
+
+import argparse
+
+import torch
+
+import mindspore as ms
+
+ATTENTION_WEIGHT_NAME = "attn.c_attn.weight"
+ATTENTION_BIAS_NAME = "attn.c_attn.bias"
+
+
+def pt2ms(value: torch.Tensor, dtype) -> ms.Tensor:
+ """
+ convert torch.Tensor to ms.Tensor with specified dtype
+ """
+ if value.dtype == torch.bfloat16:
+ np_value = value.to(torch.float32).numpy()
+ else:
+ np_value = value.detach().numpy()
+
+ if dtype:
+ return ms.Tensor(np_value, dtype=dtype)
+ return ms.Tensor(np_value, dtype=ms.bfloat16) if value.dtype == torch.bfloat16 else ms.Tensor(np_value)
+
+
+def _name_replace(name: str):
+ # qwen
+ name = name.replace(".h.", ".layers.")
+ name = name.replace(".wte.weight", ".wte.embedding_weight")
+ name = name.replace("attn.c_proj.", "attention.wo.")
+ name = name.replace("ln_1.", "attention_norm.")
+ name = name.replace("ln_2.", "ffn_norm.")
+ name = name.replace("mlp.w1.", "feed_forward.w1.")
+ name = name.replace("mlp.w2.", "feed_forward.w3.")
+ name = name.replace("mlp.c_proj.", "feed_forward.w2.")
+
+ # clip
+ name = name.replace("vision_model.", "")
+ name = name.replace("embeddings.", "")
+ name = name.replace("patch_embedding.", "conv1.")
+ name = name.replace("position_embedding.weight", "positional_embedding")
+ name = name.replace("pre_layrnorm.weight", "ln_pre.gamma")
+ name = name.replace("pre_layrnorm.bias", "ln_pre.beta")
+ name = name.replace("encoder.layers", "transformer.resblocks")
+ name = name.replace("layer_norm1.weight", "ln_1.gamma")
+ name = name.replace("layer_norm1.bias", "ln_1.beta")
+ name = name.replace("fc1", "c_fc")
+ name = name.replace("fc2", "c_proj")
+ name = name.replace("layer_norm2.weight", "ln_2.gamma")
+ name = name.replace("layer_norm2.bias", "ln_2.beta")
+ name = name.replace("self_attn", "attn")
+ name = name.replace("post_layernorm", "vision_model.post_layernorm")
+
+ # sam
+ name = name.replace("norm1.weight", "norm1.gamma")
+ name = name.replace("norm1.bias", "norm1.beta")
+ name = name.replace("norm2.weight", "norm2.gamma")
+ name = name.replace("norm2.bias", "norm2.beta")
+ return name
+
+
+def convert_attention_weight(name, value, ckpt_weights):
+ split_value = ms.numpy.array_split(value, 3)
+ attention_weight_names = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight"]
+
+ for index in range(len(split_value)):
+ cur_name = name.replace(ATTENTION_WEIGHT_NAME, attention_weight_names[index])
+ ckpt_weights.append({"name": cur_name, "data": ms.Tensor(split_value[index])})
+
+
+def convert_attention_bias(name, value, ckpt_weights):
+ split_value = ms.numpy.array_split(value, 3)
+ attention_bias_names = ["attention.wq.bias", "attention.wk.bias", "attention.wv.bias"]
+
+ for index in range(len(split_value)):
+ cur_name = name.replace(ATTENTION_BIAS_NAME, attention_bias_names[index])
+ ckpt_weights.append({"name": cur_name, "data": ms.Tensor(split_value[index])})
+
+
+def convert_pt_to_ms(torch_ckpt_path, output_path, dtype=ms.float16):
+ state_dict = torch.load(torch_ckpt_path, map_location="cpu")
+ ckpt_weights = []
+ for k, v in state_dict.items():
+ value = pt2ms(v, dtype)
+
+ msname = _name_replace(k)
+
+ if msname != k:
+ print("name: %s->%s" % (k, msname))
+
+ if ATTENTION_WEIGHT_NAME in msname:
+ convert_attention_weight(msname, value, ckpt_weights)
+ continue
+
+ if ATTENTION_BIAS_NAME in msname:
+ convert_attention_bias(msname, value, ckpt_weights)
+ continue
+
+ ckpt_weights.append({"name": msname, "data": value})
+
+ print("Saving converted weights to %s..." % output_path)
+ ms.save_checkpoint(ckpt_weights, output_path)
+ print("Done")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Vary convert script")
+ parser.add_argument("--torch_ckpt_path", required=True, help="The torch checkpoint path.")
+ parser.add_argument("--mindspore_ckpt_path", default="./vary_toy.ckpt", help="The output checkpoint path.")
+
+ args = parser.parse_args()
+
+ convert_pt_to_ms(args.torch_ckpt_path, args.mindspore_ckpt_path, ms.float16)
diff --git a/mindocr/nlp/llm/qwen_model.py b/mindocr/nlp/llm/qwen_model.py
new file mode 100644
index 000000000..1e1bee7b8
--- /dev/null
+++ b/mindocr/nlp/llm/qwen_model.py
@@ -0,0 +1,1139 @@
+from enum import Enum
+from typing import Optional, Tuple
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.common.dtype as mstype
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore._c_expression import MSContext
+from mindspore.common.initializer import initializer
+from mindspore.nn.layer.flash_attention import FlashAttention
+
+from mindocr.nlp.llm.base_llm_model import BaseLLMModel
+from mindocr.nlp.llm.configs import QwenConfig
+from mindocr.nlp.utils.kvcache_mgr import KVCacheMgr, KVCachePreprocess
+from mindocr.nlp.utils.layers import Linear
+from mindocr.nlp.utils.loss import CrossEntropyLoss
+
+
+def is_910a():
+ device = MSContext.get_instance().get_ascend_soc_version()
+ return device in ["910a", "ascend910"]
+
+
+class SeqExtendMethod(Enum):
+ """Stores the acceptable string identifiers for seq length extend method"""
+
+ PI = "PI"
+ NTK = "NTK"
+ NONE = "None"
+
+
+class LlamaEmbedding(nn.Cell):
+ """
+ Embedding Layer.
+
+ Args:
+ - **vocab_size** (int): Size of the dictionary of embeddings.
+ - **embedding_size** (int): The size of each embedding vector.
+ - **param_init_type** (mstype): The param init type, default mstype.float32.
+ - **param_init** (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
+ Refer to class `initializer` for the values of string when a string
+ is specified. Default: "normal".
+ Inputs:
+ - **input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length)
+
+ Outputs:
+ - **output** (Tensor) - The embedding vector for the input with shape (batch_size,
+ seq_length, embedding_size).
+ """
+
+ def __init__(
+ self,
+ vocab_table_size,
+ embedding_size,
+ param_init_type=mstype.float32,
+ param_init="normal",
+ parallel_optimizer=False,
+ ):
+ super().__init__()
+ self.vocab_table_size = vocab_table_size
+ self.embedding_size = embedding_size
+ self.embedding_weight = Parameter(
+ initializer(param_init, [self.vocab_table_size, self.embedding_size], dtype=param_init_type),
+ name="embedding_weight",
+ parallel_optimizer=parallel_optimizer,
+ )
+ self.gather = ops.Gather()
+
+ def construct(self, input_ids):
+ """Forward of vocab embedding."""
+ output = self.gather(self.embedding_weight, input_ids, 0)
+ return output
+
+
+class FreqsMgr(nn.Cell):
+ r"""freqs_cis manager."""
+
+ def __init__(
+ self,
+ head_dim,
+ seq_length=None,
+ max_position_embedding=4096,
+ rotary_dtype=mstype.float16,
+ theta=10000.0,
+ scaling_factor=1.0,
+ extend_method=SeqExtendMethod.NONE.value,
+ is_dynamic=False,
+ ):
+ super().__init__()
+ if seq_length is not None and seq_length > max_position_embedding:
+ max_position_embedding = seq_length
+ if extend_method == SeqExtendMethod.NTK.value:
+ theta *= scaling_factor
+ freqs_base = np.arange(0, head_dim, 2)[: (head_dim // 2)].astype(np.float32) # (head_dim // 2, )
+ freqs = 1.0 / (theta ** (freqs_base / head_dim)) # (head_dim // 2, )
+ if extend_method == SeqExtendMethod.PI.value:
+ t = np.arange(0, max_position_embedding / scaling_factor, 1 / scaling_factor).astype(np.float32)
+ else:
+ t = np.arange(0, max_position_embedding, 1).astype(np.float32)
+ freqs = np.outer(t, freqs) # (max_position_embedding, head_dim // 2)
+ emb = np.concatenate((freqs, freqs), axis=-1)
+ freqs_cos = np.cos(emb) # (seq_len, head_dim)
+ freqs_sin = np.sin(emb) # (seq_len, head_dim)
+ swap_mask = FreqsMgr.get_swap_mask(head_dim)
+
+ self.head_dim = head_dim
+ self.seq_length = max_position_embedding if seq_length is None else seq_length
+ self.is_dynamic = is_dynamic
+ self.freqs_cos = Tensor(freqs_cos, dtype=rotary_dtype)
+ self.freqs_sin = Tensor(freqs_sin, dtype=rotary_dtype)
+ self.swap_mask = Tensor(swap_mask, dtype=rotary_dtype)
+
+ self.reshape = ops.Reshape()
+ if is_dynamic:
+ self.reshape.add_prim_attr("skip_redistribution", True)
+ self.slice = ops.StridedSlice()
+ self.sub = ops.Sub()
+ self.gather = ops.Gather()
+
+ def construct(self, seq_length=None):
+ freqs_cos, freqs_sin = self.freqs_cos, self.freqs_sin
+ seqlen = seq_length if self.is_dynamic else self.seq_length
+ freqs_cos = self.slice(freqs_cos, (0, 0), (seqlen, self.head_dim), (1, 1))
+ freqs_sin = self.slice(freqs_sin, (0, 0), (seqlen, self.head_dim), (1, 1))
+ return freqs_cos, freqs_sin, self.swap_mask
+
+ def increment(self, batch_valid_length, batch_size):
+ freqs_cos = self.reshape(self.gather(self.freqs_cos, batch_valid_length, 0), (batch_size, 1, 1, self.head_dim))
+ freqs_sin = self.reshape(self.gather(self.freqs_sin, batch_valid_length, 0), (batch_size, 1, 1, self.head_dim))
+ return freqs_cos, freqs_sin, self.swap_mask
+
+ @staticmethod
+ def get_swap_mask(head_dim):
+ """Swap matrix"""
+ zero_block = np.zeros((head_dim // 2, head_dim // 2), dtype=np.float32)
+ id_block = np.identity(head_dim // 2, dtype=np.float32)
+ return np.block([[zero_block, id_block], [-id_block, zero_block]])
+
+
+class LlamaSiLU(nn.Cell):
+ def construct(self, x):
+ return ops.silu(x)
+
+
+class LlamaFeedForward(nn.Cell):
+ r"""
+ LLaMA FeedForward.
+
+ .. math::
+ (xW_1 * xW_3)W_2
+
+ Inputs:
+ - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
+ Float tensor.
+
+ Outputs:
+ Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
+ [batch * seq_length, hidden_size]`.
+
+ Raises:
+ ValueError: `hidden_dim` is not a multiple of the model parallel way.
+ ValueError: `dim` is not a multiple of the model parallel way.
+ """
+
+ def __init__(
+ self,
+ dim,
+ intermediate_size=None,
+ hidden_dim=None,
+ multiple_of=256,
+ hidden_act=LlamaSiLU,
+ ffn_dim_multiplier=None,
+ compute_dtype=mstype.float16,
+ param_init_type=mstype.float32,
+ is_dynamic=False,
+ ):
+ super().__init__()
+
+ if hidden_act is None or not (isinstance(hidden_act, str) or issubclass(hidden_act, nn.Cell)):
+ raise TypeError(
+ f"For FeedForward cell, the hidden_act should str type or nn.Cell type, but got {hidden_act}."
+ )
+
+ if intermediate_size is not None:
+ hidden_dim = intermediate_size
+ else:
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int((ffn_dim_multiplier + 0.01) * hidden_dim)
+ hidden_dim = int(2 * hidden_dim / 3)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.dtype = compute_dtype
+ self.hidden_act = hidden_act
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ self.mul = ops.Mul()
+ self.cast = ops.Cast()
+ self.w1 = Linear(
+ in_channels=dim,
+ out_channels=hidden_dim,
+ activation=hidden_act,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ self.w2 = Linear(
+ in_channels=hidden_dim,
+ out_channels=dim,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ self.w3 = Linear(
+ in_channels=dim,
+ out_channels=hidden_dim,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ def construct(self, x):
+ """Forward process of the FeedForward"""
+ x = self.cast(x, self.dtype)
+ # [bs, seq, hidden_dim] or [bs * seq, hidden_dim]
+ gate = self.w1(x) # dp,1 -> dp, mp
+ hidden = self.w3(x) # dp,1 -> dp, mp
+ hidden = self.mul(hidden, gate) # dp,mp -> dp, mp
+ output = self.w2(hidden) # dp,mp -> dp, 1
+ return output
+
+
+class LlamaRotaryEmbedding(nn.Cell):
+ r"""
+ Rotary Position Embedding.
+
+ Args:
+ - **head_dim** (int): The dim of multi head attention.
+ - **compute_dtype** (mstype): The compute type, default mstype.float16.
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
+
+ Outputs:
+ Tensor of shape :math:`(batch, seq_length, hidden_size)`.
+ """
+
+ def __init__(self, head_dim=128, compute_dtype=mstype.float32, use_rope_slice=False):
+ super().__init__(auto_prefix=False)
+ self.half_head_dim = head_dim // 2
+ self.head_dim = head_dim
+ self.dtype = compute_dtype
+ self.use_rope_slice = use_rope_slice
+ self.is_first_iteration = True
+
+ self.add = ops.Add()
+ self.bmm_swap = ops.BatchMatMul()
+ self.mul = ops.Mul()
+ self.mul_inc = ops.Mul()
+ self.neg = ops.Neg()
+ self.slice = ops.StridedSlice()
+ self.concat = ops.Concat(axis=-1)
+ self.shape = ops.Shape()
+
+ self.is_ascend = ms.get_context("device_target") == "Ascend"
+
+ def rotate_half(self, x, swap_mask):
+ # [bs, n_head/n_kv_head, seq/1, head_dim], [head_dim, head_dim]
+ if self.is_ascend:
+ x = self.bmm_swap(x, swap_mask)
+ else:
+ x = ops.matmul(x, swap_mask)
+ return x
+
+ def slice_half(self, x):
+ bs, n_head, seq, _ = self.shape(x)
+ x1 = self.slice(x, (0, 0, 0, 0), (bs, n_head, seq, self.half_head_dim), (1, 1, 1, 1))
+ x2 = self.slice(x, (0, 0, 0, self.half_head_dim), (bs, n_head, seq, self.head_dim), (1, 1, 1, 1))
+ x = self.concat((self.neg(x2), x1))
+ return x
+
+ def construct(self, xq: Tensor, xk: Tensor, freqs_cis):
+ """Forward of rotary position embedding."""
+ original_type = xq.dtype
+ xq = self.cast(xq, self.dtype)
+ xk = self.cast(xk, self.dtype)
+ # xq, xk: [bs, n_head/n_kv_head, seq/1, head_dim]
+ freqs_cos, freqs_sin, swap_mask = freqs_cis
+ mul = self.mul if self.is_first_iteration else self.mul_inc
+ if self.use_rope_slice:
+ xq_out = self.add(mul(xq, freqs_cos), mul(self.slice_half(xq), freqs_sin))
+ xk_out = self.add(mul(xk, freqs_cos), mul(self.slice_half(xk), freqs_sin))
+ else:
+ xq_out = self.add(mul(xq, freqs_cos), mul(self.rotate_half(xq, swap_mask), freqs_sin))
+ xk_out = self.add(mul(xk, freqs_cos), mul(self.rotate_half(xk, swap_mask), freqs_sin))
+
+ xq_out = self.cast(xq_out, original_type)
+ xk_out = self.cast(xk_out, original_type)
+ return xq_out, xk_out
+
+
+class LLamaAttention(nn.Cell):
+ r"""
+ This is an implementation of multi head attention in LLaMA.
+
+ Args:
+ - **batch_size** (int): The batch size of the input tensor when do incremental prediction. Should be a
+ positive value.
+ When do training or prediction, the argument will not work and the user can just pass None to the
+ argument.
+ - **src_seq_length** (int): The sequence length of the query vector.
+ - **tgt_seq_length** (int): The sequence length of the key and value vector.
+ - **dim** (int): The hidden size of the input.
+ - **head_dim** (int): The dim of head.
+ - **n_heads** (int): The number of the heads.
+ - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16.
+ Should be mstype.float32 or mstype.float16.
+ - **softmax_compute_type** (dtype.Number): The type of softmax computation module. Default mstype.float32.
+ Should be mstype.float32 or mstype.float16.
+ - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype.
+ float32. Should be mstype.float32 or mstype.float16.
+ - **qkv_has_bias** (bool): Whether Q/K/V in attention has bias or not.
+ - **use_past** (bool): Use the past state to compute, used for incremental prediction.
+ For example, if we have two words and want to generate the ten more words.
+ We just need to compute the two words" state only once, and generate the next word one by one.
+ When use_past is True, there are two steps to run the prediction.
+ In the first step, set the is_first_iteration to be True by
+ `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the
+ is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment,
+ pass the single step"s input tensor, and loop it. Default False.
+
+ Inputs:
+ - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or
+ (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
+ Otherwise, must be (batch_size, 1, hidden_size)
+ - **freqs_cis** (Tuple) - The precompute freqs and mask for rotary position embedding used in attention.
+ - **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask
+ matrix should ba (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask
+ in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length)
+ - **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, head_dim, tgt_seq_length).
+ The past calculated key vector. Used for incremental prediction when the use_past is True.
+ Default None.
+ - **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length,
+ head_dim).
+ The past calculated value vector. Used for incremental prediction when the use_past is True.
+ Default None.
+ - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index.
+ Used for incremental prediction when the use_past is True. Default None.
+
+ Outputs:
+ Tuple, a tuple contains(`output`, `layer_present`)
+
+ - **output** (Tensor) - Tensor, the float tensor of the output of the layer with
+ shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size),
+ if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size).
+
+ - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
+ ((batch_size, num_heads, head_dim, tgt_seq_length),
+ (batch_size, num_heads, tgt_seq_length, head_dim)).
+ """
+
+ def __init__(
+ self,
+ batch_size,
+ seq_length,
+ dim: int = 512,
+ n_heads: int = 8,
+ n_kv_heads: Optional[int] = None,
+ qkv_concat=False,
+ compute_dtype=mstype.float16,
+ softmax_compute_dtype=mstype.float32,
+ rotary_dtype=mstype.float32,
+ param_init_type=mstype.float32,
+ qkv_has_bias=False,
+ use_past=False,
+ is_dynamic=False,
+ use_kvcache_op=False,
+ is_flexible_shape=False,
+ use_rope_slice=False,
+ use_flash_attention=False,
+ ):
+ super().__init__()
+ self.seq_length = seq_length
+ self.hidden_size = dim
+ self.n_head = n_heads
+ self.head_dim = dim // n_heads
+ self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
+ self.n_rep = self.n_head // self.n_kv_head
+ self.kv_dim = self.n_kv_head * self.head_dim
+
+ self.dtype = compute_dtype
+ self.softmax_dtype = softmax_compute_dtype
+ self.is_first_iteration = True
+ self.use_past = use_past
+ self.use_flash_attention = use_flash_attention
+ self.qkv_concat = qkv_concat
+
+ if self.hidden_size % self.n_head != 0:
+ raise ValueError(
+ "For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple "
+ "of 'n_head', but got the hidden_size is {} and the n_head is {}.".format(self.hidden_size, self.n_head)
+ )
+
+ self.inv_norm_factor = Tensor(1.0 / self.head_dim**0.5, dtype=compute_dtype)
+
+ self.shape = ops.Shape()
+ self.reshape = ops.Reshape().add_prim_attr("skip_redistribution", True)
+ self.transpose = ops.Transpose()
+ self.merger_head_transpose = ops.Transpose()
+ self.batch_matmul = ops.BatchMatMul()
+ self.batch_matmul_q_k = ops.BatchMatMul(transpose_b=True)
+ self.mul = ops.Mul()
+ self.add = ops.Add()
+ self.softmax = ops.Softmax()
+ self.cast = ops.Cast()
+ self.cast_attn = ops.Cast()
+ self.tile_kv = ops.Tile()
+ self.slice_qkv = ops.StridedSlice()
+
+ self.apply_rotary_emb = LlamaRotaryEmbedding(self.head_dim, rotary_dtype, use_rope_slice=use_rope_slice)
+ if self.qkv_concat:
+ self.w = Linear(
+ in_channels=self.hidden_size,
+ out_channels=self.hidden_size + self.kv_dim * 2,
+ has_bias=qkv_has_bias,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+ else:
+ self.wq = Linear(
+ self.hidden_size,
+ self.hidden_size,
+ has_bias=qkv_has_bias,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+ self.wk = Linear(
+ self.hidden_size,
+ self.kv_dim,
+ has_bias=qkv_has_bias,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+ self.wv = Linear(
+ self.hidden_size,
+ self.kv_dim,
+ has_bias=qkv_has_bias,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+ self.wo = Linear(
+ in_channels=self.hidden_size,
+ out_channels=self.hidden_size,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ if self.use_flash_attention:
+ self.flash_attention = FlashAttention(self.head_dim, n_heads, next_block_num=0, high_precision=True)
+
+ if self.use_past:
+ self.kvcache_mgr = KVCacheMgr(
+ self.n_kv_head,
+ self.head_dim,
+ max_batch_size=batch_size,
+ max_seq_length=seq_length,
+ compute_dtype=compute_dtype,
+ is_dynamic=is_dynamic,
+ use_kvcache_op=use_kvcache_op,
+ is_flexible_shape=is_flexible_shape,
+ )
+
+ def construct(self, x: Tensor, freqs_cis: Tuple[Tensor, Tensor], mask=None, kvcache_inputs=None):
+ """Forward process of the MultiHeadAttention"""
+ ori_dtype = x.dtype
+ # [bs, seq/1, hidden_dim]
+ bs, seq_len, _ = self.shape(x)
+ # [bs * seq/1, hidden_dim]
+ if self.qkv_concat:
+ x = self.reshape(x, (-1, x.shape[-1]))
+ bs_seq = x.shape[0]
+ qkv = self.cast(self.w(x), self.dtype)
+ query = self.slice_qkv(qkv, (0, 0), (bs_seq, self.hidden_size), (1, 1))
+ key = self.slice_qkv(qkv, (0, self.hidden_size), (bs_seq, self.hidden_size + self.kv_dim), (1, 1))
+ value = self.slice_qkv(
+ qkv, (0, self.hidden_size + self.kv_dim), (bs_seq, self.hidden_size + self.kv_dim * 2), (1, 1)
+ )
+ else:
+ query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp
+ key = self.cast(self.wk(x), self.dtype) # dp, 1 -> dp, mp
+ value = self.cast(self.wv(x), self.dtype) # dp, 1 -> dp, mp
+
+ if self.use_past and not self.is_first_iteration:
+ query = self.reshape(query, (bs, self.n_head, 1, self.head_dim))
+ key = self.reshape(key, (bs, self.n_kv_head, 1, self.head_dim))
+ value = self.reshape(value, (bs, self.n_kv_head, 1, self.head_dim))
+ else:
+ query = self.reshape(query, (bs, seq_len, self.n_head, self.head_dim))
+ key = self.reshape(key, (bs, seq_len, self.n_kv_head, self.head_dim))
+ value = self.reshape(value, (bs, seq_len, self.n_kv_head, self.head_dim))
+ # [bs, seq/1, n_head/n_kv_head, head_dim]
+ query = self.transpose(query, (0, 2, 1, 3))
+ key = self.transpose(key, (0, 2, 1, 3))
+ value = self.transpose(value, (0, 2, 1, 3))
+ # [bs, n_head/n_kv_head, seq/1, head_dim]
+ query, key = self.apply_rotary_emb(query, key, freqs_cis) # dp, mp, 1, 1
+ # kv cache: [bs, n_kv_head, 1, head_dim] -> [bs, n_kv_head, seq, head_dim]
+ if self.use_past:
+ key, value = self.kvcache_mgr(key, value, kvcache_inputs)
+ # kv share: [bs, n_kv_head, seq, head_dim] -> [bs, n_head, seq, head_dim]
+ key = self._repeat_kv(key, self.n_rep)
+ value = self._repeat_kv(value, self.n_rep)
+ # q, k, v: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim], [bs, n_head, seq, head_dim]
+ if self.use_flash_attention:
+ attention = self.flash_attention(query, key, value, mask)
+ attention = self._merge_heads(attention)
+ else:
+ attention = self._attn(query, key, value, mask)
+ # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim]
+ output = self.wo(attention) # dp, mp -> dp, 1 / dp * mp, 1
+ output = self.cast(output, ori_dtype)
+
+ return output
+
+ def _repeat_kv(self, x, rep):
+ if rep == 1:
+ return x
+ bs, n_kv_head, seqlen, head_dim = self.shape(x)
+ x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim))
+ x = self.tile_kv(x, (1, 1, rep, 1))
+ x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim))
+ return x
+
+ def _merge_heads(self, x):
+ """
+ convert a 4d input to a 2d or 3d output
+
+ Inputs:
+ x: input tensor
+
+ Output:
+ x_merge: the 2d output
+ """
+ # [bs, n_head, seq/1, head_dim]
+ x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1
+ # [bs, seq/1, n_head, head_dim]
+ bs, seq_len, n_head, head_dim = self.shape(x)
+ # [bs, seq/1, hidden_dim]
+ new_shape = (bs, seq_len, n_head * head_dim)
+ x_merge = self.reshape(x, new_shape)
+ return x_merge
+
+ def _attn(self, query, key, value, mask):
+ """
+ Get the weighted score along the seq_length
+
+ Inputs:
+ query: the query matrix
+ key: the key matrix
+ value: the value matrix
+ mask: the attention mask adder matrix with shape (batch_size,
+ 1, seq_length, seq_length)
+ Outputs:
+ weighted_values: Tensor, the weighted sum scores
+ """
+ # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim]
+ score = self.batch_matmul_q_k(query, key)
+ # score: [bs, n_head, seq/1, seq]
+ score = self.mul(score, self.inv_norm_factor)
+ score = self.add(mask, score)
+
+ attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype))
+ # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim]
+ weighted_values = self.batch_matmul(self.cast(attention_probs, self.dtype), value)
+ # [bs, n_head, seq/1, head_dim]
+ attention_merge = self._merge_heads(weighted_values)
+ # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim]
+ return attention_merge
+
+
+class LlamaRMSNorm(nn.Cell):
+ r"""
+ A self-defined RMSNorm operation using reduce mean.
+
+ Args:
+ dim (int): The shape of the input tensor
+ eps (float): The epsilon value of the denominator. Default 1e-5.
+ compute_type: The compute type.
+ param_init_type: The layer norm param init type.
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
+
+ Outputs:
+ Tensor of shape :math:`(batch, seq_length, hidden_size)`.
+ """
+
+ def __init__(self, dim, eps=1e-6, compute_type=mstype.float32, is_dynamic=False, param_init_type=mstype.float32):
+ super(LlamaRMSNorm, self).__init__()
+ self.eps = eps
+ self.compute_type = compute_type
+ self.weight = Parameter(initializer("ones", (dim,), dtype=param_init_type), parallel_optimizer=False)
+
+ if ms.get_context("device_target") == "Ascend" and not is_910a() and not is_dynamic:
+ self.norm = ops.RmsNorm(eps)
+ self.rms_norm = self._rms_norm
+ else:
+ self.cast = ops.Cast()
+ self.mul = ops.Mul()
+ self.mul2 = ops.Mul()
+ self.square = ops.Square()
+ self.mean = ops.ReduceMean(keep_dims=True)
+ self.add = ops.Add()
+ self.rsqrt = ops.Rsqrt()
+ self.rms_norm = self._self_norm
+
+ def _self_norm(self, x):
+ original_type = x.dtype
+ norm_factor = self.square(self.cast(x, self.compute_type))
+ norm_factor = self.mean(norm_factor, -1)
+ norm_factor = self.add(norm_factor, self.eps)
+ norm_factor = self.rsqrt(norm_factor)
+ output = self.mul(x, self.cast(norm_factor, original_type))
+ output = self.mul2(output, self.cast(self.weight, original_type))
+ return output
+
+ def _rms_norm(self, x):
+ original_type = x.dtype
+ return self.norm(x, self.cast(self.weight, original_type))[0]
+
+ def construct(self, x):
+ """Forward of RMSNorm."""
+ return self.rms_norm(x)
+
+
+class QwenForCausalLM(BaseLLMModel):
+ r"""
+ Provide qwen training loss or logits through network.
+ Args:
+ config (QwenConfig): The config of Qwen model.
+
+ Returns:
+ Tensor, the loss or logits of the network.
+ """
+
+ def __init__(self, config=None):
+ super().__init__(config)
+
+ self.transformer = QwenModel(config=config)
+ self.lm_head = Linear(
+ in_channels=config.hidden_size,
+ out_channels=config.vocab_size,
+ has_bias=False,
+ compute_dtype=config.compute_dtype,
+ param_init_type=mstype.float16,
+ weight_init="normal",
+ )
+ self.loss = CrossEntropyLoss()
+
+ self.pad_token_id = config.pad_token_id
+ self.use_past = config.use_past
+ self.ignore_token_id = config.ignore_token_id
+ self.seq_length = config.seq_length
+ self.vocab_size = config.vocab_size
+ self.is_first_iteration = True
+ self.not_equal = ops.NotEqual()
+ self.cast = ops.Cast()
+ self.add = ops.Add()
+ self.reshape = ops.Reshape()
+ self.ones = ops.Ones()
+ self.slice = ops.StridedSlice()
+ self.mul = ops.Mul()
+ self.sub_batch_valid_len = ops.Sub()
+ self.gather = ops.Gather(1)
+
+ def construct(
+ self,
+ input_ids,
+ labels=None,
+ input_position=None,
+ position_ids=None,
+ attention_mask=None,
+ input_embeds=None,
+ init_reset=True,
+ batch_valid_length=None,
+ batch_index=None,
+ zactivate_len=None,
+ ):
+ bsz, seqlen = input_ids.shape
+ if self.use_past:
+ if not isinstance(batch_valid_length, Tensor):
+ batch_valid_length = self.ones((bsz,), mstype.int32)
+ if self.training:
+ tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
+ else:
+ tokens = input_ids
+
+ if batch_valid_length is not None:
+ batch_valid_length = self.reshape(batch_valid_length, (-1,))
+ if not self.is_first_iteration:
+ batch_valid_length = self.sub_batch_valid_len(batch_valid_length, 1)
+
+ output = self.transformer(
+ tokens,
+ init_reset=init_reset,
+ batch_valid_length=batch_valid_length,
+ batch_index=batch_index,
+ zactivate_len=zactivate_len,
+ )
+ pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
+ if pre_gather:
+ output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
+ logits = self.lm_head(output)
+
+ input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
+ if labels is None:
+ labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
+ else:
+ if labels.ndim > 1:
+ if self.training:
+ labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
+ label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
+ input_mask = self.mul(input_mask, label_mask)
+
+ if not self.training:
+ if not pre_gather:
+ logits = self.reshape(logits, (bsz, seqlen, -1))
+ logits = self.cast(logits, mstype.float32)
+ # makes cast effective to avoid allgather issue in Mindspore1.10
+ input_mask = self.add(input_mask, 1)
+ return logits, tokens, input_mask
+
+ if logits.ndim > 2:
+ logits = self.reshape(logits, (-1, logits.shape[-1]))
+ logits = self.cast(logits, mstype.float32)
+ labels = self.reshape(labels, (-1,))
+ input_mask = self.reshape(input_mask, (-1,))
+ loss = self.loss(logits, labels, input_mask)
+ return loss
+
+
+class QwenModel(BaseLLMModel):
+ """transformer"""
+
+ def __init__(self, config):
+ config = QwenConfig(**config)
+ super().__init__(config)
+ self.dtype = config.compute_dtype
+ self.vocab_size = config.vocab_size
+ self.num_hidden_layers = config.num_layers
+ self.embed_dim = config.hidden_size
+ self.head_dim = config.hidden_size // config.num_heads
+ self.seq_length = config.seq_length
+ self.pad_token_id = config.pad_token_id
+ self.num_attention_heads = config.num_heads
+ self.use_past = config.use_past
+ self.is_dynamic = config.is_dynamic
+ self.use_kvcache_op = config.use_kvcache_op
+ self.is_flexible_shape = config.is_flexible_shape
+
+ self.is_first_iteration = True
+ self.use_flash_attention = config.use_flash_attention
+
+ # 1. wte
+ self.wte = LlamaEmbedding(
+ self.vocab_size, self.embed_dim, param_init_type=config.param_init_type, parallel_optimizer=True
+ )
+
+ # 2. drop
+ self.drop = nn.Dropout(p=config.emb_dropout_prob)
+
+ # 4. h hidden layers for transformer
+ self.layers = nn.CellList()
+ for layer_id in range(config.num_layers):
+ layer = QwenDecodeLayer(
+ config.batch_size,
+ config.seq_length,
+ layer_id,
+ dim=config.hidden_size,
+ n_heads=config.num_heads,
+ intermediate_size=config.intermediate_size,
+ norm_eps=config.rms_norm_eps,
+ compute_dtype=config.compute_dtype,
+ layernorm_compute_dtype=config.layernorm_compute_type,
+ softmax_compute_dtype=config.softmax_compute_type,
+ rotary_dtype=config.rotary_dtype,
+ param_init_type=config.param_init_type,
+ ln_param_init_type=config.ln_param_init_type,
+ qkv_has_bias=True,
+ use_past=config.use_past,
+ use_flash_attention=config.use_flash_attention,
+ )
+
+ self.layers.append(layer)
+
+ self.freqs_mgr = FreqsMgr(
+ head_dim=self.head_dim,
+ seq_length=self.seq_length,
+ max_position_embedding=config.max_position_embedding,
+ rotary_dtype=config.rotary_dtype,
+ theta=config.theta,
+ scaling_factor=config.scaling_factor,
+ extend_method=config.extend_method,
+ is_dynamic=config.is_dynamic,
+ )
+ self.casual_mask = CausalMaskForQwen(
+ seq_length=config.seq_length,
+ compute_type=config.compute_dtype,
+ is_dynamic=config.is_dynamic,
+ pad_token_id=config.pad_token_id,
+ use_flash_attention=config.use_flash_attention,
+ )
+ self.kvcache_preprocess = KVCachePreprocess(
+ max_batch_size=config.batch_size,
+ max_seq_length=config.seq_length,
+ is_dynamic=config.is_dynamic,
+ use_kvcache_op=config.use_kvcache_op,
+ is_flexible_shape=config.is_flexible_shape,
+ )
+ # 5. ln_f
+ self.ln_f = LlamaRMSNorm(
+ self.embed_dim,
+ eps=config.rms_norm_eps,
+ compute_type=config.layernorm_compute_type,
+ param_init_type=config.ln_param_init_type,
+ )
+
+ self.shape = ops.Shape()
+
+ def construct(
+ self, input_ids: Tensor, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None
+ ):
+ """construct"""
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ # 1. wte
+ hidden_states = self.wte(input_ids)
+
+ # 2. drop
+ hidden_states = self.drop(hidden_states)
+
+ # 2. rotary_emb
+ bs, seq_len = self.shape(input_ids)
+ if not self.use_past:
+ freqs_cis = self.freqs_mgr()
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ mask = self.casual_mask.post_process(mask)
+ kvcache_inputs = None
+ else:
+ if self.is_first_iteration:
+ freqs_cis = self.freqs_mgr(seq_len)
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ else:
+ freqs_cis = self.freqs_mgr.increment(batch_valid_length, bs)
+ if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
+ mask = self.casual_mask.increment_slice(
+ self.kvcache_preprocess.range,
+ self.kvcache_preprocess.max_cache_length // bs,
+ batch_valid_length,
+ zactivate_len,
+ )
+ else:
+ mask = self.casual_mask.increment(self.kvcache_preprocess.range, batch_valid_length, zactivate_len)
+ mask = self.casual_mask.post_process(mask)
+
+ kvcache_inputs = self.kvcache_preprocess(bs, batch_valid_length, batch_index, zactivate_len)
+
+ # 4. hidden_states
+ for i in range(self.num_hidden_layers):
+ hidden_states = self.layers[i](hidden_states, freqs_cis, mask, kvcache_inputs=kvcache_inputs)
+
+ # 5. ln_f
+ hidden_states = self.ln_f(hidden_states)
+
+ return hidden_states
+
+
+class QwenDecodeLayer(nn.Cell):
+ def __init__(
+ self,
+ batch_size,
+ seq_length,
+ layer_id,
+ dim: int = 512,
+ n_heads: int = 8,
+ n_kv_heads: Optional[int] = None,
+ intermediate_size: Optional[int] = None,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[int] = None,
+ norm_eps: float = 1e-5,
+ qkv_concat=False,
+ compute_dtype=mstype.float16,
+ layernorm_compute_dtype=mstype.float32,
+ softmax_compute_dtype=mstype.float32,
+ rotary_dtype=mstype.float32,
+ param_init_type=mstype.float32,
+ ln_param_init_type=mstype.float32,
+ use_past=False,
+ is_dynamic=False,
+ use_kvcache_op=False,
+ is_flexible_shape=False,
+ use_rope_slice=False,
+ use_flash_attention=False,
+ qkv_has_bias=True,
+ ):
+ super().__init__()
+ self.batch_size = batch_size
+
+ self.seq_length = seq_length
+ self.layer_id = layer_id
+ self.hidden_size = dim
+ self.n_head = n_heads
+ self.head_dim = self.hidden_size // self.n_head
+ self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads
+ self.dtype = compute_dtype
+ self.is_first_iteration = True
+ self.use_past = use_past
+
+ self.shape = ops.Shape()
+ self.reshape = ops.Reshape().add_prim_attr("skip_redistribution", True)
+ self.add = ops.Add()
+ self.attention_norm = LlamaRMSNorm(
+ self.hidden_size,
+ norm_eps,
+ compute_type=layernorm_compute_dtype,
+ is_dynamic=is_dynamic,
+ param_init_type=ln_param_init_type,
+ )
+ self.ffn_norm = LlamaRMSNorm(
+ self.hidden_size,
+ norm_eps,
+ compute_type=layernorm_compute_dtype,
+ is_dynamic=is_dynamic,
+ param_init_type=ln_param_init_type,
+ )
+ self.attention = LLamaAttention(
+ batch_size=batch_size,
+ seq_length=seq_length,
+ dim=dim,
+ n_heads=n_heads,
+ n_kv_heads=n_kv_heads,
+ qkv_concat=qkv_concat,
+ compute_dtype=compute_dtype,
+ softmax_compute_dtype=softmax_compute_dtype,
+ rotary_dtype=rotary_dtype,
+ param_init_type=param_init_type,
+ qkv_has_bias=qkv_has_bias,
+ use_past=use_past,
+ is_dynamic=is_dynamic,
+ use_kvcache_op=use_kvcache_op,
+ is_flexible_shape=is_flexible_shape,
+ use_rope_slice=use_rope_slice,
+ use_flash_attention=use_flash_attention,
+ )
+ self.feed_forward = LlamaFeedForward(
+ dim=self.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_dim=4 * self.hidden_size,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ is_dynamic=is_dynamic,
+ )
+ self.feed_forward = QwenFeedForward(
+ dim=self.hidden_size,
+ intermediate_size=intermediate_size,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+
+ def construct(self, x, freqs_cis, mask=None, kvcache_inputs=None):
+ """Forward of transformer block."""
+ # [bs, seq/1, hidden_dim]
+ input_x = self.attention_norm(x)
+ # [bs, seq/1, hidden_dim]
+ h = self.attention(input_x, freqs_cis, mask, kvcache_inputs)
+ h = self.add(x, h)
+ ffn_norm = self.ffn_norm(h)
+ # [bs, seq/1, hidden_dim]
+ ffn_out = self.feed_forward(ffn_norm)
+ # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim]
+ out = self.add(h, ffn_out)
+ return out
+
+
+class QwenFeedForward(nn.Cell):
+ r"""
+ Qwen FeedForward.
+
+ .. math::
+ (xW_1 * xW_3)W_2
+
+ Inputs:
+ - **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
+ Float tensor.
+
+ Outputs:
+ Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size] or
+ [batch * seq_length, hidden_size]`.
+
+ Raises:
+ ValueError: `hidden_dim` is not a multiple of the model parallel way.
+ ValueError: `dim` is not a multiple of the model parallel way.
+ """
+
+ def __init__(
+ self, dim, intermediate_size=0, compute_dtype=mstype.float16, param_init_type=mstype.float32, is_dynamic=False
+ ):
+ super().__init__()
+
+ hidden_dim = intermediate_size
+ self.dtype = compute_dtype
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ self.mul = ops.Mul()
+ self.cast = ops.Cast()
+ self.silu = LlamaSiLU()
+
+ self.w1 = Linear(
+ in_channels=dim,
+ out_channels=hidden_dim,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ self.w2 = Linear(
+ in_channels=hidden_dim,
+ out_channels=dim,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ self.w3 = Linear(
+ in_channels=dim,
+ out_channels=hidden_dim,
+ has_bias=False,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ skip_redistribution=is_dynamic,
+ )
+
+ def construct(self, x):
+ """Forward process of the FeedForward"""
+ x = self.cast(x, self.dtype)
+ # [bs, seq, hidden_dim] or [bs * seq, hidden_dim]
+ gate = self.w1(x) # dp,1 -> dp, mp
+ hidden = self.w3(x) # dp,1 -> dp, mp
+ hidden = self.mul(gate, self.silu(hidden).astype(self.dtype)) # dp,mp -> dp, mp
+ output = self.w2(hidden) # dp,mp -> dp, 1
+ return output
+
+
+class CausalMaskForQwen(nn.Cell):
+ r"""Get the Lower triangular matrix from the input_ids.
+ [[[1. 0. 0. 0. 0]
+ [1. 1. 0. 0. 0]
+ [1. 1. 1. 0. 0]
+ [1. 1. 1. 1. 0]
+ [1. 1. 1. 1. 0]]]"""
+
+ def __init__(
+ self, seq_length, compute_type=mstype.float16, is_dynamic=False, pad_token_id=0, use_flash_attention=False
+ ):
+ super().__init__()
+ self.dtype = compute_type
+ self.is_dynamic = is_dynamic
+ self.pad_token_id = pad_token_id
+ self.use_flash_attention = use_flash_attention
+ self.multiply_data = Tensor([-10000.0], dtype=compute_type)
+ self.one = Tensor([1.0], dtype=compute_type)
+ self.lower_triangle_mask = Tensor(np.tril(np.ones(shape=(seq_length, seq_length))), mstype.float32)
+
+ self.shape = ops.Shape()
+ self.cast = ops.Cast()
+ self.reshape = ops.Reshape()
+ self.not_equal = ops.NotEqual()
+ self.less_equal = ops.LessEqual()
+ self.expand_dim = ops.ExpandDims()
+ self.slice = ops.StridedSlice()
+ self.mul = ops.Mul()
+ self.sub = ops.Sub()
+ self.mul_post = ops.Mul()
+ self.expand_dim_post = ops.ExpandDims()
+
+ def construct(self, tokens):
+ """Forward process of the CausalMask"""
+ bs = self.shape(tokens)[0]
+ seq_len = self.shape(tokens)[1]
+ input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), self.dtype)
+ shape_right = (bs, 1, seq_len)
+ # Mask the padded inputs
+ mask_right = self.reshape(input_mask, shape_right)
+ if not self.is_dynamic:
+ lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
+ else:
+ lower_triangle_mask = self.slice(self.lower_triangle_mask, (0, 0), (seq_len, seq_len), (1, 1))
+ lower_triangle = self.expand_dim(lower_triangle_mask, 0)
+ # the returned shape is [bs, seq_length, seq_length]
+ attention_mask = self.mul(mask_right, lower_triangle)
+ return attention_mask
+
+ def increment(self, seq_range, batch_valid_length, zactivate_len=None):
+ if zactivate_len is not None:
+ seq_range = self.slice(seq_range, (0, 0, 0), (1, 1, self.shape(zactivate_len)[0]), (1, 1, 1))
+ mask = self.less_equal(self.reshape(seq_range, (1, 1, -1)), self.reshape(batch_valid_length, (-1, 1, 1)))
+ return mask
+
+ def increment_slice(self, seq_range, seq_length, batch_valid_length, zactivate_len=None):
+ if zactivate_len is not None:
+ seq_range_mask = self.slice(seq_range, (0, 0, 0), (1, 1, self.shape(zactivate_len)[0]), (1, 1, 1))
+ else:
+ seq_range_mask = self.slice(seq_range, (0, 0, 0), (1, 1, seq_length), (1, 1, 1))
+ mask = self.less_equal(self.reshape(seq_range_mask, (1, 1, -1)), self.reshape(batch_valid_length, (-1, 1, 1)))
+ return mask
+
+ def post_process(self, mask):
+ mask = self.sub(self.one, self.cast(mask, self.dtype))
+ if not self.use_flash_attention:
+ mask = self.expand_dim_post(mask, 1)
+ mask = self.mul_post(mask, self.multiply_data)
+ return mask
diff --git a/mindocr/nlp/llm/qwen_tokenizer.py b/mindocr/nlp/llm/qwen_tokenizer.py
new file mode 100644
index 000000000..8ad4224d2
--- /dev/null
+++ b/mindocr/nlp/llm/qwen_tokenizer.py
@@ -0,0 +1,321 @@
+import base64
+import unicodedata
+from typing import Collection, Dict, List, Set, Union
+
+import numpy as np
+import tiktoken
+
+import mindspore as ms
+
+PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+
+(?!\S)|\s+"""
+ENDOFTEXT = "<|endoftext|>"
+IMSTART = "<|im_start|>"
+IMEND = "<|im_end|>"
+EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
+SPECIAL_TOKENS = (
+ (
+ ENDOFTEXT,
+ IMSTART,
+ IMEND,
+ )
+ + EXTRAS
+ + ("[", "]", "", "", "", "", "", "", "")
+)
+
+
+def to_py_obj(obj):
+ """
+ Convert a Mindspore tensor, Numpy array or python list to a python list.
+ """
+ if isinstance(obj, dict):
+ return {k: to_py_obj(v) for k, v in obj.items()}
+ if isinstance(obj, (list, tuple)):
+ return [to_py_obj(o) for o in obj]
+ if isinstance(obj, ms.Tensor):
+ return obj.asnumpy().tolist()
+ if isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
+ return obj.tolist()
+ return obj
+
+
+def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+ with open(tiktoken_bpe_file, "rb") as f:
+ contents = f.read()
+ return {
+ base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
+ }
+
+
+class QwenTokenizer:
+ """Qwen Tokenizer"""
+
+ def __init__(self, vocab_file="qwen.tiktoken", pad_token=ENDOFTEXT):
+ self.vocab_file = vocab_file
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
+
+ self.special_tokens = {
+ token: index for index, token in enumerate(SPECIAL_TOKENS, start=len(self.mergeable_ranks))
+ }
+
+ enc = tiktoken.Encoding(
+ "Qwen",
+ pat_str=PAT_STR,
+ mergeable_ranks=self.mergeable_ranks,
+ special_tokens=self.special_tokens,
+ )
+ assert (
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
+
+ self.decoder = {v: k for k, v in self.mergeable_ranks.items()}
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
+
+ self.tokenizer = enc
+
+ self.eod_id = self.tokenizer.eot_token
+ self.im_start_id = self.special_tokens[IMSTART]
+ self.im_end_id = self.special_tokens[IMEND]
+
+ self.errors = "replace"
+ self._in_target_context_manager = False
+ self.model_input_names = ["input_ids", "token_type_ids", "attention_mask"]
+ self.pad_token_type_id = 0
+ self.pad_token_id = self.convert_tokens_to_ids(pad_token)
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.n_vocab
+
+ # override Tokenizer.convert_tokens_to_string()
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
+ """
+ Converts a sequence of tokens in a single string.
+ """
+ text = ""
+ temp = b""
+ for t in tokens:
+ if isinstance(t, str):
+ if temp:
+ text += temp.decode("utf-8", errors=self.errors)
+ temp = b""
+ text += t
+ elif isinstance(t, bytes):
+ temp += t
+ else:
+ raise TypeError("token should only be of type types or str")
+ if temp:
+ text += temp.decode("utf-8", errors=self.errors)
+ return text
+
+ # called by Tokenizer.convert_tokens_to_ids() & SpecialTokensMixin
+ def _convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> Union[int, List[int]]:
+ """Convert the tokens to ids using vocab mapping"""
+ if isinstance(tokens, (str, bytes)):
+ return self._convert_token_to_id(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id(token))
+ return ids
+
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
+ """Converts a token to an id using the vocab, special tokens included"""
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ if token in self.mergeable_ranks:
+ return self.mergeable_ranks[token]
+ raise ValueError("unknown token")
+
+ # required by Tokenizer.convert_ids_to_tokens() of mindformers<=0.6
+ def _convert_ids_to_tokens(self, input_id: int):
+ return self._convert_id_to_token(input_id)
+
+ # called by Tokenizer.convert_ids_to_tokens()
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
+ """Converts an id to a token, special tokens included"""
+ if index in self.decoder:
+ return self.decoder[index]
+ raise ValueError("unknown ids")
+
+ def tokenize(
+ self,
+ text: str,
+ allowed_special: Union[Set, str] = "all",
+ disallowed_special: Union[Collection, str] = (),
+ ) -> List[Union[bytes, str]]:
+ """
+ Converts a string in a sequence of tokens.
+
+ Args:
+ text (`str`):
+ The sequence to be encoded.
+ allowed_special (`Literal["all"]` or `set`):
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
+ Default to "all".
+ disallowed_special (`Literal["all"]` or `Collection`):
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
+ Default to an empty tuple.
+
+ Returns:
+ `List[bytes|str]`: The list of tokens.
+ """
+ tokens = []
+ text = unicodedata.normalize("NFC", text)
+
+ # this implementation takes a detour: text -> token id -> token surface forms
+ for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
+ tokens.append(self.decoder[t])
+ return tokens
+
+ def _decode(self, token_ids, skip_special_tokens: bool = False) -> str:
+ """override Tokenizer._decode(), called by BaseTokenizer.decode()"""
+ if isinstance(token_ids, int):
+ token_ids = [token_ids]
+ if skip_special_tokens:
+ token_ids = [i for i in token_ids if i < self.eod_id]
+ return self.tokenizer.decode(token_ids, errors=self.errors)
+
+ def _call_one(self, text, max_length=None):
+ is_batched = isinstance(text, (list, tuple))
+
+ if is_batched:
+ return self.batch_encode_plus(batch_text_or_text_pairs=text, max_length=max_length)
+ outputs = self.batch_encode_plus(batch_text_or_text_pairs=[text], max_length=max_length)
+ outputs = {_: outputs[_][0] for _ in outputs}
+ return outputs
+
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int], None]:
+ """
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
+ vocabulary.
+
+ Args:
+ tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).
+
+ Returns:
+ `int` or `List[int]` or `None`: The token id or list of token ids.
+ """
+ if tokens is None:
+ return None
+
+ if isinstance(tokens, str):
+ return self._convert_token_to_id_with_added_voc(tokens)
+
+ ids = []
+ for token in tokens:
+ ids.append(self._convert_token_to_id_with_added_voc(token))
+ return ids
+
+ def _convert_token_to_id_with_added_voc(self, token):
+ if token is None:
+ return None
+ return self._convert_token_to_id(token)
+
+ def batch_encode_plus(self, batch_text_or_text_pairs, max_length):
+ input_ids = []
+ for ids in batch_text_or_text_pairs:
+ tokens = self.tokenize(ids)
+ input_id = self.convert_tokens_to_ids(tokens)
+ input_ids.append(input_id)
+
+ batch_outputs = {}
+ for input_id in input_ids:
+ outputs = self.prepare_for_model(input_id)
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ batch_outputs = self.pad(batch_outputs, max_length=max_length)
+
+ return batch_outputs
+
+ def pad(self, encoded_inputs, max_length=None):
+ # The model's main input name, usually `input_ids`, has be passed for padding
+ if self.model_input_names[0] not in encoded_inputs:
+ raise ValueError(
+ "You should supply an encoding or a list of encodings to this method "
+ f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+ )
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+
+ batch_size = len(required_input)
+
+ batch_outputs = {}
+ for i in range(batch_size):
+ inputs = {k: v[i] for k, v in encoded_inputs.items()}
+
+ outputs = self._pad(inputs, max_length=max_length)
+ outputs = {_: outputs[_][:max_length] for _ in outputs}
+
+ for key, value in outputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+
+ return batch_outputs
+
+ def _pad(self, encoded_inputs, max_length) -> dict:
+ return_attention_mask = "attention_mask" in self.model_input_names
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ required_input_len = len(required_input)
+
+ needs_to_be_padded = required_input_len < max_length
+
+ # Initialize attention mask if not present.
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = [1] * required_input_len
+
+ if needs_to_be_padded:
+ difference = max_length - required_input_len
+
+ if return_attention_mask:
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+ if "token_type_ids" in encoded_inputs:
+ encoded_inputs["token_type_ids"] = (
+ encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+ )
+ if "special_tokens_mask" in encoded_inputs:
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+ encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+
+ return encoded_inputs
+
+ def prepare_for_model(self, ids):
+ return_token_type_ids = "token_type_ids" in self.model_input_names
+ encoded_inputs = {"input_ids": ids}
+ if return_token_type_ids:
+ encoded_inputs["token_type_ids"] = [0] * len(ids)
+ return encoded_inputs
+
+ def __call__(self, text, max_length):
+ """
+ Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+ sequences.
+
+ Args:
+ text (`str`, `List[str]`, `List[List[str]]`, *optional*):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ max_length (int): Max sequence length.
+ """
+ encodings = self._call_one(text=text, max_length=max_length)
+ return encodings
+
+ def decode(self, token_ids, skip_special_tokens=False) -> str:
+ # Convert inputs to python lists
+ token_ids = to_py_obj(token_ids)
+
+ if isinstance(token_ids[0], list):
+ output = []
+ for item in token_ids:
+ new_strs = self._decode(token_ids=item, skip_special_tokens=skip_special_tokens)
+ output.append(new_strs)
+ else:
+ output = self._decode(token_ids=token_ids, skip_special_tokens=skip_special_tokens)
+ return output
diff --git a/mindocr/nlp/llm/vary_clip_model.py b/mindocr/nlp/llm/vary_clip_model.py
new file mode 100644
index 000000000..5c3d22f88
--- /dev/null
+++ b/mindocr/nlp/llm/vary_clip_model.py
@@ -0,0 +1,226 @@
+from collections import OrderedDict
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+
+from mindocr.nlp.utils.layers import LayerNorm, Linear
+
+
+class QuickGELU(nn.Cell):
+ def construct(self, x: Tensor):
+ return x * ops.sigmoid(1.702 * x)
+
+
+class CLIPAttention(nn.Cell):
+ """Multi-head attention module for CLIP"""
+
+ def __init__(self, embed_dim, num_heads, param_init_type=ms.float32):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.k_proj = Linear(self.embed_dim, self.embed_dim, param_init_type=param_init_type)
+ self.v_proj = Linear(self.embed_dim, self.embed_dim, param_init_type=param_init_type)
+ self.q_proj = Linear(self.embed_dim, self.embed_dim, param_init_type=param_init_type)
+ self.out_proj = Linear(self.embed_dim, self.embed_dim, param_init_type=param_init_type)
+ self.reshape = ops.Reshape()
+ self.transpose = ops.Transpose()
+ self.batch_matmul = ops.BatchMatMul()
+ self.batch_matmul_q_k = ops.BatchMatMul(transpose_b=True)
+ self.softmax = nn.Softmax()
+
+ def construct(self, x):
+ bsz, tgt_len, embed_dim = x.shape
+ query_states = self.transpose(
+ self.reshape(self.q_proj(x) * self.scale, (bsz, tgt_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)
+ )
+ key_states = self.transpose(
+ self.reshape(self.k_proj(x), (bsz, tgt_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)
+ )
+ value_states = self.transpose(
+ self.reshape(self.v_proj(x), (bsz, tgt_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)
+ )
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self.reshape(query_states, proj_shape)
+ key_states = self.reshape(key_states, proj_shape)
+ value_states = self.reshape(value_states, proj_shape)
+
+ src_len = tgt_len
+ attn_weights = self.batch_matmul_q_k(query_states, key_states)
+ if attn_weights.shape != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+ attn_weights = self.softmax(attn_weights)
+ attn_output = self.batch_matmul(attn_weights, value_states)
+ if attn_output.shape != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+ attn_output = self.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
+ attn_output = self.transpose(attn_output, (0, 2, 1, 3))
+ attn_output = self.reshape(attn_output, (bsz, tgt_len, embed_dim))
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class ResidualAttentionBlock(nn.Cell):
+ """ResidualAttention module for CLIP"""
+
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ attn_mask: Tensor = None,
+ param_init_type=ms.float32,
+ ln_param_init_type=ms.float32,
+ ):
+ super().__init__()
+
+ self.attn = CLIPAttention(d_model, n_head, param_init_type=param_init_type)
+ self.ln_1 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type)
+ self.mlp = nn.SequentialCell(
+ OrderedDict(
+ [
+ ("c_fc", Linear(d_model, d_model * 4, param_init_type=param_init_type)),
+ ("gelu", QuickGELU()),
+ ("c_proj", Linear(d_model * 4, d_model, param_init_type=param_init_type)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm((d_model,), eps=1e-5, param_init_type=ln_param_init_type)
+ self.attn_mask = Parameter(attn_mask) if attn_mask is not None else None
+
+ def construct(self, x: Tensor):
+ residual0 = x
+ x_type = x.dtype
+ x = self.ln_1(x.to(ms.float32)).to(x_type)
+ x = residual0 + self.attn(x)
+ residual1 = x
+ x = self.ln_2(x.to(ms.float32)).to(x_type)
+ x = residual1 + self.mlp(x)
+ return x
+
+
+class Transformer(nn.Cell):
+ """Vision Transformer module for CLIP"""
+
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ attn_mask: Tensor = None,
+ param_init_type=ms.float32,
+ ln_param_init_type=ms.float32,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.CellList(
+ [
+ ResidualAttentionBlock(
+ width, heads, attn_mask, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def construct(self, x: Tensor):
+ encoder_states = ()
+ hidden_state = x
+ for block in self.resblocks:
+ encoder_states += (hidden_state,)
+ hidden_state = block(hidden_state)
+ encoder_states += (hidden_state,)
+ return encoder_states
+
+
+class VisionTransformer(nn.Cell):
+ """CLIP module for Vary system"""
+
+ def __init__(
+ self,
+ input_resolution: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ output_dim: int,
+ vision_select_layer: int,
+ param_init_type=ms.float32,
+ ln_param_init_type=ms.float32,
+ ):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.vision_select_layer = vision_select_layer
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ has_bias=False,
+ pad_mode="pad",
+ weight_init="uniform",
+ bias_init="uniform",
+ dtype=param_init_type,
+ )
+
+ scale = width**-0.5
+ self.class_embedding = Parameter((scale * ops.randn(width)).astype(param_init_type))
+ self.positional_embedding = Parameter(
+ (scale * ops.randn(((input_resolution // patch_size) ** 2 + 1, width))).astype(param_init_type)
+ )
+ self.ln_pre = LayerNorm((width,), eps=1e-5, param_init_type=ln_param_init_type)
+ self.transformer = Transformer(
+ width, layers, heads, param_init_type=param_init_type, ln_param_init_type=ln_param_init_type
+ )
+
+ def construct(self, x: Tensor):
+ x_type = x.dtype
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape((x.shape[0], x.shape[1], -1)) # shape = [*, width, grid**2]
+ x = x.permute((0, 2, 1)) # shape = [*, grid**2, width]
+ x = ops.cat(
+ [self.class_embedding.to(x_type) + ops.zeros((x.shape[0], 1, x.shape[-1]), dtype=x_type), x], axis=1
+ ) # shape = [*, grid**2 + 1, width]
+ x = x + self.positional_embedding.to(x_type) # torch version: CLIPVisionEmbeddings
+ x = self.ln_pre(x) # modeling_clip.py L842
+ x = self.transformer(x) # modeling_clip.py L844, error 1e-3
+ x = x[self.vision_select_layer][:, 1:]
+ return x
+
+
+def build_model(param_init_type=ms.float32, ln_param_init_type=ms.float32):
+ """construct the CLIP module and load ckpt"""
+ vision_width = 1024
+ vision_layers = 24
+ vision_patch_size = 14
+ grid_size = round(256**0.5)
+ image_resolution = vision_patch_size * grid_size
+ out_width = 1024
+ model = VisionTransformer(
+ input_resolution=image_resolution, # image_size in transformers
+ patch_size=vision_patch_size, # patch_size in transformers
+ width=vision_width, # hidden_size
+ layers=vision_layers, # num_hidden_layers
+ heads=grid_size, # num_attention_heads
+ output_dim=out_width, # projection_dim in transformers, default: 1024
+ vision_select_layer=-2,
+ param_init_type=param_init_type,
+ ln_param_init_type=ln_param_init_type,
+ )
+
+ return model
diff --git a/mindocr/nlp/llm/vary_qwen_model.py b/mindocr/nlp/llm/vary_qwen_model.py
new file mode 100644
index 000000000..8001bd036
--- /dev/null
+++ b/mindocr/nlp/llm/vary_qwen_model.py
@@ -0,0 +1,253 @@
+import mindspore as ms
+from mindspore import ops
+
+from mindocr.nlp.llm import register_llm
+from mindocr.nlp.llm.configs import SAMConfig, VaryConfig
+from mindocr.nlp.llm.qwen_model import QwenForCausalLM, QwenModel
+from mindocr.nlp.llm.vary_clip_model import build_model
+from mindocr.nlp.llm.vary_sam_model import SAMEncoder
+from mindocr.nlp.utils.layers import Linear
+from mindocr.utils.conversation import Conversation
+
+
+class VaryQwenModel(QwenModel):
+ def __init__(self, config):
+ super(VaryQwenModel, self).__init__(config)
+ config = SAMConfig(ln_param_init_type=config.ln_param_init_type)
+ self.vision_tower_high = SAMEncoder(config)
+ self.vision_tower_high.to_float(ms.float16)
+
+ self.vision_tower = build_model(
+ param_init_type=config.param_init_type, ln_param_init_type=config.ln_param_init_type
+ )
+ self.vision_tower.to_float(ms.float16)
+
+ self.mm_projector = Linear(1024, 1024, param_init_type=config.param_init_type)
+ self.mm_projector_vary = Linear(1024, 1024, param_init_type=config.param_init_type)
+
+ self.image_start_token_pos = 22
+ self.num_patches = 256
+
+ def construct(
+ self,
+ input_ids,
+ init_reset=True,
+ batch_valid_length=None,
+ batch_index=None,
+ zactivate_len=None,
+ image=None,
+ image_high=None,
+ ):
+ # 1. wte
+ bs, seq_len = self.shape(input_ids)
+ inputs_embeds = self.wte(input_ids)
+
+ if seq_len > 1 and image is not None and image_high is not None:
+ sam_out = self.vision_tower_high(image_high)
+ sam_out = self.mm_projector_vary(sam_out)
+
+ clip_out = self.vision_tower(image)
+ clip_out = self.mm_projector(clip_out)
+
+ image_features = ops.concat((clip_out, sam_out), -1)
+
+ new_input_embeds = []
+ num_patches = self.num_patches
+ image_start_token_pos = self.image_start_token_pos
+ for i in range(bs):
+ cur_input_embeds = inputs_embeds[i]
+ per_cur_image_features = image_features[i]
+ cur_input_embeds = ops.cat(
+ (
+ cur_input_embeds[: image_start_token_pos + 1],
+ per_cur_image_features,
+ cur_input_embeds[image_start_token_pos + num_patches + 1 :],
+ ),
+ axis=0,
+ )
+
+ new_input_embeds.append(cur_input_embeds)
+
+ hidden_states = ops.stack(new_input_embeds, axis=0)
+ else:
+ hidden_states = inputs_embeds
+
+ # 2. drop
+ hidden_states = self.drop(hidden_states)
+
+ # 2. rotary_emb
+ if not self.use_past:
+ freqs_cis = self.freqs_mgr()
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ mask = self.casual_mask.post_process(mask)
+ kvcache_inputs = None
+ else:
+ if self.is_first_iteration:
+ freqs_cis = self.freqs_mgr(seq_len)
+ mask = self.casual_mask(input_ids) # mask: [bs, seq, seq]
+ else:
+ freqs_cis = self.freqs_mgr.increment(batch_valid_length, bs)
+ if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
+ mask = self.casual_mask.increment_slice(
+ self.kvcache_preprocess.range,
+ self.kvcache_preprocess.max_cache_length // bs,
+ batch_valid_length,
+ zactivate_len,
+ )
+ else:
+ mask = self.casual_mask.increment(self.kvcache_preprocess.range, batch_valid_length, zactivate_len)
+ mask = self.casual_mask.post_process(mask)
+
+ kvcache_inputs = self.kvcache_preprocess(bs, batch_valid_length, batch_index, zactivate_len)
+
+ # 4. hidden_states
+ for i in range(self.num_hidden_layers):
+ hidden_states = self.layers[i](hidden_states, freqs_cis, mask, kvcache_inputs=kvcache_inputs)
+
+ # 5. ln_f
+ hidden_states = self.ln_f(hidden_states)
+
+ return hidden_states
+
+
+@register_llm
+class VaryQwenForCausalLM(QwenForCausalLM):
+ def __init__(self, config):
+ config = VaryConfig(**config)
+ super(VaryQwenForCausalLM, self).__init__(config)
+ self.transformer = VaryQwenModel(config=config)
+ self.conversation = None
+
+ self.image_past = None
+ self.image_high_past = None
+
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
+ image = kwargs.get("image")
+ image_high = kwargs.get("image_high")
+ return {
+ "input_ids": ms.Tensor(input_ids, ms.int32),
+ "image": ms.Tensor(image, ms.float16) if image is not None else None,
+ "image_high": ms.Tensor(image_high, ms.float16) if image_high is not None else None,
+ }
+
+ def construct(
+ self,
+ input_ids,
+ labels=None,
+ input_position=None,
+ position_ids=None,
+ attention_mask=None,
+ input_embeds=None,
+ init_reset=True,
+ batch_valid_length=None,
+ batch_index=None,
+ zactivate_len=None,
+ image=None,
+ image_high=None,
+ ):
+ """construct"""
+ bsz, seqlen = input_ids.shape
+ if self.use_past:
+ if not isinstance(batch_valid_length, ms.Tensor):
+ batch_valid_length = self.ones((bsz,), ms.int32)
+ if self.training:
+ tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
+ else:
+ tokens = input_ids
+
+ if batch_valid_length is not None:
+ batch_valid_length = self.reshape(batch_valid_length, (-1,))
+ if not self.is_first_iteration:
+ batch_valid_length = self.sub_batch_valid_len(batch_valid_length, 1)
+
+ output = self.transformer(
+ tokens,
+ init_reset=init_reset,
+ batch_valid_length=batch_valid_length,
+ batch_index=batch_index,
+ zactivate_len=zactivate_len,
+ image=image,
+ image_high=image_high,
+ )
+ pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
+ if pre_gather:
+ output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
+ logits = self.lm_head(output)
+
+ input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), ms.float32)
+ if labels is None:
+ labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
+ else:
+ if labels.ndim > 1:
+ if self.training:
+ labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
+ label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), ms.float32)
+ input_mask = self.mul(input_mask, label_mask)
+
+ if not self.training:
+ if not pre_gather:
+ logits = self.reshape(logits, (bsz, seqlen, -1))
+ logits = self.cast(logits, ms.float32)
+ # makes cast effective to avoid allgather issue in Mindspore1.10
+ input_mask = self.add(input_mask, 1)
+ return logits, tokens, input_mask
+
+ if logits.ndim > 2:
+ logits = self.reshape(logits, (-1, logits.shape[-1]))
+ logits = self.cast(logits, ms.float32)
+ labels = self.reshape(labels, (-1,))
+ input_mask = self.reshape(input_mask, (-1,))
+ loss = self.loss(logits, labels, input_mask)
+ return loss
+
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ image=None,
+ image_high=None,
+ ) -> str:
+ """
+ example:
+ inputs:
+ query: Provide the ocr results of this image.
+ image: np.array.
+ image_high: np.array.
+ outputs:
+ response: the modalities of irradiation could be modified...
+ history: [
+ ("user", "Provide the ocr results of this image."),
+ ("assistant", "the modalities of irradiation could be modified..."),
+ ]
+
+ """
+ if self.conversation is None:
+ self.conversation = Conversation()
+
+ if image is not None and image_high is not None:
+ num_patch = 256
+ im_start_token = ""
+ im_end_token = ""
+ im_patch_token = ""
+ query = im_start_token + im_patch_token * num_patch + im_end_token + query
+ self.image_past = image
+ self.image_high_past = image_high
+
+ self.conversation.add_message(role="user", message=query)
+ prompt = self.conversation.get_prompt()
+
+ inputs = tokenizer([prompt], max_length=self.seq_length)
+ input_ids = inputs["input_ids"]
+ outputs = self.generate(input_ids=input_ids, image=self.image_past, image_high=self.image_high_past)
+ outputs = tokenizer.decode(outputs, skip_special_tokens=False)
+ response = outputs[0][len(prompt) :]
+
+ for special_token in tokenizer.special_tokens:
+ response = response.replace(special_token, "")
+ self.conversation.add_message(role="assistant", message=response)
+
+ return response
+
+ def reset(self):
+ if self.conversation is not None:
+ self.conversation.messages = list()
diff --git a/mindocr/nlp/llm/vary_sam_model.py b/mindocr/nlp/llm/vary_sam_model.py
new file mode 100644
index 000000000..80f341552
--- /dev/null
+++ b/mindocr/nlp/llm/vary_sam_model.py
@@ -0,0 +1,562 @@
+from typing import Optional, Tuple
+
+import mindspore as ms
+import mindspore.common.dtype as mstype
+import mindspore.numpy as np
+import mindspore.ops as ops
+from mindspore import Parameter, nn
+
+from mindocr.nlp.utils.layers import LayerNorm, Linear
+
+
+class MLPBlock(nn.Cell):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: nn.Cell = nn.GELU,
+ compute_dtype=mstype.float16,
+ param_init_type=mstype.float32,
+ ) -> None:
+ super().__init__()
+ self.lin1 = Linear(
+ in_channels=embedding_dim,
+ out_channels=mlp_dim,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+ self.lin2 = Linear(
+ in_channels=mlp_dim,
+ out_channels=embedding_dim,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+ self.act = act()
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+class LayerNorm2d(nn.Cell):
+ """
+ Layer Normalization for 2D data.
+
+ Inputs:
+ x (ms.Tensor): Input tensor.
+
+ Returns:
+ ms.Tensor: Normalized tensor.
+ """
+
+ def __init__(self, num_channels: int, eps: float = 1e-6, param_init_type=ms.float32) -> None:
+ super().__init__()
+ self.weight = Parameter(ops.Ones()(num_channels, param_init_type))
+ self.bias = Parameter(ops.Zeros()(num_channels, param_init_type))
+ self.eps = eps
+ self.pow = ops.Pow()
+ self.sqrt = ops.Sqrt()
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ u = x.mean(1, keep_dims=True)
+ s = self.pow(x - u, 2).mean(1, keep_dims=True)
+ x = (x - u) / self.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+class SAMImageEncoder(nn.Cell):
+ """
+ Image encoder
+ """
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.img_size = config.img_size
+ self.patch_size = config.patch_size
+ self.in_chans = config.in_chans
+ self.embed_dim = config.embed_dim
+ self.depth = config.depth
+ self.num_heads = config.num_heads
+ self.mlp_ratio = config.mlp_ratio
+ self.out_chans = config.out_chans
+ self.qkv_bias = config.qkv_bias
+ self.layer_norm_eps = config.layer_norm_eps
+ self.use_abs_pos = config.use_abs_pos
+ self.use_rel_pos = config.use_rel_pos
+ self.window_size = config.window_size
+ self.global_attn_indexes = config.global_attn_indexes
+
+ self.compute_dtype = config.compute_dtype
+ self.layernorm_compute_type = config.layernorm_compute_type
+ self.softmax_compute_type = config.softmax_compute_type
+ self.param_init_type = config.param_init_type
+ self.ln_param_init_type = config.ln_param_init_type
+
+ if isinstance(self.img_size, int):
+ img_h = self.img_size
+ img_w = self.img_size
+ else:
+ img_h, img_w = self.img_size
+ feat_h = img_h // self.patch_size
+ feat_w = img_w // self.patch_size
+ self.feat_size = (feat_h, feat_w)
+ if self.window_size > 0:
+ pad_h = (self.window_size - feat_h % self.window_size) % self.window_size
+ pad_w = (self.window_size - feat_w % self.window_size) % self.window_size
+ self.pad_size = (pad_h, pad_w)
+ else:
+ self.pad_size = (0, 0)
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(self.patch_size, self.patch_size),
+ stride=(self.patch_size, self.patch_size),
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ param_init_type=self.param_init_type,
+ )
+
+ self.pos_embed: Optional[Parameter] = None
+ if self.use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = Parameter(
+ ops.Zeros()(
+ (1, self.img_size // self.patch_size, self.img_size // self.patch_size, self.embed_dim),
+ self.param_init_type,
+ )
+ )
+
+ self.blocks = nn.CellList()
+ for i in range(self.depth):
+ block = Block(
+ dim=self.embed_dim,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ use_rel_pos=self.use_rel_pos,
+ window_size=self.window_size if i not in self.global_attn_indexes else 0,
+ pad_size=self.pad_size,
+ feat_size=self.feat_size,
+ input_size=(self.img_size // self.patch_size, self.img_size // self.patch_size),
+ layer_norm_eps=self.layer_norm_eps,
+ compute_dtype=self.compute_dtype,
+ layernorm_compute_type=self.layernorm_compute_type,
+ softmax_compute_type=self.softmax_compute_type,
+ param_init_type=self.param_init_type,
+ ln_param_init_type=self.ln_param_init_type,
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.SequentialCell(
+ nn.Conv2d(
+ self.embed_dim,
+ self.out_chans,
+ kernel_size=1,
+ has_bias=False,
+ dtype=self.param_init_type,
+ ),
+ LayerNorm2d(self.out_chans, param_init_type=self.ln_param_init_type),
+ nn.Conv2d(
+ self.out_chans,
+ self.out_chans,
+ kernel_size=3,
+ pad_mode="pad",
+ padding=1,
+ has_bias=False,
+ dtype=self.param_init_type,
+ ),
+ LayerNorm2d(self.out_chans, param_init_type=self.ln_param_init_type),
+ )
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ """
+ Args:
+ x (ms.Tensor): Input image tensor.
+
+ Returns:
+ ms.Tensor: Encoded image tensor.
+ """
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.transpose(0, 3, 1, 2))
+
+ return x
+
+
+class Block(nn.Cell):
+ """
+ Transformer blocks with support of window attention and residual propagation blocks
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ window_size: int = 0,
+ pad_size: Tuple[int, int] = (0, 0),
+ feat_size: Tuple[int, int] = (64, 64),
+ input_size: Optional[Tuple[int, int]] = None,
+ layer_norm_eps: float = 1.0e-12,
+ compute_dtype=mstype.float16,
+ layernorm_compute_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
+ param_init_type=mstype.float32,
+ ln_param_init_type=mstype.float32,
+ ) -> None:
+ super().__init__()
+ self.norm1 = LayerNorm((dim,), eps=layer_norm_eps, param_init_type=ln_param_init_type)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ compute_dtype=compute_dtype,
+ layernorm_compute_type=layernorm_compute_type,
+ softmax_compute_type=softmax_compute_type,
+ param_init_type=param_init_type,
+ )
+
+ self.norm2 = LayerNorm((dim,), eps=layer_norm_eps, param_init_type=ln_param_init_type)
+ self.mlp = MLPBlock(
+ embedding_dim=dim,
+ mlp_dim=int(dim * mlp_ratio),
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+
+ self.window_size = window_size
+ self.pad_size = pad_size
+ self.feat_size = feat_size
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ """
+ Args:
+ x (ms.Tensor): Input tensor.
+
+ Returns:
+ ms.Tensor: Output tensor.
+ """
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ pad_size = self.pad_size
+ window_size = self.window_size
+ b, h, w, c = x.shape
+ pad_h, pad_w = pad_size
+ if pad_h > 0 or pad_w > 0:
+ pad = ops.Pad(paddings=((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
+ x = pad(x)
+ hp, wp = h + pad_h, w + pad_w
+
+ x = x.view(b, hp // window_size, window_size, wp // window_size, window_size, c)
+ x = x.transpose(0, 1, 3, 2, 4, 5).view(-1, window_size, window_size, c)
+ # x = window_partition(x, self.window_size, self.pad_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, self.pad_size, self.feat_size)
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Cell):
+ """
+ Multi-head Attention block with relative position embeddings.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ input_size: Optional[Tuple[int, int]] = None,
+ compute_dtype=mstype.float16,
+ layernorm_compute_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
+ param_init_type=mstype.float32,
+ ) -> None:
+ super().__init__()
+ self.compute_dtype = compute_dtype
+ self.layernorm_compute_type = layernorm_compute_type
+ self.softmax_compute_type = softmax_compute_type
+
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = Linear(
+ in_channels=dim,
+ out_channels=dim * 3,
+ has_bias=qkv_bias,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+ self.proj = Linear(
+ in_channels=dim,
+ out_channels=dim,
+ compute_dtype=compute_dtype,
+ param_init_type=param_init_type,
+ )
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert input_size is not None, "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = Parameter(ops.Zeros()((2 * input_size[0] - 1, head_dim), self.compute_dtype))
+ self.rel_pos_w = Parameter(ops.Zeros()((2 * input_size[1] - 1, head_dim), self.compute_dtype))
+
+ self.softmax = ops.Softmax(axis=-1)
+ self.batchmatmul = ops.BatchMatMul()
+ self.batchmatmul_trans_b = ops.BatchMatMul(transpose_b=True)
+ self.cast = ops.Cast()
+ self.unstack = ops.Unstack(axis=0)
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ """
+ Args:
+ x (ms.Tensor): Input tensor.
+
+ Returns:
+ ms.Tensor: Output tensor.
+ """
+ b, h, w, _ = x.shape
+ ori_type = x.dtype
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(b, h * w, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ qkv = self.cast(qkv, self.compute_dtype)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = self.unstack(qkv.reshape(3, b * self.num_heads, h * w, -1))
+
+ attn = self.batchmatmul_trans_b((q * self.scale), k)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (h, w), (h, w))
+
+ attn = self.cast(attn, self.softmax_compute_type)
+ attn = self.softmax(attn)
+ attn = self.cast(attn, self.compute_dtype)
+ x = self.batchmatmul(attn, v)
+ x = x.view(b, self.num_heads, h, w, -1)
+ x = x.permute(0, 2, 3, 1, 4)
+ x = x.reshape(b, h, w, -1)
+ x = self.proj(x)
+ x = self.cast(x, ori_type)
+
+ return x
+
+
+def window_partition(x: ms.Tensor, window_size: int, pad_size: Tuple[int, int] = 0) -> ms.Tensor:
+ """
+ Partition the input tensor into non-overlapping windows with optional padding.
+
+ Args:
+ x (ms.Tensor): Input tensor with shape [B, H, W, C].
+ window_size (int): Window size.
+ pad_size (tuple[int, int]): Padding size as (pad_h, pad_w).
+
+ Returns:
+ windows (ms.Tensor): Windows after partition with shape [B * num_windows, window_size, window_size, C].
+ """
+ b, h, w, c = x.shape
+ pad_h, pad_w = pad_size
+ if pad_h > 0 or pad_w > 0:
+ pad = ops.Pad(paddings=((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
+ x = pad(x)
+ hp, wp = h + pad_h, w + pad_w
+
+ x = x.view(b, hp // window_size, window_size, wp // window_size, window_size, c)
+ windows = x.transpose(0, 1, 3, 2, 4, 5).view(-1, window_size, window_size, c)
+ return windows
+
+
+def window_unpartition(windows: ms.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> ms.Tensor:
+ """
+ Unpartition windows back into original sequences and remove padding if needed.
+
+ Args:
+ windows (ms.Tensor): Input windows with shape [B * num_windows, window_size, window_size, C].
+ window_size (int): Window size.
+ pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp).
+ hw (Tuple[int, int]): Original height and width (H, W) before padding.
+
+ Returns:
+ x (ms.Tensor): Unpartitioned sequences with shape [B, H, W, C].
+ """
+ pad_h, pad_w = pad_hw
+ h, w = hw
+ hp, wp = h + pad_h, w + pad_w
+ b = windows.shape[0] // (hp * wp // window_size // window_size)
+ x = windows.view(b, hp // window_size, wp // window_size, window_size, window_size, -1)
+ x = x.transpose(0, 1, 3, 2, 4, 5).view(b, hp, wp, -1)
+
+ if hp > h or wp > w:
+ x = x[:, :h, :w, :]
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: ms.Tensor) -> ms.Tensor:
+ """
+ Get relative positional embeddings based on the relative positions of query and key sizes.
+
+ Args:
+ q_size (int): Size of query q.
+ k_size (int): Size of key k.
+ rel_pos (ms.Tensor): Relative position embeddings (L, C).
+
+ Returns:
+ ms.Tensor: Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = ops.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).transpose(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = np.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = np.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.astype(mstype.int32)]
+
+
+def add_decomposed_rel_pos(
+ attn: ms.Tensor,
+ q: ms.Tensor,
+ rel_pos_h: ms.Tensor,
+ rel_pos_w: ms.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> ms.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from mvitv2 paper.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+
+ Args:
+ attn (ms.Tensor): Attention map.
+ q (ms.Tensor): Query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (ms.Tensor): Relative position embeddings (Lh, C) for the height axis.
+ rel_pos_w (ms.Tensor): Relative position embeddings (Lw, C) for the width axis.
+ q_size (Tuple[int, int]): Spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple[int, int]): Spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ ms.Tensor: Attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ b, _, dim = q.shape
+ r_q = q.reshape(b, q_h, q_w, dim)
+ rel_h = ops.matmul(r_q, rh.transpose(0, 2, 1)).reshape(b, q_h, q_w, rh.shape[1])
+ rel_w = ops.matmul(r_q.transpose(0, 2, 1, 3), rw.transpose(0, 2, 1)).reshape(b, q_h, q_w, rw.shape[1])
+ rel_w = rel_w.transpose(0, 2, 1, 3)
+
+ attn = (attn.view(b, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+ b, q_h * q_w, k_h * k_w
+ )
+
+ return attn
+
+
+class PatchEmbed(nn.Cell):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ param_init_type=ms.float32,
+ ) -> None:
+ """
+ Initialize the PatchEmbed layer.
+
+ Args:
+ kernel_size (Tuple[int, int]): Kernel size of the projection layer.
+ stride (Tuple[int, int]): Stride of the projection layer.
+ padding (Tuple[int, int, int, int]): Padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ has_bias=True,
+ dtype=param_init_type,
+ )
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ """
+ Forward pass of the PatchEmbed layer.
+
+ Args:
+ x (ms.Tensor): Input image tensor.
+
+ Returns:
+ ms.Tensor: Patch embeddings tensor.
+ """
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.transpose(0, 2, 3, 1)
+ return x
+
+
+class SAMEncoder(SAMImageEncoder):
+ """SAM encoder for Vary system"""
+
+ def __init__(self, config) -> None:
+ super().__init__(config)
+ self.net_2 = nn.Conv2d(
+ 256, 512, kernel_size=3, stride=2, pad_mode="pad", padding=1, has_bias=False, dtype=config.param_init_type
+ )
+ self.net_3 = nn.Conv2d(
+ 512, 1024, kernel_size=3, stride=2, pad_mode="pad", padding=1, has_bias=False, dtype=config.param_init_type
+ )
+
+ def construct(self, x):
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.transpose(0, 3, 1, 2))
+
+ x = self.net_2(x)
+ x = self.net_3(x)
+ x = x.flatten(start_dim=2).permute(0, 2, 1)
+ return x
diff --git a/mindocr/nlp/utils/__init__.py b/mindocr/nlp/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mindocr/nlp/utils/kvcache_mgr.py b/mindocr/nlp/utils/kvcache_mgr.py
new file mode 100644
index 000000000..a5bd26292
--- /dev/null
+++ b/mindocr/nlp/utils/kvcache_mgr.py
@@ -0,0 +1,259 @@
+import numpy as np
+
+import mindspore.common.dtype as mstype
+from mindspore import Parameter, Tensor, nn, ops
+
+
+class KVCacheMgr(nn.Cell):
+ """KVCache Manager."""
+
+ def __init__(
+ self,
+ n_head,
+ head_dim,
+ max_batch_size=8,
+ max_seq_length=4096,
+ compute_dtype=mstype.float16,
+ is_dynamic=False,
+ use_kvcache_op=True,
+ is_flexible_shape=False,
+ ):
+ super().__init__()
+ self.n_head = n_head
+ self.head_dim = head_dim
+ self.max_batch_size = max_batch_size
+ self.max_seq_length = max_seq_length
+ self.dtype = compute_dtype
+ self.use_kvcache_op = use_kvcache_op
+ self.is_dynamic = is_dynamic
+ self.is_flexible_shape = is_flexible_shape
+ self.is_first_iteration = True
+
+ self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
+ self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
+ self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
+ self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
+ self.seqlen_axis_tensor_pad = Tensor([2, 3], dtype=mstype.int64)
+ self.pad_before = Tensor([0, 0, 0, 0, 0], mstype.int32)
+ self.pad_after = Tensor([0, 0], mstype.int32)
+ self.pad_zero = Tensor(0, compute_dtype)
+
+ if self.use_kvcache_op:
+ # pylint: disable=W0212
+ self.prompt_kvcache = ops.operations._inner_ops.PromptKVCache()
+ # pylint: disable=W0212
+ self.decoder_kvcache = ops.operations._inner_ops.DecoderKVCache()
+ else:
+ self.add = ops.Add()
+ self.mul = ops.Mul()
+ self.assign = ops.Assign()
+ self.concat = ops.Concat(axis=0)
+ self.sub = ops.Sub()
+ self.div = ops.Div()
+ self.pad = ops.PadV3()
+ self.slice = ops.StridedSlice()
+ self.cast = ops.Cast()
+ self.shape = ops.Shape()
+ self.reshape = ops.Reshape().add_prim_attr("skip_redistribution", True)
+
+ kv_shape = (max_batch_size, n_head, max_seq_length, head_dim)
+ self.key_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="key_past", requires_grad=False)
+ self.value_past = Parameter(Tensor(np.zeros(kv_shape), compute_dtype), name="value_past", requires_grad=False)
+
+ def padding(self, key, value, seq_length):
+ """padding key, value"""
+ pad_length = self.sub(self.seq_length_tensor, seq_length)
+ # calculate padding parameter: (0, 0),(0,0),(0,pad_length),(0,0), append values of 'pad_length' in axis
+ pad_config = self.concat((self.pad_before, pad_length, self.pad_after))
+ key_padding = self.pad(key, pad_config, self.pad_zero)
+ value_padding = self.pad(value, pad_config, self.pad_zero)
+ return key_padding, value_padding
+
+ def trimming(self, key, value, zactivate_len, batch_size):
+ """tramming key, value"""
+ if self.is_flexible_shape:
+ key = self.reshape(key, (batch_size, self.n_head, -1, self.head_dim))
+ value = self.reshape(value, (batch_size, self.n_head, -1, self.head_dim))
+ if zactivate_len is not None:
+ act_len = self.shape(zactivate_len)[0]
+ key = self.slice(key, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
+ value = self.slice(value, (0, 0, 0, 0), (batch_size, self.n_head, act_len, self.head_dim), (1, 1, 1, 1))
+ elif not self.is_flexible_shape:
+ key = self.slice(
+ key, (0, 0, 0, 0), (batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1)
+ )
+ value = self.slice(
+ value, (0, 0, 0, 0), (batch_size, self.n_head, self.max_seq_length, self.head_dim), (1, 1, 1, 1)
+ )
+ return key, value
+
+ def auto_caching(self, key_update, value_update, batch_valid_length, seq_length_tensor_pad, batch_index_pad=None):
+ """use kvcache op to cache key, value"""
+ # key_update shape: [real_bs, n_head, max_seqlen, head_dim]
+ if self.is_first_iteration:
+ batch_valid_length = batch_valid_length * 0
+ self.prompt_kvcache(
+ self.key_past,
+ key_update,
+ batch_valid_length,
+ batch_index_pad,
+ self.seqlen_axis_tensor_pad,
+ seq_length_tensor_pad,
+ seq_length_tensor_pad,
+ )
+ self.prompt_kvcache(
+ self.value_past,
+ value_update,
+ batch_valid_length,
+ batch_index_pad,
+ self.seqlen_axis_tensor_pad,
+ seq_length_tensor_pad,
+ seq_length_tensor_pad,
+ )
+ return None
+
+ key_cache = self.key_past
+ value_cache = self.value_past
+ self.decoder_kvcache(
+ self.key_past,
+ key_update,
+ batch_valid_length,
+ batch_index_pad,
+ self.seqlen_axis_tensor_pad,
+ seq_length_tensor_pad,
+ seq_length_tensor_pad,
+ )
+ self.decoder_kvcache(
+ self.value_past,
+ value_update,
+ batch_valid_length,
+ batch_index_pad,
+ self.seqlen_axis_tensor_pad,
+ seq_length_tensor_pad,
+ seq_length_tensor_pad,
+ )
+ key_cache = ops.depend(key_cache, key_update)
+ value_cache = ops.depend(value_cache, value_update)
+ return key_cache, value_cache
+
+ def manual_caching(self, key_update, value_update, valid_length_vector, batch_size):
+ """use assign to cache key, value"""
+ # key_update shape: [real_bs, n_head, 1, head_dim]
+ if self.is_first_iteration:
+ if self.is_dynamic:
+ self.assign(
+ self.key_past, self.reshape(key_update, (self.max_batch_size, self.n_head, -1, self.head_dim))
+ )
+ self.assign(
+ self.value_past, self.reshape(value_update, (self.max_batch_size, self.n_head, -1, self.head_dim))
+ )
+ else:
+ self.assign(self.key_past, self.mul(key_update, valid_length_vector))
+ self.assign(self.value_past, self.mul(value_update, valid_length_vector))
+ return None
+
+ if self.is_dynamic:
+ key = self.add(
+ self.reshape(self.key_past, (batch_size, self.n_head, -1, self.head_dim)),
+ self.mul(key_update, valid_length_vector),
+ )
+ value = self.add(
+ self.reshape(self.value_past, (batch_size, self.n_head, -1, self.head_dim)),
+ self.mul(value_update, valid_length_vector),
+ )
+ self.assign(self.key_past, self.reshape(key, (self.max_batch_size, self.n_head, -1, self.head_dim)))
+ self.assign(self.value_past, self.reshape(value, (self.max_batch_size, self.n_head, -1, self.head_dim)))
+ else:
+ key = self.add(self.key_past, self.mul(key_update, valid_length_vector))
+ value = self.add(self.value_past, self.mul(value_update, valid_length_vector))
+ self.assign(self.key_past, key)
+ self.assign(self.value_past, value)
+ # key shape: [real_bs, n_head, max_cache_len // real_bs, head_dim]
+ return key, value
+
+ def construct(self, key, value, kvcache_inputs=None):
+ """The forward compute of KVCacheMgr."""
+ # TODO: add inputs check
+ batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad = kvcache_inputs
+ batch_size, _, seq_length, _ = self.shape(key)
+ if self.is_first_iteration:
+ if self.is_dynamic:
+ key_padding, value_padding = self.padding(key, value, seq_length=seq_length)
+ else:
+ key_padding, value_padding = key, value
+ if self.use_kvcache_op:
+ self.auto_caching(
+ key_padding, value_padding, batch_valid_length, seq_length_tensor_pad, batch_index_pad
+ )
+ else:
+ self.manual_caching(key_padding, value_padding, batch_valid_length, batch_size=batch_size)
+ else:
+ if self.use_kvcache_op:
+ key, value = self.auto_caching(key, value, batch_valid_length, seq_length_tensor_pad, batch_index_pad)
+ else:
+ key, value = self.manual_caching(key, value, batch_valid_length, batch_size=batch_size)
+ key, value = self.trimming(key, value, zactivate_len, batch_size=batch_size)
+
+ return key, value
+
+
+class KVCachePreprocess(nn.Cell):
+ """KVCache Manager."""
+
+ def __init__(
+ self,
+ max_batch_size=8,
+ max_seq_length=4096,
+ is_dynamic=False,
+ use_kvcache_op=False,
+ is_flexible_shape=False,
+ ):
+ super().__init__()
+ self.is_dynamic = is_dynamic
+ self.use_kvcache_op = use_kvcache_op
+ self.is_flexible_shape = is_flexible_shape
+ self.max_cache_length = max_batch_size * max_seq_length
+ range_len = self.max_cache_length if self.is_flexible_shape else max_seq_length
+ self.range = Tensor(np.arange(range_len).reshape((1, 1, -1)), mstype.int32)
+ self.cache_length_tensor = Tensor([max_batch_size * max_seq_length], dtype=mstype.int32)
+ self.cache_pad_tensor = Tensor([3], dtype=mstype.int64)
+ self.seq_length_tensor = Tensor([max_seq_length], dtype=mstype.int32)
+ self.seq_length_tensor_pad = Tensor([max_seq_length, 3], dtype=mstype.int64)
+ self.is_first_iteration = True
+
+ self.slice = ops.StridedSlice()
+ self.reshape = ops.Reshape().add_prim_attr("skip_redistribution", True)
+ self.equal = ops.Equal()
+ self.less = ops.Less()
+ self.expand_dims = ops.ExpandDims()
+ self.div = ops.Div()
+ self.concat = ops.Concat(axis=0)
+
+ def construct(self, batch_size, batch_valid_length=None, batch_index=None, zactivate_len=None):
+ """precompute kvcache inputs"""
+ seq_range = self.range
+ if self.is_dynamic and self.is_flexible_shape and not self.use_kvcache_op:
+ seq_range = self.slice(seq_range, (0, 0, 0), (1, 1, self.max_cache_length // batch_size), (1, 1, 1))
+
+ if self.use_kvcache_op:
+ if batch_index is None:
+ batch_index = ops.arange(0, batch_size, 1)
+ batch_index_pad = self.concat((batch_index, self.cache_pad_tensor))
+ seq_length_tensor_pad = self.get_seq_length_tensor_pad(batch_size=batch_size)
+ batch_valid_length = self.cast(self.reshape(batch_valid_length, (-1,)), mstype.int64)
+ kvcache_inputs = (batch_valid_length, zactivate_len, batch_index_pad, seq_length_tensor_pad)
+ else:
+ if self.is_first_iteration:
+ valid_length_vector = self.less(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
+ else:
+ valid_length_vector = self.equal(seq_range, self.reshape(batch_valid_length, (-1, 1, 1)))
+ valid_length_vector = self.expand_dims(valid_length_vector, 3)
+ kvcache_inputs = (valid_length_vector, zactivate_len, None, None)
+ return kvcache_inputs
+
+ def get_seq_length_tensor_pad(self, batch_size):
+ """get seq_length_tensor_pad"""
+ if self.is_flexible_shape:
+ max_seq_length = self.div(self.cache_length_tensor, batch_size).astype(mstype.int64)
+ return self.concat((max_seq_length, self.cache_pad_tensor))
+ return self.seq_length_tensor_pad
diff --git a/mindocr/nlp/utils/layers.py b/mindocr/nlp/utils/layers.py
new file mode 100644
index 000000000..f3fcedaa5
--- /dev/null
+++ b/mindocr/nlp/utils/layers.py
@@ -0,0 +1,180 @@
+import mindspore.common.dtype as mstype
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore._extends import cell_attr_register
+from mindspore.common.initializer import initializer
+
+__all__ = ["LayerNorm", "Linear"]
+
+
+class LayerNorm(nn.Cell):
+ r"""
+ A self-defined layer norm operation using reduce sum and reduce mean
+
+ Args:
+ normalized_shape (tuple): The shape of the input tensor
+ eps (float): The epsilon value of the denominator. Default 1e-5.
+ param_init_type: The param init type.
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
+
+ Outputs:
+ Tensor of shape :math:`(batch, seq_length, hidden_size)`.
+ """
+
+ def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32, is_self_defined=False):
+ super(LayerNorm, self).__init__()
+ if param_init_type not in [mstype.float32, mstype.float16, mstype.bfloat16]:
+ raise TypeError(
+ "The type of parameter 'param_init_type' should in [float32, float16], "
+ "but got the type : {}.".format(type(param_init_type))
+ )
+ self.is_self_defined = is_self_defined
+ if not self.is_self_defined:
+ self.layer_norm = ops.LayerNorm(begin_norm_axis=-1, begin_params_axis=-1, epsilon=eps)
+ self.gamma = Parameter(
+ initializer("ones", normalized_shape, param_init_type), name="gamma", parallel_optimizer=False
+ )
+ self.beta = Parameter(
+ initializer("zeros", normalized_shape, param_init_type), name="beta", parallel_optimizer=False
+ )
+ self.mean = ops.ReduceMean(keep_dims=True)
+ self.square = ops.Square()
+ self.sqrt = ops.Sqrt()
+ self.sub1 = ops.Sub()
+ self.sub2 = ops.Sub()
+ self.add = ops.Add()
+ self.eps = eps
+ self.mul = ops.Mul()
+ self.add2 = ops.Add()
+ self.real_div = ops.RealDiv()
+
+ def construct(self, x):
+ # x : batch x seq_length x hidden_size
+ if self.is_self_defined:
+ mean = self.mean(x, -1)
+ diff = self.sub1(x, mean)
+ variance = self.mean(self.square(diff), -1)
+ variance_eps = self.sqrt(self.add(variance, self.eps))
+ output = self.real_div(diff, variance_eps)
+ output = self.add2(self.mul(output, self.gamma), self.beta)
+ else:
+ output, _, _ = self.layer_norm(x, self.gamma, self.beta)
+ return output
+
+
+class Linear(nn.Cell):
+ r"""
+ The dense connected layer. Once the parallel mode is enabled, the input shape should be
+ 3-D tensor.
+
+ Applies dense connected layer for the input. This layer implements the operation as:
+
+ .. math::
+ \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),
+
+ where :math:`X` is the input tensors, :math:`\text{activation}` is the activation function passed as the activation
+ argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
+ data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+ with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+ Args:
+ in_channels (int): The number of channels in the input space.
+ out_channels (int): The number of channels in the output space.
+ weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+ is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
+ bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+ same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
+ has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+ activation (Union[nn.Cell, str]): activate function applied to the output of the fully connected layer,
+ eg. 'ReLU'. Default: None.
+ outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
+ Default: 1.
+ expert_group_size (int): The number of tokens in each data parallel group. Default: None.
+ compute_dtype (dtype.Number): The computation type. Default: mstype.float16
+ Inputs:
+ - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
+ to :math:`in\_channels` in `Inputs`.
+
+ Outputs:
+ Tensor of shape :math:`(*, out\_channels)`.
+
+ Raises:
+ TypeError: If `in_channels` or `out_channels` is not an int.
+ TypeError: If `has_bias` is not a bool.
+ TypeError: If `activation` is not one of str, nn.Cell, Primitive, None.
+ ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+ is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
+ ValueError: If length of shape of `bias_init` is not equal to 1
+ or shape[0] of `bias_init` is not equal to `out_channels`.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+ """
+
+ @cell_attr_register
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ weight_init="normal",
+ bias_init="zeros",
+ has_bias=True,
+ activation=None,
+ transpose_b=True,
+ outer_batch=1,
+ expert_group_size=None,
+ param_init_type=mstype.float32,
+ compute_dtype=mstype.float16,
+ skip_redistribution=False,
+ ):
+ super(Linear, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if not (isinstance(activation, str) or activation is None or issubclass(activation, nn.Cell)):
+ raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")
+ if isinstance(weight_init, Tensor) and (
+ weight_init.ndim != 2 or weight_init.shape[0] != out_channels or weight_init.shape[1] != in_channels
+ ):
+ raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
+ weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
+ self.outer_batch = outer_batch
+ self.expert_group_size = expert_group_size
+ self.transpose_b = transpose_b
+ self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
+ self.matmul = ops.MatMul(transpose_b=transpose_b)
+ self.bias = None
+ self.has_bias = has_bias
+ if self.has_bias:
+ if isinstance(bias_init, Tensor) and (bias_init.ndim != 1 or bias_init.shape[0] != out_channels):
+ raise ValueError("The shape of parameter 'bias_init' is error, please check shape of 'bias_init'.")
+ self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
+ self.bias.parallel_optimizer = False
+ self.bias_add = ops.Add()
+ self.act_name = activation
+ if callable(activation):
+ self.activation = activation()
+ else:
+ self.activation = activation
+ self.activation_flag = self.activation is not None
+ self.dtype = compute_dtype
+ self.cast = ops.Cast()
+ self.reshape = ops.Reshape()
+ if skip_redistribution:
+ self.reshape.add_prim_attr("skip_redistribution", True)
+ self.shape = ops.Shape()
+
+ def construct(self, x):
+ """Forward process, x should be a tensor"""
+ out_shape = self.shape(x)[:-1] + (self.out_channels,)
+ x = self.reshape(x, (-1, self.in_channels))
+ ori_dtype = ops.dtype(x)
+ weight = self.cast(self.weight, self.dtype)
+ x = self.cast(x, self.dtype)
+ x = self.matmul(x, weight)
+ if self.has_bias:
+ x = self.bias_add(x, self.cast(self.bias, self.dtype))
+ if self.activation_flag:
+ x = self.activation(x)
+ x = ops.cast(x, ori_dtype)
+ output = self.reshape(x, out_shape)
+ return output
diff --git a/mindocr/nlp/utils/loss.py b/mindocr/nlp/utils/loss.py
new file mode 100644
index 000000000..7e59c6f9f
--- /dev/null
+++ b/mindocr/nlp/utils/loss.py
@@ -0,0 +1,139 @@
+from mindspore import Tensor, nn, ops
+from mindspore.common import dtype as mstype
+
+__all__ = ["CrossEntropyLoss"]
+
+
+class _Softmax(nn.Cell):
+ """
+ Calculate the softmax results with given logits. The bprop of the cell is rewritten,
+ just returns the accepted dout as returns. This cell should be used together with _NLLoss,
+ to optimize the bprop of the cross entroy loss.
+
+ Inputs:
+ - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
+ the backbone.
+
+ - **label** (Tensor) - Tensor of shape (N, 1). The ground truth label of the sample.
+
+ Returns:
+ The corresponding softmax results.
+ """
+
+ def __init__(self):
+ super(_Softmax, self).__init__()
+ # on/off value for onehot, for smooth labeling, modify the off_value
+ self.on_value = Tensor(1.0, mstype.float32)
+ self.off_value = Tensor(0.0, mstype.float32)
+
+ self.sum = ops.ReduceSum()
+ self.max = ops.ArgMaxWithValue(axis=-1, keep_dims=True)
+ self.sub = ops.Sub()
+ self.exp = ops.Exp()
+ self.div = ops.RealDiv()
+ self.onehot = ops.OneHot()
+
+ def construct(self, logits, label):
+ """Forward process"""
+ # LogSoftmax for logits over last dimension
+ logits = ops.cast(logits, mstype.float32)
+ _, logit_max = self.max(logits)
+ logit_sub = self.sub(logits, logit_max)
+ logit_exp = self.exp(logit_sub)
+ exp_sum = self.sum(logit_exp, -1)
+ exp_sum = ops.Reshape()(exp_sum, (ops.shape(exp_sum)[0], 1))
+ softmax_result = self.div(logit_exp, exp_sum)
+
+ one_hot_label = self.onehot(label, ops.shape(logits)[-1], self.on_value, self.off_value)
+ return softmax_result, one_hot_label
+
+ def bprop(self, logits, label, _, dout):
+ """just return the loss of the dout. Note this should be used together with _NLLLoss"""
+ d_logits = ops.cast(dout[0], ops.dtype(logits))
+ return d_logits, ops.zeros_like(label)
+
+
+class _NLLLoss(nn.Cell):
+ """
+ Calculate the NLLLoss results with given softmax results and the label. The bprop of the cell is rewritten.
+ This cell should be used together with _Softmax, to optimize the bprop of the cross entroy loss.
+
+ Inputs:
+ - **softmax_result** (Tensor) - Tensor of shape (N, C). Data type is float32.
+ - **one_hot_label** (Tensor) - Tensor of shape (N, C). The ground truth label in one-hot format of the sample.
+
+ Returns:
+ The corresponding loss results.
+ """
+
+ def __init__(self, eps_const=1e-24):
+ super(_NLLLoss, self).__init__()
+ self.repeat_loss = 1
+ self.eps_const = Tensor(eps_const, mstype.float32)
+ self.sum = ops.ReduceSum()
+ self.mul = ops.Mul()
+ self.neg = ops.Neg()
+ self.log = ops.Log()
+ self.add = ops.Add()
+
+ def construct(self, softmax_result, one_hot_label):
+ log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
+ loss = self.mul(log_softmax_result, one_hot_label)
+ loss_unsum = self.neg(loss)
+ loss_reduce = self.sum(loss_unsum, -1)
+ return loss_reduce
+
+ def bprop(self, softmax_result, one_hot_label, _, dout):
+ """A simplified function. Note this should be used together with _Softmax"""
+ logits = softmax_result - one_hot_label
+ logits = logits * ops.ExpandDims()(dout, -1) * self.repeat_loss
+
+ return logits, ops.zeros_like(one_hot_label)
+
+
+class CrossEntropyLoss(nn.Cell):
+ """
+ Calculate the cross entropy loss.
+
+ Inputs:
+ - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
+ the backbone.
+
+ - **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
+
+ - **input_mask** (Tensor) - Tensor of shape (N, ). input_mask indicates whether there are padded inputs and for
+ padded inputs it will not be counted into loss.
+
+ Returns:
+ The corresponding cross entropy loss.
+ """
+
+ def __init__(self, eps_const=1e-24):
+ super(CrossEntropyLoss, self).__init__()
+ self.enable_force_redistribute = False
+ self.sum2 = ops.ReduceSum()
+ self.mul2 = ops.Mul()
+ self.add2 = ops.Add()
+ self.div2 = ops.RealDiv()
+ self.relu = ops.ReLU()
+
+ self._softmax = _Softmax()
+ self._nllloss = _NLLLoss(eps_const)
+
+ def construct(self, logits, label, input_mask):
+ """Forward process"""
+ # The add is used for forcing the redistribution before stepping in sub graphs, when semi/auto parallel enabled.
+ if self.enable_force_redistribute:
+ logits = self.add(logits, 0)
+ label = self.add_label(label, 0)
+ softmax, one_hot_label = self._softmax(logits, label)
+ loss_reduce = self._nllloss(softmax, one_hot_label)
+
+ # Using input_mask to mask the loss
+ input_mask = ops.Reshape()(input_mask, (-1,))
+ numerator = self.sum2(self.mul2(loss_reduce, input_mask))
+
+ denominator = self.add2(self.sum2(input_mask), ops.Cast()(ops.tuple_to_array((1e-5,)), mstype.float32))
+ loss = self.div2(numerator, denominator)
+
+ return loss
diff --git a/mindocr/utils/conversation.py b/mindocr/utils/conversation.py
new file mode 100644
index 000000000..b0002db36
--- /dev/null
+++ b/mindocr/utils/conversation.py
@@ -0,0 +1,41 @@
+import dataclasses
+from typing import List, Tuple
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+
+ def __init__(
+ self,
+ system: str = None,
+ roles: Tuple[str, str] = ("user", "assistant"),
+ messages: List[Tuple[str, str]] = None,
+ sep: str = "<|im_end|>",
+ ):
+ self.system = (
+ "<|im_start|>{system}\n{message}{sep}\n".format(
+ system="system",
+ message="You should follow the instructions carefully and explain your answers in detail.",
+ sep=sep,
+ )
+ if system is None
+ else system
+ )
+ self.roles = roles
+ self.messages = list() if messages is None else messages
+ self.sep = sep
+
+ def get_messages(self):
+ return self.messages
+
+ def get_prompt(self):
+ ret = self.system if self.system else ""
+ for role, message in self.messages:
+ ret += "<|im_start|>{role}\n{message}{sep}\n".format(role=role, message=message, sep=self.sep)
+ return ret
+
+ def add_message(self, role, message):
+ if role not in self.roles:
+ raise ValueError("role must be in {}.".format(self.roles))
+ self.messages.append((role, message))
diff --git a/requirements.txt b/requirements.txt
index 20ba0a4bf..25c8aa770 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,3 +20,4 @@ seqeval>=1.2.2
requests>=2.31.0
pycocotools>=2.0.2
setuptools-scm
+albumentations
diff --git a/tools/infer/text/predict_llm.py b/tools/infer/text/predict_llm.py
new file mode 100644
index 000000000..9eb02fb20
--- /dev/null
+++ b/tools/infer/text/predict_llm.py
@@ -0,0 +1,112 @@
+import argparse
+import logging
+import os
+
+from PIL import Image
+
+import mindspore as ms
+
+from mindocr.data.transforms.llm_transform import image_processor, image_processor_high
+from mindocr.nlp.llm.configs import LLMConfig
+from mindocr.nlp.llm.qwen_tokenizer import QwenTokenizer
+from mindocr.nlp.llm.vary_qwen_model import VaryQwenForCausalLM
+from mindocr.utils.logger import set_logger
+
+
+def load_image(image_file):
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "1"):
+ return True
+ elif v.lower() in ("no", "false", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Inference Config Args")
+ parser.add_argument("--image_dir", type=str, required=True, help="image path")
+ parser.add_argument("--query", type=str, required=False, default="Provide the ocr results of this image.")
+ parser.add_argument("--config_path", type=str, required=False, default="../../../configs/llm/vary/vary_toy.yaml")
+ parser.add_argument("--chat_mode", type=str2bool, required=False, default=False)
+ args = parser.parse_args()
+ return args
+
+
+class LLMGenerator(object):
+ def __init__(self, args):
+ config_path = args.config_path
+ config = LLMConfig(config_path)
+ ms.set_context(
+ mode=ms.GRAPH_MODE,
+ device_target="Ascend",
+ enable_graph_kernel=False,
+ graph_kernel_flags="--disable_expand_ops=Softmax,Dropout --enable_parallel_fusion=true "
+ "--reduce_fuse_depth=8 --enable_auto_tensor_inplace=true",
+ ascend_config={"precision_mode": "must_keep_origin_dtype"},
+ max_call_depth=10000,
+ max_device_memory="58GB",
+ save_graphs=False,
+ save_graphs_path="./graph",
+ device_id=os.environ.get("DEVICE_ID", 0),
+ )
+ self.tokenizer = QwenTokenizer(**config.processor.tokenizer)
+ self.model = VaryQwenForCausalLM.from_pretrained(config_path)
+
+ self.image_dir = args.image_dir
+ self.query = args.query
+ self.seq_length = self.model.seq_length
+ self.chat_mode = args.chat_mode
+
+ def _call_one(self, query=None, image=None, image_high=None):
+ response = self.model.chat(tokenizer=self.tokenizer, query=query, image=image, image_high=image_high)
+ print(">" * 100)
+ print(response)
+ print("<" * 100)
+ return response
+
+ def __call__(self, query=None, image_dir=None):
+ self.model.reset()
+ is_first_iteration = True
+ if query is None:
+ query = self.query
+ if image_dir is None:
+ image_dir = self.image_dir
+ image = load_image(image_dir)
+ image_high = image_processor_high(image)
+ image = image_processor(image)
+ while True:
+ try:
+ if is_first_iteration:
+ self._call_one(query=query, image=image, image_high=image_high)
+ if not self.chat_mode:
+ break
+ is_first_iteration = False
+ if self.chat_mode:
+ logging.info("You can input 'exit' to quit the conversation, or input your query:")
+ query = input()
+ if query == "exit":
+ break
+ self._call_one(query=query, image=None, image_high=None)
+ except ValueError as e:
+ if "check your inputs and set max_length larger than your inputs length." in e.args[0]:
+ logging.warning("The input is too long. The conversation is closed.")
+ break
+ raise e
+
+
+def main():
+ set_logger()
+ args = parse_args()
+ llm_generator = LLMGenerator(args)
+ llm_generator()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/paddle2mindir.sh b/tools/paddle2mindir.sh
new file mode 100644
index 000000000..88c4c75fc
--- /dev/null
+++ b/tools/paddle2mindir.sh
@@ -0,0 +1,153 @@
+#!/bin/bash
+
+usage() {
+ echo -e "Usage"
+ echo -e " paddle2mindir.sh [-m=PPOCR_MODEL_NAME] \\"
+ echo -e " [-p=SAVE_DIR] \\"
+ echo -e " "
+ echo -e "Description"
+ echo -e " PPOCR_MODEL_NAME: Name of support models. Supported models: 'ch_PP-OCRv4', 'ch_PP-OCRv4_server'"
+ echo -e " SAVE_DIR: folder to save downloaded ppocr models and converted mindir"
+ exit -1
+}
+
+SAVE_DIR_=ppocr_models
+for key in "$@"; do
+ case $key in
+ -m=*|--ppocr_model_name=*) PPOCR_MODEL_NAME_="${key#*=}";;
+ -p=*|--save_dir=*) SAVE_DIR_="${key#*=}";;
+ -h|--help) usage;;
+ esac
+done
+
+# convert ppocr to mindir with dynamic shape
+generate_dynamic_shape_config_file(){
+ if [ $GENERATE_CONFIG_FILE == "True" ]; then
+ echo "[acl_build_options]" > $CONFIG_FILE
+ echo "input_format=NCHW" >> $CONFIG_FILE
+ echo "input_shape_range=x:[-1,3,$DATA_SHAPE_H,$DATA_SHAPE_W]" >> $CONFIG_FILE
+ fi
+}
+
+report_paddle2onnx(){
+ report_paddle2onnx_filename=$SAVE_ONNX_FILE
+ if [ -f "$report_paddle2onnx_filename" ]; then
+ echo -e "\033[32mpaddle2onnx Success\033[0m: $report_paddle2onnx_filename" | tee -a $logFile
+ else
+ echo -e "\033[31mpaddle2onnx Failed\033[0m: $report_paddle2onnx_filename" | tee -a $logFile
+ fi
+}
+report_convert(){
+ report_convert_filename=$SAVE_CONVERT_FILE
+ # 2.1.1: ms, 2.2.0:mindir
+ if [ -f "$report_convert_filename".ms ]; then
+ echo -e "\033[32mConvert Success\033[0m: $report_convert_filename.ms" | tee -a $logFile
+ elif [ -f "$report_convert_filename".mindir ]; then
+ echo -e "\033[32mConvert Success\033[0m: $report_convert_filename.mindir" | tee -a $logFile
+ else
+ echo -e "\033[31mConvert Failed\033[0m: $report_convert_filename" | tee -a $logFile
+ fi
+}
+## pp-ocr configuration
+# Converter_lite
+CONVERTER_PATH="converter_lite"
+# path to save log
+LOG_NAME="paddle2mindir.log"
+# folder to save downloaded ppocr models and converted mindir
+SAVE_DIR=${SAVE_DIR_}
+# If generated config.txt,default True
+GENERATE_CONFIG_FILE=True
+# Models, supported: ["ch_PP-OCRv4", "ch_PP-OCRv4_server"]
+PPOCR_MODEL_NAME=${PPOCR_MODEL_NAME_}
+############
+
+FILE_PATH=$(cd "$(dirname "$0")"; pwd)
+logFile="$FILE_PATH/$LOG_NAME"
+infoCmd=">> $logFile 2>&1"
+
+pip3 install paddle2onnx==1.0.5
+ppocr_path=$FILE_PATH/$SAVE_DIR
+mkdir -p $ppocr_path/models
+cd $ppocr_path/models
+
+if [ "$PPOCR_MODEL_NAME" = "ch_PP-OCRv4" ]; then
+ det_model="ch_PP-OCRv4_det_infer"
+ rec_model="ch_PP-OCRv4_rec_infer"
+ cls_model="ch_ppocr_mobile_v2.0_cls_infer"
+elif [ "$PPOCR_MODEL_NAME" = "ch_PP-OCRv4_server" ]; then
+ det_model="ch_PP-OCRv4_det_server_infer"
+ rec_model="ch_PP-OCRv4_rec_server_infer"
+ cls_model="ch_ppocr_mobile_v2.0_cls_infer"
+else
+ echo "$PPOCR_MODEL_NAME is not supported"
+ exit
+fi
+
+### det model convertion
+cd $ppocr_path/models
+wget https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/${det_model}.tar --no-check-certificate
+tar xvf ${det_model}.tar
+cd ${det_model}
+SAVE_DB_ONNX_FILE=$ppocr_path/models/${det_model}/det_db.onnx
+cmd="paddle2onnx --model_dir $ppocr_path/models/${det_model} --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file $SAVE_DB_ONNX_FILE --opset_version 11 --enable_onnx_checker True $infoCmd"
+echo -e "\033[36mpaddle2onnx command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_ONNX_FILE=$SAVE_DB_ONNX_FILE
+report_paddle2onnx
+CONFIG_FILE=$ppocr_path/models/${det_model}/dynamic_config.txt
+DATA_SHAPE_H=-1
+DATA_SHAPE_W=-1
+generate_dynamic_shape_config_file
+SAVE_DB_MINDIR_FILE=$ppocr_path/models/${det_model}/det_db_dynamic_output
+cmd="$CONVERTER_PATH --saveType=MINDIR --fmk=ONNX --optimize=ascend_oriented --modelFile=$SAVE_DB_ONNX_FILE --outputFile=$SAVE_DB_MINDIR_FILE --configFile=dynamic_config.txt $infoCmd"
+echo -e "\033[36mconvert command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_CONVERT_FILE=$SAVE_DB_MINDIR_FILE
+report_convert
+mv $ppocr_path/models/${det_model}/det_db_dynamic_output.mindir $ppocr_path/${PPOCR_MODEL_NAME}_det_db_dynamic_output.mindir
+
+### rec model convertion
+cd $ppocr_path/models
+wget https://paddleocr.bj.bcebos.com/PP-OCRv4/chinese/${rec_model}.tar --no-check-certificate
+tar xvf ${rec_model}.tar
+cd ${rec_model}
+SAVE_REC_ONNX_FILE=$ppocr_path/models/${rec_model}/rec_crnn.onnx
+cmd="paddle2onnx --model_dir $ppocr_path/models/${rec_model} --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file $SAVE_REC_ONNX_FILE --opset_version 11 --enable_onnx_checker True $infoCmd"
+echo -e "\033[36mpaddle2onnx command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_ONNX_FILE=$SAVE_REC_ONNX_FILE
+report_paddle2onnx
+CONFIG_FILE=$ppocr_path/models/${rec_model}/dynamic_config.txt
+DATA_SHAPE_H=-1
+DATA_SHAPE_W=-1
+generate_dynamic_shape_config_file
+SAVE_REC_MINDIR_FILE=$ppocr_path/models/${rec_model}/rec_crnn_dynamic_output
+cmd="$CONVERTER_PATH --saveType=MINDIR --fmk=ONNX --optimize=ascend_oriented --modelFile=$SAVE_REC_ONNX_FILE --outputFile=$SAVE_REC_MINDIR_FILE --configFile=dynamic_config.txt $infoCmd"
+echo -e "\033[36mconvert command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_CONVERT_FILE=$SAVE_REC_MINDIR_FILE
+report_convert
+mv $ppocr_path/models/${rec_model}/rec_crnn_dynamic_output.mindir $ppocr_path/${PPOCR_MODEL_NAME}_rec_crnn_dynamic_output.mindir
+
+### cls model convertion
+cd $ppocr_path/models
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/${cls_model}.tar --no-check-certificate
+tar xvf ${cls_model}.tar
+cd ${cls_model}
+SAVE_CLS_ONNX_FILE=$ppocr_path/models/${cls_model}/cls_mv4.onnx
+cmd="paddle2onnx --model_dir $ppocr_path/models/${cls_model} --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file $SAVE_CLS_ONNX_FILE --opset_version 11 --enable_onnx_checker True $infoCmd"
+echo -e "\033[36mpaddle2onnx command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_ONNX_FILE=$SAVE_CLS_ONNX_FILE
+report_paddle2onnx
+CONFIG_FILE=$ppocr_path/models/${cls_model}/dynamic_config.txt
+DATA_SHAPE_H=-1
+DATA_SHAPE_W=-1
+generate_dynamic_shape_config_file
+SAVE_CLS_MINDIR_FILE=$ppocr_path/models/${cls_model}/cls_mv4_dynamic_output
+cmd="$CONVERTER_PATH --saveType=MINDIR --fmk=ONNX --optimize=ascend_oriented --modelFile=$SAVE_CLS_ONNX_FILE --outputFile=$SAVE_CLS_MINDIR_FILE --configFile=dynamic_config.txt $infoCmd"
+echo -e "\033[36mconvert command:\033[0m $cmd" | tee -a $logFile
+eval $cmd
+SAVE_CONVERT_FILE=$SAVE_CLS_MINDIR_FILE
+report_convert
+mv $ppocr_path/models/${cls_model}/cls_mv4_dynamic_output.mindir $ppocr_path/${PPOCR_MODEL_NAME}_cls_mv4_dynamic_output.mindir