diff --git a/data/ubc/embedding_test.py b/data/ubc/embedding_test.py index 9e89e459..14ccfcd2 100755 --- a/data/ubc/embedding_test.py +++ b/data/ubc/embedding_test.py @@ -4,6 +4,7 @@ python3 embedding_test.py small/G80223_20230513.bin_scale.bin ''' + import os import sys import umap @@ -78,8 +79,8 @@ def get_model(model_type='tsne'): rgb_values[:, i][np.where(rgb_values[:, i] < 0.)] = 0. rgb_values[:, i][np.where(rgb_values[:, i] > 1.)] = 1. -# Create a figure with two subplots (for UMAP/t-SNE and RGB values) -fig, axs = plt.subplots(1, 2, figsize=(16, 6)) # 1 row, 2 columns +# Create a figure with three subplots (for UMAP/t-SNE, RGB values with masks, and original image) +fig, axs = plt.subplots(1, 3, figsize=(18, 6)) # 1 row, 3 columns # First subplot: UMAP or t-SNE projection colored by RGB values scatter = axs[0].scatter(embedding[:, 0], embedding[:, 1], c=rgb_values, s=1, alpha=0.5) # Scatter plot with RGB coloring @@ -87,7 +88,7 @@ def get_model(model_type='tsne'): axs[0].set_xlabel(f'{model_type.upper()} Component 1') axs[0].set_ylabel(f'{model_type.upper()} Component 2') -# Second subplot: Display RGB values using imshow +# Second subplot: Display RGB values using imshow with the mask applied # Reshape the RGB values back to (height, width, 3) for imshow rgb_image = rgb_values.reshape(height, width, 3) # Recreate the RGB image @@ -96,6 +97,15 @@ def get_model(model_type='tsne'): axs[1].axis('off') # Hide axes for better visualization axs[1].set_title('RGB Values of Raster') +# Third subplot: Original image (without any modifications) +# Display the original image (in this case, we are using RGB bands for visualization) +original_image = np.moveaxis(data[:3, :, :], 0, -1) # First 3 bands for the original image + +# Display the original image +axs[2].imshow(original_image) +axs[2].axis('off') # Hide axes for better visualization +axs[2].set_title('Original Image') + # Initialize variables for polygon drawing polygon_points = [] polygon_path = None @@ -170,3 +180,4 @@ def mask_polygon(): plt.tight_layout() plt.show() +