Merge pull request #391 from open-mmlab/plyfager/ada-readme
[Doc] Update READMEs for ADA
Showing 7 changed files with 310 additions and 0 deletions.

@@ -0,0 +1,78 @@
# ADA

> [Training Generative Adversarial Networks with Limited Data](https://arxiv.org/pdf/2006.06676.pdf)

<!-- [ALGORITHM] -->

## Abstract

Training generative adversarial networks (GAN) using too little data typically leads to discriminator overfitting, causing training to diverge. We propose an adaptive discriminator augmentation mechanism that significantly stabilizes training in limited data regimes. The approach does not require changes to loss functions or network architectures, and is applicable both when training from scratch and when fine-tuning an existing GAN on another dataset. We demonstrate, on several datasets, that good results are now possible using only a few thousand training images, often matching StyleGAN2 results with an order of magnitude fewer images. We expect this to open up new application domains for GANs. We also find that the widely used CIFAR-10 is, in fact, a limited data benchmark, and improve the record FID from 5.59 to 2.42.

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/22982797/165902671-ee835ca5-3957-451e-8e7d-e3741d90e0b1.png"/>
</div>

## Results and Models

<div align="center">
<b> Results (compressed) from StyleGAN3-ada trained by MMGeneration</b>
<br/>
<img src="https://user-images.githubusercontent.com/22982797/165905181-66d6b4e7-6d40-48db-8281-50ebd2705f64.png" width="800"/>
</div>

|      Model      |      Dataset       |  Iter  | FID50k | Config | Log | Download |
| :-------------: | :----------------: | :----: | :----: | :----: | :-: | :------: |
| stylegan3-t-ada | MetFaces 1024x1024 | 130000 | 15.09  | [config](https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py) | [log](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_20220328_142211.log.json) | [model](https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth) |

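To try the released checkpoint from the table above, a minimal sampling sketch could look like the following. It assumes the `init_model` and `sample_unconditional_model` helpers from `mmgen.apis`; exact argument names may differ between MMGeneration versions.

```python
# A minimal sampling sketch; it assumes the `init_model` and
# `sample_unconditional_model` helpers from `mmgen.apis`, whose exact
# argument names may differ between MMGeneration versions.
from mmgen.apis import init_model, sample_unconditional_model

config = 'configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py'
checkpoint = ('https://download.openmmlab.com/mmgen/stylegan3/'
              'stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_'
              'iter_130000_20220401_115101-f2ef498e.pth')

model = init_model(config, checkpoint, device='cuda:0')
# Sample a few images with the EMA generator (the weights used for the FID above).
images = sample_unconditional_model(model, num_samples=4, sample_model='ema')
print(images.shape)  # e.g. torch.Size([4, 3, 1024, 1024])
```
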
## Usage

Currently we only implement ADA for StyleGANv2/v3. To use this training trick, you should use `ADAStyleGAN2Discriminator` as your discriminator.

An example:

```python
model = dict(
    xxx,
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    xxx
)
```

Here, you can adjust `ada_kimg` to change the magnitude of augmentation (the smaller the value, the greater the magnitude).

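To make the role of `ada_kimg` concrete, here is an illustrative sketch of the adaptive update described in the paper, not MMGeneration's actual implementation: the augmentation probability `p` is nudged up or down depending on a discriminator-overfitting heuristic, and `ada_kimg` sets how many thousand images it takes for `p` to sweep the full `[0, 1]` range. The target value 0.6 and the 4-batch adjustment interval are the paper's defaults and are assumptions here; `update_p` is a hypothetical helper.

```python
# Illustrative sketch of the adaptive augmentation-probability update from the
# ADA paper; this is NOT MMGeneration's implementation, and `update_p` is a
# hypothetical helper. `ada_target` (0.6) and `ada_interval` (4 batches) are
# the paper's defaults.
def update_p(p, real_logits, batch_size, ada_interval=4,
             ada_target=0.6, ada_kimg=100):
    # r_t = E[sign(D(real))]: close to +1 means the discriminator is
    # overfitting to the training set.
    r_t = sum(1.0 if logit > 0 else -1.0 for logit in real_logits) / len(real_logits)
    # Step size: p can traverse the whole [0, 1] range over ada_kimg * 1000
    # images, so a smaller ada_kimg reacts faster and augments more aggressively.
    step = (batch_size * ada_interval) / (ada_kimg * 1000)
    p += step if r_t > ada_target else -step
    return min(max(p, 0.0), 1.0)
```
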
`aug_kwargs` is usually set as follows:

```python
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}
```

Here, each number is the probability multiplier for the corresponding operation. For details, you can refer to [augment](https://github.com/open-mmlab/mmgeneration/tree/master/mmgen/models/architectures/stylegan/ada/augment.py).

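As an illustration of how such a multiplier is typically combined with the adaptive probability `p`, here is a small sketch; it is not the code in `augment.py`, and `apply_pipeline` and `ops` are hypothetical names.

```python
# Minimal sketch of how a per-operation probability multiplier is combined
# with the global adaptive probability `p`; this is NOT the code in
# augment.py, and `apply_pipeline` / `ops` are hypothetical names.
import random

def apply_pipeline(img, p, aug_kwargs, ops):
    for name, multiplier in aug_kwargs.items():
        # An op with multiplier 1 simply follows the adaptive `p`;
        # a multiplier of 0 disables the op entirely.
        prob = min(p * multiplier, 1.0)
        if name in ops and random.random() < prob:
            img = ops[name](img)
    return img
```
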
## Citation

```latex
@inproceedings{Karras2020ada,
  title     = {Training Generative Adversarial Networks with Limited Data},
  author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
  booktitle = {Proc. NeurIPS},
  year      = {2020}
}
```

@@ -0,0 +1,22 @@
Collections:
- Metadata:
    Architecture:
    - ADA
  Name: ADA
  Paper:
  - https://arxiv.org/pdf/2006.06676.pdf
  README: configs/ada/README.md
Models:
- Config: https://github.com/open-mmlab/mmgeneration/tree/master/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py
  In Collection: ADA
  Metadata:
    Training Data: Others
  Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8
  Results:
  - Dataset: Others
    Metrics:
      FID50k: 15.09
      Iter: 130000.0
      Log: '[log]'
    Task: Tricks for GANs
  Weights: https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8_best_fid_iter_130000_20220401_115101-f2ef498e.pth

configs/ada/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py (99 additions, 0 deletions)
@@ -0,0 +1,99 @@
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
    'type': 'SynthesisNetwork',
    'channel_base': 65536,
    'channel_max': 1024,
    'magnitude_ema_beta': 0.999,
    'conv_kernel': 1,
    'use_radial_filters': True
}
r1_gamma = 3.3  # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth'  # noqa

# ada settings
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        input_bgr2rgb=True,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10.  # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=ema_beta),
        interval=1,
        start_iter=0,
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

evaluation = dict(
    type='GenerativeEvalHook',
    interval=dict(milestones=[100000], interval=[10000, 5000]),
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000

configs/ada/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py (96 additions, 0 deletions)
@@ -0,0 +1,96 @@
_base_ = [
    '../_base_/models/stylegan/stylegan3_base.py',
    '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py'
]

synthesis_cfg = {
    'type': 'SynthesisNetwork',
    'channel_base': 32768,
    'channel_max': 512,
    'magnitude_ema_beta': 0.999
}
r1_gamma = 6.6  # set by user
d_reg_interval = 16

load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth'  # noqa
# ada settings
aug_kwargs = {
    'xflip': 1,
    'rotate90': 1,
    'xint': 1,
    'scale': 1,
    'rotate': 1,
    'aniso': 1,
    'xfrac': 1,
    'brightness': 1,
    'contrast': 1,
    'lumaflip': 1,
    'hue': 1,
    'saturation': 1
}

model = dict(
    type='StaticUnconditionalGAN',
    generator=dict(
        out_size=1024,
        img_channels=3,
        rgb2bgr=True,
        synthesis_cfg=synthesis_cfg),
    discriminator=dict(
        type='ADAStyleGAN2Discriminator',
        in_size=1024,
        input_bgr2rgb=True,
        data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)),
    gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
    disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval))

imgs_root = 'data/metfaces/images/'
data = dict(
    samples_per_gpu=4,
    train=dict(dataset=dict(imgs_root=imgs_root)),
    val=dict(imgs_root=imgs_root))

ema_half_life = 10.  # G_smoothing_kimg

ema_kimg = 10
ema_nimg = ema_kimg * 1000
ema_beta = 0.5**(32 / max(ema_nimg, 1e-8))

custom_hooks = [
    dict(
        type='VisualizeUnconditionalSamples',
        output_dir='training_samples',
        interval=5000),
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=ema_beta),
        interval=1,
        start_iter=0,
        priority='VERY_HIGH')
]

inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl'
metrics = dict(
    fid50k=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True))

evaluation = dict(
    type='GenerativeEvalHook',
    interval=dict(milestones=[80000], interval=[10000, 5000]),
    metrics=dict(
        type='FID',
        num_images=50000,
        inception_pkl=inception_pkl,
        inception_args=dict(type='StyleGAN'),
        bgr2rgb=True),
    sample_kwargs=dict(sample_model='ema'))

lr_config = None

total_iters = 160000