Skip to content

Commit

Permalink
code cleanup (#3)
Browse files Browse the repository at this point in the history
* code cleanup

* cleanup

* vscode settings

* cleanup

* black

---------

Co-authored-by: Nathan Zhao <[email protected]>
  • Loading branch information
codekansas and nathanjzhao authored Jul 24, 2024
1 parent c2b45b6 commit aea158b
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 177 deletions.
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"[python]": {
"editor.formatOnSave": true,
"editor.defaultFormatter": "ms-python.black-formatter"
}
}
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,22 @@ To customize the training of the consistency models, the following command line
- `--loss_type`: The type of loss function to use. Can be either `mse` for Mean Squared Error or `huber` for Huber loss. Default is `mse`.
- `--partial_sampling`: Enables partial sampling, which can be useful for reducing the number of sampling steps required to reach the final model prediction. This is disabled by default and can be enabled by adding this flag without any value.


### Example Usage

To run the training with specific options, you can use the command line as follows:

```bash
python train.py --prefix experiment1 --n_epochs 200 --output_dir ./experiment1_outputs --device cuda:0 --loss_type huber --partial_sampling
python train.py \
--prefix experiment1 \
--n_epochs 200 \
--output_dir ./experiment1_outputs \
--device cuda:0 \
--loss_type huber \
--partial_sampling
```

## Miscellaneous

TODO

- [ ] Latent Consistency Modeling with a VAE
15 changes: 5 additions & 10 deletions dataloader.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
"""Defines a simple MNIST dataloader."""

from typing import Tuple
from torchvision import datasets, transforms

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define transformations for the images
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])


def mnist() -> Tuple[DataLoader, DataLoader]:
# Download and load the training data
train_dataset = datasets.MNIST(
root="mnist_data", train=True, download=True, transform=transform
)
train_dataset = datasets.MNIST(root="mnist_data", train=True, download=True, transform=transform)

train_loader = DataLoader(
dataset=train_dataset,
Expand All @@ -23,9 +20,7 @@ def mnist() -> Tuple[DataLoader, DataLoader]:
)

# Download and load the test data
test_dataset = datasets.MNIST(
root="mnist_data", train=False, download=True, transform=transform
)
test_dataset = datasets.MNIST(root="mnist_data", train=False, download=True, transform=transform)

