diff --git a/model/ladi_vton/inference.py b/model/ladi_vton/inference.py
index 3cd020a..49cfbd6 100644
--- a/model/ladi_vton/inference.py
+++ b/model/ladi_vton/inference.py
@@ -81,7 +81,7 @@ def parse_args():
     ###수정
     parser.add_argument("--dataset", type=str, default="dresscode", choices=["dresscode", "vitonhd"], help="dataset to use") #required=True,
     parser.add_argument("--category", type=str, choices=['all', 'lower_body', 'upper_body', 'dresses'], default='lower_body') ###
-    parser.add_argument("--use_png", default=False, action="store_true")
+    parser.add_argument("--use_png", default=True, action="store_true")
     parser.add_argument("--num_inference_steps", default=50, type=int)
     parser.add_argument("--guidance_scale", default=7.5, type=float)
     parser.add_argument("--compute_metrics", default=False, action="store_true")
@@ -211,7 +211,8 @@ def main_ladi(db_dir, output_buffer_dir):

     # Prepare the dataloader and create the output directory
     test_dataloader = accelerator.prepare(test_dataloader)
-    save_dir = os.path.join(args.output_dir, args.test_order)
+    save_dir = os.path.join(args.output_dir, args.test_order) ## 수정
+    save_dir = args.output_dir
     os.makedirs(save_dir, exist_ok=True)

     generator = torch.Generator("cuda").manual_seed(args.seed)
@@ -305,16 +306,16 @@ def main_ladi(db_dir, output_buffer_dir):

         # Save images
         for gen_image, cat, name in zip(generated_images, category, batch["im_name"]):
-            if not os.path.exists(os.path.join(save_dir, cat)):
-                os.makedirs(os.path.join(save_dir, cat))
+            # if not os.path.exists(os.path.join(save_dir, cat)):
+            #     os.makedirs(os.path.join(save_dir, cat))

             if args.use_png:
                 name = name.replace(".jpg", ".png")
                 gen_image.save(
-                    os.path.join(save_dir, cat, name))
+                    os.path.join(save_dir, f'{cat}.png'))
             else:
                 gen_image.save(
-                    os.path.join(save_dir, cat, name), quality=95)
+                    os.path.join(save_dir, f'{cat}.jpg'), quality=95)

     # Free up memory
     del val_pipe
diff --git a/model/pytorch_openpose/extract_keypoint.py b/model/pytorch_openpose/extract_keypoint.py
index 39fe2e0..69b21bb 100644
--- a/model/pytorch_openpose/extract_keypoint.py
+++ b/model/pytorch_openpose/extract_keypoint.py
@@ -29,9 +29,7 @@ def main_openpose(target_buffer_dir, output_buffer_dir):
     # image read
     # test_image = test_image
     oriImg = cv2.imread(test_image) # B,G,R order
-    print('*******', test_image)
-    print('*******',oriImg.shape)
-    # oriImg = cv2.resize(oriImg, (384, 512)) # resize
+    oriImg = cv2.resize(oriImg, (384, 512)) # resize

     # body_estimation foreward
     candidate, subset = body_estimation(oriImg)