[Chore] Rename functions
- Changed function names

Related to: #33
Hyunmin-H committed Aug 17, 2023
1 parent 0cc7df2 commit b0c9d12
Showing 5 changed files with 5 additions and 442 deletions.
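In short, the byte-based entry points take over the old names: main_schp_from_image_byte becomes main_schp and main_ladi_fromImageByte becomes main_ladi, while the directory-based implementations are deleted. A minimal migration sketch for call sites (illustrative only; the argument values are placeholders and the diff below does not show the full function bodies):

# Before this commit
# main_schp_from_image_byte(image_byte, dataset='atr')
# main_ladi_fromImageByte(category_, target_bytes, schp_img, keypoint_dict,
#                         garment_bytes, garment_mask, ladi_models)

# After this commit: same signatures, shorter names
main_schp(image_byte, dataset='atr')
main_ladi(category_, target_bytes, schp_img, keypoint_dict,
          garment_bytes, garment_mask, ladi_models)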
94 changes: 1 addition & 93 deletions model/Self_Correction_Human_Parsing/simple_extractor.py
@@ -87,99 +87,7 @@ def get_palette(num_cls):
            lab >>= 3
    return palette

def get_metaData(image, input_size):
    img = image
    h, w, _ = img.shape

    # Get person center and scale
    # assumes a module-level _box2cs helper; the original `self._box2cs` is
    # undefined here, since this is not a method
    person_center, s = _box2cs([0, 0, w - 1, h - 1])
    r = 0

    import cv2
    from utils.transforms import get_affine_transform
    trans = get_affine_transform(person_center, s, r, input_size)
    input_img = cv2.warpAffine(
        img,
        trans,
        (int(input_size[1]), int(input_size[0])),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0, 0, 0))

    meta = {
        'center': person_center,
        'height': h,
        'width': w,
        'scale': s,
        'rotation': r
    }

    # return the warped crop and its affine metadata
    return input_img, meta


def main_schp(target_buffer_dir):

    args = get_arguments()
    args.input_dir = target_buffer_dir

    gpus = [int(i) for i in args.gpu.split(',')]
    assert len(gpus) == 1
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    num_classes = dataset_settings[args.dataset]['num_classes']
    input_size = dataset_settings[args.dataset]['input_size']
    label = dataset_settings[args.dataset]['label']
    print("Evaluating total class number {} with {}".format(num_classes, label))

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    state_dict = torch.load(args.model_restore)['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        # BGR-order ImageNet mean/std (the dataset reads images with OpenCV)
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset = SimpleFolderDataset(root=args.input_dir, input_size=input_size, transform=transform)
    dataloader = DataLoader(dataset)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    palette = get_palette(num_classes)
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            image, meta = batch
            img_name = meta['name'][0]
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            parsing_result_path = os.path.join(args.output_dir, img_name[:-4] + '.png')
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)
            output_img.save(parsing_result_path)
            if args.logits:
                logits_result_path = os.path.join(args.output_dir, img_name[:-4] + '.npy')
                np.save(logits_result_path, logits_result)
    return

-def main_schp_from_image_byte(image_byte, dataset='atr'):
+def main_schp(image_byte, dataset='atr'):
    args = get_arguments()

    gpus = [int(i) for i in args.gpu.split(',')]
246 changes: 1 addition & 245 deletions model/ladi_vton/inference.py
@@ -96,231 +96,8 @@ def parse_args():

    return args


@torch.inference_mode()
def main_ladi(category_, db_dir, output_buffer_dir, ladi_models, target_name='target.jpg'):
    args = parse_args()
    args.dresscode_dataroot = db_dir
    args.output_dir = output_buffer_dir
    args.category = category_

    # Check if the dataset dataroot is provided
    if args.dataset == "vitonhd" and args.vitonhd_dataroot is None:
        raise ValueError("VitonHD dataroot must be provided")
    if args.dataset == "dresscode" and args.dresscode_dataroot is None:
        raise ValueError("DressCode dataroot must be provided")

    # Setup accelerator and device.
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    t = time.time()

    weight_dtype = torch.float32
    if args.mixed_precision == 'fp16':
        weight_dtype = torch.float16

    (val_scheduler, text_encoder, vae, vision_encoder, processor, tokenizer,
     unet, emasc, inversion_adapter, tps, refinement) = ladi_models

    print('***Ladi load time', time.time() - t)

    int_layers = [1, 2, 3, 4, 5]

    # Enable xformers memory efficient attention if requested
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Load the datasets
    if args.category != 'all':
        category = [args.category]
    else:
        category = ['dresses', 'upper_body', 'lower_body']

    outputlist = ['image', 'pose_map', 'inpaint_mask', 'im_mask', 'category', 'im_name', 'cloth']
    if args.dataset == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args.dresscode_dataroot,
            phase='test',
            order=args.test_order,
            radius=5,
            outputlist=outputlist,
            category=category,
            size=(512, 384),
            target_name=target_name
        )
    elif args.dataset == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args.vitonhd_dataroot,
            phase='test',
            order=args.test_order,
            radius=5,
            outputlist=outputlist,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args.dataset} not implemented")

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Set to eval mode
    text_encoder.eval()
    vae.eval()
    emasc.eval()
    inversion_adapter.eval()
    unet.eval()
    tps.eval()
    refinement.eval()
    vision_encoder.eval()

    # Create the pipeline
    val_pipe = StableDiffusionTryOnePipeline(
        text_encoder=text_encoder,
        vae=vae,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=val_scheduler,
        emasc=emasc,
        emasc_int_layers=int_layers,
    ).to(device)

    # Prepare the dataloader and create the output directory
    test_dataloader = accelerator.prepare(test_dataloader)
    save_dir = os.path.join(args.output_dir, args.test_order)  # modified
    save_dir = args.output_dir
    os.makedirs(save_dir, exist_ok=True)
    generator = torch.Generator("cuda").manual_seed(args.seed)

    # Generate the images
    for idx, batch in enumerate(tqdm(test_dataloader)):
        model_img = batch.get("image").to(weight_dtype)
        mask_img = batch.get("inpaint_mask")
        if mask_img is not None:
            mask_img = mask_img.to(weight_dtype)
        pose_map = batch.get("pose_map").to(weight_dtype)
        category = batch.get("category")
        cloth = batch.get("cloth").to(weight_dtype)
        im_mask = batch.get('im_mask').to(weight_dtype)

        # Generate the warped cloth.
        # For the sake of performance, the TPS parameters are predicted on a low-resolution image.
        low_cloth = torchvision.transforms.functional.resize(cloth, (256, 192),
                                                             torchvision.transforms.InterpolationMode.BILINEAR,
                                                             antialias=True)
        low_im_mask = torchvision.transforms.functional.resize(im_mask, (256, 192),
                                                               torchvision.transforms.InterpolationMode.BILINEAR,
                                                               antialias=True)
        low_pose_map = torchvision.transforms.functional.resize(pose_map, (256, 192),
                                                                torchvision.transforms.InterpolationMode.BILINEAR,
                                                                antialias=True)
        agnostic = torch.cat([low_im_mask, low_pose_map], 1)

        low_grid, theta, rx, ry, cx, cy, rg, cg = tps(low_cloth, agnostic)

        # Upsample the grid to the original image size and warp the cloth using the predicted TPS parameters
        highres_grid = torchvision.transforms.functional.resize(low_grid.permute(0, 3, 1, 2),
                                                                size=(512, 384),
                                                                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                                                                antialias=True).permute(0, 2, 3, 1)

        warped_cloth = F.grid_sample(cloth, highres_grid, padding_mode='border')

        # Refine the warped cloth using the refinement network
        warped_cloth = torch.cat([im_mask, pose_map, warped_cloth], 1)
        warped_cloth = refinement(warped_cloth)
        warped_cloth = warped_cloth.clamp(-1, 1)

        # Get the visual features of the in-shop cloths
        input_image = torchvision.transforms.functional.resize((cloth + 1) / 2, (224, 224),
                                                               antialias=True).clamp(0, 1)
        processed_images = processor(images=input_image, return_tensors="pt")
        clip_cloth_features = vision_encoder(
            processed_images.pixel_values.to(model_img.device, dtype=weight_dtype)).last_hidden_state

        # Compute the predicted PTEs
        word_embeddings = inversion_adapter(clip_cloth_features.to(model_img.device))
        word_embeddings = word_embeddings.reshape((word_embeddings.shape[0], args.num_vstar, -1))

        category_text = {
            'dresses': 'a dress',
            'upper_body': 'an upper body garment',
            'lower_body': 'a lower body garment',
        }
        text = [f'a photo of a model wearing {category_text[category]} {" $ " * args.num_vstar}' for
                category in batch['category']]

        # Tokenize text
        tokenized_text = tokenizer(text, max_length=tokenizer.model_max_length, padding="max_length",
                                   truncation=True, return_tensors="pt").input_ids
        tokenized_text = tokenized_text.to(word_embeddings.device)

        # Encode the text using the PTEs extracted from the in-shop cloths
        encoder_hidden_states = encode_text_word_embedding(text_encoder, tokenized_text,
                                                           word_embeddings, args.num_vstar).last_hidden_state

        # Generate images
        generated_images = val_pipe(
            image=model_img,
            mask_image=mask_img,
            pose_map=pose_map,
            warped_cloth=warped_cloth,
            prompt_embeds=encoder_hidden_states,
            height=512,
            width=384,
            guidance_scale=args.guidance_scale,
            num_images_per_prompt=1,
            generator=generator,
            cloth_input_type='warped',
            num_inference_steps=args.num_inference_steps
        ).images

        # Save images
        for gen_image, cat, name in zip(generated_images, category, batch["im_name"]):
            # if not os.path.exists(os.path.join(save_dir, cat)):
            #     os.makedirs(os.path.join(save_dir, cat))

            if args.use_png:
                name = name.replace(".jpg", ".png")  # note: `name` is not used in the save path below
                gen_image.save(os.path.join(save_dir, f'{cat}.png'))
            else:
                gen_image.save(os.path.join(save_dir, f'{cat}.jpg'), quality=95)

    # Free up memory
    del val_pipe
    del text_encoder
    del vae
    del emasc
    del unet
    del tps
    del refinement
    del vision_encoder
    torch.cuda.empty_cache()

    if args.compute_metrics:
        metrics = compute_metrics(save_dir, args.test_order, args.dataset, args.category, ['all'],
                                  args.dresscode_dataroot, args.vitonhd_dataroot)

        with open(os.path.join(save_dir, f"metrics_{args.test_order}_{args.category}.json"), "w+") as f:
            json.dump(metrics, f, indent=4)

-def main_ladi_fromImageByte(category_, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models, target_name='target.jpg'):
+def main_ladi(category_, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models, target_name='target.jpg'):
    args = parse_args()
    args.category = category_

@@ -374,8 +151,6 @@ def main_ladi_fromImageByte(category_, target_bytes, schp_img, keypoint_dict, ga
        num_workers=args.num_workers,
    )



    # Set to eval mode
    text_encoder.eval()
    vae.eval()
@@ -495,25 +270,6 @@ def main_ladi_fromImageByte(category_, target_bytes, schp_img, keypoint_dict, ga
            #     os.makedirs(os.path.join(save_dir, cat))

            return gen_image


    # Free up memory
    del val_pipe
    del text_encoder
    del vae
    del emasc
    del unet
    del tps
    del refinement
    del vision_encoder
    torch.cuda.empty_cache()

    if args.compute_metrics:
        metrics = compute_metrics(save_dir, args.test_order, args.dataset, args.category, ['all'],
                                  None, args.vitonhd_dataroot)

        with open(os.path.join(save_dir, f"metrics_{args.test_order}_{args.category}.json"), "w+") as f:
            json.dump(metrics, f, indent=4)

# if __name__ == "__main__":
# main()
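Taken together, the renamed entry points form a bytes-in, image-out try-on pipeline: main_ladi returns the generated PIL image (see `return gen_image` above). A hypothetical end-to-end sketch, assuming main_schp's return value is the parsing map consumed as schp_img (the diff does not show its return statement) and with keypoint_dict, garment_bytes, garment_mask, and ladi_models prepared by the caller:

schp_img = main_schp(target_bytes, dataset='atr')  # assumed to return the parsing result
gen_image = main_ladi('upper_body', target_bytes, schp_img, keypoint_dict,
                      garment_bytes, garment_mask, ladi_models,
                      target_name='target.jpg')
gen_image.save('result.jpg')  # gen_image is a PIL image from the diffusion pipeline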
