From 581756fd66c584f341888e875855c320b33e30d7 Mon Sep 17 00:00:00 2001
From: Jenny Zhang <yucizhang2008@gmail.com>
Date: Thu, 16 Jan 2025 00:25:35 -0800
Subject: [PATCH] build: function restructure so can operate in sequential
 order

---
 src/sharpedge/modulate_image.py | 88 +++++++++++++++++----------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/src/sharpedge/modulate_image.py b/src/sharpedge/modulate_image.py
index f8163ed..13d0dcc 100644
--- a/src/sharpedge/modulate_image.py
+++ b/src/sharpedge/modulate_image.py
@@ -119,49 +119,53 @@ def modulate_image(img, mode='gray', ch_swap=None, ch_extract=None):
         print("Converting RGB to grayscale...")
         img = np.mean(img, axis=-1)
 
-    # Validate channel extraction when requested (only for RGB images)
-    if ch_extract is not None:
-        if len(img.shape) == 2:
-            warnings.warn("Grayscale images have no channels to extract.", UserWarning)
-            return img  # No channel extraction for grayscale
-        
-        # Validate ch_extract: should be a list or tuple of 0, 1, or 2, with no duplicates
-        if not isinstance(ch_extract, (list, tuple)):
-            raise TypeError("ch_extract must be a list or tuple.")
-        
-        if not all(isinstance(ch, int) for ch in ch_extract):
-            raise TypeError("All elements in ch_extract must be integers.")
-        
-        if not all(ch in [0, 1, 2] for ch in ch_extract):
-            raise ValueError("Invalid channel indices. Only 0, 1, or 2 are valid.")
-        
-        if len(set(ch_extract)) != len(ch_extract):
-            raise ValueError("ch_extract contains duplicate channel indices.")
-        
-        # Handle channel extraction 
-        return img[..., ch_extract]
-
-    # Validate channel swapping when requested (only for RGB images)
-    if ch_swap is not None:
-        if len(img.shape) == 2:
-            warnings.warn("Grayscale images have no channels to swap.", UserWarning)
-            return img  # No channel swapping for grayscale
-        
-        # Validate ch_swap: must be a list or tuple of 3 integers, with no duplicates, and must include all 0, 1, 2
-        if not isinstance(ch_swap, (list, tuple)):
-            raise TypeError(f"ch_swap must be a list or tuple, got {type(ch_swap)}.")
-        
-        if not all(isinstance(ch, int) for ch in ch_swap):
-            raise TypeError("All elements in ch_swap must be integers.")
-        
-        if len(ch_swap) != 3 or not all(ch in [0, 1, 2] for ch in ch_swap):
-            raise ValueError("ch_swap must be three elements of valid RGB channel indices 0, 1, or 2.")
-        
-        if len(set(ch_swap)) != 3:
-            raise ValueError("ch_swap must include all channels 0, 1, and 2 exactly once.")
+    # Check if the image is grayscale (2D) after conversion
+    if len(img.shape) == 2:
+        if ch_swap is not None or ch_extract is not None:
+            warnings.warn("Grayscale images have no channels to swap or extract.", UserWarning)
+        return img  # Return grayscale image
+    
+    # Proceed with channel manipulations when image is RGB (3D) after conversion
+    if len(img.shape) == 3:
+    
+        # Validate channel swapping when requested
+        if ch_swap is not None:
+            
+            # Validate ch_swap: must be a list or tuple of 3 integers, with no duplicates, and must include all 0, 1, 2
+            if not isinstance(ch_swap, (list, tuple)):
+                raise TypeError(f"ch_swap must be a list or tuple, got {type(ch_swap)}.")
+            
+            if not all(isinstance(ch, int) for ch in ch_swap):
+                raise TypeError("All elements in ch_swap must be integers.")
+            
+            if len(ch_swap) != 3 or not all(ch in [0, 1, 2] for ch in ch_swap):
+                raise ValueError("ch_swap must be three elements of valid RGB channel indices 0, 1, or 2.")
+            
+            if len(set(ch_swap)) != 3:
+                raise ValueError("ch_swap must include all channels 0, 1, and 2 exactly once.")
+
+            # Perform channel swapping
+            img = img[..., ch_swap]
+    
+        # Validate channel extraction when requested (can be potentially after ch_swap)
+        if ch_extract is not None:
+
+            # Validate ch_extract: should be a list or tuple of 0, 1, or 2, with no duplicates
+            if not isinstance(ch_extract, (list, tuple)):
+                raise TypeError("ch_extract must be a list or tuple.")
+            
+            if not all(isinstance(ch, int) for ch in ch_extract):
+                raise TypeError("All elements in ch_extract must be integers.")
+            
+            if not all(ch in [0, 1, 2] for ch in ch_extract):
+                raise ValueError("Invalid channel indices. Only 0, 1, or 2 are valid.")
+            
+            if len(set(ch_extract)) != len(ch_extract):
+                raise ValueError("ch_extract contains duplicate channel indices.")
+            
+            # Perform channel extraction 
+            img = img[..., ch_extract]
 
-        # Handle channel swapping
-        return img[..., ch_swap]
 
     # If no operation is requested, return the original image
     return img