๐ก ์์ด๋ค์ด ์ค์ผ์น๋ฅผ ๊ทธ๋ฆฌ๋ฉด ์์ฑํ AI๊ฐ ์ด๋ฅผ ๊ทธ๋ฆผ์ผ๋ก ๋ณํํด ์ฃผ๋ ์๋น์ค์ ๋๋ค. ๊ฒ์์ ์ธ ์์๋ฅผ ๋ํด ์์ด๋ค์ด ์ค์ค๋ก ์์ ๋ง์ ์ฝํ ์ธ ๋ฅผ ์์ฑํ๋ฉด์ ์์ ๊ฐ๊ณผ ์ฐฝ์๋ ฅ์ ์ฆ์งํ๊ณ , ์ฌ๋ฏธ์ ์ฑ์ทจ๊ฐ์ ๊ฒฝํํ ์ ์๋๋ก ํฉ๋๋ค.
- ์ฐธ๊ณ ๋ฌธํ : Pix2Pix ํ๋ก์ ํธ ํ์ด์ง | Github | Paper | HED paper
-
๋ฐ์ดํฐ ์ค๋น ๊ณผ์ ์ด ๊ถ๊ธํ์ ๋ถ์ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ Section์ ์ฐธ๊ณ ํด์ฃผ์ธ์.
-
๋ชจ๋ธ์ด ๋์๊ฐ๋ ์๋ฆฌ๊ฐ ๊ถ๊ธํ์ ๋ถ์ ๋ชจ๋ธ ์๊ฐ Section์ ์ฐธ๊ณ ํด์ฃผ์ธ์.
-
๋ชจ๋ธ์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๊ณ ์ถ์ผ์ ๋ถ์ ์ฌ์ฉํ ๋ฐ์ดํฐ Section์ ์ฐธ๊ณ ํด์ฃผ์ธ์.
- Gamification & ์์ด๋ค ์นํ์ ์ธ UI
๊ทธ๋ฆผ ๋ณํ | ์ฃผ์ ๋ณ ๊ทธ๋ฆฌ๊ธฐ |
---|---|
- ์์ด๋ค์ ํฅ๋ฏธ๋ฅผ ์ ๋ฐํ๊ธฐ ์ํ ๊ฒ์์ ์์
๋จ๊ณ๋ณ ๊ทธ๋ฆฌ๊ธฐ | ์ฑ์ |
---|---|
- ๊ทธ๋ฆฐ ๊ทธ๋ฆผ์ ๊ฒ์ํ๊ณ ์ข์์๋ฅผ ๋ฐ์์ ๋ญํน์ ๋ค ์ ์์ต๋๋ค
๋ญํน ๊ทธ๋ฆผ ์กฐํ | ์ฃผ์ ๋ณ ๊ทธ๋ฆผ ์กฐํ | ์๋ฆผ ๊ธฐ๋ฅ |
---|---|---|
- ๋ํ ์ด๋ฅผ ์นด์นด์คํก์ผ๋ก๋ ๊ณต์ ํ ์ ์์ต๋๋ค
์นด์นด์คํก ๊ณต์ ํ๊ธฐ |
---|
๐ฉ ๊ถ๊ทน์ ์ธ ๋ชฉํ : ์๋น์ค์ ํต์ฌ ๊ธฐ๋ฅ ์ค ํ๋์ธ ๊ทธ๋ฆฐ ์ค์ผ์น๋ฅผ ํด๋น ๊ทธ๋ฆผ์ผ๋ก ๋ณํํด์ฃผ๋ generator ํ์ต
- ์ด๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด์ ํด์ผํ ์ผ
- ๋ฐ์ดํฐ ์์ง
- ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
- ๋ชจ๋ธ ํ์ต
- ๋ชจ๋ธ ๋น๊ต & checkpoint ๊ฒฐ์
- FastAPI ์๋ฒ์ ๋ชจ๋ธ ๋์ฐ๊ธฐ
- ๋ง์ ๊ฒฝ์ฐ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ํด๋นํ๋ edge๊ฐ ํจ๊ป ์ ๊ณต๋์ง ์์ต๋๋ค.
- ์ด๋ฏธ์ง ์๊ฐ ๋ง๊ธฐ์ ์ด๋ฅผ ์ง์ ๊ทธ๋ฆฌ๋ ๊ฒ์ ํ์ค์ ์ผ๋ก ์ด๋ ค์ฐ๋, pix2pix ์ ์๋ค์ implementation์ ์ฐธ๊ณ , HED(Holistically-Nested Edge Detection)๋ก edge๋ฅผ ์ถ์ถํ ๋ค, post-processing ์์
์ ๊ฑฐ์ณค์ต๋๋ค.
- pix2pix github์ Extracting Edges Section์ ์ฐธ๊ณ ํด์ฃผ์ธ์
- ๊ฒฐ๋ก ์ ์ผ๋ก ์ ๋ ๊ฒ ์ด์ด๋ถ์ฌ์ ํ์ต ์์ ๋ถ๋ฌ์ค๊ฒ ๋ฉ๋๋ค
- ์์ ์ด๋ฏธ์ง : DVM Car Dataset, bmw series 5
โป Distribution Mismatch
- ํ์ต์ ์ํค๋ ๋ฐ์ดํฐ๋ HED์ ์ํด ์๋์ผ๋ก ์ถ์ถ๋ edge์ธ๋ฐ, ์ค์ ์ฌ์ฉ์๊ฐ ์ด๋ฅผ ๋ฐ๋ผ ๊ทธ๋ฆฌ๊ธฐ๋ ํ์ค์ ์ผ๋ก ์ด๋ ต์ต๋๋ค.
- ๊ทธ๋์ ํ์ต์ํค๋ ๋ฐ์ดํฐ์ ๊ถ๊ทน์ ์ผ๋ก ์ ์ฉํ๊ณ ์ ํ๋ ๋ฐ์ดํฐ์ distribution mismatch๊ฐ ๋ฐ์ํ๋๋ฐ, ์ผ๋ฐ์ ์ผ๋ก ์ด๋ ์ค์ ์ ์ฉ ์์์์ ์ฑ๋ฅ ์ ํ๋ฅผ ์ด๋ํ ์ ์๋ค๋ ์ ์ ์ง๊ณ ๋์ด๊ฐ์ผ ํฉ๋๋ค.
- ๊ทธ๋ ๋ค ํ๋๋ผ๋ ์ ํฌ๊ฐ ๊ถ๊ทน์ ์ผ๋ก ์ํ๋ ๊ฒ์ ์ ์ ๊ฐ ๊ทธ๋ฆฐ edge๋ฅผ ์ ๋ณํํด์ฃผ๋ generator์ด๊ธฐ ๋๋ฌธ์ ์ ํฌ๋ ์ง์ edge๋ฅผ ๊ทธ๋ ค์ ์ ์ ๊ฐ ์ ์ฌ์ ์ผ๋ก ๊ทธ๋ฆด ๋งํ, ๊ทธ๋ฆด ์ ์๋ ์์ค์ edge๋ฅผ ๊ธฐ์ค์ผ๋ก ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํด์ ๋ชจ๋ธ์ ์ ์ ํ์ต๋๋ค.
HED์ ์ํด ์๋์ผ๋ก ์ถ์ถ๋ edge | ์ง์ ๊ทธ๋ ค๋ณธ edge |
---|---|
- ์ ์๋ค๋ ๋ง์ฐฌ๊ฐ์ง๋ก ์ด๋ฅผ ์ ๋ ํ์ฌ, paper ๋ถ๋ก์ ํ์ต ๋ฐ์ดํฐ์ ๋ํ ๊ฒฐ๊ณผ ๋ฟ๋ง ์๋๋ผ, ์ฌ๋์ด ๊ทธ๋ฆฐ ๋ฐ์ดํฐ์ ๋ํ ๊ฒฐ๊ณผ๋ ํจ๊ป ์ฒจ๋ถํ์ต๋๋ค.
- pix2pix๋ conditional GAN(Generative Adversarial Network)๋ฅผ ์ด์ฉํด์ (paired) Image-to-Image Translation ๋ฌธ์ ์ ์ ๊ทผํฉ๋๋ค
๋ฐฐ๊ฒฝ ์ง์
- (paired) Image-to-Image Translation์ด๋ ๋ง ๊ทธ๋๋ก, ์ด๋ค ํ ์ด๋ฏธ์ง๊ฐ ์ฃผ์ด์ก์ ๋ ์ด๋ฅผ ๋์ํ๋ ํ ์ด๋ฏธ์ง๋ก ๋ฐ๊พธ๋ ๊ฒ์ ๋๋ค
- ์ฌ๊ธฐ์ paired์ ์๋ฏธ๋ ์ด๋ค ๊ตฌ์กฐ๋ฅผ ๊ณต์ ํ๋, (input์ผ๋ก output์ ์ด๋ ์ ๋ ์ค๋ช ๊ฐ๋ฅํ) (input, output) pair๊ฐ ์๋ ํ๊ฒฝ์ ๋งํฉ๋๋ค
- ์์ ์ ์ฉ ์ฌ๋ก
- GAN์ด๋ Generative Adversarial Network์ ์ฝ์๋ก, generator์ discriminator ๋๊ฐ์ neural network๋ก ์ด๋ฃจ์ด์ ธ ์๋๋ฐ, Generative : ๋ญ๊ฐ๋ฅผ ๋ง๋ค์ด๋ด๋, Adversarial : generator์ discriminator๊ฐ ๋ญ๊ฐ ์๋ก ๊ฒฝ์ํ๋ค๋(or ๋์์ ์ฃผ๋) ๋ป์ ๋ด๊ณ ์์ต๋๋ค
- generator์ ๋ชฉ์ ์ ์ฌ์ค์ ์ธ ๋ฐ์ดํฐ(image, audio ๋ฑ)๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ฒ์
๋๋ค
- ์ฌ๊ธฐ์ ์ฌ์ค์ ์ด๋ผ ํจ์, (discriminator๊ฐ) ์ค์ ๋ฐ์ดํฐ์ ๊ตฌ๋ณํ๊ธฐ ์ด๋ ค์ด ๊ฒ์ ๋งํฉ๋๋ค
- discriminator์ ๋ชฉ์ ์ ์ด๋ค ๋ฐ์ดํฐ(image, audio ๋ฑ)๊ฐ ์ฃผ์ด์ก์ ๋, ์ด๊ฒ์ด generator๊ฐ ๋ง๋ค์ด๋ธ fake ๋ฐ์ดํฐ์ธ์ง, ํน์ real ๋ฐ์ดํฐ์ธ์ง ๊ตฌ๋ณํด๋ด๋ ๊ฒ์ ๋๋ค
- generator์ ๋ชฉ์ ์ ์ฌ์ค์ ์ธ ๋ฐ์ดํฐ(image, audio ๋ฑ)๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ฒ์
๋๋ค
์๊ฐ
- pix2pix๋ input image๊ฐ ์ฃผ์ด์ง๋ฉด ํด๋นํ๋ ํ๊ฒ์ output image๋ก ๋ฐ๊ฟ์ฃผ๋ paired Image-to-Image Translation Task๋ฅผ ์ํ ๋ชจ๋ธ์ ๋๋ค
- ๋ง์ ์ด์ ์ GAN์ด noise๋ฅผ input์ผ๋ก ์ฃผ๋ฉด output์ ๋ฐํํ๋ ๊ฒ์ ๋นํด, pix2pix์์๋ input์ผ๋ก condition(๋ณํํ๊ณ ์ ํ๋ ์ด๋ฏธ์ง)์ ์ฃผ๊ณ ๋ณ๋์ noise๋ ์ฃผ์ง ์์ต๋๋ค
- ๊ทธ๋์ condition์ด ๋์ผํ ํ generator๋ deterministicํ๊ฒ ๋ฉ๋๋ค
- ์ฆ, ๊ฐ์ condition์ด ์ฃผ์ด์ง๋ฉด ๊ฐ์ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด๋ด๊ฒ ๋ฉ๋๋ค
- ๊ทธ๋์ condition์ด ๋์ผํ ํ generator๋ deterministicํ๊ฒ ๋ฉ๋๋ค
- ์ ์๋ค๋ ์ฒ์์๋ noise๋ ๊ฐ์ด ์ฃผ๋ ๋ฐฉํฅ์ ๊ณ ๋ คํ์ง๋ง ๊ทธ๋ฆฌ ํจ๊ณผ์ ์ด์ง ์์์ ์ ์ธํ๋ค๊ณ ํฉ๋๋ค
- generator๊ฐ noise๋ฅผ ๋ฌด์ํ๋ ์ชฝ์ผ๋ก ํ์ตํ๋ ๊ฒฝํฅ์ ๋ณด์๋ค๊ณ ํฉ๋๋ค
Past conditional GANs have acknowledged this and provided Gaussian noise z as an input to the generator, in addition to x (e.g., [55]). In initial experiments, we did not find this strategy effective โ the generator simply learned to ignore the noise โ which is consistent with Mathieu et al. (์ถ์ฒ : pix2pix paper)
- input : (H, W, 1) image tensor(ํ๋ฐฑ), ๋ฒ์ : [0, 1]
- output : (H, W, 3) image tensor(์ปฌ๋ฌ), ๋ฒ์ : [-1, 1]
- ์ ์๋ค์ด ์ด ๋ ผ๋ฌธ์ ๋ฐํํ ๋๋ U-net ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์ง๋ง, ์ ํฌ๋ ์ดํ CycleGAN์์ ์ฌ์ฉํ๋ Resnet ๊ธฐ๋ฐ์ generator๋ฅผ ์ฌ์ฉํ์ต๋๋ค
- conv + 2 Contracting Blocks + 9 Residual Blocks + 2 Expanding Blocks + conv
- c7s1-64, d128, d256, R256 * 9, u128, u64, c7s1-3 (CycleGAN ์ ์๋ค์ Notation ์ฐธ๊ณ )
- Contracting Block : conv + instance_norm โ width & height๋ฅผ ์ ๋ฐ์ผ๋ก ์ค์ ๋๋ค
- Residual Block : conv + instance_norm + relu + conv + instance_norm + input๊ณผ์ skip_connection โ width & height๋ฅผ ๊ทธ๋๋ก ์ ์งํฉ๋๋ค
- Expanding Block : transposed_conv + instance_norm โ width & height๋ฅผ 2๋ฐฐ๋ก ๋๋ฆฝ๋๋ค
- padding : reflection_pad
- ๊ฒฐ๋ก ์ ์ผ๋ก input๊ณผ output์ ์ด๋ฏธ์ง ํฌ๊ธฐ๋ ๋์ผํฉ๋๋ค
- ํ๋ผ๋ฏธํฐ ์ : 11,377,155๊ฐ (condition์ด ํ๋ฐฑ edge channel ํ๋์ผ ๊ฒฝ์ฐ)
-
input : (H, W, 4) image tensor(real or fake image + condition(์ฐ๋ฆฌ์ ๊ฒฝ์ฐ edge)), ๋ฒ์ : [-1, 1]
- ํ๊ฒ ์ด๋ฏธ์ง ๋ฟ๋ง ์๋๋ผ, ๊ทธ ํ๊ฒ์ ๋ง๋ค์ด๋ด๊ธฐ ์ํ condition์ ํ๊ฒ ์ด๋ฏธ์ง์ ์ฑ๋ ์ถ์ ๋ถ์ ๋๋ค
- ์ด๋ ๊ธฐ์กด์ conditional GAN์์ ๊ทธ๋ฌ๋ฏ, ๋จ์ํ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง๋ง ๊ฐ์ง๊ณ ๊ทธ ์ด๋ฏธ์ง๊ฐ ์ง์ง์ธ์ง ์๋์ง ๊ตฌ๋ณํ๋ ๊ฒ๋ณด๋ค, ๊ทธ ๊ฒฐ๊ณผ ์ด๋ฏธ์ง๊ฐ ์ ์ปจ๋์ ์ผ๋ก๋ถํฐ ๋์์ ๋ ์ง์ง์ธ์ง ๊ฐ์ง์ธ์ง ๊ตฌ๋ณํ๋ ๊ฒ์ด ์ฑ๋ฅ์ ๋ ์ข์๊ธฐ์ ๊ทธ๋ฌ๋ค๊ณ ํฉ๋๋ค
-
output : patchGAN output tensor, ๋ฒ์ : ์ ํ ์์, but 0์ ๊ฐ๊น์ธ์๋ก discriminator๋ ๊ฐ์ง๋ก ํ๋จํ๋ ๊ฒ์ด๊ณ , 1์ ๊ฐ๊น์ธ์๋ก ์ง์ง๋ก ํ๋จํ๋ ๊ฒ์ ๋๋ค
-
discriminator๋ก๋ PatchGAN discriminator์ ์ฌ์ฉํฉ๋๋ค
-
์ด์ ์ ๋ง์ GAN์ด ์ด๋ฏธ์ง ์ ์ฒด๋ฅผ ํ๋ฒ์ ๋ณด๊ณ , ์ด๊ฒ real์ธ์ง fake์ธ์ง ๊ตฌ๋ถํ๋ค๋ฉด, PatchGAN discriminator๋ ์ด๋ฏธ์ง ์ ์ฒด๋ฅผ ํ๋ฒ์ ๋ณด์ง ์๊ณ , ๊ฐ๊ฐ ํด๋นํ๋ ์ด๋ฏธ์ง Patch ๋ณ๋ก ๊ทธ ๋ถ๋ถ์ด ์ฌ์ค์ ์ธ๊ฐ(real distribution๊ณผ discriminator๊ฐ ๊ตฌ๋ถํ ์ ์๋๊ฐ) ์๋๊ฐ(fake)๋ฅผ ํ๋จํฉ๋๋ค
- ๊ทธ๋ฆฌ๊ณ ์ถํ์ ์ดํด๋ณผ loss์์ ์ด ์ ๋ณด๋ค์ ์ทจํฉํฉ๋๋ค
- ์ฌ๊ธฐ์ ์ฃผ์ํ ์ ์, ์ด๋ฏธ์ง๋ฅผ ์ฌ๋ฌ ๊ฐ์ patch๋ก ์๋ผ์ ํ๋ํ๋ ๋ฃ๋ ๊ฒ์ด ์๋๋ผ, ๋ชจ๋ ํ๋ฒ์ convolution์ ์ฑ์ง์ ์ด์ฉํด์ ์งํํ๊ฒ ๋ฉ๋๋ค
- output์ ๊ฐ cell๋ค์ด ๋ณด๋ patch๋ ์๋ก ๊ฒน์น ์ ์์ต๋๋ค
-
PatchGAN์ ๊ธฐ๋ณธ์ ์ผ๋ก receptive field size ๋ณด๋ค ๋ ๋ฉ๋ฆฌ ์๋ ํฝ์ ๋ค์ ์๋ก ๋ ๋ฆฝ์ ์ด๋ผ๊ณ ๊ฐ์ ์ ํ๊ธฐ ๋๋ฌธ์ ๋ค๋ฅธ ๋ง๋ก Markovian Discriminator, Local-patch Disctiminator๋ผ๊ณ ๋ ๋ถ๋ฆฐ๋ค๊ณ ํฉ๋๋ค
-
์๋๋ ์ ์๋ค์ด ํ ์คํธ ํด๋ณธ ์ฌ๋ฌ๊ฐ์ง receptive field size(discriminator์ output์ค ํ cell์ด ๋ณด๋ ์ ๋ ฅ ์ด๋ฏธ์ง patch ํฌ๊ธฐ, in pixel)์ ๋ํ ๊ฒฐ๊ณผ ์์์ ๋๋ค
-
์ ํฌ๋ ์ ์๋ค์ ์ ํ์ ์ฐธ๊ณ ํด์ 70์ ์ฌ์ฉํ์ต๋๋ค
-
C64 - C128 - C256 - C512 - output layer (CycleGAN ์ ์๋ค์ notation ์ฐธ๊ณ )
-
ํ๋ผ๋ฏธํฐ ์ : 2,765,633๊ฐ (condition์ด ํ๋ฐฑ edge channel ํ๋์ผ ๊ฒฝ์ฐ)
loss - pix2pix๋ ๋ค์ objective๋ฅผ ๊ธฐ์ค์ผ๋ก ํ์ตํฉ๋๋ค
- ์ฌ๊ธฐ์ G๋ generator์ด๊ณ , D๋ discriminator์ ๋๋ค
- ๊ตฌ์ฑํ๋ ๊ฒ์ ๋๊ฐ์ง๋ก ๋๋ ๋ณด๋ฉด ํฌ๊ฒ L_cGAN๊ณผ L_L1์ผ๋ก ๋๋ ์ ์์ต๋๋ค
-
L_cGAN
- ์ด๋ conditional gan์์ ์ฌ์ฉํ๋ loss์ ๋์ผํฉ๋๋ค
- ๋จ, ์ ํฌ๋ cross entropy loss๋์ ์ LSGAN์์ ์ฌ์ฉ๋์๋ least square adversarial loss๋ฅผ ์ฌ์ฉํ์ต๋๋ค
- ์ฆ, generator๋ input condition x(์ฐ๋ฆฌ์ ๊ฒฝ์ฐ sketch)๋ฅผ ๋ฃ์ด์ ์์ฑ๋ ๊ฒฐ๊ณผ G(x)๊ฐ discriminator์๊ฒ ์ฌ์ค์ ์ธ ์ด๋ฏธ์ง์ฒ๋ผ ๋ณด์ด๋๋ก ํ์ตํ๊ณ
- discriminator๋ y(์ค์ target)์ 1๋ก, ๊ฐ์ง(G(x))๋ 0์ด ๋๋๋ก ํ์ตํฉ๋๋ค
-
L_L1 loss
- input์ generator์ ๋ฃ์ด์ ์์ฑ๋ ์ด๋ฏธ์ง์ ์ํ๋ target ์ฌ์ด์ pixel level์์์ L1 distance๋ฅผ ๊ณ์ฐํฉ๋๋ค
- ์ฝ๊ฒ ๋งํ์๋ฉด, model์ output๊ณผ target(์ด์์ ์ธ ๊ฒฐ๊ณผ)์ ํฝ์ ๊ฐ์ ์ ๋๊ฐ ์ฐจ์ด(RGB ๋ชจ๋)๋ฅผ ๋ชจ๋ ๊ตฌํ ํ ์ด๋ฅผ ํ๊ท ๋ด๋ฉด ๋ฉ๋๋ค
- ์ ์๋ค์ ์ ์ฒด์ ์ผ๋ก ๋น๊ต์ blurryํ ๊ฒฐ๊ณผ๋ฅผ ๋ง๋๋ L2(์ ๋๊ฐ ๋์ ์ ๊ณฑํฉ์ ์ฌ์ฉ) ๋์ L1 distance๋ฅผ ์ฌ์ฉํ์ต๋๋ค
- ์ฐ๋ฆฌ๊ฐ ๊ถ๊ทน์ ์ผ๋ก ์ํ๋ ๊ฒ์ input์ target์ผ๋ก ์ ๋ฐ๊ฟ์ฃผ๋ generator์ด๊ธฐ ๋๋ฌธ์ ์์ฐ์ค๋ฌ์ด ์ ํ์ผ๋ก ๋ฐ์๋ค์ผ ์ ์์ต๋๋ค
-
def discriminator_loss_function(real_D_out, fake_D_out):
'''
LSGAN loss
<params>
real_D_out : ์ค์ ์ด๋ฏธ์ง๊ฐ ์ฃผ์ด์ก์ ๋, discriminator์ ๊ฒฐ๊ณผ๊ฐ
fake_D_out : ๊ฐ์ง ์ด๋ฏธ์ง๊ฐ ์ฃผ์ด์ก์ ๋, discirminator์ ๊ฒฐ๊ณผ๊ฐ
'''
# ์ ์๋ค์ ๋ฐฉ์์ ๋ฐ๋ผ, 2๋ก ๋๋์ผ๋ก์จ D๊ฐ ๋ฐฐ์ฐ๋ ์๋๋ฅผ ๋ฆ์ถ๋ค (G๊ฐ Generator Adversarial Loss๋ก๋ถํฐ ๋ฐฐ์ฐ๋ ๊ฒ์ ๋นํด์)
return 0.5 * (tf.math.reduce_mean(tf.math.squared_difference(real_D_out, tf.ones_like(real_D_out))) +
tf.math.reduce_mean(tf.math.squared_difference(fake_D_out, tf.zeros_like(fake_D_out))))
def generator_adversarial_loss_function(fake_D_out):
'''
LSGAN loss
<params>
fake_D_out : ๊ฐ์ง ์ด๋ฏธ์ง๊ฐ ์ฃผ์ด์ก์ ๋, discirminator์ ๊ฒฐ๊ณผ๊ฐ
'''
return tf.math.reduce_mean(tf.math.squared_difference(fake_D_out, tf.ones_like(fake_D_out)))
def generator_L1_loss_function(real_images, fake_images):
'''
L1 loss
<params>
real_images : ์ค์ ์ด๋ฏธ์ง
fake_images : generator์ ์ํด์ ์์ฑ๋ ์ด๋ฏธ์ง
'''
return tf.math.reduce_mean(tf.math.abs(real_images - fake_images))
# discriminator loss ๊ณ์ฐ
discriminator_loss = discriminator_loss_function(real_D_out, fake_D_out)
# generator loss ๊ณ์ฐ
generator_adversarial_loss = generator_adversarial_loss_function(fake_D_out)
generator_L1_loss = generator_L1_loss_function(real_image, fake_image)
generator_loss = generator_adversarial_loss + LAMBDA * generator_L1_loss
- L1 loss๋ฅผ ๊ตฌํ๋ ๋ฐ๋ discriminator๋ฅผ ์ด์ฉํ์ง ์์ผ๋ฏ๋ก
- discriminator_loss = LSGAN_discriminator_loss
- generator_loss = LSGAN_generator_loss + lambda * L1_loss
- lambda๋ 10์ ์ฌ์ฉํ์ต๋๋ค
- ์ค์ ํ์ต์ discriminator์ ํ๋ผ๋ฏธํฐ๋ฅผ discriminator_loss๋ฅผ ๋ฎ์ถ๋ ๋ฐฉํฅ์ผ๋ก ํ๋ฒ ์ ๋ฐ์ดํธ ํ๊ณ , ๊ทธ๋ฆฌ๊ณ ๊ทธ ๋ค์์ generator์ ํ๋ผ๋ฏธํฐ๋ฅผ generator_loss๋ฅผ ๋ฎ์ถ๋ ๋ฐฉํฅ์ผ๋ก ํ๋ฒ ์ ๋ฐ์ดํธ ํ๋ ๊ณผ์ ์ ๋ฐ๋ณตํฉ๋๋ค
- ๋ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๋์์ ํ์ตํ์ง ์๋ ๊ฒ์ด ์ค์ํฉ๋๋ค (๋ฐ๋ก๋ฐ๋ก ๋ฒ๊ฐ์๊ฐ๋ฉฐ ํด์ผํฉ๋๋ค)
- ๊ทธ ์ด์ ๋ generator(์ ๋ง๋ค์ด๋ด์)์ ์ญํ ๊ณผ discriminator(์ ๊ตฌ๋ณํ์)์ ์ญํ ์ด ์ด์ฐ ๋ณด์๋ฉด ์๋ก ์๋ฐ๋๋๋ฐ, ๋๊ฐ์ parameter๋ฅผ ๋์์ ํ์ตํ๋ค๋ฉด ์๋ก ์ ์ถฉ, ํํํ๋ ๋ฐฉํฅ์ผ๋ก ํ์ตํ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค
- discriminator_loss๋ฅผ ๋ฎ์ถ๋ค ํจ์ LSGAN_discriminator_loss๋ฅผ ๋ฎ์ถ๋ ๋ฐฉํฅ, ์ฆ ๊ฐ์ง ์ด๋ฏธ์ง๋ 0์ผ๋ก ์์ธกํ๋ ค๊ณ ํ๊ณ , ์ง์ง ์ด๋ฏธ์ง๋ 1๋ก ์์ธกํ๋ ค๊ณ ํ๋, ์ด๋ฏธ์ง๋ฅผ ์ ๊ตฌ๋ณํ๋ ค๊ณ ํ์ตํ๋ ๊ณผ์ ์ ๋๋ค
- ๊ทธ๋ฆฌ๊ณ ๊ทธ ๋ค์ generator_loss๋ฅผ ๋ฎ์ถ๋ค ํจ์ LSGAN_generator_loss + lambda * L1_loss๋ฅผ ๋ฎ์ถ๋ ๋ฐฉํฅ์ธ๋ฐ
-
LSGAN_generator_loss๋ฅผ ๋ค์ ์ดํด๋ณด๋ฉด G(x), ์ฆ ๊ฐ์ง์ด๋ฏธ์ง๊ฐ discriminator์๊ฒ 1(์ง์ง์ฒ๋ผ ๋ณด์ด๋๋ก)์ ๊ฐ๊น๋๋ก ํ์ตํฉ๋๋ค
-
๊ถ๊ทน์ ์ผ๋ก D๋ ์ฌ์ฉํ์ง ์๊ณ ์ข์ G๋ฅผ ์ป์ด๋ด๋ ๊ฒ์ด ๋ชฉ์ ์ธ๋ฐ, D๊ฐ ์ ํ์ํ๋?์ ๋๋ต์ ์ฌ๊ธฐ์ ํ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค
-
D๋ ์ด์ ์ ๊ฐ์ง(G(x))์ ์ง์ง(y)๋ฅผ ๊ตฌ๋ณํ๋ ค๊ณ ํ์ตํ๊ธฐ ๋๋ฌธ์ (๋น๋ก ํ๋ฒ์ step์ด์ง๋ง, ์กฐ๊ธ์ด๋ผ๋, ๊ณ์ ๋์ ๋๋ค๋ฉด) ๊ฐ์ง์ ์ง์ง๊ฐ ์ด๋ค ๋ถ๋ถ์์ ๋ค๋ฅธ์ง ์ด๋ ์ ๋ ์๊ฒ ๋ฉ๋๋ค
-
๊ทธ๋์ ์ด๋ป๊ฒ ๊ตฌ๋ณํ ์ ์์๋์ง ์ด ์ ๋ณด๋ฅผ G์ ํ๋ผ๋ฏธํฐ์ ์ ๋ฌ(D๋ ์ด๋ ํ์ตํ์ง ์์ต๋๋ค)ํด์ G๊ฐ ๋ ์ ๋ง๋ค๊ฒ ๋๋ ๊ฑฐ๋ผ๊ณ ์ดํดํ๋ฉด ์ข์๊ฑฐ ๊ฐ์ต๋๋ค
- ์ฌ๊ธฐ์ ์ ๋ฌ์ ๋ฌผ๋ก back propagation์ ํตํด ์ด๋ฃจ์ด์ง๋๋ค
-
- ์ด๋ฐ ์์ผ๋ก D๋ ์ ๊ตฌ๋ณํ๋ ค๊ณ ํ๋ฒ ๋ฐฐ์ฐ๊ณ , ๊ทธ ๋ค์์ ์ด๋ป๊ฒ ๊ตฌ๋ณํ๋์ง ๊ทธ ์ ๋ณด๋ฅผ G์๋ ์ ๋ฌํด์ G๋ ๋ ์ ๋ง๋ค๊ฒ ๋๊ณ , ๋ค์ ๋ D๋ ์ด๊ฑธ ์ค์ ์ ๊ตฌ๋ณํด๋ณด๋ ค๊ณ ๋ ธ๋ ฅํ๊ณ , ๋ค์ G์๊ฒ ์ด๋ป๊ฒ ๊ตฌ๋ณํ๋์ง ์ ๋ณด๋ฅผ ์ ๋ฌํ๊ณ ์ด๋ฐ ์์ผ๋ก GAN์ ํ์ตํ๋ ๊ฒ์ ๋๋ค
- ์ถ์ฒ : ๊ตฌ๊ธ์ cartoon set
- ๋ฐ์ดํฐ ์ : 9996 (์๋ณธ 10๋ง ๊ฐ ์ค์์ ์ผ๋ถ๋ฅผ ์ถ์ถํ์ฌ ์ํ)
- batch ์ฌ์ด์ฆ : 4
- ํ์ต์ํจ epoch ์ : 28
- ํน์ด์ฌํญ
- ์ฑ๋ฅ ๊ฐ์ ์ ์ํด์ color ์ ๋ณด๋ฅผ condition์ผ๋ก ์ถ๊ฐํด ๋ณด๊ธฐ๋ ํ๊ณ , ๋ฐ์ดํฐ์ ์๋ฅผ ๋๋ ค๋ณด๊ธฐ๋ ํ์์ผ๋(10๋ง ๊ฐ, ์๋ณธ ๋ฐ์ดํฐ ์ ๋ถ), ์ฌ์ฉ์๊ฐ ๊ทธ๋ฆฐ edge์ ๋ํ ๋ณํ ์ฑ๋ฅ์ ์ด๋ ๋ค ํ ๊ฐ์ ์ ์ด ๋ณด์ด์ง ์์ (ํ์ต ๋ฐ์ดํฐ๋ ๊ต์ฅํ ์ ๋ณํ)
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : Kaggle, Panda or Bear Image Classification
- ๋ฐ์ดํฐ ์ : 300 (๊ณฐ ๋ฐ์ดํฐ๋ ์ ์ธํ๊ณ , ํ๋ค ๋ฐ์ดํฐ๋ง ์ฌ์ฉ)
- batch ์ฌ์ด์ฆ : 1
- ํ์ต ์ํจ epoch ์ : 180
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : DVM car dataset
- ๋ฐ์ดํฐ ์ : 11476 (DVM car dataset์์ ์ธ๋จ ํ์ bmw series 5 & 7๋ง ์ถ์ถ)
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : 19
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : ์ ์๋ค์ด ์ฌ์ฉํ๋ ๋ฐ์ดํฐ์
- ๋ฐ์ดํฐ ์ : 138567
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : 5
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : ์ ์๋ค์ด ์ฌ์ฉํ๋ ๋ฐ์ดํฐ์
- ๋ฐ์ดํฐ ์ : 49825
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : 25
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : Kaggle, maplestory_characters_hd
- ๋ฐ์ดํฐ ์ : 69372
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : 14
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : Kaggle, Gemstones Images
- ๋ฐ์ดํฐ ์ : 3219
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : ๋๋ต 36
- ์์ ๊ฒฐ๊ณผ
- ์ถ์ฒ : Kaggle, Cosmos Images
- ๋ฐ์ดํฐ ์ : 4649
- batch ์ฌ์ด์ฆ : 4
- ํ์ต ์ํจ epoch ์ : 40
- ์์ ๊ฒฐ๊ณผ