diff --git a/app.py b/app.py index 949a406..282143d 100644 --- a/app.py +++ b/app.py @@ -192,10 +192,13 @@ def segment_image(img,prompt_mode, categoryname, custom_category, expressiong, r boxes = LSJ_box_postprocess(pred_boxes,padding_size,re_size, ori_height,ori_width) mask_pred = mask_pred[topk_indices] - pred_masks = F.interpolate( mask_pred[None,], size=(padding_size[0], padding_size[1]), mode="bilinear", align_corners=False ) - pred_masks = pred_masks[:,:,:re_size[0],:re_size[1]] - pred_masks = F.interpolate( pred_masks, size=(ori_height,ori_width), mode="bilinear", align_corners=False ) - pred_masks = (pred_masks>0).detach().cpu().numpy()[0] + if len(topk_indices) == 0: # 检查是否有选中的索引 + pred_masks = [] + else: + pred_masks = F.interpolate( mask_pred[None,], size=(padding_size[0], padding_size[1]), mode="bilinear", align_corners=False ) + pred_masks = pred_masks[:,:,:re_size[0],:re_size[1]] + pred_masks = F.interpolate( pred_masks, size=(ori_height,ori_width), mode="bilinear", align_corners=False ) + pred_masks = (pred_masks>0).detach().cpu().numpy()[0] if 'mask' in results_select: @@ -324,11 +327,12 @@ def segment_image(img,prompt_mode, categoryname, custom_category, expressiong, r boxes = LSJ_box_postprocess(pred_boxes,padding_size,re_size, ori_height,ori_width) mask_pred = mask_pred[topk_indices] - pred_masks = F.interpolate( mask_pred[None,], size=(padding_size[0], padding_size[1]), mode="bilinear", align_corners=False ) - pred_masks = pred_masks[:,:,:re_size[0],:re_size[1]] - pred_masks = F.interpolate( pred_masks, size=(ori_height,ori_width), mode="bilinear", align_corners=False ) - pred_masks = (pred_masks>0).detach().cpu().numpy()[0] - mask_results_list.append(pred_masks) + if len(topk_indices) > 0: # 检查是否有选中的索引 + pred_masks = F.interpolate( mask_pred[None,], size=(padding_size[0], padding_size[1]), mode="bilinear", align_corners=False ) + pred_masks = pred_masks[:,:,:re_size[0],:re_size[1]] + pred_masks = F.interpolate( pred_masks, size=(ori_height,ori_width), mode="bilinear", align_corners=False ) + pred_masks = (pred_masks>0).detach().cpu().numpy()[0] + mask_results_list.append(pred_masks) zero_mask = np.zeros_like(copyed_img) for mask,RGB in zip(mask_results_list,visual_prompt_RGB_list):