Skip to content

Commit

Permalink
[Feat] ladi 결과 저장까지 완료 #6
Browse files Browse the repository at this point in the history
- ladi 결과가 user_db/ladi/buffer에 lower_body.png로 저장됨
- openpose 실행 전 input을  512, 384로 resize
  • Loading branch information
Hyunmin-H committed Jul 15, 2023
1 parent 14bf68f commit 2e5eab9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
13 changes: 7 additions & 6 deletions model/ladi_vton/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def parse_args():
###수정
parser.add_argument("--dataset", type=str, default="dresscode", choices=["dresscode", "vitonhd"], help="dataset to use") #required=True,
parser.add_argument("--category", type=str, choices=['all', 'lower_body', 'upper_body', 'dresses'], default='lower_body') ###
parser.add_argument("--use_png", default=False, action="store_true")
parser.add_argument("--use_png", default=True, action="store_true")
parser.add_argument("--num_inference_steps", default=50, type=int)
parser.add_argument("--guidance_scale", default=7.5, type=float)
parser.add_argument("--compute_metrics", default=False, action="store_true")
Expand Down Expand Up @@ -211,7 +211,8 @@ def main_ladi(db_dir, output_buffer_dir):

# Prepare the dataloader and create the output directory
test_dataloader = accelerator.prepare(test_dataloader)
save_dir = os.path.join(args.output_dir, args.test_order)
save_dir = os.path.join(args.output_dir, args.test_order) ## 수정
save_dir = args.output_dir
os.makedirs(save_dir, exist_ok=True)
generator = torch.Generator("cuda").manual_seed(args.seed)

Expand Down Expand Up @@ -305,16 +306,16 @@ def main_ladi(db_dir, output_buffer_dir):

# Save images
for gen_image, cat, name in zip(generated_images, category, batch["im_name"]):
if not os.path.exists(os.path.join(save_dir, cat)):
os.makedirs(os.path.join(save_dir, cat))
# if not os.path.exists(os.path.join(save_dir, cat)):
# os.makedirs(os.path.join(save_dir, cat))

if args.use_png:
name = name.replace(".jpg", ".png")
gen_image.save(
os.path.join(save_dir, cat, name))
os.path.join(save_dir, f'{cat}.png'))
else:
gen_image.save(
os.path.join(save_dir, cat, name), quality=95)
os.path.join(save_dir, f'{cat}.jpg'), quality=95)

# Free up memory
del val_pipe
Expand Down
4 changes: 1 addition & 3 deletions model/pytorch_openpose/extract_keypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def main_openpose(target_buffer_dir, output_buffer_dir):
# image read
# test_image = test_image
oriImg = cv2.imread(test_image) # B,G,R order
print('*******', test_image)
print('*******',oriImg.shape)
# oriImg = cv2.resize(oriImg, (384, 512)) # resize
oriImg = cv2.resize(oriImg, (384, 512)) # resize

# body_estimation foreward
candidate, subset = body_estimation(oriImg)
Expand Down

0 comments on commit 2e5eab9

Please sign in to comment.