Skip to content

zheng-yuwei/PyTorch-Image-Classification

Repository files navigation

基于PyTorch的分类网络库

实现的分类网络包括:

  • PyTorch自带的网络:resnet, shufflenet, densenet, mobilenet, mnasnet等;
  • MobileNet v3;
  • EfficientNet系列;
  • ResNeSt系列;

包含特性

  • 支持多种功能(application/):训练、测试、转JIT部署、模型蒸馏、可视化;
  • 数据增强(dataloader/enhancement):AutoAugment,自定义Augment(MyAugment),mixup数据增强,多尺度训练数据增强;
  • 库中包含多种优化器(optim):目前使用的是Adam,同时推荐RAdam;
  • 不同损失指标的实现(criterions):OHM、GHM、weighted loss等;

文件结构说明

  • applications: 包括test.py, train.py, convert.py等应用,提供给main.py调用;
  • checkpoints: 训练好的模型文件保存目录(当前可能不存在);
  • criterions: 自定义损失函数;
  • data: 训练/测试/验证/预测等数据集存放的路径;
  • dataloader: 数据加载、数据增强、数据预处理(默认采用ImageNet方式);
  • demos: 模型使用的demo,目前classifier.py显示如何调用jit格式模型进行预测;
  • logs: 训练过程中TensorBoard日志存放的文件(当前可能不存在);
  • models: 自定义的模型结构;
  • optim: 一些前沿的优化器,PyTorch官方还未实现;
  • pretrained: 预训练模型文件;
  • utils: 工具脚本:混淆矩阵、图片数据校验、模型结构打印、日志等;
  • config.py: 配置文件;
  • main.py: 总入口;
  • requirements.txt: 工程依赖包列表;

使用说明

数据准备

在文件夹data下放数据,分成三个文件夹: train/test/val,对应 训练/测试/验证 数据文件夹; 每个子文件夹下,依据分类类别每个类别建立一个对应的文件夹,放置该类别的图片。

数据准备完毕后,使用utils/check_images.py脚本,检查图像数据的有效性,防止在训练过程中遇到无效图片中止训练。

最终大概结构为:

- data
  - train
    - class_0
      - 0.jpg
      - 1.jpg
      - ...
    - class_1
      - ...
    - ..
  - test
    - ...
  - val
    - ...
- dataloader
- ...

部分重要配置参数说明

针对config.py里的部分重要参数说明如下:

  • --data: 数据集根目录,下面包含train, test, val三个目录的数据集,默认当前文件夹下data/目录;
  • --image_size: 输入应该为两个整数值,预训练模型的输入时正方形的,也就是[224, 224]之类的; 实际可以根据自己需要更改,数据预处理时,会将图像 等比例resize然后再padding(默认用0 padding)到 指定的输入尺寸。
  • --num_classes: 分类模型的预测类别数;
  • -b: 设置batch size大小,默认为256,可根据GPU显存设置;
  • -j: 设置数据加载的进程数,默认为8,可根据CPU使用量设置;
  • --criterion: 损失函数,一种使用PyTorch自带的softmax损失函数,一种使用我自定义的sigmoid损失函数; sigmoid损失函数则是将多分类问题转化为多标签二分类问题,同时我增加了几个如GHM自定义的sigmoid损失函数, 可通过--weighted_loss --ghm_loss --threshold_loss --ohm_loss指定是否启动;
  • --lr: 初始学习率,main.py里我默认使用Adam优化器;目前学习率的scheduler我使用的是LambdaLR接口,自定义函数规则如下, 详细可参考main.pyadjust_learning_rate(epoch, args)函数:
~ warmup: 0.1
~ warmup + int([1.5 * (epochs - warmup)]/4.0): 1, 
~ warmup + int([2.5 * (epochs - warmup)]/4.0): 0.1
~ warmup + int([3.5 * (epochs - warmup)]/4.0) 0.01
~ epochs: 0.001
  • --warmup: warmup的迭代次数,训练前warmup个epoch会将 初始学习率*0.1 作为warmup期间的学习率;
  • --epochs: 训练的总迭代次数;
  • --aug: 是否使用数据增强,目前默认使用的是我自定义的数据增强方式:dataloader/my_augment.py
  • --mixup: 数据增强mixup,默认 False;
  • --multi_scale: 多尺度训练,默认 False;
  • --resume: 权重文件路径,模型文件将被加载以进行模型初始化,--jit--evaluation时需要指定;
  • --jit: 将模型转为JIT格式,利于部署;
  • --evaluation: 在测试集上进行模型评估;
  • --knowledge: 指定数据集,使用教师模型(配合resume选型指定)对该数据集进行预测,获取概率文件(知识), 生成的概率文件路径为data/distill.txt,同时生成原始概率data/label.txt;
  • --distill: 模型蒸馏(需要教师模型输出的概率文件),默认 False, 使用该模式训练前,需要先启用--knowledge train --resume teacher.pth对训练集进行测试,生成概率文件作为教师模型的概率; 概率文件形式为data路径下distill*.txt模式的文件,有多个文件会都使用,取均值作为教师模型的概率输出指导接下来训练的学生模型;
  • --visual_data: 对指定数据集运行测试,并进行可视化;
  • --visual_method: 可视化方法,包含cam, grad-cam, grad-camm++三种;
  • --make_curriculum: 制作课程学习的课程文件;
  • --curriculum_thresholds: 不同课程中样本的阈值;
  • --curriculum_weights: 不同课程中样本的损失函数权重;
  • --curriculum_learning: 进行课程学习,从data/curriculum.txt中读取样本权重数据,训练时给对应样本的损失函数加权;

BTW,在models/efficientnet/model.py中增加了sample-free的思想,目前代码注释掉了,若需要可以借鉴使用。 sample-free主要是我使用bce进行多标签二分类时,我希望任务偏好某些类别,所以在初始某些类别的bias上设置一个较大的数,提高初始概率。 (具体计算公式可参考原论文 Is Sampling Heuristics Necessary in Training Deep Object Detectors)

参数的详细说明可查看config.py文件。


快速使用

可参考对应的z_task_shell/*.sh文件

模型部署demo

训练好模型后,想用该模型对图像数据进行预测,可使用demos目录下的脚本classifier.py

cd demos
python classifier.py

Reference

d-li14/mobilenetv3.pytorch

lukemelas/EfficientNet-PyTorch

zhanghang1989/ResNeSt

yizt/Grad-CAM.pytorch

TODO

  • 预训练模型下载URL整理(参考Reference);
  • 模型的openvino格式的转换和对应的部署demo;

About

基于PyTorch框架实现的图像分类网络

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published