Merge pull request #391 from open-mmlab/plyfager/ada-readme
[Doc] Update READMEs for ADA
plyfager authored Aug 24, 2022
2 parents bb4fac2 + d7457f1 commit 92be69b
Showing 7 changed files with 310 additions and 0 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -157,6 +157,13 @@ These methods have been carefully studied and supported in our frameworks:

</details>

<details open>
<summary>Tricks for GANs (click to collapse)</summary>

- [ADA](configs/ada/README.md) (NeurIPS'2020)

</details>

<details open>
<summary>Image2Image Translation (click to collapse)</summary>

7 changes: 7 additions & 0 deletions README_zh-CN.md
@@ -156,6 +156,13 @@ pip3 install -e .[all]

</details>

<details open>
<summary>Tricks for GANs (click to collapse)</summary>

- [ADA](configs/ada/README.md) (NeurIPS'2020)

</details>

<details open>
<summary>Image2Image Translation (click to collapse)</summary>

78 changes: 78 additions & 0 deletions configs/ada/README.md
@@ -0,0 +1,78 @@
# ADA

> [Training Generative Adversarial Networks with Limited Data](https://arxiv.org/pdf/2006.06676.pdf)
<!-- [ALGORITHM] -->

## Abstract

Training generative adversarial networks (GAN) using too little data typically leads to discriminator overfitting, causing training to diverge. We propose an adaptive discriminator augmentation mechanism that significantly stabilizes training in limited data regimes. The approach does not require changes to loss functions or network architectures, and is applicable both when training from scratch and when fine-tuning an existing GAN on another dataset. We demonstrate, on several datasets, that good results are now possible using only a few thousand training images, often matching StyleGAN2 results with an order of magnitude fewer images. We expect this to open up new application domains for GANs. We also find that the widely used CIFAR-10 is, in fact, a limited data benchmark, and improve the record FID from 5.59 to 2.42.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/22982797/165902671-ee835ca5-3957-451e-8e7d-e3741d90e0b1.png"/>
</div>

## Results and Models

<div align="center">
<b> Results (compressed) from StyleGAN3-ada trained by MMGeneration</b>
<br/>
<img src="https://user-images.githubusercontent.com/22982797/165905181-66d6b4e7-6d40-48db-8281-50ebd2705f64.png" width="800"/>
</div>

| Model | Dataset | Iter | FID50k | Config | Log | Download |
| :-------------: | :---------------: | :----: | :----: | :-------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------: |
| stylegan3-t-ada | MetFaces 1024x1024 | 130000 | 15.09 | [config](https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py) | [log](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_20220328_142211.log.json) | [model](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth) |

## Usage

Currently, we only implement ADA for StyleGANv2/v3. To use this training trick, set `ADAStyleGAN2Discriminator` as your discriminator.

An example:

```python
model = dict(
    xxx,
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    xxx
)
```

Here, you can adjust `ada_kimg` to change the strength of the augmentation (the smaller the value, the stronger the augmentation).
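
To make the role of `ada_kimg` concrete, below is a minimal sketch of the adaptive heuristic from the ADA paper and NVIDIA's reference implementation; the function name `update_augment_p`, the target value `0.6`, and the toy loop are illustrative assumptions rather than MMGeneration API. The overfitting measure r_t = E[sign(D(real))] is compared with a target, and the global augmentation probability `p` is nudged toward it with a step size proportional to `1 / ada_kimg`:

```python
import numpy as np


def update_augment_p(p, sign_real_mean, batch_size, update_interval,
                     ada_target=0.6, ada_kimg=100):
    """Sketch of the ADA heuristic (assumed names, not the MMGeneration API).

    r_t = E[sign(D(real))] measures discriminator overfitting. If r_t exceeds
    the target, increase the global augmentation probability p, otherwise
    decrease it. The step size scales with 1 / ada_kimg, so a smaller
    ada_kimg reacts faster, i.e. applies stronger augmentation sooner.
    """
    adjust = np.sign(sign_real_mean - ada_target) \
        * (batch_size * update_interval) / (ada_kimg * 1000)
    return float(np.clip(p + adjust, 0.0, 1.0))


# With ada_kimg=100, p can traverse the whole [0, 1] range after roughly
# 100k real images; with ada_kimg=20 the same change happens 5x faster.
p = 0.0
for _ in range(1000):
    p = update_augment_p(p, sign_real_mean=0.9, batch_size=32, update_interval=4)
print(p)
```

In short, halving `ada_kimg` doubles how quickly `p`, and hence the augmentation strength, responds to discriminator overfitting.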

`aug_kwargs` is usually set as follows:

```python
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}
```

Here, each number is the probability multiplier for the corresponding operation. For details, you can refer to [augment](https://github.com/open-mmlab/mmgeneration/tree/master/mmgen/models/architectures/stylegan/ada/augment.py).
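
To illustrate how these multipliers interact with the global probability `p` that ADA tunes, here is a simplified sketch; `apply_ada_pipeline` and the toy `ops` mapping are hypothetical stand-ins rather than the actual `ADAAug` internals. Each listed operation is applied independently with probability `multiplier * p`, capped at 1, so a multiplier of 0 disables an operation:

```python
import random


def apply_ada_pipeline(image, aug_kwargs, p, ops):
    """Simplified sketch (assumed helper, not the real ADAAug implementation).

    Every augmentation op listed in aug_kwargs is applied independently with
    probability multiplier * p (capped at 1), where p is the single global
    probability adjusted by ADA during training.
    """
    for name, op in ops.items():
        multiplier = aug_kwargs.get(name, 0)
        if random.random() < min(multiplier * p, 1.0):
            image = op(image)
    return image


# Toy usage: with p=0.2 and both multipliers equal to 1, each op fires with
# probability 0.2 on this fake "image".
ops = {'xflip': lambda x: x[::-1], 'brightness': lambda x: [v + 0.1 for v in x]}
aug_kwargs = {'xflip': 1, 'brightness': 1}
print(apply_ada_pipeline([0.1, 0.5, 0.9], aug_kwargs, p=0.2, ops=ops))
```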

## Citation

```latex
@inproceedings{Karras2020ada,
title = {Training Generative Adversarial Networks with Limited Data},
author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
booktitle = {Proc. NeurIPS},
year = {2020}
}
```
22 changes: 22 additions & 0 deletions configs/ada/metafile.yml
@@ -0,0 +1,22 @@
Collections:
- Metadata:
    Architecture:
    - ADA
  Name: ADA
  Paper:
  - https://arxiv.org/pdf/2006.06676.pdf
  README: configs/ada/README.md
Models:
- Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py
  In Collection: ADA
  Metadata:
    Training Data: Others
  Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8
  Results:
  - Dataset: Others
    Metrics:
      FID50k: 15.09
      Iter: 130000.0
      Log: '[log]'
    Task: Tricks for GANs
  Weights: https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth
99 changes: 99 additions & 0 deletions configs/ada/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py
@@ -0,0 +1,99 @@
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
    'type': 'SynthesisNetwork',
    'channel_base': 65536,
    'channel_max': 1024,
    'magnitude_ema_beta': 0.999,
    'conv_kernel': 1,
    'use_radial_filters': True
}
r1_gamma = 3.3 # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa

# ada settings
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        input_bgr2rgb=True,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10. # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=ema_beta),
        interval=1,
        start_iter=0,
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

evaluation = dict(
    type='GenerativeEvalHook',
    interval=dict(milestones=[100000], interval=[10000, 5000]),
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000
96 changes: 96 additions & 0 deletions configs/ada/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py
@@ -0,0 +1,96 @@
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
    'type': 'SynthesisNetwork',
    'channel_base': 32768,
    'channel_max': 512,
    'magnitude_ema_beta': 0.999
}
r1_gamma = 6.6 # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa
# ada settings
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        input_bgr2rgb=True,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10. # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=ema_beta),
        interval=1,
        start_iter=0,
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

evaluation = dict(
    type='GenerativeEvalHook',
    interval=dict(milestones=[80000], interval=[10000, 5000]),
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000
1 change: 1 addition & 0 deletions model-index.yml
@@ -1,4 +1,5 @@
Import:
- configs/ada/metafile.yml
- configs/biggan/metafile.yml
- configs/cyclegan/metafile.yml
- configs/dcgan/metafile.yml
