New: modernized repository #40

Open · wants to merge 1 commit into master
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
+# IDE related
+.vscode/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
14 changes: 12 additions & 2 deletions README.md
@@ -34,9 +34,19 @@ Alternatively you can build your own dataset by setting up the following directory structure:
 |   |   |   └── B       # Contains domain B images (i.e. Batman)
 
 ### 2. Train!
+
+Before training, make sure to start up the visdom server in another terminal. Otherwise, you will get HTTPConnection errors. It's as simple as:
+
 ```
-./train --dataroot datasets/<dataset_name>/ --cuda
+visdom
 ```
+
+Next, you can launch the actual training script. If your Python interpreter is not located at `/usr/bin/python3` (e.g., if you are using conda), you can delete the first line of `./train`.
+
+```
+python train.py --dataroot datasets/<dataset_name>/ --cuda
+```
+
 This command will start a training session using the images under the *dataroot/train* directory with the hyperparameters that showed best results according to the CycleGAN authors. You are free to change those hyperparameters; see ```./train --help``` for a description of those.
 
 Both generator and discriminator weights will be saved under the output directory.
@@ -53,7 +63,7 @@ You can also view the training progress as well as live output images by running

 ## Testing
 ```
-./test --dataroot datasets/<dataset_name>/ --cuda
+python test.py --dataroot datasets/<dataset_name>/ --cuda
 ```
 This command will take the images under the *dataroot/test* directory, run them through the generators, and save the output under the *output/A* and *output/B* directories. As with train, some parameters, like the weights to load, can be tweaked; see ```./test --help``` for more information.

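For a quick end-to-end check of the renamed entry points, the commands can be run back to back. This is an editor's sketch rather than part of the diff; it assumes the default `datasets/horse2zebra/` layout and a CUDA device:

```
visdom &    # the dashboard is served on http://localhost:8097 by default
python train.py --dataroot datasets/horse2zebra/ --cuda
python test.py --dataroot datasets/horse2zebra/ --cuda
```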
45 changes: 45 additions & 0 deletions dataset.py
@@ -0,0 +1,45 @@
import glob
import os
import random
from typing import Callable, List, Optional

from PIL import Image

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(
        self,
        root: str,
        transforms_: Optional[List[Callable]] = None,
        unaligned: bool = True,
        mode: str = 'train',
        grayscale: bool = False
    ) -> None:
        self.transform: Callable[[Image.Image], torch.Tensor] = transforms.Compose(transforms_)
        self.unaligned: bool = unaligned
        self.grayscale: bool = grayscale

        self.files_A: List[str] = sorted(glob.glob(os.path.join(root, f'{mode}/A') + '/*.*'))
        self.files_B: List[str] = sorted(glob.glob(os.path.join(root, f'{mode}/B') + '/*.*'))

    def __getitem__(self, index: int) -> dict:
        idx_A: int = index % len(self.files_A)
        item_A: Image.Image = Image.open(self.files_A[idx_A])

        idx_B: int = random.randint(0, len(self.files_B) - 1) \
            if self.unaligned else (index % len(self.files_B))
        item_B: Image.Image = Image.open(self.files_B[idx_B])

        if self.grayscale:
            item_A = item_A.convert('L')
            item_B = item_B.convert('L')

        return dict(A=self.transform(item_A), B=self.transform(item_B))

    def __len__(self) -> int:
        return max(len(self.files_A), len(self.files_B))
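
For orientation, here is a minimal usage sketch of the new `ImageDataset` (editor's example, not part of the diff; the dataset path and transform list are assumptions mirroring the defaults in `test.py`):

```
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from dataset import ImageDataset

# Assumed layout: datasets/horse2zebra/train/A and datasets/horse2zebra/train/B
transforms_ = [
    transforms.ToTensor(),                                    # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # [0, 1] -> [-1, 1]
]

dataset = ImageDataset('datasets/horse2zebra/', transforms_=transforms_, mode='train')
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape)  # e.g. torch.Size([1, 3, 256, 256]) each
```

Because `unaligned=True` by default, each `A` image is paired with a randomly drawn `B` image rather than the one at the same index, which matches CycleGAN's unpaired training setup.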
28 changes: 0 additions & 28 deletions datasets.py

This file was deleted.

104 changes: 63 additions & 41 deletions models.py
@@ -1,92 +1,114 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
 
-        conv_block = [ nn.ReflectionPad2d(1),
-                       nn.Conv2d(in_features, in_features, 3),
-                       nn.InstanceNorm2d(in_features),
-                       nn.ReLU(inplace=True),
-                       nn.ReflectionPad2d(1),
-                       nn.Conv2d(in_features, in_features, 3),
-                       nn.InstanceNorm2d(in_features) ]
+        conv_block = [
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            nn.InstanceNorm2d(in_features),
+            nn.ReLU(inplace=True),
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            nn.InstanceNorm2d(in_features)
+        ]
 
         self.conv_block = nn.Sequential(*conv_block)
 
     def forward(self, x):
         return x + self.conv_block(x)
 
 
 class Generator(nn.Module):
     def __init__(self, input_nc, output_nc, n_residual_blocks=9):
         super(Generator, self).__init__()
 
-        # Initial convolution block
-        model = [ nn.ReflectionPad2d(3),
-                  nn.Conv2d(input_nc, 64, 7),
-                  nn.InstanceNorm2d(64),
-                  nn.ReLU(inplace=True) ]
+        # Initial convolution block
+        model = [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(input_nc, 64, 7),
+            nn.InstanceNorm2d(64),
+            nn.ReLU(inplace=True)
+        ]
 
         # Downsampling
         in_features = 64
-        out_features = in_features*2
+        out_features = in_features * 2
         for _ in range(2):
-            model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
-                       nn.InstanceNorm2d(out_features),
-                       nn.ReLU(inplace=True) ]
+            model += [
+                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                nn.InstanceNorm2d(out_features),
+                nn.ReLU(inplace=True)
+            ]
 
             in_features = out_features
-            out_features = in_features*2
+            out_features = in_features * 2
 
         # Residual blocks
         for _ in range(n_residual_blocks):
             model += [ResidualBlock(in_features)]
 
         # Upsampling
-        out_features = in_features//2
+        out_features = in_features // 2
         for _ in range(2):
-            model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
-                       nn.InstanceNorm2d(out_features),
-                       nn.ReLU(inplace=True) ]
+            model += [
+                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                nn.InstanceNorm2d(out_features),
+                nn.ReLU(inplace=True)
+            ]
             in_features = out_features
-            out_features = in_features//2
+            out_features = in_features // 2
 
         # Output layer
-        model += [ nn.ReflectionPad2d(3),
-                   nn.Conv2d(64, output_nc, 7),
-                   nn.Tanh() ]
+        model += [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(64, output_nc, 7),
+            nn.Tanh()
+        ]
 
         self.model = nn.Sequential(*model)
 
     def forward(self, x):
         return self.model(x)
 
 
 class Discriminator(nn.Module):
     def __init__(self, input_nc):
         super(Discriminator, self).__init__()
 
         # A bunch of convolutions one after another
-        model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
-                  nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),
-                   nn.InstanceNorm2d(128),
-                   nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),
-                   nn.InstanceNorm2d(256),
-                   nn.LeakyReLU(0.2, inplace=True) ]
-
-        model += [ nn.Conv2d(256, 512, 4, padding=1),
-                   nn.InstanceNorm2d(512),
-                   nn.LeakyReLU(0.2, inplace=True) ]
+        model = [
+            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(64, 128, 4, stride=2, padding=1),
+            nn.InstanceNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(128, 256, 4, stride=2, padding=1),
+            nn.InstanceNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
+
+        model += [
+            nn.Conv2d(256, 512, 4, padding=1),
+            nn.InstanceNorm2d(512),
+            nn.LeakyReLU(0.2, inplace=True)
+        ]
 
         # FCN classification layer
         model += [nn.Conv2d(512, 1, 4, padding=1)]
 
         self.model = nn.Sequential(*model)
 
     def forward(self, x):
         x = self.model(x)
         # Average pooling and flatten
         return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
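
As a quick shape sanity check on these two networks (editor's sketch, not part of the PR; the 256 input size mirrors the `--size` default in `test.py`):

```
import torch

from models import Discriminator, Generator

netG = Generator(input_nc=3, output_nc=3)  # 9 residual blocks by default
netD = Discriminator(input_nc=3)

x = torch.randn(1, 3, 256, 256)  # one random RGB image, NCHW
with torch.no_grad():
    print(netG(x).shape)  # torch.Size([1, 3, 256, 256]); the two stride-2 downsamples
                          # are undone by the two stride-2 transposed convolutions
    print(netD(x).shape)  # torch.Size([1, 1]); patch scores average-pooled to one value per image
```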
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
numpy==1.22.1
Pillow==9.1.0
torch==1.10.2
torchvision==0.11.3
visdom==0.1.8.9
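
These pins install in the usual way; the virtual environment name below is an assumption, not something the PR prescribes:

```
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```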
63 changes: 34 additions & 29 deletions test → test.py
@@ -1,35 +1,41 @@
 #!/usr/bin/python3
 
 import argparse
-import sys
 import os
+import sys
 
+import torch
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
 import torchvision.transforms as transforms
 from torchvision.utils import save_image
-from torch.utils.data import DataLoader
-from torch.autograd import Variable
-import torch
 
-from models import Generator
-from datasets import ImageDataset
+from dataset import ImageDataset
+from models import Generator
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
-parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/', help='root directory of the dataset')
-parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
-parser.add_argument('--output_nc', type=int, default=3, help='number of channels of output data')
-parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
-parser.add_argument('--cuda', action='store_true', help='use GPU computation')
-parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--generator_A2B', type=str, default='output/netG_A2B.pth', help='A2B generator checkpoint file')
-parser.add_argument('--generator_B2A', type=str, default='output/netG_B2A.pth', help='B2A generator checkpoint file')
+parser.add_argument('--batchSize', type=int, default=1,
+                    help='size of the batches')
+parser.add_argument('--dataroot', type=str, default='datasets/horse2zebra/',
+                    help='root directory of the dataset')
+parser.add_argument('--input_nc', type=int, default=1,
+                    help='number of channels of input data')
+parser.add_argument('--output_nc', type=int, default=1,
+                    help='number of channels of output data')
+parser.add_argument('--size', type=int, default=256,
+                    help='size of the data (squared assumed)')
+parser.add_argument('--cuda', action='store_true',
+                    help='use GPU computation')
+parser.add_argument('--n_cpu', type=int, default=8,
+                    help='number of cpu threads to use during batch generation')
+parser.add_argument('--generator_A2B', type=str, default='output/netG_A2B.pth',
+                    help='A2B generator checkpoint file')
+parser.add_argument('--generator_B2A', type=str, default='output/netG_B2A.pth',
+                    help='B2A generator checkpoint file')
 opt = parser.parse_args()
 print(opt)
 
 if torch.cuda.is_available() and not opt.cuda:
     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
 
 ###### Definition of variables ######
 # Networks
 netG_A2B = Generator(opt.input_nc, opt.output_nc)
 netG_B2A = Generator(opt.output_nc, opt.input_nc)
@@ -52,13 +58,13 @@
 input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
 
 # Dataset loader
-transforms_ = [ transforms.ToTensor(),
-                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
-dataloader = DataLoader(ImageDataset(opt.dataroot, transforms_=transforms_, mode='test'),
-                        batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
-###################################
+transforms_ = [
+    transforms.ToTensor(),
+    transforms.Normalize(tuple([0.5] * opt.output_nc), tuple([0.5] * opt.output_nc))
+]
 
-###### Testing######
+dataset = ImageDataset(opt.dataroot, transforms_=transforms_, mode='test', grayscale=(opt.input_nc == 1))
+dataloader = DataLoader(dataset, batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
 
 # Create output dirs if they don't exist
 if not os.path.exists('output/A'):
@@ -72,14 +78,13 @@
     real_B = Variable(input_B.copy_(batch['B']))
 
     # Generate output
-    fake_B = 0.5*(netG_A2B(real_A).data + 1.0)
-    fake_A = 0.5*(netG_B2A(real_B).data + 1.0)
+    fake_B = 0.5 * (netG_A2B(real_A).data + 1.0)
+    fake_A = 0.5 * (netG_B2A(real_B).data + 1.0)
 
     # Save image files
-    save_image(fake_A, 'output/A/%04d.png' % (i+1))
-    save_image(fake_B, 'output/B/%04d.png' % (i+1))
+    save_image(fake_A, f'output/A/{(i + 1):04d}.png')
+    save_image(fake_B, f'output/B/{(i + 1):04d}.png')
 
-    sys.stdout.write('\rGenerated images %04d of %04d' % (i+1, len(dataloader)))
+    sys.stdout.write(f'\rGenerated images {(i + 1):04d} of {len(dataloader):04d}')
 
 sys.stdout.write('\n')
-###################################
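
One detail worth spelling out in the loop above: both generators end in `nn.Tanh`, so their outputs lie in [-1, 1], and `0.5 * (x + 1.0)` maps them back to the [0, 1] range that `save_image` expects; this is the inverse of the `Normalize` transform applied when the images were loaded. A tiny standalone illustration (editor's example):

```
import torch

x = torch.tensor([-1.0, 0.0, 1.0])  # extremes and midpoint of the tanh range
print(0.5 * (x + 1.0))              # tensor([0.0000, 0.5000, 1.0000])
```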
11 changes: 11 additions & 0 deletions tox.ini
@@ -0,0 +1,11 @@
[flake8]
ignore = W293, C901
max-line-length = 110
exclude =
    # No need to traverse these directories
    .git
    # Jupyter Notebooks
    *.ipynb
    # __init__ files break rule F401
    */__init__.py
max-complexity = 10
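
flake8 picks up the `[flake8]` section of `tox.ini` automatically when invoked from the repository root, so the lint check reduces to:

```
pip install flake8
flake8 .
```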