test_loader = DataLoader(
dataset=test_dataset,
Expand Down
101 changes: 53 additions & 48 deletions infer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
"""Defines the inference script."""

import argparse
import logging
import os

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.utils.data.dataloader import DataLoader
from torchvision.utils import save_image

from model import ConsistencyModel
from dataloader import mnist
from model import ConsistencyModel

logger = logging.getLogger(__name__)

def get_low_quality_image(test_loader):
"""
Get a single low quality image from the test loader.

def get_low_quality_image(test_loader: DataLoader) -> tuple[Tensor, int]:
"""Get a single low quality image from the test loader.
Args:
- test_loader: DataLoader for the MNIST test set
test_loader: DataLoader for the MNIST test set
Returns:
- A low quality (12x12) image tensor and its corresponding label
A low quality (12x12) image tensor and its corresponding label
"""
# Get a single batch from the test loader
images, labels = next(iter(test_loader))
Expand All @@ -27,27 +34,33 @@ def get_low_quality_image(test_loader):

# Resize the image to 14x14
low_quality_image = F.interpolate(
image.unsqueeze(0), size=(14, 14), mode="bilinear", align_corners=False
image.unsqueeze(0),
size=(14, 14),
mode="bilinear",
align_corners=False,
).squeeze(0)

return low_quality_image, label


def finish_low_quality_image(
model, low_quality_image, device, partial_start=40.0, steps=None
):
"""
Finish a low quality MNIST image using partial sampling.
model: nn.Module,
low_quality_image: Tensor,
device: torch.device,
partial_start: float = 40.0,
steps: list[float] | None = None,
) -> Tensor:
"""Finish a low quality MNIST image using partial sampling.
Args:
- model: The trained ConsistencyModel
- low_quality_image: A tensor of size 1x12x12
- device: The device to run the model on
- partial_start: The starting point for partial sampling (default: 40.0)
- steps: List of timesteps for sampling (if None, default steps will be used)
model: The trained ConsistencyModel
low_quality_image: A tensor of size 1x12x12
device: The device to run the model on
partial_start: The starting point for partial sampling (default: 40.0)
steps: List of timesteps for sampling (if None, default steps will be used)
Returns:
- A tensor of the finished image (28x28)
A tensor of the finished image (28x28)
"""
# Upscale the image to 28x28
upscaled_image = F.interpolate(
Expand All @@ -69,31 +82,40 @@ def finish_low_quality_image(

# Perform partial sampling
with torch.no_grad():
finished_image = model.sample(
noisy_image.unsqueeze(0), ts=steps, partial_start=partial_start
)
finished_image = model.sample(noisy_image.unsqueeze(0), ts=steps, partial_start=partial_start)

# Denormalize and clamp the image
finished_image = (finished_image.squeeze(0) * 0.5 + 0.5).clamp(0, 1)

return finished_image


def main(args):
def main() -> None:
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser(description="Consistency Model Training and Image Finishing")
parser.add_argument("--device", type=str, default="cuda:0", help="CUDA device to use")
parser.add_argument("--prefix", type=str, default="", help="Prefix for checkpoint and output names")
parser.add_argument(
"--output_dir",
type=str,
default="./contents",
help="Output directory for checkpoints and images",
)
args = parser.parse_args()

n_channels = 1
name = "mnist"

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
logger.info("Using device: %s", device)

train_loader, test_loader = mnist()
model = ConsistencyModel(n_channels, D=128)
_, test_loader = mnist()
model = ConsistencyModel(n_channels, hdims=128)
model.to(device)

# Load the trained model
model.load_state_dict(
torch.load(os.path.join(args.output_dir, f"{args.prefix}ct_{name}.pth"))
)
model.load_state_dict(torch.load(os.path.join(args.output_dir, f"{args.prefix}ct_{name}.pth")))
model.eval()

# Get a low quality image from the test set
Expand All @@ -114,27 +136,10 @@ def main(args):
os.path.join(args.output_dir, f"{args.prefix}finished_image.png"),
)

print(f"Processed image with label: {label}")
print(f"Low quality image saved as: {args.prefix}low_quality_image.png")
print(f"Finished image saved as: {args.prefix}finished_image.png")
logger.info("Processed image with label: %s", label)
logger.info("Low quality image saved as: %slow_quality_image.png", args.prefix)
logger.info("Finished image saved as: %sfinished_image.png", args.prefix)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Consistency Model Training and Image Finishing"
)
parser.add_argument(
"--device", type=str, default="cuda:0", help="CUDA device to use"
)
parser.add_argument(
"--prefix", type=str, default="", help="Prefix for checkpoint and output names"
)
parser.add_argument(
"--output_dir",
type=str,
default="./contents",
help="Output directory for checkpoints and images",
)
args = parser.parse_args()

