Skip to content

Commit

Permalink
plots
Browse files Browse the repository at this point in the history
  • Loading branch information
fdobad committed Nov 7, 2024
1 parent e8115bb commit e3da07f
Showing 1 changed file with 61 additions and 14 deletions.
75 changes: 61 additions & 14 deletions src/fire2a/agglomerative_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,19 @@ def write(
return True


def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"):
def plot(labels_reshaped, pipeline, info_list, **kwargs):
"""Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
Args:
labels_reshaped (np.ndarray): The reshaped labels of the clusters
pipeline (Pipeline): The pipeline object containing all the steps of the pipeline
info_list (list): A list of dictionaries containing information about each feature
**kwargs: Additional keyword arguments
n_clusters (int): The number of clusters
distance_threshold (float): The linkage distance threshold
sieve (int): The number of pixels to use as a sieve filter
block (bool): Block the execution until the plot window is closed
filename (str): The filename to save the plot
"""
from matplotlib import pyplot as plt

no_data_imputed = pipeline.named_steps["no_data_imputer"].output_data
Expand All @@ -425,7 +437,19 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"):

names = [info_list[i]["fname"] for i in nohots_idxs]

fig, axs = plt.subplots(3, 2)
fgs = np.array(plt.rcParams["figure.figsize"]) * 5
fig, axs = plt.subplots(3, 2, figsize=fgs)
suptitle = ""
if n_clusters := kwargs.get("n_clusters"):
suptitle = f"n_clusters: {n_clusters}"
if distance_threshold := kwargs.get("distance_threshold"):
suptitle = f"distance_threshold: {distance_threshold}"
if sieve := kwargs.get("sieve"):
suptitle += f", sieve: {sieve}"
if n_clusters or distance_threshold or sieve:
suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
suptitle += "\n(Not showing categorical data)"
fig.suptitle(suptitle)

# plot violin plot
axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
Expand All @@ -446,12 +470,14 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"):
axs[1, 0].set_title("Violin Plot of Common Rescaled")
axs[1, 0].yaxis.grid(True)
axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
axs[0, 1].set_ylabel("Adjusted range")

# plot boxplot
axs[1, 1].boxplot(pre_aggclu_data)
axs[1, 1].set_title("Box Plot of Common Rescaled")
axs[1, 1].yaxis.grid(True)
axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
axs[0, 1].set_ylabel("Adjusted range")

# cluster history
unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
Expand All @@ -466,11 +492,15 @@ def plot(labels_reshaped, pipeline, info_list): # , filename="plot.png"):
axs[2, 1].set_ylabel("Number of Clusters")
axs[2, 1].set_title("Histogram of Cluster Sizes")

# plt.show(block=False)
plt.show()
# if filename:
# logger.info(f"Saving plot to {filename}")
# plt.savefig(filename)
plt.tight_layout()
if filename := kwargs.get("filename"):
logger.info(f"Saving plot to {filename}")
plt.savefig(filename)
else:
if block := kwargs.get("block"):
plt.show(block=block)
else:
plt.show()


def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
Expand Down Expand Up @@ -552,12 +582,6 @@ def arg_parser(argv=None):
parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
parser.add_argument(
"-p",
"--plots",
action="store_true",
help="Raise a matplotlib window with input/output data related to the clustering. For example, the rescaled data distributions and the clustering size history",
)
parser.add_argument(
"-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
)
Expand All @@ -581,6 +605,29 @@ def arg_parser(argv=None):
help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
)
parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")

plot = parser.add_argument_group(
"Plotting, Visually inspect input distributions: NoData treated observations, rescaled data, with violing plots and boxplots. Also check output clustering size history and histograms"
)
plot.add_argument(
"-p",
"--plots",
action="store_true",
help="Activate the plotting routines",
)
plot.add_argument(
"-b",
"--block",
action="store_false",
default=True,
help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS",
)
plot.add_argument(
"-f",
"--filename",
type=str,
help="Filename to save the plot. If not provided, matplotlib will raise a window",
)
args = parser.parse_args(argv)
args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
if Path(args.config_file).is_file() is False:
Expand Down Expand Up @@ -662,7 +709,7 @@ def main(argv=None):

# 7 debugging plots
if args.plots:
plot(labels_reshaped, pipeline, info_list)
plot(labels_reshaped, pipeline, info_list, **vars(args))

# 8. ESCRIBIR RASTER
if not args.no_write:
Expand Down

0 comments on commit e3da07f

Please sign in to comment.