Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up 2D cvt_archive_heatmap by order of magnitude #355

Merged
merged 5 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Drop Python 3.7 support and upgrade dependencies (#350)
- Add visualization of QDax repertoires (#353)
- Improve cvt_archive_heatmap flexibility (#354)
- Speed up 2D cvt_archive_heatmap by order of magnitude (#355)

#### Documentation

Expand Down
69 changes: 49 additions & 20 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,27 +422,56 @@ def cvt_archive_heatmap(archive,
if min_obj == max_obj:
min_obj, max_obj = min_obj - 0.01, max_obj + 0.01

# Shade the regions.
#
# Note: by default, the first region will be an empty list -- see:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html
# However, this empty region is ignored by ax.fill since `polygon` is also
# an empty list in this case.
# Vertices of all cells.
vertices = []
# The facecolor of each cell. Shape (n_regions, 4) for RGBA format, but we
# do not know n_regions in advance.
facecolors = []
# Boolean array indicating which of the facecolors needs to be computed with
# the cmap. The other colors correspond to empty cells. Shape (n_regions,)
facecolor_cmap_mask = []
# The objective corresponding to the regions which must be passed through
# the cmap. Shape (sum(facecolor_cmap_mask),)
facecolor_objs = []

# Cycle through the regions to set up polygon vertices and facecolors.
for region, objective in zip(vor.regions, region_obj):
# This check is O(n), but n is typically small, and creating
# `polygon` is also O(n) anyway.
if -1 not in region:
if objective is None:
# Transparent white (RGBA format) -- this ensures that if a
# figure is saved with a transparent background, the empty cells
# will also be transparent.
color = (1.0, 1.0, 1.0, 0.0)
else:
normalized_obj = np.clip(
(objective - min_obj) / (max_obj - min_obj), 0.0, 1.0)
color = cmap(normalized_obj)
polygon = vor.vertices[region]
ax.fill(*zip(*polygon), color=color, ec=ec, lw=lw)
# Checking for -1 is O(n), but n is typically small.
#
# We check length since the first region is an empty list by default:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.Voronoi.html
if -1 in region or len(region) == 0:
continue

if objective is None:
# Transparent white (RGBA format) -- this ensures that if a figure
# is saved with a transparent background, the empty cells will also
# be transparent.
facecolors.append(np.array([1.0, 1.0, 1.0, 0.0]))
facecolor_cmap_mask.append(False)
else:
facecolors.append(np.empty(4))
facecolor_cmap_mask.append(True)
facecolor_objs.append(objective)

vertices.append(vor.vertices[region])

# Compute facecolors from the cmap. We first normalize the objectives and
# clip them to [0, 1].
normalized_objs = np.clip(
(np.asarray(facecolor_objs) - min_obj) / (max_obj - min_obj), 0.0, 1.0)
facecolors = np.asarray(facecolors)
facecolors[facecolor_cmap_mask] = cmap(normalized_objs)

# Plot the collection on the axes. Note that this is faster than plotting
# each polygon individually with ax.fill().
ax.add_collection(
matplotlib.collections.PolyCollection(
vertices,
edgecolors=ec,
facecolors=facecolors,
linewidths=lw,
))

# Create a colorbar.
mappable = ScalarMappable(cmap=cmap)
Expand Down
41 changes: 29 additions & 12 deletions tests/visualize/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

# pylint: disable = redefined-outer-name

# Tolerance for root mean square difference between the pixels of the images,
# where 255 is the max value. We only have tolerance for `cvt_archive_heatmap`
# since it is a bit more finicky than the other plots.
CVT_IMAGE_TOLERANCE = 0.1


@pytest.fixture(autouse=True)
def clean_matplotlib():
Expand Down Expand Up @@ -379,7 +384,8 @@ def test_heatmap_archive__grid_custom_cbar_axis(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_archive__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive)
Expand All @@ -403,7 +409,8 @@ def test_heatmap_with_custom_axis__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_with_custom_axis__cvt(cvt_archive):
_, ax = plt.subplots(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, ax=ax)
Expand All @@ -427,7 +434,8 @@ def test_heatmap_long__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive)
Expand All @@ -451,7 +459,8 @@ def test_heatmap_long_square__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long_square"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long_square__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive, aspect="equal")
Expand All @@ -475,7 +484,8 @@ def test_heatmap_long_transpose__grid(long_grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_long_transpose"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_long_transpose__cvt(long_cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(long_cvt_archive, transpose_measures=True)
Expand All @@ -502,7 +512,8 @@ def test_heatmap_with_limits__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_limits"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_with_limits__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-1.0, vmax=-0.5)
Expand All @@ -527,7 +538,8 @@ def test_heatmap_listed_cmap__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_listed_cmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_listed_cmap__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, cmap=[[1, 0, 0], [0, 1, 0], [0, 0, 1]])
Expand All @@ -553,7 +565,8 @@ def test_heatmap_coolwarm_cmap__grid(grid_archive):

@image_comparison(baseline_images=["cvt_archive_heatmap_with_coolwarm_cmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_heatmap_coolwarm_cmap__cvt(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, cmap="coolwarm")
Expand Down Expand Up @@ -614,23 +627,26 @@ def test_sliding_archive_mismatch_xy_with_boundaries():

@image_comparison(baseline_images=["cvt_archive_heatmap_vmin_equals_vmax"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_vmin_equals_vmax(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, vmin=-0.5, vmax=-0.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_centroids"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_with_centroids(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, plot_centroids=True)


@image_comparison(baseline_images=["cvt_archive_heatmap_with_samples"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_with_samples(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, plot_samples=True)
Expand All @@ -650,7 +666,8 @@ def test_cvt_archive_heatmap_no_samples_error():

@image_comparison(baseline_images=["cvt_archive_heatmap_voronoi_style"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_voronoi_style(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey")
Expand Down
3 changes: 2 additions & 1 deletion tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def clean_matplotlib():

@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"])
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in visualize_test.py
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

Expand Down