Skip to content

Commit

Permalink
fixes (#4)
Browse files Browse the repository at this point in the history
* fixes

* small fix

* black

* fix lint

---------

Co-authored-by: Benjamin Bolte <[email protected]>
  • Loading branch information
nathanjzhao and codekansas authored Jul 24, 2024
1 parent ce8430f commit d7e0aaf
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ out*/
*.stl
mnist_data/
contents/
*.pth
Binary file added assets/steps_figure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 21 additions & 1 deletion infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn.functional as F
from torch import Tensor, nn
from torch.utils.data.dataloader import DataLoader
from torchvision.utils import save_image
from torchvision.utils import make_grid, save_image

from dataloader import mnist
from model import ConsistencyModel
Expand Down Expand Up @@ -140,6 +140,26 @@ def main() -> None:
logger.info("Low quality image saved as: %slow_quality_image.png", args.prefix)
logger.info("Finished image saved as: %sfinished_image.png", args.prefix)

image, _ = next(iter(test_loader))
image = image[:20]
step_spreads = [
[80.0],
[5.0, 80.0],
[5.0, 50.0, 80.0],
[5.0, 20.0, 40.0, 80.0],
[5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0],
]

for i, steps in enumerate(step_spreads):
num_steps = len(steps)
xh = model.sample(
torch.randn_like(image).to(device=device) * 80.0,
list(reversed(steps)),
)
xh = (xh * 0.5 + 0.5).clamp(0, 1)
grid = make_grid(xh, nrow=4)
save_image(grid, os.path.join(args.output_dir, f"{args.prefix}ct_{name}_sample_{num_steps}step.png"))


if __name__ == "__main__":
main()
8 changes: 4 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def sample(self, x: Tensor, ts: list[float], partial_start: float | None = None)
# Start from a partially denoised state
start_idx = next(i for i, t in enumerate(ts) if t <= partial_start)
x = self(x, ts[start_idx])
ts = ts[start_idx:]
ts = ts[start_idx + 1 :]

# just running through the model at random timestamps until end
# bigger jumps more unstable
for t in ts[1:]:
# Just running through the model at random timestamps until end
# Bigger jumps more unstable
for t in ts:
z = torch.randn_like(x)
x = x + (math.sqrt(t**2 - self.eps**2) * z)
x = self(x, t)
Expand Down

0 comments on commit d7e0aaf

Please sign in to comment.