Skip to content

Commit

Permalink
update for yolov8 segmentation header
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Apr 20, 2024
1 parent 306fbfb commit 99eaf2f
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions keras_cv_attention_models/yolov8/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,48 +293,30 @@ def yolov8_head(
return outputs


def yolov8_seg_head(
inputs,
num_classes=80,
regression_len=64,
num_anchors=1,
depth=2,
hidden_channels=-1,
use_object_scores=False,
segment_hidden_channels=-1,
segment_num_masks=32,
activation="swish",
classifier_activation="sigmoid",
name="",
):
def yolov8_seg_head(inputs, depth=2, hidden_channels=-1, num_masks=32, activation="swish", name=""):
channel_axis = -1 if image_data_format() == "channels_last" else 1
segment_hidden_channels = segment_hidden_channels if segment_hidden_channels > 0 else max(64, inputs[0].shape[channel_axis])
hidden_channels = hidden_channels if hidden_channels > 0 else max(64, inputs[0].shape[channel_axis])

""" mask_protos """
mask_protos = inputs[0]
mask_protos = conv_bn(mask_protos, segment_hidden_channels, 3, activation=activation, name=name + "mask_protos_1_")
mask_protos = layers.Conv2DTranspose(segment_hidden_channels, kernel_size=2, strides=2, padding="VALID", name=name + "mask_protos_up")(mask_protos)
mask_protos = conv_bn(mask_protos, segment_hidden_channels, 3, activation=activation, name=name + "mask_protos_2_")
mask_protos = conv_bn(mask_protos, segment_num_masks, 1, activation=activation, name=name + "mask_protos_3_")
mask_protos = conv_bn(mask_protos, hidden_channels, 3, activation=activation, name=name + "mask_protos_1_")
mask_protos = layers.Conv2DTranspose(hidden_channels, kernel_size=2, strides=2, padding="VALID", name=name + "mask_protos_up")(mask_protos)
mask_protos = conv_bn(mask_protos, hidden_channels, 3, activation=activation, name=name + "mask_protos_2_")
mask_protos = conv_bn(mask_protos, num_masks, 1, activation=activation, name=name + "mask_protos_3_")
mask_protos = mask_protos if image_data_format() == "channels_last" else layers.Permute([2, 3, 1])(mask_protos)

""" mask_coefficients """
mask_coefficients = []
mask_hidden_channels = max(inputs[0].shape[channel_axis] // 4, segment_num_masks)
mask_hidden_channels = max(inputs[0].shape[channel_axis] // 4, num_masks)
for id, feature in enumerate(inputs):
cur_name = name + "{}_".format(id + 1)
for id in range(depth):
feature = conv_bn(feature, mask_hidden_channels, 3, activation=activation, name=cur_name + "mask_coefficients_{}_".format(id + 1))
feature = conv2d_no_bias(feature, segment_num_masks, use_bias=True, bias_initializer="ones", name=cur_name + "mask_coefficients_{}_".format(depth + 1))
feature = conv2d_no_bias(feature, num_masks, use_bias=True, bias_initializer="ones", name=cur_name + "mask_coefficients_{}_".format(depth + 1))
feature = feature if image_data_format() == "channels_last" else layers.Permute([2, 3, 1])(feature)
feature = layers.Reshape([-1, feature.shape[-1]], name=cur_name + "mask_coefficients_reshape")(feature)
mask_coefficients.append(feature)
mask_coefficients = functional.concat(mask_coefficients, axis=1)

detect_out = yolov8_head(
inputs, num_classes, regression_len, num_anchors, depth, hidden_channels, use_object_scores, activation, classifier_activation, name=name
)
return functional.concat([detect_out, mask_coefficients], axis=-1), mask_protos # detect_out, mask_coefficients needs to apply NMS together
return functional.concat(mask_coefficients, axis=1), mask_protos


""" YOLOV8 models """
Expand Down Expand Up @@ -386,11 +368,16 @@ def YOLOV8(

header_kwargs = {"use_object_scores": use_object_scores, "activation": activation, "classifier_activation": classifier_activation}
if segment_num_masks > 0:
header_kwargs.update({"segment_num_masks": segment_num_masks})
outputs = yolov8_seg_head(fpn_features, num_classes, regression_len, num_anchors, **header_kwargs, name="seg_head_")
mask_coefficients, mask_protos = yolov8_seg_head(fpn_features, num_masks=segment_num_masks, activation=activation, name="seg_head_")
detect_out = yolov8_head(fpn_features, num_classes, regression_len, num_anchors, **header_kwargs, name="seg_head_")
detect_mask_out = functional.concat([detect_out, mask_coefficients], axis=-1) # detect_out, mask_coefficients needs to apply NMS together
outputs = [
layers.Activation("linear", dtype="float32", name="detect_mask_outputs_fp32")(detect_mask_out),
layers.Activation("linear", dtype="float32", name="mask_protos_outputs_fp32")(mask_protos),
]
else:
outputs = yolov8_head(fpn_features, num_classes, regression_len, num_anchors, **header_kwargs, name="head_")
outputs = layers.Activation("linear", dtype="float32", name="outputs_fp32")(outputs)
outputs = layers.Activation("linear", dtype="float32", name="outputs_fp32")(outputs)

model = models.Model(inputs, outputs, name=model_name)
reload_model_weights(model, PRETRAINED_DICT, "yolov8", pretrained)
Expand Down

0 comments on commit 99eaf2f

Please sign in to comment.