-
Notifications
You must be signed in to change notification settings - Fork 0
/
vae_to_video.py
81 lines (68 loc) · 2.09 KB
/
vae_to_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
# latent_dim sizes the VAE latent space. It must equal the value the checkpoint
# was trained with, because it fixes the shapes of z_mean, z_log_var and the
# decoder's first layer. batch_size is not used by this script (presumably a
# leftover from training).
latent_dim, batch_size = 32, 64
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps ``input_size`` features through two ReLU layers
    (128 -> 64 units) to the mean and log-variance of a diagonal Gaussian over
    a ``latent_dim``-dimensional latent space. The decoder mirrors the encoder
    and ends in a Sigmoid, so reconstructions lie in [0, 1].

    Submodule names and layer order determine the ``state_dict`` keys
    (``encoder.0.weight``, ``z_mean.weight``, ``decoder.4.weight``, ...), so
    they must not change if saved checkpoints are to keep loading.

    Args:
        input_size: Number of features per sample (last dimension of the input).
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_size, latent_dim):
        super().__init__()
        # Encoder: input_size -> 128 -> 64
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        # Separate heads for the posterior mean and log-variance.
        self.z_mean = nn.Linear(64, latent_dim)
        self.z_log_var = nn.Linear(64, latent_dim)
        # Decoder: latent_dim -> 64 -> 128 -> input_size, squashed to [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_size),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``.

        Sampling through ``eps`` (the reparameterization trick) keeps the
        result differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Tensor whose last dimension is ``input_size`` (presumably
                ``(batch, input_size)`` — confirm against callers).

        Returns:
            Tuple ``(x_hat, z_mean, z_log_var)``: the reconstruction (values in
            [0, 1]) plus the posterior parameters needed for the KL term.
        """
        hidden = self.encoder(x)
        z_mean = self.z_mean(hidden)
        z_log_var = self.z_log_var(hidden)
        z = self.reparameterize(z_mean, z_log_var)
        x_hat = self.decoder(z)
        return x_hat, z_mean, z_log_var
# ---------------------------------------------------------------------------
# Sample random latent vectors, decode them with a trained VAE and write the
# decoded samples to an MP4 video.
# ---------------------------------------------------------------------------
CHECKPOINT_PATH = "vae_model.pt"
VIDEO_PATH = "vae_output.mp4"
NUM_FRAMES = 300
FPS = 30
FRAME_SIZE = (512, 512)  # (width, height), the order cv2.VideoWriter expects


def _to_bgr_frame(values, frame_size):
    """Convert one decoded sample into an 8-bit BGR image for ``cv2.VideoWriter``.

    Args:
        values: Array of decoder outputs in [0, 1] (the decoder ends in a
            Sigmoid). Its total size must be ``width * height`` (grayscale) or
            ``width * height * 3`` (colour).
        frame_size: ``(width, height)`` of the output frame.

    Returns:
        A ``(height, width, 3)`` ``uint8`` array.

    Raises:
        ValueError: If ``values`` cannot be reshaped into a frame of that size.
    """
    width, height = frame_size
    pixels = np.asarray(values, dtype=np.float32).reshape(-1)
    # VideoWriter only accepts 8-bit frames; scale [0, 1] -> [0, 255].
    pixels = np.clip(np.rint(pixels * 255.0), 0, 255).astype(np.uint8)
    if pixels.size == width * height * 3:
        # NOTE(review): assumes samples were flattened as HWC in BGR order —
        # confirm against how the training frames were flattened.
        return pixels.reshape(height, width, 3)
    if pixels.size == width * height:
        return cv2.cvtColor(pixels.reshape(height, width), cv2.COLOR_GRAY2BGR)
    raise ValueError(
        f"decoder produces {pixels.size} values per sample, which cannot form a "
        f"{width}x{height} frame (expected {width * height} for grayscale or "
        f"{width * height * 3} for BGR)"
    )


# Load the weights before building the model so it gets the input size it was
# trained with: ``encoder.0`` is the first nn.Linear of VAE.encoder and its
# weight has shape (128, input_size). Fall back to the previous default of 2.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
first_weight = state_dict.get("encoder.0.weight")
input_size = first_weight.shape[1] if first_weight is not None else 2

model = VAE(input_size=input_size, latent_dim=latent_dim)
model.load_state_dict(state_dict)
model.eval()

# Decode all latent vectors in one batch:
# (NUM_FRAMES, latent_dim) -> (NUM_FRAMES, input_size).
with torch.no_grad():
    z = torch.randn(NUM_FRAMES, latent_dim)
    images = model.decoder(z).cpu().numpy()

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(VIDEO_PATH, fourcc, FPS, FRAME_SIZE)
if not out.isOpened():
    # VideoWriter otherwise drops every frame silently.
    raise RuntimeError(f"could not open {VIDEO_PATH!r} for writing with codec 'mp4v'")
try:
    for sample in images:
        out.write(_to_bgr_frame(sample, FRAME_SIZE))
finally:
    # Always finalize the container, even if a frame fails to convert.
    out.release()