Commit
more PyTorch model modifications for better ONNX exports
phager90 committed Sep 1, 2020
1 parent b5c600f commit 0a8d426
Showing 2 changed files with 108 additions and 8 deletions.
82 changes: 80 additions & 2 deletions effdet/config/model_config.py
@@ -204,6 +204,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0-d92fd44f.pth',
),
tf_efficientdet_d0_phager=dict(
name='tf_efficientdet_d0',
backbone_name='efficientnet_b0', # HACK
image_size=512,
fpn_channels=64,
fpn_cell_repeats=3,
box_class_repeats=3,
backbone_args=dict(drop_path_rate=0.2),
pad_type='', # HACK: do not mimic TF padding
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0-d92fd44f.pth',
),
tf_efficientdet_d1=dict(
name='tf_efficientdet_d1',
backbone_name='tf_efficientnet_b1',
@@ -214,6 +225,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1-4c7ebaf2.pth'
),
tf_efficientdet_d1_phager=dict(
name='tf_efficientdet_d1',
backbone_name='efficientnet_b1', # HACK
image_size=640,
fpn_channels=88,
fpn_cell_repeats=4,
box_class_repeats=3,
pad_type='', # HACK: do not mimic TF padding
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1-4c7ebaf2.pth'
),
tf_efficientdet_d2=dict(
name='tf_efficientdet_d2',
backbone_name='tf_efficientnet_b2',
@@ -224,6 +246,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2-cb4ce77d.pth',
),
tf_efficientdet_d2_phager=dict(
name='tf_efficientdet_d2',
backbone_name='efficientnet_b2', # HACK
image_size=768,
fpn_channels=112,
fpn_cell_repeats=5,
box_class_repeats=3,
pad_type='', # HACK: do not mimic TF padding
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2-cb4ce77d.pth',
),
tf_efficientdet_d3=dict(
name='tf_efficientdet_d3',
backbone_name='tf_efficientnet_b3',
@@ -234,6 +267,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_47-0b525f35.pth',
),
tf_efficientdet_d3_phager=dict(
name='tf_efficientdet_d3',
backbone_name='efficientnet_b3', # HACK
image_size=896,
fpn_channels=160,
fpn_cell_repeats=6,
box_class_repeats=4,
pad_type='', # HACK: do not mimic TF padding
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_47-0b525f35.pth',
),
tf_efficientdet_d4=dict(
name='tf_efficientdet_d4',
backbone_name='tf_efficientnet_b4',
@@ -244,6 +288,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4-5b370b7a.pth',
),
tf_efficientdet_d4_phager=dict(
name='tf_efficientdet_d4',
backbone_name='efficientnet_b4', #HACK
image_size=1024,
fpn_channels=224,
fpn_cell_repeats=7,
box_class_repeats=4,
pad_type='', # HACK: do not mimic TF padding
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4-5b370b7a.pth',
),
tf_efficientdet_d5=dict(
name='tf_efficientdet_d5',
backbone_name='tf_efficientnet_b5',
@@ -254,6 +309,17 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5-ef44aea8.pth',
),
tf_efficientdet_d5_phager=dict(
name='tf_efficientdet_d5',
backbone_name='efficientnet_b5', #HACK
image_size=1280,
fpn_channels=288,
fpn_cell_repeats=7,
box_class_repeats=4,
pad_type='', # HACK: do not mimic TF padding
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5-ef44aea8.pth',
),
tf_efficientdet_d6=dict(
name='tf_efficientdet_d6',
backbone_name='tf_efficientnet_b6',
@@ -265,6 +331,18 @@ def default_detection_model_configs():
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d6-51cb0132.pth'
),
tf_efficientdet_d6_phager=dict(
name='tf_efficientdet_d6',
backbone_name='efficientnet_b6', #HACK
image_size=1280,
fpn_channels=384,
fpn_cell_repeats=8,
box_class_repeats=5,
pad_type='', # HACK: do not mimic TF padding
fpn_name='bifpn_sum', # Use unweighted sum for training stability.
backbone_args=dict(drop_path_rate=0.2),
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d6-51cb0132.pth'
),
tf_efficientdet_d7=dict(
name='tf_efficientdet_d7',
backbone_name='tf_efficientnet_b6',
@@ -308,15 +386,15 @@ def default_detection_model_configs():
),
tf_efficientdet_lite0_phager=dict( # HACK VERSION
name='efficientdet_lite0',
backbone_name='efficientnet_lite0',
backbone_name='efficientnet_lite0', # HACK
image_size=512,
fpn_channels=64,
fpn_cell_repeats=3,
box_class_repeats=3,
act_type='relu',
redundant_bias=False,
backbone_args=dict(drop_path_rate=0.1),
pad_type='',
pad_type='', # do not mimic TF padding
# unlike other tf_ models, this was not ported from tf automl impl, but trained from tf pretrained efficient lite
# weights using this code, will likely replace if/when official det-lite weights are released
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_lite0-f5f303a9.pth',
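A minimal export sketch for one of the added `_phager` configs (hedged; it assumes the repository's existing get_efficientdet_config() and EfficientDet entry points, which are not shown in this diff):

# Sketch only: the factory names below are assumed from the rest of the
# effdet codebase; adjust if they differ in your checkout.
import torch
from effdet import get_efficientdet_config, EfficientDet

config = get_efficientdet_config('tf_efficientdet_d0_phager')
model = EfficientDet(config)
model.eval()

# pad_type='' plus the plain efficientnet_b0 backbone avoid dynamic
# SAME-padding ops, so a fixed-size dummy input is enough for tracing.
dummy = torch.randn(1, 3, config.image_size, config.image_size)
torch.onnx.export(model, dummy, 'tf_efficientdet_d0_phager.onnx', opset_version=11)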
34 changes: 28 additions & 6 deletions effdet/efficientdet.py
@@ -88,6 +88,20 @@ def forward(self, x):
x = self.act(x)
return x

# phager: FIX https://discuss.pytorch.org/t/using-nn-function-interpolate-inside-nn-sequential/23588
class Interpolate(nn.Module):
def __init__(self, scale_factor, mode):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode

def forward(self, x):
# FIX: https://github.com/pytorch/pytorch/issues/27376; simpler, but not ideal yet
# the actual issues are the batch-size and channel dimensions during interpolation
x = self.interp(x, size=[int(self.scale_factor * x.shape[2]), int(self.scale_factor * x.shape[3])], mode=self.mode)
#x = self.interp(x, scale_factor=[int(self.scale_factor), int(self.scale_factor)], mode=self.mode)
return x

class ResampleFeatureMap(nn.Sequential):

@@ -129,7 +143,9 @@ def __init__(self, in_channels, out_channels, reduction_ratio=1., pad_type='', p
#print("YYY "+ str(int(1 // reduction_ratio)))
#assert int(1 // reduction_ratio) == 2 # HACK: making stuff static...
#scale = 2
self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale))
# phager: FIX FOR https://github.com/pytorch/pytorch/issues/27376 -> adding interpolation module
# OLD: self.add_module('upsample', nn.UpsamplingNearest2d(scale_factor=scale))
self.add_module('upsample', Interpolate(scale_factor=scale, mode='nearest'))

# def forward(self, x):
# # here for debugging only
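A hedged sanity check for the upsampling workaround above: the size-based Interpolate module should trace to a Resize with explicit output sizes, while nn.UpsamplingNearest2d(scale_factor=...) is the path affected by the pytorch/pytorch#27376 issue referenced in the comments. A sketch (results depend on the installed PyTorch and opset version):

# Sketch: export both upsampling variants; assumes the Interpolate class
# defined above is importable from effdet.efficientdet after this commit.
import io
import torch
import torch.nn as nn
from effdet.efficientdet import Interpolate

x = torch.randn(1, 64, 32, 32)

# Workaround: output sizes are computed from the traced input shape.
torch.onnx.export(Interpolate(scale_factor=2, mode='nearest'), x, io.BytesIO(), opset_version=11)

# Original module using scale_factor, i.e. the case the linked issue describes.
torch.onnx.export(nn.UpsamplingNearest2d(scale_factor=2), x, io.BytesIO(), opset_version=11)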
@@ -183,26 +199,32 @@ def __init__(self, feature_info, fpn_config, fpn_channels, inputs_offsets, targe
def forward(self, x):
dtype = x[0].dtype
nodes = []
sum_nodes = []
for offset in self.inputs_offsets:
input_node = x[offset]
input_node = self.resample[str(offset)](input_node)
nodes.append(input_node)

if self.weight_method == 'attn':
normalized_weights = torch.softmax(self.edge_weights.type(dtype), dim=0)
x = torch.stack(nodes, dim=-1) * normalized_weights
sum_nodes = [node * normalized_weights[i] for i, node in enumerate(nodes)]
#x = torch.stack(nodes, dim=-1) * normalized_weights
elif self.weight_method == 'fastattn':
print("FPN detach")
edge_weights = nn.functional.relu(self.edge_weights.type(dtype))
weights_sum = torch.sum(edge_weights)
weights_sum = weights_sum.item()
x = torch.stack(
[(nodes[i] * edge_weights[i].item()) / (weights_sum + 0.0001) for i in range(len(nodes))], dim=-1)
#x = torch.stack([nodes[i] * (edge_weights[i].item() / (weights_sum + 0.0001)) for i in range(len(nodes))], dim=-1)
sum_nodes = [nodes[i] * (edge_weights[i].item() / (weights_sum + 0.0001)) for i in range(len(nodes))]
elif self.weight_method == 'sum':
x = torch.stack(nodes, dim=-1)
sum_nodes = nodes
#x = torch.stack(nodes, dim=-1)
else:
raise ValueError('unknown weight_method {}'.format(self.weight_method))
x = torch.sum(x, dim=-1)
#x = torch.sum(x, dim=-1)
x = sum_nodes[0]
for i in range(1, len(sum_nodes)):
x += sum_nodes[i]
return x
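The combine step above now accumulates the weighted inputs with a running sum instead of torch.stack followed by torch.sum along the new dimension. A small equivalence sketch with toy tensors:

# Sketch: the running sum matches the original stack-and-reduce formulation,
# while avoiding the stack along a new dimension at export time.
import torch

nodes = [torch.randn(1, 64, 8, 8) for _ in range(3)]

stacked = torch.sum(torch.stack(nodes, dim=-1), dim=-1)  # original form

running = nodes[0].clone()
for n in nodes[1:]:  # rewritten form
    running = running + n

assert torch.allclose(stacked, running)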


