diff --git a/src/fire2a/agglomerative_clustering.py b/src/fire2a/agglomerative_clustering.py index 9e3f02c..bc17704 100644 --- a/src/fire2a/agglomerative_clustering.py +++ b/src/fire2a/agglomerative_clustering.py @@ -399,7 +399,19 @@ def write( return True -def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"): +def plot(labels_reshaped, pipeline, info_list, **kwargs): + """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram. + Args: + labels_reshaped (np.ndarray): The reshaped labels of the clusters + pipeline (Pipeline): The pipeline object containing all the steps of the pipeline + info_list (list): A list of dictionaries containing information about each feature + **kargs: Additional keyword arguments + n_clusters (int): The number of clusters + distance_threshold (float): The linkage distance threshold + sieve (int): The number of pixels to use as a sieve filter + block (bool): Block the execution until the plot window is closed + filename (str): The filename to save the plot + """ from matplotlib import pyplot as plt no_data_imputed = pipeline.named_steps["no_data_imputer"].output_data @@ -425,7 +437,19 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"): names = [info_list[i]["fname"] for i in nohots_idxs] - fig, axs = plt.subplots(3, 2) + fgs = np.array(plt.rcParams["figure.figsize"]) * 5 + fig, axs = plt.subplots(3, 2, figsize=fgs) + suptitle = "" + if n_clusters := kwargs.get("n_clusters"): + suptitle = f"n_clusters: {n_clusters}" + if distance_threshold := kwargs.get("distance_threshold"): + suptitle = f"distance_threshold: {distance_threshold}" + if sieve := kwargs.get("sieve"): + suptitle += f", sieve: {sieve}" + if n_clusters or distance_threshold or sieve: + suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}" + suptitle += "\n(Not showing categorical data)" + fig.suptitle(suptitle) # plot violin plot axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True) @@ -446,12 +470,14 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"): axs[1, 0].set_title("Violin Plot of Common Rescaled") axs[1, 0].yaxis.grid(True) axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names) + axs[0, 1].set_ylabel("Adjusted range") # plot boxplot axs[1, 1].boxplot(pre_aggclu_data) axs[1, 1].set_title("Box Plot of Common Rescaled") axs[1, 1].yaxis.grid(True) axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names) + axs[0, 1].set_ylabel("Adjusted range") # cluster history unique_labels, counts = np.unique(labels_reshaped, return_counts=True) @@ -466,11 +492,15 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"): axs[2, 1].set_ylabel("Number of Clusters") axs[2, 1].set_title("Histogram of Cluster Sizes") - # plt.show(block=False) - plt.show() - # if filename: - # logger.info(f"Saving plot to {filename}") - # plt.savefig(filename) + plt.tight_layout() + if filename := kwargs.get("filename"): + logger.info(f"Saving plot to {filename}") + plt.savefig(filename) + else: + if block := kwargs.get("block"): + plt.show(block=block) + else: + plt.show() def sieve_filter(data, threshold=2, connectedness=4, feedback=None): @@ -552,12 +582,6 @@ def arg_parser(argv=None): parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="") parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg") parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857") - parser.add_argument( - "-p", - "--plots", - action="store_true", - help="Raise a matplotlib window with input/output data related to the clustering. For example, the rescaled data distributions and the clustering size history", - ) parser.add_argument( "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)" ) @@ -581,6 +605,29 @@ def arg_parser(argv=None): help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor", ) parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3") + + plot = parser.add_argument_group( + "Plotting, Visually inspect input distributions: NoData treated observations, rescaled data, with violing plots and boxplots. Also check output clustering size history and histograms" + ) + plot.add_argument( + "-p", + "--plots", + action="store_true", + help="Activate the plotting routines", + ) + plot.add_argument( + "-b", + "--block", + action="store_false", + default=True, + help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS", + ) + plot.add_argument( + "-f", + "--filename", + type=str, + help="Filename to save the plot. If not provided, matplotlib will raise a window", + ) args = parser.parse_args(argv) args.geotransform = tuple(map(float, args.geotransform[1:-1].split(","))) if Path(args.config_file).is_file() is False: @@ -662,7 +709,7 @@ def main(argv=None): # 7 debbuging plots if args.plots: - plot(labels_reshaped, pipeline, info_list) + plot(labels_reshaped, pipeline, info_list, **vars(args)) # 8. ESCRIBIR RASTER if not args.no_write: