diff --git a/code_soup/common/perturbation.py b/code_soup/common/perturbation.py
new file mode 100644
index 0000000..b2e023d
--- /dev/null
+++ b/code_soup/common/perturbation.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class Perturbation(ABC):
+    """
+    Abstract base class for all perturbation metrics.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
diff --git a/code_soup/common/vision/perturbations.py b/code_soup/common/vision/perturbations.py
new file mode 100644
index 0000000..bdc85cf
--- /dev/null
+++ b/code_soup/common/vision/perturbations.py
@@ -0,0 +1,70 @@
+from math import log10
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from code_soup.common.perturbation import Perturbation
+
+
+class VisualPerturbation(Perturbation):
+    """
+    A base class for computing various visual perturbation metrics between an
+    original image and its perturbed counterpart.
+
+    Methods
+        __init__(self, original: Union[np.ndarray, torch.Tensor], perturbed: Union[np.ndarray, torch.Tensor])
+    """
+
+    def __init__(
+        self,
+        original: Union[np.ndarray, torch.Tensor],
+        perturbed: Union[np.ndarray, torch.Tensor],
+    ):
+        """
+        Store the original and perturbed images, automatically casting numpy
+        arrays to tensors via torch.from_numpy().
+        """
+
+        if isinstance(original, torch.Tensor):
+            self.original = original
+        else:
+            self.original = torch.from_numpy(original)
+
+        if isinstance(perturbed, torch.Tensor):
+            self.perturbed = perturbed
+        else:
+            self.perturbed = torch.from_numpy(perturbed)
+
+    def flatten(self, array: torch.Tensor) -> torch.Tensor:
+        return array.flatten()
+
+    def totensor(self, array: np.ndarray) -> torch.Tensor:
+        return torch.from_numpy(array)
+
+    def subtract(self, original: torch.Tensor, perturbed: torch.Tensor) -> torch.Tensor:
+        return torch.sub(original, perturbed)
+
+    def calculate_LPNorm(self, p: Union[int, str]) -> float:
+        # The Frobenius norm of the difference equals the L2 norm of its flattened form.
+        if p == "inf":
+            return torch.linalg.vector_norm(self.flatten(self.subtract(self.original, self.perturbed)), ord=float("inf")).item()
+        elif p == "fro":
+            return self.calculate_LPNorm(2)
+        else:
+            return torch.linalg.norm(self.flatten(self.subtract(self.original, self.perturbed)), ord=p).item()
+
+    def calculate_PSNR(self) -> float:
+        # Assumes pixel values are scaled to [0, 1], i.e. a peak signal value of 1.0.
+        return 20 * log10(1.0 / self.calculate_RMSE())
+
+    def calculate_RMSE(self) -> float:
+        loss = nn.MSELoss()
+        return (loss(self.original, self.perturbed) ** 0.5).item()
+
+    def calculate_SAM(self):
+        raise NotImplementedError
+
+    def calculate_SRE(self):
+        raise NotImplementedError
diff --git a/tests/test_common/test_vision/test_perturbations.py b/tests/test_common/test_vision/test_perturbations.py
new file mode 100644
index 0000000..eb71992
--- /dev/null
+++ b/tests/test_common/test_vision/test_perturbations.py
@@ -0,0 +1,60 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from torchvision.datasets.fakedata import FakeData
+from torchvision.transforms import ToTensor
+
+from code_soup.common.vision.perturbations import VisualPerturbation
+
+
+class TestVisualPerturbation(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        torch.manual_seed(42)
+        np.random.seed(42)
+        random.seed(42)
+        df = FakeData(size=2, image_size=(3, 64, 64))
+        a, b = tuple(df)
+        a, b = ToTensor()(a[0]).unsqueeze_(0), ToTensor()(b[0]).unsqueeze_(0)
+        cls.obj_tensor = VisualPerturbation(original=a, perturbed=b)
+        cls.obj_numpy = VisualPerturbation(original=a.numpy(), perturbed=b.numpy())
+
+    def test_LPNorm(self):
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_tensor.calculate_LPNorm(p=1), 4143.0249, places=3
+        )
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_numpy.calculate_LPNorm(p="fro"),
+            45.6525,
+            places=3,
+        )
+
+    def test_PSNR(self):
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_tensor.calculate_PSNR(),
+            33.773994480876496,
+            places=3,
+        )
+
+    def test_RMSE(self):
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_tensor.calculate_RMSE(),
+            0.018409499898552895,
+            places=3,
+        )
+
+    def test_SAM(self):
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_tensor.calculate_SAM(),
+            89.34839413786915,
+            places=3,
+        )
+
+    def test_SRE(self):
+        self.assertAlmostEqual(
+            TestVisualPerturbation.obj_tensor.calculate_SRE(),
+            41.36633261587073,
+            places=3,
+        )
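
Usage sketch (not part of the diff above): a minimal, hypothetical example of how the VisualPerturbation metrics might be called once this change lands, assuming the code_soup package is importable and that both images are float tensors scaled to [0, 1], the range calculate_PSNR assumes. The tensor names, shapes, and noise level below are illustrative only.

import torch

from code_soup.common.vision.perturbations import VisualPerturbation

# Illustrative inputs: a random "original" image and a lightly noised copy,
# both kept in [0, 1] so the peak-signal assumption in calculate_PSNR holds.
original = torch.rand(1, 3, 64, 64)
perturbed = (original + 0.05 * torch.randn_like(original)).clamp(0.0, 1.0)

metrics = VisualPerturbation(original=original, perturbed=perturbed)
print("L1 norm  :", metrics.calculate_LPNorm(p=1))
print("Linf norm:", metrics.calculate_LPNorm(p="inf"))
print("RMSE     :", metrics.calculate_RMSE())
print("PSNR (dB):", metrics.calculate_PSNR())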