Env
model: u2net downloaded from the link in the repository README: https://pan.baidu.com/s/1WjwyEwDiaUjBbx_QxcXBwQ
Test input
image: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/gJT2UhwWmcHgM6ep.jpg
background: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/DCAqTPHo7Advhmrv.jpg
current result: https://cdn.avatar.dmc-ai.cn/avatar/photos/2024/02/02/RFD5jxgohUXX4fAX.jpg
Test Method

```python
import glob
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

# Modules from the U-2-Net repository (run this from the repo root)
from data_loader import RescaleT, ToTensorLab, SalObjDataset
from model import U2NET, U2NETP

# normPRED() and save_output() are the helpers defined in u2net_test.py


def main():
    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp
    # model_name = 'u2netp'  # u2netp

    image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        print("...load U2NET---173.6 MB")
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3, 1)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # Load the replacement background once instead of on every iteration
    background_src = cv2.imread("test_data/bg-05.jpg")
    if background_src is None:
        raise FileNotFoundError("could not read test_data/bg-05.jpg")
    # cv2.imwrite returns False (no exception) if the output folder is missing
    os.makedirs("test_data/rbg", exist_ok=True)

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):
        print("inferencing:", img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image'].type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        # No gradients are needed at inference time
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        # normalization
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)

        # save results to test_results folder
        os.makedirs(prediction_dir, exist_ok=True)
        save_output(img_name_list[i_test], pred, prediction_dir)

        ########## modification comes here ##########
        predict_np = pred.squeeze().cpu().numpy()

        input_image_file_path = img_name_list[i_test]
        image = cv2.imread(input_image_file_path)
        background = cv2.resize(background_src, (image.shape[1], image.shape[0]))

        # Upscale the 320x320 mask back to the original image size
        im = Image.fromarray(predict_np * 255).convert('RGB')
        imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)
        data = np.asarray(imo, dtype="int32")

        # Keep the original pixel only where the mask value exceeds 0.98 * 255
        condition = data > 0.98 * 255
        output_image = np.where(condition, image, background)
        cv2.imwrite(f"test_data/rbg/{os.path.basename(input_image_file_path)}", output_image)

        del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()
```
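A possible alternative to the hard `np.where` threshold, sketched under assumptions and not taken from the U-2-Net repository: instead of keeping only pixels whose upscaled mask value exceeds 0.98 * 255, the normalized prediction can be used directly as a soft alpha matte (standard alpha compositing, out = alpha * image + (1 - alpha) * background). The helper name `alpha_composite` is hypothetical. It assumes `image` and `background` are same-sized uint8 arrays in the same channel order (both BGR when loaded with `cv2.imread`) and that `mask` is the `normPRED` output resized to the image size, with values in [0, 1].

```python
import numpy as np


def alpha_composite(image: np.ndarray, background: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend foreground over background using a soft mask as alpha (hypothetical helper).

    image, background: uint8 arrays of shape (H, W, 3), same size and channel order.
    mask: float array of shape (H, W); values are clipped to [0, 1].
    """
    if image.shape != background.shape:
        raise ValueError("image and background must have the same shape")
    if mask.shape != image.shape[:2]:
        raise ValueError("mask must have shape (H, W) matching the image")

    # Clip to [0, 1] and add a channel axis so alpha broadcasts over the 3 channels.
    alpha = np.clip(mask.astype(np.float32), 0.0, 1.0)[..., np.newaxis]

    # "Over" blend: out = alpha * foreground + (1 - alpha) * background.
    fg = image.astype(np.float32)
    bg = background.astype(np.float32)
    blended = alpha * fg + (1.0 - alpha) * bg

    # Round and convert back to 8-bit so cv2.imwrite can save the result.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


if __name__ == "__main__":
    # Tiny synthetic example: a grey foreground over a black background.
    fg = np.full((2, 2, 3), 200, dtype=np.uint8)
    bg = np.zeros((2, 2, 3), dtype=np.uint8)
    m = np.array([[1.0, 0.5], [0.0, 0.25]], dtype=np.float32)
    print(alpha_composite(fg, bg, m)[..., 0])
```

In the loop above, the `im`/`imo`/`data`/`condition`/`np.where` lines could then become `mask = cv2.resize(predict_np, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)` followed by `output_image = alpha_composite(image, background, mask)`. Because U-2-Net predicts a saliency map rather than a trained alpha matte, soft edges may still show some background bleed or halos; whether this addresses what is wrong in the linked result is untested.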
I'm a newbie in this field. Could anyone offer some help? Thanks a lot~