PyTorch implementation of TPAMI 2023: CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning

xjtushujun/CMW-Net

CMW-Net

TPAMI 2023: CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning (Official PyTorch implementation)

======================================================================================================================================================

This is the code for the paper: CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning. Jun Shu, Xiang Yuan, Deyu Meng, and Zongben Xu. Official site, arXiv version


Overview

Modern deep neural networks (DNNs) can easily overfit to biased training data containing corrupted labels or class imbalance. Sample re-weighting methods are widely used to alleviate this data bias issue. Most current methods, however, require manually pre-specifying the weighting scheme and its additional hyper-parameters based on the characteristics of the investigated problem and training data. This makes them hard to apply in practical scenarios, where data bias situations are complex and vary across classes. To address this issue, we propose a meta-model capable of adaptively learning an explicit weighting scheme directly from data. Specifically, by treating each training class as a separate learning task, our method extracts an explicit weighting function that takes the sample loss and a task/class feature as input and outputs the sample weight, so that adaptively varying weighting schemes can be imposed on different sample classes according to their own intrinsic bias characteristics. The architecture of the CMW-Net meta-model is shown below:

[Figure: architecture of the CMW-Net meta-model]
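Independently of the architecture figure, the input/output contract described in the overview (sample loss and a class feature in, sample weight out) can be sketched in a few lines. This is only an illustration in plain Python, not the authors' network: the one-hidden-layer shape, the use of the log class size as the class feature, and the names `WeightFunction` and `cross_entropy` are all assumptions, and the meta-learning procedure that would actually train the weighting function is omitted, so the parameters here stay at their random initial values.

```python
import math
import random


def cross_entropy(logits, label):
    """Per-sample cross-entropy from raw logits (log-sum-exp for stability)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


class WeightFunction:
    """Toy one-hidden-layer MLP: (sample loss, class feature) -> weight in [0, 1]."""

    def __init__(self, hidden=16, seed=0):
        rng = random.Random(seed)
        # Randomly initialised parameters; a real weighting function would be learned.
        self.w1 = [[rng.gauss(0, 0.5), rng.gauss(0, 0.5)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.w2 = [rng.gauss(0, 0.5) for _ in range(hidden)]
        self.b2 = 0.0

    def __call__(self, loss, class_feature):
        # ReLU hidden layer followed by a sigmoid output.
        hidden = [max(0.0, w[0] * loss + w[1] * class_feature + b)
                  for w, b in zip(self.w1, self.b1)]
        score = sum(w * h for w, h in zip(self.w2, hidden)) + self.b2
        return 1.0 / (1.0 + math.exp(-score))


rng = random.Random(1)
class_counts = [5000, 1000, 200, 50]  # hypothetical class sizes
batch = [([rng.gauss(0, 1) for _ in class_counts], rng.randrange(len(class_counts)))
         for _ in range(8)]           # (logits, label) pairs

weight_fn = WeightFunction()
losses = [cross_entropy(logits, y) for logits, y in batch]
# Assumed class feature: log of the size of the sample's class.
weights = [weight_fn(loss, math.log(class_counts[y]))
           for loss, (_, y) in zip(losses, batch)]

# Re-weighted training objective for the classifier on this batch.
weighted_loss = sum(w * l for w, l in zip(weights, losses)) / len(batch)
print(round(weighted_loss, 4))
```

Because the class feature is an input, two samples with the same loss but from classes of different sizes can receive different weights, which is the class-aware behaviour the overview describes.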

Prerequisites

  • Python 3.7
  • PyTorch >= 1.5.0
  • Torchvision >= 0.4.0
  • sklearn
  • torchnet
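As a quick sanity check before running the scripts, something like the following can report which of the listed packages are importable. It is a minimal sketch: the import names in `REQUIRED_MODULES` are assumed from the package list above, `missing_modules` is a hypothetical helper, and only presence is checked, not the minimum versions.

```python
import importlib.util
import sys

# Import names assumed for the packages listed above.
REQUIRED_MODULES = ["torch", "torchvision", "sklearn", "torchnet"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", ".".join(map(str, sys.version_info[:3])))
    missing = missing_modules(REQUIRED_MODULES)
    print("Missing modules:", ", ".join(missing) if missing else "none")
```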

Experiments

Synthetic and real data experiments substantiate the capability of our method to achieve proper weighting schemes in various data bias cases. The task-transferability of the learned weighting scheme is also substantiated. A performance gain over previous state-of-the-art methods can be readily achieved without additional hyper-parameter tuning or meta gradient descent steps. The general applicability of our method to multiple robust deep learning issues has also been validated. We provide the running scripts in the corresponding code directories. Detailed descriptions and the main results are given below.

Learning with Synthetic Biased Data

Class Imbalance Experiments

You can reproduce the results of the class imbalance experiments (Table 1 in the paper) by running:

cd section4/Class_Imbalance
bash table1.sh

The main results are shown below:

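CIFAR-10-LT, by imbalance factor:

| Method | 200 | 100 | 50 | 20 | 10 | 1 |
|---|---|---|---|---|---|---|
| ERM | 34.32 | 29.64 | 25.19 | 17.77 | 13.61 | 7.53 |
| Focal loss | 34.71 | 29.62 | 23.29 | 17.24 | 13.34 | 6.97 |
| CB loss | 31.11 | 27.63 | 21.95 | 15.64 | 13.23 | 7.53 |
| LDAM loss | - | 26.65 | - | - | 13.04 | - |
| L2RW | 33.49 | 25.84 | 21.07 | 16.90 | 14.81 | 10.75 |
| MW-Net | 32.80 | 26.43 | 20.90 | 15.55 | 12.45 | 7.19 |
| MCW with CE loss | 29.34 | 23.59 | 19.49 | 13.54 | 11.15 | 7.21 |
| CMW-Net with CE loss | 27.80 | 21.15 | 17.26 | 12.45 | 10.97 | 8.30 |
| MCW with LDAM loss | 25.10 | 20.00 | 17.77 | 15.63 | 12.60 | 10.29 |
| CMW-Net with LDAM loss | 25.57 | 19.95 | 17.66 | 13.08 | 11.42 | 7.04 |
| SADE | 19.37 | 16.78 | 14.81 | 11.78 | 9.88 | 7.72 |
| CMW-Net with SADE | 19.11 | 16.04 | 13.54 | 10.25 | 9.39 | 5.39 |

CIFAR-100-LT, by imbalance factor:

| Method | 200 | 100 | 50 | 20 | 10 | 1 |
|---|---|---|---|---|---|---|
| ERM | 65.16 | 61.68 | 56.15 | 48.86 | 44.29 | 29.50 |
| Focal loss | 64.38 | 61.59 | 55.68 | 48.05 | 44.22 | 28.85 |
| CB loss | 64.44 | 61.23 | 55.21 | 48.06 | 42.43 | 29.37 |
| LDAM loss | 60.40 | - | - | - | 43.09 | - |
| L2RW | 66.62 | 59.77 | 55.56 | 48.36 | 46.27 | 35.89 |
| MW-Net | 63.38 | 58.39 | 54.34 | 46.96 | 41.09 | 29.90 |
| MCW with CE loss | 60.69 | 56.65 | 51.47 | 44.38 | 40.42 | - |
| CMW-Net with CE loss | 60.85 | 55.25 | 49.73 | 43.06 | 39.41 | 30.81 |
| MCW with LDAM loss | 60.47 | 55.92 | 50.84 | 47.62 | 42.00 | - |
| CMW-Net with LDAM loss | 59.81 | 55.87 | 51.14 | 45.26 | 40.32 | 29.19 |
| SADE | 54.78 | 50.20 | 46.12 | 40.06 | 36.40 | 28.08 |
| CMW-Net with SADE | 54.59 | 49.50 | 46.01 | 39.42 | 34.78 | 27.50 |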

For details, refer to Section 4.1 of the main paper.
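Table 1 indexes its results by imbalance factor. A common way to build long-tailed CIFAR variants, assumed here and not confirmed by this README, is to let the per-class training counts decay exponentially so that the largest class is `imbalance_factor` times the size of the smallest. The sketch below shows that profile; `long_tailed_counts` and its parameter names are illustrative, and 5000 is the number of training images per class in the original CIFAR-10.

```python
def long_tailed_counts(max_per_class, num_classes, imbalance_factor):
    """Per-class sample counts decaying exponentially from head to tail class.

    imbalance_factor = (largest class size) / (smallest class size).
    Requires num_classes >= 2.
    """
    counts = []
    for i in range(num_classes):
        frac = i / (num_classes - 1)  # 0 for the head class, 1 for the tail class
        counts.append(int(max_per_class * (1.0 / imbalance_factor) ** frac))
    return counts


cifar10_lt_100 = long_tailed_counts(5000, 10, 100)
print(cifar10_lt_100[0], cifar10_lt_100[-1])  # head and tail class sizes
```

Under this assumption an imbalance factor of 1 leaves every class at its full size, which corresponds to the balanced column of the table.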

Feature-independent Label Noise Experiment

You can reproduce the results of the feature-independent label noise experiments (Table 2 and Table 3 in the paper) by running:

cd section4/Feature-independent_Label_Noise
bash table2.sh

The main results are shown below:

| Dataset | Method | Symmetric 0.2 | Symmetric 0.4 | Symmetric 0.6 | Symmetric 0.8 | Asymmetric 0.2 | Asymmetric 0.4 | Asymmetric 0.6 | Asymmetric 0.8 |
|---|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | ERM | 86.98 ± 0.12 | 77.52 ± 0.41 | 73.63 ± 0.85 | 53.82 ± 1.04 | 83.60 ± 0.24 | 77.85 ± 0.98 | 69.69 ± 0.72 | 55.20 ± 0.28 |
| CIFAR-10 | Forward | 87.99 ± 0.36 | 83.25 ± 0.38 | 74.96 ± 0.65 | 54.64 ± 0.44 | 91.34 ± 0.28 | 89.87 ± 0.61 | 87.24 ± 0.96 | 81.07 ± 1.92 |
| CIFAR-10 | GCE | 89.99 ± 0.16 | 87.31 ± 0.53 | 82.15 ± 0.47 | 57.36 ± 2.08 | 89.75 ± 1.53 | 87.75 ± 0.36 | 67.21 ± 3.64 | 57.46 ± 0.31 |
| CIFAR-10 | M-correction | 93.80 ± 0.23 | 92.53 ± 0.11 | 90.30 ± 0.34 | 86.80 ± 0.11 | 92.15 ± 0.18 | 91.76 ± 0.57 | 87.59 ± 0.33 | 67.78 ± 1.22 |
| CIFAR-10 | DivideMix | 95.70 ± 0.31 | 95.00 ± 0.17 | 94.23 ± 0.23 | 92.90 ± 0.31 | 93.96 ± 0.21 | 91.80 ± 0.78 | 80.14 ± 0.45 | 59.23 ± 0.38 |
| CIFAR-10 | L2RW | 89.45 ± 0.62 | 87.18 ± 0.84 | 81.57 ± 0.66 | 58.59 ± 1.84 | 90.46 ± 0.56 | 89.76 ± 0.53 | 88.22 ± 0.71 | 85.17 ± 0.31 |
| CIFAR-10 | MW-Net | 90.46 ± 0.52 | 86.53 ± 0.57 | 82.98 ± 0.34 | 64.41 ± 0.92 | 92.69 ± 0.24 | 90.17 ± 0.11 | 68.55 ± 0.76 | 58.29 ± 1.33 |
| CIFAR-10 | CMW-Net | 91.09 ± 0.54 | 86.91 ± 0.37 | 83.33 ± 0.55 | 64.80 ± 0.72 | 93.02 ± 0.25 | 92.70 ± 0.32 | 91.28 ± 0.40 | 87.50 ± 0.26 |
| CIFAR-10 | CMW-Net-SL | 96.20 ± 0.33 | 95.29 ± 0.14 | 94.51 ± 0.32 | 92.10 ± 0.76 | 95.48 ± 0.29 | 94.51 ± 0.52 | 94.18 ± 0.21 | 93.07 ± 0.24 |
| CIFAR-100 | ERM | 60.38 ± 0.75 | 46.92 ± 0.51 | 31.82 ± 1.16 | 8.29 ± 3.24 | 61.05 ± 0.11 | 50.30 ± 1.11 | 37.34 ± 1.80 | 12.46 ± 0.43 |
| CIFAR-100 | Forward | 63.71 ± 0.49 | 49.34 ± 0.60 | 37.90 ± 0.76 | 9.57 ± 1.01 | 64.97 ± 0.47 | 52.37 ± 0.71 | 44.58 ± 0.60 | 15.84 ± 0.62 |
| CIFAR-100 | GCE | 68.02 ± 1.05 | 64.18 ± 0.30 | 54.46 ± 0.31 | 15.61 ± 0.97 | 66.15 ± 0.44 | 56.85 ± 0.72 | 40.58 ± 0.47 | 15.82 ± 0.63 |
| CIFAR-100 | M-correction | 73.90 ± 0.14 | 70.10 ± 0.14 | 59.50 ± 0.35 | 48.20 ± 0.23 | 71.85 ± 0.19 | 70.83 ± 0.48 | 60.51 ± 0.52 | 16.06 ± 0.33 |
| CIFAR-100 | DivideMix | 76.90 ± 0.21 | 75.20 ± 0.12 | 72.00 ± 0.33 | 59.60 ± 0.21 | 76.12 ± 0.44 | 73.47 ± 0.63 | 45.83 ± 0.83 | 16.98 ± 0.40 |
| CIFAR-100 | L2RW | 65.32 ± 0.42 | 55.75 ± 0.81 | 41.16 ± 0.85 | 16.80 ± 0.22 | 65.93 ± 0.17 | 62.48 ± 0.56 | 51.66 ± 0.49 | 12.40 ± 0.61 |
| CIFAR-100 | MW-Net | 69.93 ± 0.40 | 65.29 ± 0.43 | 55.59 ± 1.07 | 27.63 ± 0.56 | 69.80 ± 0.34 | 64.88 ± 0.63 | 56.89 ± 0.95 | 17.05 ± 0.52 |
| CIFAR-100 | CMW-Net | 70.11 ± 0.19 | 65.84 ± 0.50 | 56.93 ± 0.38 | 28.36 ± 0.67 | 71.07 ± 0.56 | 66.15 ± 0.51 | 58.21 ± 0.78 | 17.41 ± 0.16 |
| CIFAR-100 | CMW-Net-SL | 77.84 ± 0.12 | 76.25 ± 0.67 | 72.61 ± 0.92 | 55.21 ± 0.31 | 77.73 ± 0.37 | 75.69 ± 0.68 | 61.54 ± 0.72 | 18.34 ± 0.21 |

| Dataset | Method | Symmetric 0.2 | Symmetric 0.5 | Symmetric 0.8 | Symmetric 0.9 | Asymmetric 0.4 |
|---|---|---|---|---|---|---|
| CIFAR-10 | DivideMix | 95.7 | 94.4 | 92.9 | 75.4 | 92.1 |
| CIFAR-10 | ELR+ | 94.6 | 93.8 | 93.1 | 75.2 | 92.7 |
| CIFAR-10 | REED | 95.7 | 95.4 | 94.1 | 93.5 | - |
| CIFAR-10 | AugDesc | 96.2 | 95.1 | 93.6 | 91.8 | 94.3 |
| CIFAR-10 | C2D | 96.2 | 95.1 | 94.3 | 93.4 | 90.8 |
| CIFAR-10 | Two-step | 96.2 | 95.3 | 93.7 | 92.7 | 92.4 |
| CIFAR-10 | CMW-Net-SL | 96.2 | 95.1 | 92.1 | 48.0 | 94.5 |
| CIFAR-10 | CMW-Net-SL+ | 96.6 | 96.2 | 95.4 | 93.7 | 96.0 |
| CIFAR-100 | DivideMix | 77.3 | 74.6 | 60.2 | 31.5 | 72.1 |
| CIFAR-100 | ELR+ | 77.5 | 72.4 | 58.2 | 30.8 | 76.5 |
| CIFAR-100 | REED | 76.5 | 72.2 | 66.5 | 59.4 | - |
| CIFAR-100 | AugDesc | 79.2 | 77.0 | 66.1 | 40.9 | 76.8 |
| CIFAR-100 | C2D | 78.3 | 76.1 | 67.4 | 58.5 | 75.1 |
| CIFAR-100 | Two-step | 79.1 | 78.2 | 70.1 | 53.2 | 65.5 |
| CIFAR-100 | CMW-Net-SL | 77.84 | 76.2 | 55.2 | 21.2 | 75.7 |
| CIFAR-100 | CMW-Net-SL+ | 80.2 | 78.2 | 71.1 | 64.6 | 77.2 |

For details, refer to Section 4.2 of the main paper.
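The symmetric and asymmetric settings in the tables above differ in how corrupted labels are chosen. The sketch below illustrates one common reading of each, under stated assumptions: symmetric noise replaces a label with a different class drawn uniformly, and asymmetric noise maps a class to one fixed confusable class. The function names and the CIFAR-10 class pairs are illustrative and may not match the protocol used by the scripts in this repository.

```python
import random


def symmetric_noise(labels, num_classes, rate, seed=0):
    """With probability `rate`, replace a label by a different class drawn uniformly."""
    rng = random.Random(seed)
    noisy = []
    for y in labels:
        if rng.random() < rate:
            others = [c for c in range(num_classes) if c != y]
            noisy.append(rng.choice(others))
        else:
            noisy.append(y)
    return noisy


def asymmetric_noise(labels, transition, rate, seed=0):
    """With probability `rate`, map a label to its fixed confusable class.

    `transition` maps source class -> target class; other classes stay clean.
    """
    rng = random.Random(seed)
    return [transition[y] if y in transition and rng.random() < rate else y
            for y in labels]


# Illustrative CIFAR-10 pairs (assumption): truck->automobile, bird->airplane,
# cat<->dog, deer->horse.
cifar10_pairs = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

clean = [i % 10 for i in range(10000)]
sym = symmetric_noise(clean, num_classes=10, rate=0.4)
asym = asymmetric_noise(clean, transition=cifar10_pairs, rate=0.4)

flipped = sum(a != b for a, b in zip(sym, clean)) / len(clean)
print(f"symmetric labels changed: {flipped:.3f}")
```

With this definition the fraction of changed labels under symmetric noise is close to `rate`, while asymmetric noise only touches the classes listed in `transition`.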

Feature-dependent Label Noise Experiment

You can reproduce the results of Table 4 in the paper by running:

cd section4/Feature-dependent_Label_Noise
bash table4.sh

The main results are shown below:

| Dataset | Noise | ERM | LRT | GCE | MW-Net | PLC | CMW-Net | CMW-Net-SL |
|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | Type-I (35%) | 78.11 ± 0.74 | 80.98 ± 0.80 | 80.65 ± 0.39 | 82.20 ± 0.40 | 82.80 ± 0.27 | 82.27 ± 0.33 | 84.23 ± 0.17 |
| CIFAR-10 | Type-I (70%) | 41.98 ± 1.96 | 41.52 ± 4.53 | 36.52 ± 1.62 | 38.85 ± 0.67 | 42.74 ± 2.14 | 42.23 ± 0.69 | 44.19 ± 0.69 |
| CIFAR-10 | Type-II (35%) | 76.65 ± 0.57 | 80.74 ± 0.25 | 77.60 ± 0.88 | 81.28 ± 0.56 | 81.54 ± 0.47 | 81.69 ± 0.57 | 83.12 ± 0.40 |
| CIFAR-10 | Type-II (70%) | 45.57 ± 1.12 | 81.08 ± 0.35 | 40.30 ± 1.46 | 42.15 ± 1.07 | 46.04 ± 2.20 | 46.30 ± 0.77 | 48.26 ± 0.88 |
| CIFAR-10 | Type-III (35%) | 76.89 ± 0.79 | 76.89 ± 0.79 | 79.18 ± 0.61 | 81.57 ± 0.73 | 81.50 ± 0.50 | 81.52 ± 0.38 | 83.10 ± 0.34 |
| CIFAR-10 | Type-III (70%) | 43.32 ± 1.00 | 44.47 ± 1.23 | 37.10 ± 0.59 | 42.43 ± 1.27 | 45.05 ± 1.13 | 43.76 ± 0.96 | 45.15 ± 0.91 |
| CIFAR-100 | Type-I (35%) | 57.68 ± 0.29 | 56.74 ± 0.34 | 58.37 ± 0.18 | 62.10 ± 0.50 | 60.01 ± 0.43 | 62.43 ± 0.38 | 64.01 ± 0.11 |
| CIFAR-100 | Type-I (70%) | 39.32 ± 0.43 | 45.29 ± 0.43 | 40.01 ± 0.71 | 44.71 ± 0.49 | 45.92 ± 0.61 | 46.68 ± 0.64 | 47.62 ± 0.44 |
| CIFAR-100 | Type-II (35%) | 57.83 ± 0.25 | 57.25 ± 0.68 | 58.11 ± 1.05 | 63.78 ± 0.24 | 63.68 ± 0.29 | 64.08 ± 0.26 | 64.13 ± 0.19 |
| CIFAR-100 | Type-II (70%) | 39.30 ± 0.32 | 43.71 ± 0.51 | 37.75 ± 0.46 | 44.61 ± 0.41 | 45.03 ± 0.50 | 50.01 ± 0.51 | 51.99 ± 0.35 |
| CIFAR-100 | Type-III (35%) | 56.07 ± 0.79 | 56.57 ± 0.30 | 57.51 ± 1.16 | 62.53 ± 0.33 | 63.68 ± 0.29 | 63.21 ± 0.23 | 64.47 ± 0.15 |
| CIFAR-100 | Type-III (70%) | 40.01 ± 0.18 | 44.41 ± 0.19 | 40.53 ± 0.60 | 45.17 ± 0.77 | 44.45 ± 0.62 | 47.38 ± 0.65 | 48.78 ± 0.62 |

You can reproduce the results of Table 5 in the paper by running:

cd section4/Feature-dependent_Label_Noise
bash table5.sh

The main results are shown below:

| Dataset | Noise | ERM | LRT | GCE | MW-Net | PLC | CMW-Net | CMW-Net-SL |
|---|---|---|---|---|---|---|---|---|
| CIFAR-10 | Type-I + Symmetric | 75.26 ± 0.32 | 75.97 ± 0.27 | 78.08 ± 0.66 | 76.39 ± 0.42 | 79.04 ± 0.50 | 78.42 ± 0.47 | 82.00 ± 0.36 |
| CIFAR-10 | Type-I + Asymmetric | 75.21 ± 0.64 | 76.96 ± 0.45 | 76.91 ± 0.56 | 76.54 ± 0.56 | 78.31 ± 0.41 | 77.14 ± 0.38 | 80.69 ± 0.47 |
| CIFAR-10 | Type-II + Symmetric | 74.92 ± 0.63 | 75.94 ± 0.58 | 75.69 ± 0.21 | 76.57 ± 0.81 | 80.08 ± 0.37 | 76.77 ± 0.63 | 80.96 ± 0.23 |
| CIFAR-10 | Type-II + Asymmetric | 74.28 ± 0.39 | 77.03 ± 0.62 | 75.30 ± 0.81 | 75.35 ± 0.40 | 77.63 ± 0.30 | 77.08 ± 0.52 | 80.94 ± 0.14 |
| CIFAR-10 | Type-III + Symmetric | 74.00 ± 0.38 | 75.66 ± 0.57 | 77.00 ± 0.12 | 76.28 ± 0.82 | 80.06 ± 0.47 | 77.16 ± 0.30 | 81.58 ± 0.55 |
| CIFAR-10 | Type-III + Asymmetric | 75.31 ± 0.34 | 77.19 ± 0.74 | 75.70 ± 0.91 | 75.82 ± 0.77 | 77.54 ± 0.70 | 76.49 ± 0.88 | 80.48 ± 0.48 |
| CIFAR-100 | Type-I + Symmetric | 48.86 ± 0.56 | 45.66 ± 1.60 | 52.90 ± 0.53 | 57.70 ± 0.32 | 60.09 ± 0.15 | 59.17 ± 0.42 | 60.87 ± 0.56 |
| CIFAR-100 | Type-I + Asymmetric | 45.85 ± 0.93 | 52.04 ± 0.15 | 52.69 ± 1.14 | 56.61 ± 0.71 | 56.40 ± 0.34 | 57.42 ± 0.81 | 61.35 ± 0.52 |
| CIFAR-100 | Type-II + Symmetric | 49.32 ± 0.36 | 43.86 ± 1.31 | 53.61 ± 0.46 | 54.08 ± 0.18 | 60.01 ± 0.63 | 59.16 ± 0.18 | 61.00 ± 0.41 |
| CIFAR-100 | Type-II + Asymmetric | 46.50 ± 0.95 | 52.11 ± 0.46 | 51.98 ± 0.37 | 58.53 ± 0.45 | 61.43 ± 0.33 | 58.99 ± 0.91 | 61.35 ± 0.57 |
| CIFAR-100 | Type-III + Symmetric | 48.94 ± 0.61 | 42.79 ± 1.78 | 52.07 ± 0.35 | 55.29 ± 0.57 | 60.14 ± 0.97 | 58.48 ± 0.79 | 60.21 ± 0.48 |
| CIFAR-100 | Type-III + Asymmetric | 45.70 ± 0.12 | 50.31 ± 0.39 | 50.87 ± 1.12 | 58.43 ± 0.60 | 54.56 ± 1.11 | 58.83 ± 0.57 | 60.52 ± 0.53 |

For details, refer to Section 4.3 of the main paper.

Learning with Real Biased Data

Learning with Real-world Noisy Datasets

We test our method on ANIMAL-10N and mini WebVision. You can reproduce the results on ANIMAL-10N (Table 6 in the paper) by running:

cd section5/ANIMAL-10N
bash table6.sh

The main results are shown below:

| Method | Test Accuracy | Method | Test Accuracy |
|---|---|---|---|
| ERM | 79.4 ± 0.14 | ActiveBias | 80.5 ± 0.26 |
| Co-teaching | 80.2 ± 0.13 | SELFIE | 81.8 ± 0.09 |
| PLC | 83.4 ± 0.43 | MW-Net | 80.7 ± 0.52 |
| CMW-Net | 80.9 ± 0.48 | CMW-Net-SL | 84.7 ± 0.28 |

You can reproduce the results on mini WebVision (Table 7 in the paper) by running:

cd section5/mini_WebVision
bash table7.sh

The main results are shown below:

| Methods | ILSVRC12 top1 | ILSVRC12 top5 | WebVision top1 | WebVision top5 |
|---|---|---|---|---|
| Forward | 61.12 | 82.68 | 57.36 | 82.36 |
| MentorNet | 63.00 | 81.40 | 57.80 | 79.92 |
| Co-teaching | 63.58 | 85.20 | 61.48 | 84.70 |
| Iterative-CV | 65.24 | 85.34 | 61.60 | 84.98 |
| MW-Net | 69.34 | 87.44 | 65.80 | 87.52 |
| CMW-Net | 70.56 | 88.76 | 66.44 | 87.68 |
| DivideMix | 77.32 | 91.64 | 75.20 | 90.84 |
| ELR | 77.78 | 91.68 | 70.29 | 89.76 |
| DivideMix | 76.32 | 90.65 | 74.42 | 91.21 |
| CMW-Net-SL | 78.08 | 92.96 | 75.72 | 92.52 |
| DivideMix with C2D | 79.42 | 92.32 | 78.57 | 93.04 |
| CMW-Net-SL+C2D | 80.44 | 93.36 | 77.36 | 93.48 |

For details, refer to Section 5.1 of the main paper.
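The mini WebVision table reports top1 and top5 columns, presumably top-k accuracy in percent. As a reminder of how that metric is computed, here is a small sketch; `topk_accuracy` is an illustrative helper and not a function from this repository.

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, y in zip(scores, labels):
        # Indices of the k classes with the largest scores for this sample.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += y in topk
    return 100.0 * hits / len(labels)


scores = [
    [0.1, 0.7, 0.2],  # ranks class 1 first
    [0.5, 0.3, 0.2],  # ranks class 0 first, class 1 second
    [0.2, 0.3, 0.5],  # ranks class 2 first
    [0.6, 0.1, 0.3],  # ranks class 0 first, class 2 second
]
labels = [1, 1, 2, 1]

top1 = topk_accuracy(scores, labels, k=1)  # 2 of 4 samples correct
top2 = topk_accuracy(scores, labels, k=2)  # 3 of 4 samples correct
print(top1, top2)
```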

Webly Supervised Fine-Grained Recognition

We further run our method on the WebFG-496 benchmark, which consists of three sub-datasets: Web-aircraft, Web-bird, and Web-car. You can reproduce the results on WebFG-496 (Table 8 in the paper) by running:

cd section5/WebFG-496
bash table8.sh

The main results are shown below:

| Methods | Web-Bird | Web-Aircraft | Web-Car | Average |
|---|---|---|---|---|
| ERM | 66.56 | 64.33 | 67.42 | 66.10 |
| Decoupling | 70.56 | 75.97 | 75.00 | 73.84 |
| Co-teaching | 73.85 | 72.76 | 73.10 | 73.24 |
| Peer-learning | 76.48 | 74.38 | 78.52 | 76.46 |
| MW-Net | 75.60 | 72.93 | 77.33 | 75.29 |
| CMW-Net | 75.72 | 73.72 | 77.42 | 75.62 |
| CMW-Net-SL | 77.41 | 76.48 | 79.70 | 77.86 |

For details, refer to Section 5.2 of the main paper.

Transferability of CMW-Net

A potential benefit of the weighting scheme meta-learned by CMW-Net is that it is model-agnostic and can hopefully be equipped into other learning algorithms in a plug-and-play manner. To validate this transferability, we transfer a CMW-Net meta-learned on a relatively small dataset to significantly larger-scale ones. Specifically, we use the CMW-Net trained on CIFAR-10 with feature-dependent label noise (i.e., 35% Type-I + 30% Asymmetric), as introduced in Sec. 4.3 of the paper, since it closely simulates real-world noise configurations. The extracted weighting function is depicted below.

[Figure: weighting function extracted by CMW-Net]

WebVision Dataset

We deploy it on the full WebVision dataset. Even with a relatively concise form, our method still outperforms the second-best method, Heteroscedastic, by an evident margin. This further validates the potential usefulness of CMW-Net for practical large-scale problems with complicated data bias situations, with an intrinsic reduction of labor and computation costs, since a proper weighting scheme can be readily specified for a learning algorithm. You can reproduce the performance on the full WebVision dataset (Table 10 in the main paper) by running:

cd section6/webvision
bash table10.sh

The main results are shown below:

| Methods | ILSVRC12 top1 | ILSVRC12 top5 | WebVision top1 | WebVision top5 |
|---|---|---|---|---|
| ERM | 69.7 | 87.0 | 62.9 | 83.6 |
| MentorNet | 70.8 | 88.0 | 62.5 | 83.0 |
| MentorMix | 74.3 | 90.5 | 67.5 | 87.2 |
| HAR | 75.0 | 90.6 | 67.1 | 86.7 |
| MILE | 76.5 | 90.9 | 68.7 | 86.4 |
| Heteroscedastic | 76.6 | 92.1 | 68.6 | 87.1 |
| CurriculumNet | 79.3 | 93.6 | - | - |
| ERM + CMW-Net-SL | 77.9 | 92.6 | 69.6 | 88.5 |

For details, refer to Section 6 of the main paper.

Extensional Applications

We evaluate the generality of our proposed adaptive sample weighting strategy in more robust learning tasks.

Partial-Label Learning

CMW-Net significantly enhances the performance of the baseline method in both test cases, showing its potential usability in the partial-label learning task. You can reproduce the partial-label learning results (Fig. 9 in the paper) by running:

cd section7/Partial-Label_Learning
bash fig9.sh

The main results are shown below.

Accuracy comparison of PRODEN with and without the CMW-Net strategy on CIFAR-10:

Accuracy comparison of PRODEN with and without the CMW-Net strategy on CIFAR-100:

For details, refer to Section 7.1 of the main paper.
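The per-table commands given in the sections above can be collected into one small driver. This is a convenience sketch, not a script shipped with the repository: `run_all` and its `dry_run` flag are hypothetical, and only the directory and script names are taken from this README. It assumes that it is started from the repository root and that `bash` is available.

```python
import subprocess
from pathlib import Path

# (directory, script) pairs exactly as listed in the sections above.
EXPERIMENTS = [
    ("section4/Class_Imbalance", "table1.sh"),
    ("section4/Feature-independent_Label_Noise", "table2.sh"),
    ("section4/Feature-dependent_Label_Noise", "table4.sh"),
    ("section4/Feature-dependent_Label_Noise", "table5.sh"),
    ("section5/ANIMAL-10N", "table6.sh"),
    ("section5/mini_WebVision", "table7.sh"),
    ("section5/WebFG-496", "table8.sh"),
    ("section6/webvision", "table10.sh"),
    ("section7/Partial-Label_Learning", "fig9.sh"),
]


def run_all(repo_root=".", dry_run=True):
    """Run each script from its own directory; with dry_run=True only list the commands."""
    commands = []
    for directory, script in EXPERIMENTS:
        workdir = Path(repo_root) / directory
        commands.append((str(workdir), ["bash", script]))
        if not dry_run:
            # check=True stops at the first script that exits with an error.
            subprocess.run(["bash", script], cwd=workdir, check=True)
    return commands


for workdir, cmd in run_all(dry_run=True):
    print(f"(cd {workdir} && {' '.join(cmd)})")
```

Calling `run_all(dry_run=False)` would execute the scripts one after another, which may take a long time given the number of experiments.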

Citation

If you find this code useful, please cite our paper.

@article{CMW-Net,
  title={CMW-Net: Learning a Class-Aware Sample Weighting Mapping for Robust Deep Learning},
  author={Jun Shu and Xiang Yuan and Deyu Meng and Zongben Xu},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  pages={1-15},
  year={2023}
}

Acknowledgments

We thank the following GitHub repos for their valuable codebases:
