esrgan_infer.py
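
"""ESRGAN inference script.

Loads a trained ESRGAN checkpoint and super-resolves every grayscale image
found in an input directory, optionally saving a bicubic-upscaled copy and a
side-by-side comparison next to each output.
"""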
import torch
from PIL import Image
import os
from torchvision.utils import save_image
from torchvision import transforms
import argparse
from tqdm import tqdm
from esrgan import ESRGAN
import torch.nn.functional as F
import time


def load_model(checkpoint_path):
    """Load the trained ESRGAN model"""
    model = ESRGAN.load_from_checkpoint(checkpoint_path)
    model.eval()
    return model.to('cuda' if torch.cuda.is_available() else 'cpu')


def process_image(model, image_path, output_path):
    """Process a single image"""
    # Load and preprocess image (grayscale)
    img = Image.open(image_path).convert('L')
    transform = transforms.Compose([transforms.ToTensor()])
    input_tensor = transform(img).unsqueeze(0)

    # Move to the same device as the model
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)

    # Generate SR image
    with torch.no_grad():
        sr_image = model(input_tensor)

    # Save output
    save_image(sr_image, output_path)
    return sr_image


def main():
    parser = argparse.ArgumentParser(description='ESRGAN Inference')
    parser.add_argument('--checkpoint', '-p', type=str,
                        default='best_models/esrgan-epoch=17-val_ssim=0.8919.ckpt',
                        help='Path to model checkpoint')
    parser.add_argument('--input_dir', '-i', type=str, default='inference_data/input',
                        help='Directory containing input images')
    parser.add_argument('--output_dir', '-o', type=str, default='inference_data/output',
                        help='Directory to save super-resolved images')
    parser.add_argument('--compare', '-c', action='store_true',
                        help='Save comparison with input image')
    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load model
    print("Loading model...")
    model = load_model(args.checkpoint)
    device = next(model.parameters()).device
    print(f"Model loaded on {device}")

    # Process all images in the input directory
    image_files = [f for f in os.listdir(args.input_dir)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    print(f"Found {len(image_files)} images to process")

    for img_file in tqdm(image_files, desc="Processing images"):
        input_path = os.path.join(args.input_dir, img_file)
        output_path = os.path.join(args.output_dir, f"sr_{img_file}")

        # Process image and time the end-to-end inference (including saving)
        start = time.perf_counter()
        sr_image = process_image(model, input_path, output_path)
        end = time.perf_counter()
        time_ms = (end - start) * 1000
        tqdm.write(f"{img_file}: {time_ms:.2f} ms")

        # Create comparison if requested
        if args.compare:
            # Load the original LR image
            lr_img = Image.open(input_path).convert('L')
            transform = transforms.Compose([transforms.ToTensor()])
            lr_tensor = transform(lr_img).unsqueeze(0)

            # Resize LR image to match SR size using bicubic interpolation
            lr_resized = F.interpolate(
                lr_tensor.to(device),
                size=sr_image.shape[-2:],  # height and width from the SR image
                mode='bicubic',
                align_corners=False,
            )

            # Create side-by-side comparison
            comparison = torch.cat([lr_resized.cpu(), sr_image.cpu()], dim=-1)
            save_image(lr_resized, os.path.join(args.output_dir, f"lr_bicubic_{img_file}"))
            save_image(comparison, os.path.join(args.output_dir, f"compare_{img_file}"))


if __name__ == "__main__":
    main()
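

# Example invocation (the paths below are just the argparse defaults from
# main(); substitute your own checkpoint and data directories):
#
#   python esrgan_infer.py \
#       --checkpoint best_models/esrgan-epoch=17-val_ssim=0.8919.ckpt \
#       --input_dir inference_data/input \
#       --output_dir inference_data/output \
#       --compare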