Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correction for visualization #91

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions src/morphoclass/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def plot_accuracies(
file_name : str, pathlib.Path, optional
Path to where to store the plot.
supported_ext : list, optional
Supported image extensions. If default, it will save as ".png", ".eps", ".pdf".
Supported image extensions. If default, it will save as ".png", ".svg", ".pdf".

Returns
-------
Expand All @@ -355,11 +355,10 @@ def plot_accuracies(

ax.grid(color="gray", linestyle=":", linewidth=1)
ax.legend()
ax.set_rasterized(True)
fig.suptitle(title)
if file_name:
if not supported_ext:
supported_ext = [".png", ".eps", ".pdf"]
supported_ext = [".png", ".svg", ".pdf"]
for ext in supported_ext:
fig.savefig(file_name.with_suffix(ext))

Expand Down Expand Up @@ -415,19 +414,18 @@ def plot_confusion_matrix(cm, file_name, labels, supported_ext=None):
labels : list
List of labels for prediction.
"""
fig = Figure(dpi=75)
fig = Figure(dpi=300)
ax = fig.subplots()

sns.heatmap(cm, annot=True, ax=ax, xticklabels=labels, yticklabels=labels)
ax.invert_yaxis()
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_rasterized(True)
fig.tight_layout()

if file_name:
if not supported_ext:
supported_ext = [".png", ".eps", ".pdf"]
supported_ext = [".png", ".svg", ".pdf"]
for ext in supported_ext:
fig.savefig(file_name.with_suffix(ext))
return fig
Expand Down Expand Up @@ -630,7 +628,7 @@ def plot_number_of_nodes(
The figsize for the matplotlib figure.
supported_ext : list, optional
The extensions to store the images. By default it will store them
as ".png", ".eps", ".pdf".
as ".png", ".svg", ".pdf".

Returns
-------
Expand Down Expand Up @@ -679,13 +677,12 @@ def plot_number_of_nodes(
ax.tick_params(axis="both")

ax.tick_params(axis="x", labelrotation=45)
ax.set_rasterized(True)
# fig.tight_layout()

image_path = imagedir / f"number_of_nodes_{rtype.replace(' ', '_')}_{k}"
if image_path:
if not supported_ext:
supported_ext = [".png", ".eps", ".pdf"]
supported_ext = [".png", ".svg", ".pdf"]
for ext in supported_ext:
fig.savefig(image_path.with_suffix(ext))

Expand Down Expand Up @@ -717,7 +714,7 @@ def plot_number_of_nodes_per_layer(
The figsize for the matplotlib figure.
supported_ext : list, optional
The extensions to store the images. By default it will store them
as ".png", ".eps", ".pdf".
as ".png", ".svg", ".pdf".

Returns
-------
Expand Down Expand Up @@ -770,14 +767,13 @@ def plot_number_of_nodes_per_layer(
)
ax.tick_params(axis="both")
# fig.tight_layout()
ax.set_rasterized(True)

image_path = (
imagedir / f"number_of_nodes_{rtype.replace(' ', '_')}_per_layer_{k}"
)
if image_path:
if not supported_ext:
supported_ext = [".png", ".eps", ".pdf"]
supported_ext = [".png", ".svg", ".pdf"]
for ext in supported_ext:
fig.savefig(image_path.with_suffix(ext))

Expand Down
Loading