Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor feat extractor #37

Merged
merged 4 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions modeling/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# Base Class
class ExtractorModel:
name: str
dim: int
model: torch.nn.Module
preprocess: Callable

Expand All @@ -47,6 +48,7 @@ class ExtractorModel:
# ConvNext Models
class ConvnextBaseExtractor(ExtractorModel):
name = "convnext_base"
dim = 1024

def __init__(self):
self.model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
Expand All @@ -56,6 +58,7 @@ def __init__(self):

class ConvnextTinyExtractor(ExtractorModel):
name = "convnext_tiny"
dim = 768

def __init__(self):
self.model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
Expand All @@ -65,6 +68,7 @@ def __init__(self):

class ConvnextSmallExtractor(ExtractorModel):
name = "convnext_small"
dim = 768

def __init__(self):
self.model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
Expand All @@ -74,6 +78,7 @@ def __init__(self):

class ConvnextLargeExtractor(ExtractorModel):
name = "convnext_lg"
dim = 1536

def __init__(self):
self.model = convnext_large(weights=ConvNeXt_Large_Weights.IMAGENET1K_V1)
Expand All @@ -85,33 +90,37 @@ def __init__(self):
# DenseNet Models
class Densenet121Extractor(ExtractorModel):
name = "densenet121"
dim = 1024

def __init__(self):
self.model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
self.model.classifier = torch.nn.Identity()
self.preprocess = DenseNet121_Weights.IMAGENET1K_V1.transforms()


class Densenet161Extractor():
class Densenet161Extractor(ExtractorModel):
name = "densenet161"
dim = 2208

def __init__(self):
self.model = densenet161(weights=DenseNet161_Weights.IMAGENET1K_V1)
self.model.classifier = torch.nn.Identity()
self.preprocess = DenseNet161_Weights.IMAGENET1K_V1.transforms()


class Densenet169Extractor():
class Densenet169Extractor(ExtractorModel):
name = "densenet169"
dim = 1664

def __init__(self):
self.model = densenet169(weights=DenseNet169_Weights.IMAGENET1K_V1)
self.model.classifier = torch.nn.Identity()
self.preprocess = DenseNet169_Weights.IMAGENET1K_V1.transforms()


class Densenet201Extractor():
class Densenet201Extractor(ExtractorModel):
name = "densenet201"
dim = 1920

def __init__(self):
self.model = densenet201(weights=DenseNet201_Weights.IMAGENET1K_V1)
Expand All @@ -123,6 +132,7 @@ def __init__(self):
# EfficientNet Models
class EfficientnetSmallExtractor(ExtractorModel):
name = "efficientnet_small"
dim = 1280

def __init__(self):
self.model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
Expand All @@ -132,6 +142,7 @@ def __init__(self):

class EfficientnetMediumExtractor(ExtractorModel):
name = "efficientnet_med"
dim = 1280

def __init__(self):
self.model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
Expand All @@ -141,6 +152,7 @@ def __init__(self):

class EfficientnetLargeExtractor(ExtractorModel):
name = "efficientnet_large"
dim = 1280

def __init__(self):
self.model = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.IMAGENET1K_V1)
Expand All @@ -164,6 +176,7 @@ def __init__(self):

class Resnet18Extractor(ExtractorModel):
name = "resnet18"
dim = 512

def __init__(self):
self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
Expand All @@ -173,6 +186,7 @@ def __init__(self):

class Resnet50Extractor(ExtractorModel):
name = "resnet50"
dim = 2048

def __init__(self):
self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
Expand All @@ -182,6 +196,7 @@ def __init__(self):

class Resnet101Extractor(ExtractorModel):
name = "resnet101"
dim = 2048

def __init__(self):
self.model = resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)
Expand All @@ -191,6 +206,7 @@ def __init__(self):

class Resnet152Extractor(ExtractorModel):
name = "resnet152"
dim = 2048

def __init__(self):
self.model = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
Expand All @@ -202,6 +218,7 @@ def __init__(self):
# VGG Models
class Vgg16Extractor(ExtractorModel):
name = "vgg16"
dim = 4096

def __init__(self):
self.model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
Expand All @@ -211,6 +228,7 @@ def __init__(self):

class BN_Vgg16Extractor(ExtractorModel):
name = "bn_vgg16"
dim = 4096

def __init__(self):
self.model = vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
Expand All @@ -220,6 +238,7 @@ def __init__(self):

class Vgg19Extractor(ExtractorModel):
name = "vgg19"
dim = 4096

def __init__(self):
self.model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
Expand All @@ -229,6 +248,7 @@ def __init__(self):

class BN_VGG19Extractor(ExtractorModel):
name = "bn_vgg19"
dim = 4096

def __init__(self):
self.model = vgg19_bn(weights=VGG19_BN_Weights.IMAGENET1K_V1)
Expand All @@ -242,6 +262,10 @@ def __init__(self):
model.name: model for model
in sys.modules[__name__].ExtractorModel.__subclasses__() if model.name != 'inceptionv3'}

model_dim_map = {
model.name: model.dim for model
in sys.modules[__name__].ExtractorModel.__subclasses__() if model.name != 'inceptionv3'}

if __name__ == "__main__":
import numpy as np
dummy_guid = 'cpb-aacip-fe9efa663c6'
Expand Down
Loading
Loading