Skip to content

Commit

Permalink
Update backbone.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xoiga123 authored Oct 11, 2022
1 parent d53d2ea commit 48ff4a0
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,11 @@ def __init__(self, num_classes=80, compound_coef=0, seg_classes=1, backbone_name
weights='imagenet',
)

if not onnx_export:
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
onnx_export=onnx_export,
**kwargs)
else:
self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
pyramid_levels=(torch.arange(self.pyramid_levels[self.compound_coef]) + 3).tolist(),
onnx_export=onnx_export,
**kwargs)
if onnx_export:
## TODO: timm
self.encoder.set_swish(memory_efficient=False)

Expand Down Expand Up @@ -121,9 +120,9 @@ def forward(self, inputs):

regression = self.regressor(features)
classification = self.classifier(features)
anchors = self.anchors(inputs, inputs.dtype)

if not self.onnx_export:
anchors = self.anchors(inputs, inputs.dtype)
return features, regression, classification, anchors, segmentation
else:
return regression, classification, segmentation
Expand Down

0 comments on commit 48ff4a0

Please sign in to comment.