flake8 check
emmaamblard committed Oct 19, 2023
1 parent 08b43d1 commit 46ff750
Showing 3 changed files with 45 additions and 39 deletions.
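The diff below is a pure style pass (line wrapping, whitespace around operators, removal of commented-out code). As a purely hypothetical way to reproduce such a check locally, flake8 can also be driven from Python; this snippet is not part of the repository, and the project's actual lint configuration is not shown in this commit:

from flake8.api import legacy as flake8_api

# Hypothetical local check (not part of this commit): lint the package
# directory with flake8's documented legacy Python API.
style_guide = flake8_api.get_style_guide(max_line_length=79)
report = style_guide.check_files(["multi_plankton_separation"])
print("flake8 violations found:", report.total_errors)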
51 changes: 24 additions & 27 deletions multi_plankton_separation/api.py
@@ -33,7 +33,6 @@ def predict(**kwargs):

from PIL import Image
from webargs import fields, validate
from skimage.segmentation import find_boundaries

import multi_plankton_separation.config as cfg
from multi_plankton_separation.misc import _catch_error
@@ -109,20 +108,22 @@ def get_predict_args():
Get the list of arguments for the predict function
"""
# Get list of available models
list_models = [filename[:-3] for filename in os.listdir(cfg.MODEL_DIR) if filename.endswith(".pt")]
list_models = [filename[:-3]
for filename in os.listdir(cfg.MODEL_DIR)
if filename.endswith(".pt")]

arg_dict = {
"image": fields.Field(
required=True,
type="file",
location="form",
description="An image containing plankton to separate",
),
"image": fields.Field(
required=True,
type="file",
location="form",
description="An image containing plankton to separate",
),
"model": fields.Str(
required=False,
missing=list_models[0],
enum = list_models,
description = "The model used to perform instance segmentation"
enum=list_models,
description="The model used to perform instance segmentation"
),
"threshold": fields.Float(
required=False,
@@ -133,7 +134,8 @@ def get_predict_args():
required=False,
missing='image/png',
validate=validate.OneOf(['image/png']),
description="Returns an image or a json with the path to the saved result"),
description="Return an image or a json with the path to the saved result"
),
}

return arg_dict
@@ -144,7 +146,6 @@ def predict(**kwargs):
"""
Prediction function
"""
#kwargs = {"model": "mask_multi_plankton_b8", "threshold": 0.9}

# Check if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -154,16 +155,15 @@ def predict(**kwargs):
if model is None:
message = "Model not found."
return message

# Convert image to tensor
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
orig_img = Image.open(kwargs['image'].filename)
#orig_img = Image.open("/Users/emmaamblard/Downloads/seg_data/images/img_00105.png")
img = transform(orig_img)

# Get predicted masks
pred_masks, pred_masks_probs = get_predicted_masks(model, img, kwargs["threshold"])

# Get sum of masks probabilities and mask centers
mask_sum = np.zeros(pred_masks[0].shape)
mask_centers_x = []
@@ -174,7 +174,7 @@ def predict(**kwargs):
center_x, center_y = np.unravel_index(np.argmax(mask), mask.shape)
mask_centers_x.append(center_x)
mask_centers_y.append(center_y)

mask_centers = zip(mask_centers_x, mask_centers_y)

# Apply watershed algorithm
@@ -189,16 +189,19 @@

plot_width = mask_sum.shape[0] + 1000
plot_height = mask_sum.shape[1] + 1000
px = 1/plt.rcParams['figure.dpi']
fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(plot_width * 5 * px, plot_height * px), subplot_kw={'xticks': [], 'yticks': []})
px = 1 / plt.rcParams['figure.dpi']
fig, axes = plt.subplots(nrows=1, ncols=5,
figsize=(plot_width * 5 * px, plot_height * px),
subplot_kw={'xticks': [], 'yticks': []})

# Plot original image
axes[0].imshow(orig_img, interpolation='none')
for mask in pred_masks_probs:
rmin, rmax, cmin, cmax = bounding_box(mask)
x, y = cmin, rmin
width, height = cmax - cmin, rmax - rmin
rect = patches.Rectangle((x, y), width, height, linewidth=1, edgecolor='r', facecolor='none')
rect = patches.Rectangle((x, y), width, height,
linewidth=1, edgecolor='r', facecolor='none')
axes[0].add_patch(rect)
axes[0].set_title("Detected objects: {}".format(len(pred_masks)))

@@ -225,15 +228,9 @@ def predict(**kwargs):
plt.savefig(result_path, bbox_inches='tight')
plt.close()

if(kwargs["accept"] == 'image/png'):
if kwargs["accept"] == 'image/png':
message = open(output_path, 'rb')
else:
message = "Result saved in {}".format(output_path)

return message


if __name__ == '__main__':
message = predict()
print(message)
pass
2 changes: 1 addition & 1 deletion multi_plankton_separation/config.py
@@ -5,4 +5,4 @@

BASE_DIR = Path(__file__).resolve().parents[1]
MODEL_DIR = os.path.join(BASE_DIR, "models")
TEMP_DIR = os.path.join(BASE_DIR, "temp")
31 changes: 20 additions & 11 deletions multi_plankton_separation/utils.py
@@ -25,7 +25,9 @@ def get_model_instance_segmentation(num_classes):
# Initialize mask predictor
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
model.roi_heads.mask_predictor = MaskRCNNPredictor(
in_features_mask, hidden_layer, num_classes
)

return model

@@ -39,11 +41,13 @@ def load_saved_model(model_name, device):
if not os.path.exists(model_path):
print("Model {} not found.".format(model_name))
return None

state_dict = torch.load(model_path, map_location=device)
model = get_model_instance_segmentation(list(state_dict["roi_heads.mask_predictor.mask_fcn_logits.bias"].size())[0])
model = get_model_instance_segmentation(
list(state_dict["roi_heads.mask_predictor.mask_fcn_logits.bias"].size())[0]
)
model.load_state_dict(state_dict)

return model


@@ -59,9 +63,9 @@ def get_predicted_masks(model, image, score_threshold=0.9, mask_threshold=0.7):
pred_masks_probs = pred[0]['masks'].detach().numpy().squeeze(1)
try:
pred_t = [pred_score.index(x) for x in pred_score if x > score_threshold][-1]
pred_score = pred_score[:pred_t+1]
pred_masks = pred_masks[:pred_t+1]
pred_masks_probs = pred_masks_probs[:pred_t+1]
pred_score = pred_score[:pred_t + 1]
pred_masks = pred_masks[:pred_t + 1]
pred_masks_probs = pred_masks_probs[:pred_t + 1]
except IndexError:
pred_t = 'null'
pred_score = 'null'
@@ -73,7 +77,8 @@ def get_predicted_masks(model, image, score_threshold=0.9, mask_threshold=0.7):

def get_watershed_result(mask_map, mask_centers):
"""
Apply the watershed algorithm on the predicted mask map, using the mask centers as markers
Apply the watershed algorithm on the predicted mask map,
using the mask centers as markers
"""
markers_mask = np.zeros(mask_map.shape, dtype=bool)
for (x, y) in mask_centers:
@@ -83,8 +88,12 @@ def get_watershed_result(mask_map, mask_centers):
watershed_mask = np.zeros(mask_map.shape, dtype='int64')
watershed_mask[mask_map > .01] = 1

labels = watershed(-mask_map, markers, mask=watershed_mask, watershed_line=False)
labels_with_lines = watershed(-mask_map, markers, mask=watershed_mask, watershed_line=True)
labels = watershed(
-mask_map, markers, mask=watershed_mask, watershed_line=False
)
labels_with_lines = watershed(
-mask_map, markers, mask=watershed_mask, watershed_line=True
)
labels_with_lines[labels == 0] = -1

return labels_with_lines
@@ -99,4 +108,4 @@ def bounding_box(img):
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]

return rmin, rmax, cmin, cmax
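
As a side note on the watershed call reformatted in get_watershed_result above, here is a minimal, self-contained sketch of the same markers-plus-mask pattern. It is purely illustrative, not part of the commit, and assumes NumPy, SciPy and scikit-image are available:

import numpy as np
from scipy.ndimage import label
from skimage.segmentation import watershed

# Toy "mask probability map": two overlapping blobs standing in for two
# touching plankton predicted by the segmentation model.
yy, xx = np.mgrid[0:64, 0:64]
blob_a = np.exp(-((yy - 32) ** 2 + (xx - 22) ** 2) / 80.0)
blob_b = np.exp(-((yy - 32) ** 2 + (xx - 42) ** 2) / 80.0)
mask_map = np.maximum(blob_a, blob_b)

# One marker per expected object, placed at the blob centres.
markers_mask = np.zeros(mask_map.shape, dtype=bool)
markers_mask[32, 22] = True
markers_mask[32, 42] = True
markers, _ = label(markers_mask)

# Flood the inverted map so high-probability pixels become basins, restrict
# the flooding to the predicted foreground, and draw a separating line.
foreground = mask_map > 0.01
labels = watershed(-mask_map, markers, mask=foreground, watershed_line=True)
print(np.unique(labels))  # 0 = background / separating line, 1 and 2 = objects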
