@@ -338,7 +338,7 @@ def _render_shapes(
338338        cax  =  None 
339339        if  aggregate_with_reduction  is  not None :
340340            vmin  =  aggregate_with_reduction [0 ].values  if  norm .vmin  is  None  else  norm .vmin 
341-             vmax  =  aggregate_with_reduction [1 ].values  if  norm .vmin  is  None  else  norm .vmax 
341+             vmax  =  aggregate_with_reduction [1 ].values  if  norm .vmax  is  None  else  norm .vmax 
342342            if  (norm .vmin  is  not None  or  norm .vmax  is  not None ) and  norm .vmin  ==  norm .vmax :
343343                # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and 
344344                # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) 
@@ -846,20 +846,22 @@ def _render_images(
846846    # 2) Image has any number of channels but 1 
847847    else :
848848        layers  =  {}
849-         for  ch_index , c  in  enumerate (channels ):
850-             layers [c ] =  img .sel (c = c ).copy (deep = True ).squeeze ()
851- 
852-             if  not  isinstance (render_params .cmap_params , list ):
853-                 if  render_params .cmap_params .norm  is  not None :
854-                     layers [c ] =  render_params .cmap_params .norm (layers [c ])
849+         for  ch_idx , ch  in  enumerate (channels ):
850+             layers [ch ] =  img .sel (c = ch ).copy (deep = True ).squeeze ()
851+             if  isinstance (render_params .cmap_params , list ):
852+                 ch_norm  =  render_params .cmap_params [ch_idx ].norm 
853+                 ch_cmap_is_default  =  render_params .cmap_params [ch_idx ].cmap_is_default 
855854            else :
856-                 if  render_params .cmap_params [ch_index ].norm  is  not None :
857-                     layers [c ] =  render_params .cmap_params [ch_index ].norm (layers [c ])
855+                 ch_norm  =  render_params .cmap_params .norm 
856+                 ch_cmap_is_default  =  render_params .cmap_params .cmap_is_default 
857+ 
858+             if  not  ch_cmap_is_default  and  ch_norm  is  not None :
859+                 layers [ch_idx ] =  ch_norm (layers [ch_idx ])
858860
859861        # 2A) Image has 3 channels, no palette info, and no/only one cmap was given 
860862        if  palette  is  None  and  n_channels  ==  3  and  not  isinstance (render_params .cmap_params , list ):
861863            if  render_params .cmap_params .cmap_is_default :  # -> use RGB 
862-                 stacked  =  np .stack ([layers [c ] for  c  in  channels ], axis = - 1 )
864+                 stacked  =  np .stack ([layers [ch ] for  ch  in  layers ], axis = - 1 )
863865            else :  # -> use given cmap for each channel 
864866                channel_cmaps  =  [render_params .cmap_params .cmap ] *  n_channels 
865867                stacked  =  (
@@ -892,12 +894,54 @@ def _render_images(
892894            # overwrite if n_channels == 2 for intuitive result 
893895            if  n_channels  ==  2 :
894896                seed_colors  =  ["#ff0000ff" , "#00ff00ff" ]
895-             else :
897+                 channel_cmaps  =  [_get_linear_colormap ([c ], "k" )[0 ] for  c  in  seed_colors ]
898+                 colored  =  np .stack (
899+                     [channel_cmaps [ch_ind ](layers [ch ]) for  ch_ind , ch  in  enumerate (channels )],
900+                     0 ,
901+                 ).sum (0 )
902+                 colored  =  colored [:, :, :3 ]
903+             elif  n_channels  ==  3 :
896904                seed_colors  =  _get_colors_for_categorical_obs (list (range (n_channels )))
905+                 channel_cmaps  =  [_get_linear_colormap ([c ], "k" )[0 ] for  c  in  seed_colors ]
906+                 colored  =  np .stack (
907+                     [channel_cmaps [ind ](layers [ch ]) for  ind , ch  in  enumerate (channels )],
908+                     0 ,
909+                 ).sum (0 )
910+                 colored  =  colored [:, :, :3 ]
911+             else :
912+                 if  isinstance (render_params .cmap_params , list ):
913+                     cmap_is_default  =  render_params .cmap_params [0 ].cmap_is_default 
914+                 else :
915+                     cmap_is_default  =  render_params .cmap_params .cmap_is_default 
897916
898-             channel_cmaps  =  [_get_linear_colormap ([c ], "k" )[0 ] for  c  in  seed_colors ]
899-             colored  =  np .stack ([channel_cmaps [ind ](layers [ch ]) for  ind , ch  in  enumerate (channels )], 0 ).sum (0 )
900-             colored  =  colored [:, :, :3 ]
917+                 if  cmap_is_default :
918+                     seed_colors  =  _get_colors_for_categorical_obs (list (range (n_channels )))
919+                 else :
920+                     # Sample n_channels colors evenly from the colormap 
921+                     if  isinstance (render_params .cmap_params , list ):
922+                         seed_colors  =  [
923+                             render_params .cmap_params [i ].cmap (i  /  (n_channels  -  1 )) for  i  in  range (n_channels )
924+                         ]
925+                     else :
926+                         seed_colors  =  [render_params .cmap_params .cmap (i  /  (n_channels  -  1 )) for  i  in  range (n_channels )]
927+                 channel_cmaps  =  [_get_linear_colormap ([c ], "k" )[0 ] for  c  in  seed_colors ]
928+ 
929+                 # Stack (n_channels, height, width) → (height*width, n_channels) 
930+                 H , W  =  next (iter (layers .values ())).shape 
931+                 comp_rgb  =  np .zeros ((H , W , 3 ), dtype = float )
932+ 
933+                 # For each channel: map to RGBA, apply constant alpha, then add 
934+                 for  ch_idx , ch  in  enumerate (channels ):
935+                     layer_arr  =  layers [ch ]
936+                     rgba  =  channel_cmaps [ch_idx ](layer_arr )
937+                     rgba [..., 3 ] =  render_params .alpha 
938+                     comp_rgb  +=  rgba [..., :3 ] *  rgba [..., 3 ][..., None ]
939+ 
940+                 colored  =  np .clip (comp_rgb , 0 , 1 )
941+                 logger .info (
942+                     f"Your image has { n_channels }  
943+                     f"multichannel strategy 'stack' to render." 
944+                 )  # TODO: update when pca is added as strategy 
901945
902946            _ax_show_and_transform (
903947                colored ,
@@ -943,6 +987,7 @@ def _render_images(
943987                zorder = render_params .zorder ,
944988            )
945989
990+         # 2D) Image has n channels, no palette but cmap info 
946991        elif  palette  is  not None  and  got_multiple_cmaps :
947992            raise  ValueError ("If 'palette' is provided, 'cmap' must be None." )
948993
0 commit comments