-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add test_inference_utils and test_data_utils
- Loading branch information
Ruruthia
committed
Sep 27, 2024
1 parent
74baf71
commit ebd6ef8
Showing
2 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import unittest | ||
import numpy as np | ||
import cv2 | ||
from PIL import Image | ||
from cvdm.utils.data_utils import ( | ||
read_and_patch_image_from_filename, | ||
center_crop, | ||
obtain_noisy_sample, | ||
) | ||
|
||
|
||
class TestImageProcessing(unittest.TestCase): | ||
|
||
def setUp(self): | ||
# This method will run before each test | ||
# Create a sample random image (300x300) with 3 color channels | ||
self.sample_image = np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8) | ||
|
||
def test_read_and_patch_image_from_filename(self): | ||
# Test resizing when image is smaller than im_size | ||
im_size = 400 | ||
|
||
# Convert the sample image to a format that the function can accept | ||
cv2.imwrite("test_image.jpg", self.sample_image) # Temporarily save for testing | ||
patched_image = read_and_patch_image_from_filename("test_image.jpg", im_size) | ||
self.assertEqual(patched_image.size, (im_size, im_size)) | ||
|
||
# Test extracting patches when image is larger than im_size | ||
im_size = 100 | ||
patched_image = read_and_patch_image_from_filename("test_image.jpg", im_size) | ||
self.assertEqual(patched_image.size, (im_size, im_size)) | ||
|
||
def test_center_crop(self): | ||
# Use the random image created in setUp | ||
crop_size = 2048 | ||
|
||
# Create a larger dummy image for cropping | ||
larger_dummy_image = np.random.rand(3000, 3000, 3).astype(np.float32) | ||
cropped_image = center_crop(larger_dummy_image, crop_size) | ||
|
||
# Check if the cropped image has the expected size | ||
self.assertEqual(cropped_image.shape, (crop_size, crop_size, 3)) | ||
|
||
def test_obtain_noisy_sample(self): | ||
# Create a dummy input for testing | ||
x = [ | ||
np.random.rand(256, 256, 3).astype(np.float32), | ||
np.array(0.5, dtype=np.float32), | ||
] | ||
samples = obtain_noisy_sample(x) | ||
|
||
# Check that the output is as expected | ||
self.assertEqual(len(samples), 4) | ||
for sample in samples: | ||
self.assertEqual(sample.shape, (256, 256, 3)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
from tensorflow.keras.layers import Dense | ||
from tensorflow.keras.models import Input, Sequential | ||
|
||
from cvdm.utils.inference_utils import ( | ||
create_output_montage, | ||
ddpm_obtain_sr_img, | ||
log_loss, | ||
log_metrics, | ||
obtain_output_montage_and_metrics, | ||
save_output_montage, | ||
save_weights, | ||
) | ||
|
||
|
||
class TestDDPMFunctions(unittest.TestCase): | ||
|
||
def setUp(self): | ||
# Set up test parameters and models | ||
self.x = np.random.rand(1, 256, 256, 3).astype(np.float32) # Random input | ||
self.y = np.random.rand(1, 256, 256, 3).astype(np.float32) # Random target | ||
self.timesteps_test = 10 | ||
|
||
# Simple noise model and schedule model for testing | ||
self.noise_model = self._create_simple_model() | ||
self.schedule_model = self._create_simple_model() | ||
self.mu_model = self._create_simple_model() # Optional model | ||
|
||
self.output_shape = (1, 256, 256, 3) # Example output shape | ||
|
||
def _create_simple_model(self): | ||
# Create a simple model for testing | ||
model = Sequential( | ||
[ | ||
Input(shape=(None, None, 3)), # Flexible input shape | ||
Dense(3, activation="sigmoid"), # Output must match expected shape | ||
] | ||
) | ||
return model | ||
|
||
def test_ddpm_obtain_sr_img(self): | ||
pred_diff, gamma_vec, alpha_vec = ddpm_obtain_sr_img( | ||
self.x, | ||
self.timesteps_test, | ||
self.noise_model, | ||
self.schedule_model, | ||
self.mu_model, | ||
out_shape=self.output_shape, | ||
) | ||
self.assertEqual(pred_diff.shape, self.output_shape) | ||
self.assertEqual(gamma_vec.shape, self.output_shape + (self.timesteps_test,)) | ||
self.assertEqual(alpha_vec.shape, self.output_shape + (self.timesteps_test,)) | ||
|
||
def test_create_output_montage(self): | ||
pred_y = np.random.rand(1, 256, 256, 3) | ||
gamma_vec = np.random.rand(1, 256, 256, 10) # Random gamma vector | ||
output_image = create_output_montage(pred_y, gamma_vec, self.y, self.x) | ||
|
||
# Check the output shape of the montage image | ||
self.assertEqual(output_image.ndim, 3) # Ensure it is a 3D image | ||
self.assertGreater(output_image.shape[0], 0) # Check it has some height | ||
self.assertGreater(output_image.shape[1], 0) # Check it has some width | ||
|
||
def test_log_loss(self): | ||
# Check if logging loss does not raise any errors | ||
avg_loss = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) # Example loss values | ||
log_loss(None, avg_loss, "test_prefix") # Should print without errors | ||
|
||
def test_log_metrics(self): | ||
# Test logging metrics | ||
metrics_dict = {"accuracy": 0.9, "loss": 0.1} | ||
log_metrics(None, metrics_dict, "test_prefix") # Should print without errors | ||
|
||
def test_save_weights(self): | ||
# Check if saving weights does not raise any errors | ||
try: | ||
save_weights( | ||
None, self.noise_model, self.mu_model, 1, "test_path", "test_run_id" | ||
) | ||
except Exception as e: | ||
self.fail(f"save_weights raised an exception: {e}") | ||
|
||
def test_save_output_montage(self): | ||
# Check if saving output montage does not raise any errors | ||
output_montage = np.random.rand(256, 256, 3) | ||
try: | ||
save_output_montage( | ||
None, output_montage, 1, "test_path", "test_run_id", "test_prefix" | ||
) | ||
except Exception as e: | ||
self.fail(f"save_output_montage raised an exception: {e}") | ||
|
||
def test_obtain_output_montage_and_metrics(self): | ||
output_montage, metrics = obtain_output_montage_and_metrics( | ||
self.x, | ||
self.y, | ||
self.noise_model, | ||
self.schedule_model, | ||
self.mu_model, | ||
self.timesteps_test, | ||
diff_inp=True, | ||
task="imagenet_sr", | ||
) | ||
|
||
self.assertEqual(output_montage.shape, (256, 256, 3)) # Check output shape | ||
self.assertIsInstance(metrics, dict) # Ensure metrics is a dictionary |