main(args)
main()
86 changes: 48 additions & 38 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Defines consistency model."""

import torch
import math
from typing import List, Optional
import torch.nn as nn
from typing import Literal

import torch
import torch.nn.functional as F
from torch import Tensor, nn


def blk(ic, oc):
def blk(ic: int, oc: int) -> nn.Module:
return nn.Sequential(
nn.GroupNorm(32, num_channels=ic),
nn.SiLU(),
Expand All @@ -19,50 +20,44 @@ def blk(ic, oc):


class ConsistencyModel(nn.Module):
def __init__(self, n_channel: int, eps: float = 0.002, D: int = 128) -> None:
def __init__(self, n_channel: int, eps: float = 0.002, hdims: int = 128) -> None:
super(ConsistencyModel, self).__init__()

self.eps = eps

self.freqs = torch.exp(
-math.log(10000) * torch.arange(start=0, end=D, dtype=torch.float32) / D
)
self.freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=hdims, dtype=torch.float32) / hdims)

self.down = nn.Sequential(
*[
nn.Conv2d(n_channel, D, 3, padding=1),
blk(D, D),
blk(D, 2 * D),
blk(2 * D, 2 * D),
nn.Conv2d(n_channel, hdims, 3, padding=1),
blk(hdims, hdims),
blk(hdims, 2 * hdims),
blk(2 * hdims, 2 * hdims),
]
)

self.time_downs = nn.Sequential(
nn.Linear(2 * D, D),
nn.Linear(2 * D, D),
nn.Linear(2 * D, 2 * D),
nn.Linear(2 * D, 2 * D),
nn.Linear(2 * hdims, hdims),
nn.Linear(2 * hdims, hdims),
nn.Linear(2 * hdims, 2 * hdims),
nn.Linear(2 * hdims, 2 * hdims),
)

self.mid = blk(2 * D, 2 * D)
self.mid = blk(2 * hdims, 2 * hdims)

self.up = nn.Sequential(
*[
blk(2 * D, 2 * D),
blk(2 * 2 * D, D),
blk(D, D),
nn.Conv2d(2 * D, 2 * D, 3, padding=1),
blk(2 * hdims, 2 * hdims),
blk(2 * 2 * hdims, hdims),
blk(hdims, hdims),
nn.Conv2d(2 * hdims, 2 * hdims, 3, padding=1),
]
)
self.last = nn.Conv2d(2 * D + n_channel, n_channel, 3, padding=1)
self.last = nn.Conv2d(2 * hdims + n_channel, n_channel, 3, padding=1)

def forward(self, x, t):
def forward(self, x: Tensor, t: Tensor) -> Tensor:
if isinstance(t, float):
t = (
torch.tensor([t] * x.shape[0], dtype=torch.float32)
.to(x.device)
.unsqueeze(1)
)
t = torch.tensor([t] * x.shape[0], dtype=torch.float32).to(x.device).unsqueeze(1)

# time embedding
args = t.float() * self.freqs[None].to(t.device)
Expand Down Expand Up @@ -105,7 +100,15 @@ def forward(self, x, t):
# as time progresses want to rely more on the model's output image (likely to be more informed)
return c_skip_t[:, :, None, None] * x_ori + c_out_t[:, :, None, None] * x

def loss(self, x, z, t1, t2, ema_model, loss_type="mse"):
def loss(
self,
x: Tensor,
z: Tensor,
t1: Tensor,
t2: Tensor,
ema_model: nn.Module,
loss_type: Literal["mse", "huber"] = "mse",
) -> Tensor:
x2 = x + z * t2[:, :, None, None]

# forward pass
Expand All @@ -121,15 +124,16 @@ def loss(self, x, z, t1, t2, ema_model, loss_type="mse"):
# across the flow.
x1 = ema_model(x1, t1)

if loss_type == "mse":
return F.mse_loss(x1, x2)
elif loss_type == "huber":
return pseudo_huber_loss(x1, x2)
else:
raise ValueError("Invalid loss type. Choose 'mse' or 'huber'.")
match loss_type:
case "mse":
return F.mse_loss(x1, x2)
case "huber":
return pseudo_huber_loss(x1, x2)
case _:
raise ValueError("Invalid loss type. Choose 'mse' or 'huber'.")

@torch.no_grad()
def sample(self, x, ts: List[float], partial_start: Optional[float] = None):
def sample(self, x: Tensor, ts: list[float], partial_start: float | None = None) -> Tensor:
if partial_start is not None:
# Start from a partially denoised state
start_idx = next(i for i, t in enumerate(ts) if t <= partial_start)
Expand All @@ -140,12 +144,18 @@ def sample(self, x, ts: List[float], partial_start: Optional[float] = None):
# bigger jumps more unstable
for t in ts[1:]:
z = torch.randn_like(x)
x = x + math.sqrt(t**2 - self.eps**2) * z
x = x + (math.sqrt(t**2 - self.eps**2) * z)
x = self(x, t)

return x


def pseudo_huber_loss(x, y, delta=1.0):
def pseudo_huber_loss(x: Tensor, y: Tensor, delta: float = 1.0) -> Tensor:
diff = x - y
return torch.mean(delta**2 * (torch.sqrt(1 + (diff / delta) ** 2) - 1))


def kerras_boundaries(sigma: float, eps: float, n: int, t: float) -> Tensor:
# This will be used to generate the boundaries for the time discretization
bounds = [(eps ** (1 / sigma) + i / (n - 1) * (t ** (1 / sigma) - eps ** (1 / sigma))) ** sigma for i in range(n)]
return torch.tensor(bounds)
11 changes: 0 additions & 11 deletions mypy.ini

This file was deleted.

Loading

0 comments on commit aea158b

Please sign in to comment.