diff --git a/surrogates_overview/scripts/cifar_label_map.py b/surrogates_overview/scripts/cifar_label_map.py new file mode 100644 index 0000000..ee419cc --- /dev/null +++ b/surrogates_overview/scripts/cifar_label_map.py @@ -0,0 +1,166 @@ +""" +CIFAR 10 & 100 Labels Map +========================= + +This module provides a map from a class id to label. +Two maps are available: + +* ``CIFAR10_LABEL_MAP`` -- maps class ids to labels for CIFAR10; and +* ``CIFAR100_LABEL_MAP`` -- maps class ids to labels for CIFAR100. + +The data set files needed to regenerate the label maps are available at +. + +See +for more details. +""" +# Author: Kacper Sokol +# License: new BSD + + +import pickle + + +def _load_cifar10_labels(data_folder): + """Generates the label map for CIFAR10.""" + with open(f'{data_folder}/cifar-10-batches-py/batches.meta', 'rb') as fo: + cf10meta = pickle.load(fo, encoding='bytes') + + cf10labels = {i: j.decode() + for i, j in enumerate(cf10meta.get(b'label_names'))} + + return cf10labels + + +def _load_cifar100_labels(data_folder, fine_labels=True): + """Generates the label map for CIFAR100.""" + with open(f'{data_folder}/cifar-100-python/meta', 'rb') as fo: + cf100meta = pickle.load(fo, encoding='bytes') + + if fine_labels: + cf100_labels_type = b'fine_label_names' + else: + cf100_labels_type = b'coarse_label_names' + + cf100labels = {i: j.decode() + for i, j in enumerate(cf100meta.get(cf100_labels_type))} + + return cf100labels + + +CIFAR10_LABEL_MAP = { + 0: 'airplane', + 1: 'automobile', + 2: 'bird', + 3: 'cat', + 4: 'deer', + 5: 'dog', + 6: 'frog', + 7: 'horse', + 8: 'ship', + 9: 'truck' +} + + +CIFAR100_LABEL_MAP = { + 0: 'apple', + 1: 'aquarium_fish', + 2: 'baby', + 3: 'bear', + 4: 'beaver', + 5: 'bed', + 6: 'bee', + 7: 'beetle', + 8: 'bicycle', + 9: 'bottle', + 10: 'bowl', + 11: 'boy', + 12: 'bridge', + 13: 'bus', + 14: 'butterfly', + 15: 'camel', + 16: 'can', + 17: 'castle', + 18: 'caterpillar', + 19: 'cattle', + 20: 'chair', + 21: 'chimpanzee', + 22: 'clock', + 23: 'cloud', + 24: 'cockroach', + 25: 'couch', + 26: 'crab', + 27: 'crocodile', + 28: 'cup', + 29: 'dinosaur', + 30: 'dolphin', + 31: 'elephant', + 32: 'flatfish', + 33: 'forest', + 34: 'fox', + 35: 'girl', + 36: 'hamster', + 37: 'house', + 38: 'kangaroo', + 39: 'keyboard', + 40: 'lamp', + 41: 'lawn_mower', + 42: 'leopard', + 43: 'lion', + 44: 'lizard', + 45: 'lobster', + 46: 'man', + 47: 'maple_tree', + 48: 'motorcycle', + 49: 'mountain', + 50: 'mouse', + 51: 'mushroom', + 52: 'oak_tree', + 53: 'orange', + 54: 'orchid', + 55: 'otter', + 56: 'palm_tree', + 57: 'pear', + 58: 'pickup_truck', + 59: 'pine_tree', + 60: 'plain', + 61: 'plate', + 62: 'poppy', + 63: 'porcupine', + 64: 'possum', + 65: 'rabbit', + 66: 'raccoon', + 67: 'ray', + 68: 'road', + 69: 'rocket', + 70: 'rose', + 71: 'sea', + 72: 'seal', + 73: 'shark', + 74: 'shrew', + 75: 'skunk', + 76: 'skyscraper', + 77: 'snail', + 78: 'snake', + 79: 'spider', + 80: 'squirrel', + 81: 'streetcar', + 82: 'sunflower', + 83: 'sweet_pepper', + 84: 'table', + 85: 'tank', + 86: 'telephone', + 87: 'television', + 88: 'tiger', + 89: 'tractor', + 90: 'train', + 91: 'trout', + 92: 'tulip', + 93: 'turtle', + 94: 'wardrobe', + 95: 'whale', + 96: 'willow_tree', + 97: 'wolf', + 98: 'woman', + 99: 'worm' +} diff --git a/surrogates_overview/scripts/image_classifier.py b/surrogates_overview/scripts/image_classifier.py index 14d058c..7319ae8 100644 --- a/surrogates_overview/scripts/image_classifier.py +++ b/surrogates_overview/scripts/image_classifier.py @@ -2,8 +2,11 @@ Image Classifier ================ -This module implements an image classifier based on PyTorch. -Inception v3 and AlexNet are availabel. +This module implements image classifiers based on PyTorch. +Inception v3 and AlexNet are available for ImageNet; +ResNet56 is available for CIFAR10; and +RepVGG (a2) is available for CIFAR100. + See for more details. """ @@ -11,6 +14,7 @@ # License: new BSD from scripts.imagenet_label_map import IMAGENET_LABEL_MAP +from scripts.cifar_label_map import CIFAR10_LABEL_MAP, CIFAR100_LABEL_MAP import numpy as np @@ -33,6 +37,26 @@ def _get_preprocess_transform(): return transf +def _get_preprocess_transform_cifar10(): + # https://github.com/chenyaofo/pytorch-cifar-models/issues/4 + # https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar10.conf + normalize = transforms.Normalize( + mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + transf = transforms.Compose([transforms.ToTensor(), normalize]) + + return transf + + +def _get_preprocess_transform_cifar100(): + # https://github.com/chenyaofo/pytorch-cifar-models/issues/4 + # https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar100.conf + normalize = transforms.Normalize( + mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]) + transf = transforms.Compose([transforms.ToTensor(), normalize]) + + return transf + + class ImageClassifier(object): """Image classifier based on PyTorch.""" @@ -128,3 +152,77 @@ def proba2tuple(self, Y, labels_no=5): tuples_.append((lab, Y[idx, cls], cls)) tuples.append(tuples_) return tuples + + +class ImageNetClassifier(ImageClassifier): + """ImageNet classifiers -- Inception v3 & AlexNet -- based on PyTorch.""" + + +class Cifar10Classifier(ImageClassifier): + """CIFAR10 classifiers -- ResNet56 -- based on PyTorch.""" + + def __init__(self, use_gpu=False): + """Initialises the image classifier.""" + # Get class labels + self.class_idx = CIFAR10_LABEL_MAP + + # Get the model + # https://github.com/huyvnphan/PyTorch_CIFAR10 + clf = torch.hub.load( + 'chenyaofo/pytorch-cifar-models', + 'cifar10_resnet56', + pretrained=True) + + if use_gpu: + if CUDA_AVAILABLE: + clf = clf.to(DEVICE) + # clf.cuda() + predict_proba = self._predict_proba_gpu + else: + logger.warning('GPU was requested but it is not available. ' + 'Using CPU instead.') + predict_proba = self._predict_proba_cpu + else: + predict_proba = self._predict_proba_cpu + self.predict_proba = predict_proba + + self.clf = clf + self.clf.eval() + + # Get transformation + self.preprocess_transform = _get_preprocess_transform_cifar10() + + +class Cifar100Classifier(ImageClassifier): + """CIFAR100 classifiers -- RepVGG (a2) -- based on PyTorch.""" + + def __init__(self, use_gpu=False): + """Initialises the image classifier.""" + # Get class labels + self.class_idx = CIFAR100_LABEL_MAP + + # Get the model + # https://github.com/huyvnphan/PyTorch_CIFAR10 + clf = torch.hub.load( + 'chenyaofo/pytorch-cifar-models', + 'cifar100_repvgg_a2', + pretrained=True) + + if use_gpu: + if CUDA_AVAILABLE: + clf = clf.to(DEVICE) + # clf.cuda() + predict_proba = self._predict_proba_gpu + else: + logger.warning('GPU was requested but it is not available. ' + 'Using CPU instead.') + predict_proba = self._predict_proba_cpu + else: + predict_proba = self._predict_proba_cpu + self.predict_proba = predict_proba + + self.clf = clf + self.clf.eval() + + # Get transformation + self.preprocess_transform = _get_preprocess_transform_cifar100()