diff --git a/models/efficientnet.py b/models/efficientnet.py index 4200f2b..72e716a 100644 --- a/models/efficientnet.py +++ b/models/efficientnet.py @@ -456,6 +456,22 @@ def efficientnet_b0(*args, **kwargs): model.num_classes = ... return model + +@register_model +def efficientnet_b1(*args, **kwargs): + model = _efficientnet_b0(**kwargs) + model.name = "EfficentNetB1" + model.num_classes = ... + return model + + +@register_model +def efficientnet_b2(*args, **kwargs): + model = _efficientnet_b0(**kwargs) + model.name = "EfficentNetB2" + model.num_classes = ... + return model + @register_model def efficientnet_b0_eff(*args, **kwargs): bneck_conf = partial(MBConvConfig, width_mult=1.0, depth_mult=1.0)