Skip to content

Commit

Permalink
docs: little update
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiteretsu77 committed Mar 24, 2024
1 parent 318d491 commit 6bca544
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
9 changes: 7 additions & 2 deletions test_code/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This is file is to execute the inference for a single image or a folder input
'''
import argparse
import time
import os, sys, cv2, shutil, warnings
import torch
from torchvision.transforms import ToTensor
Expand Down Expand Up @@ -76,6 +77,7 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc

# Sample Command
# 4x GRL (Default): python test_code/inference.py --model GRL --scale 4 --weight_path pretrained/4x_APISR_GRL_GAN_generator.pth
# 4x RRDB: python test_code/inference.py --model RRDB --scale 4 --weight_path pretrained/4x_APISR_RRDB_GAN_generator.pth
# 2x RRDB: python test_code/inference.py --model RRDB --scale 2 --weight_path pretrained/2x_APISR_RRDB_GAN_generator.pth


Expand Down Expand Up @@ -116,8 +118,9 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc
elif model == "RRDB":
generator = load_rrdb(weight_path, scale=scale) # Can be any size
generator = generator.to(dtype=weight_dtype)



start = time.time()
# Take the input path and do inference
if os.path.isdir(store_dir): # If the input is a directory, we will iterate it
for filename in sorted(os.listdir(input_dir)):
Expand All @@ -131,7 +134,9 @@ def super_resolve_img(generator, input_path, output_path=None, weight_dtype=torc
output_path = os.path.join(store_dir, filename+"_"+str(scale)+"x.png")
# In default, we will automatically use crop to match 4x size
super_resolve_img(generator, input_dir, output_path, weight_dtype, crop_for_4x=True)

end = time.time()

print("Total inference time spent is ", end-start)



Expand Down
1 change: 1 addition & 0 deletions train_code/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def parse_args():

if not args.auto_resume_closest and not args.auto_resume_best:
# Restart tensorboard (delete all things under ./runs)
print("We will remove the log of tensorboard.")
if os.path.exists("./runs"):
storage_manage()
shutil.rmtree("./runs")
Expand Down

0 comments on commit 6bca544

Please sign in to comment.