Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visual Perturbations #63

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions code_soup/common/perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod


class Perturbation(ABC):
"""
Docstring for Abstract Class Perturbation
"""

@classmethod
@abstractmethod
def __init__(self):
pass
70 changes: 70 additions & 0 deletions code_soup/common/vision/perturbations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from abc import abstractmethod
from typing import Union

import numpy as np
import torch
import torch.nn as nn

from math import log10

from code_soup.common.perturbation import Perturbation


class VisualPerturbation(Perturbation):
"""
An abstract method for various Visual Perturbation Metrics
Methods
__init__(self, original : Union[np.ndarray, torch.Tensor], perturbed: Union[np.ndarray, torch.Tensor])
- init method
"""

def __init__(
self,
original: Union[np.ndarray, torch.Tensor],
perturbed: Union[np.ndarray, torch.Tensor],
):
"""
Docstring
#Automatically cast to Tensor using the torch.from_numpy() in the __init__ using if
"""

if type(original) == torch.Tensor:
self.original = original
else:
self.original = torch.from_numpy(original)
print(self.original.shape)

if type(perturbed) == torch.Tensor:
self.perturbed = perturbed
else:
self.perturbed = torch.from_numpy(perturbed)

def flatten(self, array : torch.tensor) -> torch.Tensor:
return array.flatten()

def totensor(self, array : np.ndarray) -> torch.Tensor:
return torch.from_numpy(array)

def subtract(self,original : torch.Tensor, perturbed : torch.Tensor) -> torch.Tensor:
return torch.sub(original, perturbed)

def calculate_LPNorm(self, p: Union[int, str]) -> float:
if p == 'inf':
return torch.linalg.vector_norm(self.flatten(self.subtract(self.original,self.perturbed)), ord = float('inf')).item()
elif p == 'fro':
return self.calculate_LPNorm(2)
else:
return torch.linalg.norm(self.flatten(self.subtract(self.original,self.perturbed)), ord = p).item()

def calculate_PSNR(self) -> float:
return 20 * log10(1.0/self.calculate_RMSE())

def calculate_RMSE(self) -> float:
loss = nn.MSELoss()
return (loss(self.original, self.perturbed)**0.5).item()

def calculate_SAM(self):
raise NotImplementedError

def calculate_SRE(self):
raise NotImplementedError
60 changes: 60 additions & 0 deletions tests/test_common/test_vision/test_perturbations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random
import unittest

import numpy as np
import torch
from torchvision.datasets.fakedata import FakeData
from torchvision.transforms import ToTensor

from code_soup.common.vision.perturbations import VisualPerturbation


class TestVisualPerturbation(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
df = FakeData(size=2, image_size=(3, 64, 64))
a, b = tuple(df)
a, b = ToTensor()(a[0]).unsqueeze_(0), ToTensor()(b[0]).unsqueeze_(0)
cls.obj_tensor = VisualPerturbation(original=a, perturbed=b)
cls.obj_numpy = VisualPerturbation(original=a.numpy(), perturbed=b.numpy())

def test_LPNorm(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_LPNorm(p=1), 4143.0249, places=3
)
self.assertAlmostEqual(
TestVisualPerturbation.obj_numpy.calculate_LPNorm(p="fro"),
45.6525,
places=3,
)

def test_PSNR(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_PSNR(),
33.773994480876496,
places=3,
)

def test_RMSE(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_RMSE(),
0.018409499898552895,
places=3,
)

def test_SAM(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_SAM(),
89.34839413786915,
places=3,
)

def test_SRE(self):
self.assertAlmostEqual(
TestVisualPerturbation.obj_tensor.calculate_SRE(),
41.36633261587073,
places=3,
)