diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..030d9e7 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: pylint + +on: + pull_request: + workflow_dispatch: + +jobs: + checks: + runs-on: ubuntu-20.04 + strategy: + max-parallel: 4 + matrix: + python-version: [3.7, 3.9] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + - name: Check lint + run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .)-lint \ No newline at end of file diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml new file mode 100644 index 0000000..9a87e95 --- /dev/null +++ b/.github/workflows/tox.yml @@ -0,0 +1,26 @@ +name: tox + +on: + pull_request: + workflow_dispatch: + +jobs: + checks: + runs-on: ubuntu-20.04 + strategy: + max-parallel: 4 + matrix: + python-version: [3.7, 3.9] + + steps: + - uses: actions/checkout@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install tox + - name: Test with tox + run: tox -e py$(echo ${{ matrix.python-version }} | tr -d .) \ No newline at end of file diff --git a/.gitignore b/.gitignore index 3678d7e..6c411a6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# pre-compiled spectrum and models +*.npy +*.h5 + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/.pylintrc b/.pylintrc index 0e09569..e32b347 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,7 +6,7 @@ disable= E1123, # issues between pylint and tensorflow since 2.2.0 E1120, # see pylint#3613 C3001, # lambda function as variable - + C0116, C0114, # docstring [FORMAT] max-line-length=100 max-args=12 @@ -15,4 +15,7 @@ max-args=12 min-similarity-lines=6 ignore-comments=yes ignore-docstrings=yes -ignore-imports=no \ No newline at end of file +ignore-imports=no + +[TYPECHECK] +ignored-modules=torch diff --git a/README.md b/README.md index ec5eb2c..9b595a4 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ # Horama: A Compact Library for Feature Visualization Experiments +@todo add notebooks +@todo add illustration images + logo + Horama provides the implementation code for the research paper: - *Unlocking Feature Visualization for Deeper Networks with MAgnitude Constrained Optimization* by Thomas Fel*, Thibaut Boissin*, Victor Boutin*, Agustin Picard*, Paul Novello*, Julien Colin, Drew Linsley, Tom Rousseau, Rémi Cadène, Laurent Gardes, Thomas Serre. [Read the paper on arXiv](https://arxiv.org/abs/2211.10154). -In addition, this repository introduces various feature visualization methods, including a reimagined approach to the [incredible work of the Clarity team](https://distill.pub/2017/feature-visualization/) and an implementation of [Feature Accentuation](https://arxiv.org/abs/2402.10039) from Hamblin & al. For an official reproduction of distill's work complete with comprehensive notebooks, we highly recommend Lucent. However, Horama focuses on experimentation within PyTorch, offering a compact and modifiable codebase. +In addition, this repository introduces various feature visualization methods, including a reimagined approach to the [incredible work of the Clarity team](https://distill.pub/2017/feature-visualization/) and an implementation of [Feature Accentuation](https://arxiv.org/abs/2402.10039) from Hamblin & al. For an official reproduction of distill's work complete with comprehensive notebooks, we highly recommend [Lucent](https://github.com/greentfrapp/lucent). However, Horama focuses on **experimentation** within PyTorch, offering a compact and easily hackable codebase. # 🚀 Getting Started with Horama @@ -31,11 +34,21 @@ objective = lambda images: torch.mean(model(images)[:, 1]) image1, alpha1 = maco(objective) plot_maco(image1, alpha1) +plt.show() image2, alpha2 = fourier(objective) plot_maco(image2, alpha2) +plt.show() ``` +# Notebooks + +@todo: fourier, maco for various models on timm +@todo: cossim vs logits +@todo: speedup process, what parameters to change +@todo: feature inversion +@todo: feature accentuation + # Complete API Complete API Guide @@ -76,15 +89,17 @@ When optimizing, it's crucial to fine-tune the hyperparameters. Parameters like @article{fel2023maco, title={Unlocking Feature Visualization for Deeper Networks with MAgnitude Constrained Optimization}, author={Thomas, Fel and Thibaut, Boissin and Victor, Boutin and Agustin, Picard and Paul, Novello and Julien, Colin and Drew, Linsley and Tom, Rousseau and Rémi, Cadène and Laurent, Gardes and Thomas, Serre}, + journal={Advances in Neural Information Processing Systems (NeurIPS)}, year={2023}, } ``` # Additional Resources -For a simpler and maintenance-friendly implementation for TensorFlow and more on feature visualization methods, check out the Xplique toolbox. A simpler and maintain implementation of the code for Tensorflow and the other feature visualization methods used in the paper come from the [Xplique toolbox](https://github.com/deel-ai/xplique). Additionally, we have created a website called the [LENS Project](https://github.com/serre-lab/Lens), which features the 1000 classes of ImageNet. -# Authors of the code +For a code faithful to the original work of the Clarity team, we highly recommend [Lucent](https://github.com/greentfrapp/lucent). + +# Authors - [Thomas Fel](https://thomasfel.fr) - thomas_fel@brown.edu, PhD Student DEEL (ANITI), Brown University \ No newline at end of file diff --git a/horama/__init__.py b/horama/__init__.py index bd48996..7050cc3 100644 --- a/horama/__init__.py +++ b/horama/__init__.py @@ -14,4 +14,4 @@ from .maco_fv import maco from .fourier_fv import fourier from .plots import plot_maco -from .losses import dot_cossim \ No newline at end of file +from .losses import dot_cossim diff --git a/horama/common.py b/horama/common.py index c0e3f24..97de5e2 100644 --- a/horama/common.py +++ b/horama/common.py @@ -1,13 +1,17 @@ import torch from torchvision.ops import roi_align + def standardize(tensor): # standardizes the tensor to have 0 mean and unit variance tensor = tensor - torch.mean(tensor) tensor = tensor / (torch.std(tensor) + 1e-4) return tensor + def recorrelate_colors(image, device): + # recorrelates the colors of the images + assert len(image.shape) == 3 # tensor for color correlation svd square root color_correlation_svd_sqrt = torch.tensor( @@ -17,9 +21,6 @@ def recorrelate_colors(image, device): dtype=torch.float32 ).to(device) - # recorrelates the colors of the images - assert len(image.shape) == 3 - permuted_image = image.permute(1, 2, 0).contiguous() flat_image = permuted_image.view(-1, 3) @@ -28,8 +29,11 @@ def recorrelate_colors(image, device): return recorrelated_image -def optimization_step(objective_function, image, box_size, noise_level, number_of_crops_per_iteration, model_input_size): + +def optimization_step(objective_function, image, box_size, noise_level, + number_of_crops_per_iteration, model_input_size): # performs an optimization step on the generated image + # pylint: disable=C0103 assert box_size[1] >= box_size[0] assert len(image.shape) == 3 @@ -39,7 +43,8 @@ def optimization_step(objective_function, image, box_size, noise_level, number_o # generate random boxes x0 = 0.5 + torch.randn((number_of_crops_per_iteration,), device=device) * 0.15 y0 = 0.5 + torch.randn((number_of_crops_per_iteration,), device=device) * 0.15 - delta_x = torch.rand((number_of_crops_per_iteration,), device=device) * (box_size[1] - box_size[0]) + box_size[1] + delta_x = torch.rand((number_of_crops_per_iteration,), + device=device) * (box_size[1] - box_size[0]) + box_size[1] delta_y = delta_x boxes = torch.stack([torch.zeros((number_of_crops_per_iteration,), device=device), @@ -48,11 +53,13 @@ def optimization_step(objective_function, image, box_size, noise_level, number_o x0 + delta_x * 0.5, y0 + delta_y * 0.5], dim=1) * image.shape[1] - cropped_and_resized_images = roi_align(image.unsqueeze(0), boxes, output_size=(model_input_size, model_input_size)).squeeze(0) + cropped_and_resized_images = roi_align(image.unsqueeze( + 0), boxes, output_size=(model_input_size, model_input_size)).squeeze(0) # add normal and uniform noise for better robustness cropped_and_resized_images.add_(torch.randn_like(cropped_and_resized_images) * noise_level) - cropped_and_resized_images.add_((torch.rand_like(cropped_and_resized_images) - 0.5) * noise_level) + cropped_and_resized_images.add_( + (torch.rand_like(cropped_and_resized_images) - 0.5) * noise_level) # compute the score and loss score = objective_function(cropped_and_resized_images) diff --git a/horama/fourier_fv.py b/horama/fourier_fv.py index 3fb3523..16e13f7 100644 --- a/horama/fourier_fv.py +++ b/horama/fourier_fv.py @@ -1,7 +1,9 @@ import torch -from .common import standardize, recorrelate_colors, optimization_step from tqdm import tqdm +from .common import standardize, recorrelate_colors, optimization_step + + def fft_2d_freq(width, height): # calculate the 2D frequency grid for FFT freq_y = torch.fft.fftfreq(height).unsqueeze(1) @@ -11,21 +13,26 @@ def fft_2d_freq(width, height): return torch.sqrt(freq_x**2 + freq_y**2) + def get_fft_scale(width, height, decay_power=1.0): - # generate the FFT scale based on the image size and decay power + # generate the scaler that account for power decay in FFT space frequencies = fft_2d_freq(width, height) - fft_scale = 1.0 / torch.maximum(frequencies, torch.tensor(1.0 / max(width, height))) ** decay_power + fft_scale = 1.0 / torch.maximum(frequencies, + torch.tensor(1.0 / max(width, height))) ** decay_power fft_scale = fft_scale * torch.sqrt(torch.tensor(width * height).float()) return fft_scale.to(torch.complex64) -def init_olah_buffer(width, height, std=1.0): - # initialize the Olah buffer with a random spectrum + +def init_lucid_buffer(width, height, std=1.0): + # initialize the buffer with a random spectrum a la Lucid spectrum_shape = (3, width, height // 2 + 1) - random_spectrum = torch.complex(torch.randn(spectrum_shape) * std, torch.randn(spectrum_shape) * std) + random_spectrum = torch.complex(torch.randn(spectrum_shape) * std, + torch.randn(spectrum_shape) * std) return random_spectrum + def fourier_preconditionner(spectrum, spectrum_scaler, values_range, device): # precondition the Fourier spectrum and convert it to spatial domain assert spectrum.shape[0] == 3 @@ -37,16 +44,21 @@ def fourier_preconditionner(spectrum, spectrum_scaler, values_range, device): spatial_image = standardize(spatial_image) color_recorrelated_image = recorrelate_colors(spatial_image, device) - image = torch.sigmoid(color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0] + image = torch.sigmoid( + color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0] return image -def fourier(objective_function, decay_power=1.5, total_steps=1000, learning_rate=1.0, image_size=1280, model_input_size=224, - noise=0.05, values_range=(-2.5, 2.5), crops_per_iteration=6, box_size=(0.20, 0.25), device='cuda'): - # perform the Olah optimization process + +def fourier( + objective_function, decay_power=1.5, total_steps=1000, learning_rate=1.0, image_size=1280, + model_input_size=224, noise=0.05, values_range=(-2.5, 2.5), + crops_per_iteration=6, box_size=(0.20, 0.25), + device='cuda'): + # perform the Lucid (Olah & al.) optimization process assert values_range[1] >= values_range[0] assert box_size[1] >= box_size[0] - spectrum = init_olah_buffer(image_size, image_size, std=1.0) + spectrum = init_lucid_buffer(image_size, image_size, std=1.0) spectrum_scaler = get_fft_scale(image_size, image_size, decay_power) spectrum = spectrum.to(device) @@ -56,11 +68,12 @@ def fourier(objective_function, decay_power=1.5, total_steps=1000, learning_rate optimizer = torch.optim.NAdam([spectrum], lr=learning_rate) transparency_accumulator = torch.zeros((3, image_size, image_size)).to(device) - for step in tqdm(range(total_steps)): + for _ in tqdm(range(total_steps)): optimizer.zero_grad() image = fourier_preconditionner(spectrum, spectrum_scaler, values_range, device) - loss, img = optimization_step(objective_function, image, box_size, noise, crops_per_iteration, model_input_size) + loss, img = optimization_step(objective_function, image, box_size, + noise, crops_per_iteration, model_input_size) loss.backward() transparency_accumulator += torch.abs(img.grad) optimizer.step() diff --git a/horama/losses.py b/horama/losses.py index ec617a5..2c0fab2 100644 --- a/horama/losses.py +++ b/horama/losses.py @@ -1,14 +1,15 @@ import torch + def cosine_similarity(tensor_a, tensor_b): - # calculate cosine similarity norm_dims = list(range(1, len(tensor_a.shape))) tensor_a = torch.nn.functional.normalize(tensor_a.float(), dim=norm_dims) tensor_b = torch.nn.functional.normalize(tensor_b.float(), dim=norm_dims) return torch.sum(tensor_a * tensor_b, dim=norm_dims) + def dot_cossim(tensor_a, tensor_b, cossim_pow=2.0): - # compute dot product scaled by cosine similarity + # see https://github.com/tensorflow/lucid/issues/116 cosim = torch.clamp(cosine_similarity(tensor_a, tensor_b), min=1e-1) ** cossim_pow dot = torch.sum(tensor_a * tensor_b) return dot * cosim diff --git a/horama/maco_fv.py b/horama/maco_fv.py index 7bd7a9d..96bbe10 100644 --- a/horama/maco_fv.py +++ b/horama/maco_fv.py @@ -8,26 +8,33 @@ from .common import optimization_step, standardize, recorrelate_colors -MACO_SPECTRUM_URL = "https://storage.googleapis.com/serrelab/loupe/spectrums/imagenet_decorrelated.npy" +MACO_SPECTRUM_URL = ("https://storage.googleapis.com/serrelab/loupe/" + "spectrums/imagenet_decorrelated.npy") MACO_SPECTRUM_FILENAME = 'spectrum_decorrelated.npy' + def init_maco_buffer(image_shape, std_deviation=1.0): # initialize the maco buffer with a random phase and a magnitude template spectrum_shape = (image_shape[0], image_shape[1] // 2 + 1) # generate random phase - random_phase = torch.randn(3, *spectrum_shape, dtype=torch.float32) * std_deviation + random_phase = torch.randn( + 3, *spectrum_shape, dtype=torch.float32) * std_deviation # download magnitude template if not exists if not os.path.isfile(MACO_SPECTRUM_FILENAME): - download_url(MACO_SPECTRUM_URL, root=".", filename=MACO_SPECTRUM_FILENAME) + download_url(MACO_SPECTRUM_URL, root=".", + filename=MACO_SPECTRUM_FILENAME) # load and resize magnitude template - magnitude = torch.tensor(np.load(MACO_SPECTRUM_FILENAME), dtype=torch.float32).cuda() - magnitude = F.interpolate(magnitude.unsqueeze(0), size=spectrum_shape, mode='bilinear', align_corners=False, antialias=True)[0] + magnitude = torch.tensor( + np.load(MACO_SPECTRUM_FILENAME), dtype=torch.float32) + magnitude = F.interpolate(magnitude.unsqueeze( + 0), size=spectrum_shape, mode='bilinear', align_corners=False, antialias=True)[0] return magnitude, random_phase -def maco_preconditioner(magnitude_template, phase, values_range): + +def maco_preconditioner(magnitude_template, phase, values_range, device): # apply the maco preconditioner to generate spatial images from magnitude and phase # tfel: check why r exp^(j theta) give slighly diff results standardized_phase = standardize(phase) @@ -40,32 +47,39 @@ def maco_preconditioner(magnitude_template, phase, values_range): # recorrelate colors and adjust value range color_recorrelated_image = recorrelate_colors(spatial_image, device) - final_image = torch.sigmoid(color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0] + final_image = torch.sigmoid( + color_recorrelated_image) * (values_range[1] - values_range[0]) + values_range[0] return final_image -def maco(objective_function, total_steps=1000, learning_rate=1.0, image_size=1280, model_input_size=224, - noise=0.05, values_range=(-2.5, 2.5), crops_per_iteration=6, box_size=(0.20, 0.25), device='cuda'): + +def maco(objective_function, total_steps=1000, learning_rate=1.0, image_size=1280, + model_input_size=224, noise=0.05, values_range=(-2.5, 2.5), + crops_per_iteration=6, box_size=(0.20, 0.25), + device='cuda'): # perform the maco optimization process assert values_range[1] >= values_range[0] assert box_size[1] >= box_size[0] - magnitude, phase = init_maco_buffer((image_size, image_size), std_deviation=1.0) + magnitude, phase = init_maco_buffer( + (image_size, image_size), std_deviation=1.0) magnitude = magnitude.to(device) phase = phase.to(device) phase.requires_grad = True optimizer = torch.optim.NAdam([phase], lr=learning_rate) - transparency_accumulator = torch.zeros((3, image_size, image_size)).to(device) + transparency_accumulator = torch.zeros( + (3, image_size, image_size)).to(device) - for step in tqdm(range(total_steps)): + for _ in tqdm(range(total_steps)): optimizer.zero_grad() # preprocess and compute loss - img = maco_preconditioner(magnitude, phase, values_range) - loss, img = optimization_step(objective_function, img, box_size, noise, crops_per_iteration, model_input_size, device) + img = maco_preconditioner(magnitude, phase, values_range, device) + loss, img = optimization_step( + objective_function, img, box_size, noise, crops_per_iteration, model_input_size) loss.backward() - # get dy/dx to update transparency mask + # get dL/dx to update transparency mask transparency_accumulator += torch.abs(img.grad) optimizer.step() diff --git a/horama/plots.py b/horama/plots.py index 3aad048..6416901 100644 --- a/horama/plots.py +++ b/horama/plots.py @@ -2,10 +2,12 @@ import matplotlib.pyplot as plt import torch + def to_numpy(tensor): # Ensure tensor is on CPU and convert to NumPy return tensor.detach().cpu().numpy() + def check_format(arr): # ensure numpy array and move channels to the last dimension # if they are in the first dimension @@ -15,6 +17,7 @@ def check_format(arr): return np.moveaxis(arr, 0, -1) return arr + def normalize(image): # normalize image to 0-1 range image = np.array(image, dtype=np.float32) @@ -22,9 +25,11 @@ def normalize(image): image /= image.max() return image -def clip_percentile(img, p=0.1): + +def clip_percentile(img, percentile=0.1): # clip pixel values to specified percentile range - return np.clip(img, np.percentile(img, p), np.percentile(img, 100-p)) + return np.clip(img, np.percentile(img, percentile), np.percentile(img, 100-percentile)) + def show(img, **kwargs): # display image with normalization and channels in the last dimension @@ -33,7 +38,7 @@ def show(img, **kwargs): plt.imshow(img, **kwargs) plt.axis('off') - plt.show() + def plot_maco(image, alpha, percentile_image=1.0, percentile_alpha=80): # visualize image with alpha mask overlay after normalization and clipping @@ -49,4 +54,3 @@ def plot_maco(image, alpha, percentile_image=1.0, percentile_alpha=80): # overlay alpha mask on the image plt.imshow(np.concatenate([image, alpha], -1)) plt.axis('off') - plt.show() diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..a9131bd --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,5 @@ +tox +pytest +pytest-cov +pylint +timm \ No newline at end of file diff --git a/setup.py b/setup.py index e010397..1b87ed4 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ author="Thomas FEL, Thibaut BOISSIN, Victor BOUTIN, Agustin PICARD, Paul NOVELLO", author_email="thomas_fel@brown.edu", license="MIT", - install_requires=['numpy','matplotlib', 'torch', 'torchvision'], + install_requires=['numpy', 'matplotlib', 'torch', 'torchvision'], packages=find_packages(), python_requires=">=3.6", classifiers=[ @@ -22,4 +22,4 @@ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], -) \ No newline at end of file +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..f6075cd --- /dev/null +++ b/tests/common.py @@ -0,0 +1,12 @@ +import torch +import torch.nn as nn + + +class SimpleDummyModel(nn.Module): + def __init__(self): + super(SimpleDummyModel, self).__init__() + + def forward(self, x): + x = torch.mean(x, (1, 2, 3)) + x = torch.relu(x) + return x diff --git a/tests/test_fourier.py b/tests/test_fourier.py new file mode 100644 index 0000000..39b550c --- /dev/null +++ b/tests/test_fourier.py @@ -0,0 +1,16 @@ +import torch +from horama import fourier + +from .common import SimpleDummyModel + + +def test_fourier(): + def objective(images): return torch.mean(model(images)) + model = SimpleDummyModel() + + img_size = 200 + image, alpha = fourier(objective, total_steps=10, image_size=img_size, + model_input_size=100, device='cpu') + + assert image.size() == (3, img_size, img_size) + assert alpha.size() == (3, img_size, img_size) diff --git a/tests/test_maco.py b/tests/test_maco.py new file mode 100644 index 0000000..a097ef1 --- /dev/null +++ b/tests/test_maco.py @@ -0,0 +1,16 @@ +import torch +from horama import maco + +from .common import SimpleDummyModel + + +def test_maco(): + def objective(images): return torch.mean(model(images)) + model = SimpleDummyModel() + + img_size = 200 + image, alpha = maco(objective, total_steps=10, image_size=img_size, + model_input_size=100, device='cpu') + + assert image.size() == (3, img_size, img_size) + assert alpha.size() == (3, img_size, img_size) diff --git a/tests/test_plots.py b/tests/test_plots.py new file mode 100644 index 0000000..2fcf1cc --- /dev/null +++ b/tests/test_plots.py @@ -0,0 +1,41 @@ +import torch +import pytest +from horama import maco, fourier, plot_maco +import matplotlib.pyplot as plt + +from .common import SimpleDummyModel + + +@pytest.fixture +def cleanup_plot(): + yield + plt.close('all') + plt.ion() + + +def test_plot_maco(cleanup_plot): + def objective(images): return torch.mean(model(images)) + model = SimpleDummyModel() + + img_size = 200 + image, alpha = maco(objective, total_steps=10, image_size=img_size, + model_input_size=100, device='cpu') + + plot_maco(image, alpha) + + fig = plt.gcf() + assert fig is not None, "Plotting failed: no figure created" + + +def test_plot_fourier(cleanup_plot): + def objective(images): return torch.mean(model(images)) + model = SimpleDummyModel() + + img_size = 200 + image, alpha = fourier(objective, total_steps=10, image_size=img_size, + model_input_size=100, device='cpu') + + plot_maco(image, alpha) + + fig = plt.gcf() + assert fig is not None, "Plotting failed: no figure created" diff --git a/tests/test_timm.py b/tests/test_timm.py new file mode 100644 index 0000000..f5890fb --- /dev/null +++ b/tests/test_timm.py @@ -0,0 +1,46 @@ +import torch +import timm +import pytest +from horama import maco, fourier, plot_maco + + +@pytest.fixture +def setup_model(): + model = timm.create_model('mobilenetv3_small_050.lamb_in1k', pretrained=False).cpu().eval() + def objective(images): return torch.mean(model(images)[:, 1]) + return model, objective + + +def test_fourier_timm(setup_model): + model, objective = setup_model + + img_size = 128 + model_size = 128 + + image1, alpha1 = maco(objective, total_steps=10, image_size=img_size, + model_input_size=model_size, device='cpu') + plot_maco(image1, alpha1) + + assert image1.size() == (3, img_size, img_size) + assert alpha1.size() == (3, img_size, img_size) + + image2, alpha2 = fourier(objective, total_steps=10, image_size=img_size, + model_input_size=model_size, device='cpu') + plot_maco(image2, alpha2) + + assert image2.size() == (3, img_size, img_size) + assert alpha2.size() == (3, img_size, img_size) + + +def test_maco_timm(setup_model): + model, objective = setup_model + + img_size = 128 + model_size = 128 + + image2, alpha2 = fourier(objective, total_steps=10, image_size=img_size, + model_input_size=model_size, device='cpu') + plot_maco(image2, alpha2) + + assert image2.size() == (3, img_size, img_size) + assert alpha2.size() == (3, img_size, img_size) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..a53420c --- /dev/null +++ b/tox.ini @@ -0,0 +1,13 @@ +[tox] +envlist = py{37,39},py{37,39}-lint + +[testenv] +deps = + -rrequirements.txt + -rrequirements_dev.txt +commands = + pytest --cov=horama --disable-pytest-warnings {posargs} + +[testenv:py{37,39}-lint] +commands = + python -m pylint horama \ No newline at end of file