Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clip Voronoi regions in cvt_archive_heatmap #356

Merged
merged 3 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- Add visualization of QDax repertoires (#353)
- Improve cvt_archive_heatmap flexibility (#354)
- Speed up 2D cvt_archive_heatmap by order of magnitude (#355)
- Clip Voronoi regions in cvt_archive_heatmap (#356)

#### Documentation

Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,5 @@
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"qdax": ("https://qdax.readthedocs.io/en/latest/", None),
"shapely": ("https://shapely.readthedocs.io/en/stable/", None),
}
1 change: 1 addition & 0 deletions pinned_reqs/extras_visualize.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ scikit-learn==1.3.0
scipy==1.10.1
threadpoolctl==3.0.0
matplotlib==3.7.2
shapely==2.0.1
56 changes: 45 additions & 11 deletions ribs/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import shapely
from matplotlib.cm import ScalarMappable
from scipy.spatial import Voronoi # pylint: disable=no-name-in-module

Expand Down Expand Up @@ -270,6 +271,7 @@ def cvt_archive_heatmap(archive,
vmax=None,
cbar="auto",
cbar_kwargs=None,
clip=False,
plot_centroids=False,
plot_samples=False,
ms=1):
Expand Down Expand Up @@ -343,6 +345,15 @@ def cvt_archive_heatmap(archive,
the colorbar on the specified Axes.
cbar_kwargs (dict): Additional kwargs to pass to
:func:`~matplotlib.pyplot.colorbar`.
clip (bool, shapely.Polygon): Clip the heatmap cells to a given polygon.
By default, we draw the cells along the outer edges of the heatmap
as polygons that extend beyond the archive bounds, but these
polygons are hidden because we set the axis limits to be the archive
bounds. Passing `clip=True` will clip the heatmap such that these
"outer edge" polygons are within the archive bounds. An arbitrary
polygon can also be passed in to clip the heatmap to a custom shape.
See `#356 <https://github.com/icaros-usc/pyribs/pull/356>`_ for more
info.
plot_centroids (bool): Whether to plot the cluster centroids.
plot_samples (bool): Whether to plot the samples used when generating
the clusters.
Expand Down Expand Up @@ -376,6 +387,11 @@ def cvt_archive_heatmap(archive,
upper_bounds = np.flip(upper_bounds)
centroids = np.flip(centroids, axis=1)

# If clip is on, make it default to an archive bounding box.
if clip and not isinstance(clip, shapely.Polygon):
clip = shapely.box(lower_bounds[0], lower_bounds[1], upper_bounds[0],
upper_bounds[1])

if plot_samples:
samples = archive.samples
if transpose_measures:
Expand Down Expand Up @@ -443,18 +459,36 @@ def cvt_archive_heatmap(archive,
if -1 in region or len(region) == 0:
continue

if objective is None:
# Transparent white (RGBA format) -- this ensures that if a figure
# is saved with a transparent background, the empty cells will also
# be transparent.
facecolors.append(np.array([1.0, 1.0, 1.0, 0.0]))
facecolor_cmap_mask.append(False)
if clip:
# Clip the cell vertices to the polygon. Clipping may cause some
# cells to split into two or more polygons, especially if the clip
# polygon has holes.
polygon = shapely.Polygon(vor.vertices[region])
intersection = polygon.intersection(clip)
if isinstance(intersection, shapely.MultiPolygon):
for polygon in intersection.geoms:
vertices.append(polygon.exterior.coords)
n_splits = len(intersection.geoms)
else:
# The intersection is a single Polygon.
vertices.append(intersection.exterior.coords)
n_splits = 1
else:
facecolors.append(np.empty(4))
facecolor_cmap_mask.append(True)
facecolor_objs.append(objective)

vertices.append(vor.vertices[region])
vertices.append(vor.vertices[region])
n_splits = 1

# Repeat values for each split.
for _ in range(n_splits):
if objective is None:
# Transparent white (RGBA format) -- this ensures that if a
# figure is saved with a transparent background, the empty cells
# will also be transparent.
facecolors.append(np.array([1.0, 1.0, 1.0, 0.0]))
facecolor_cmap_mask.append(False)
else:
facecolors.append(np.empty(4))
facecolor_cmap_mask.append(True)
facecolor_objs.append(objective)

# Compute facecolors from the cmap. We first normalize the objectives and
# clip them to [0, 1].
Expand Down
10 changes: 7 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from setuptools import find_packages, setup

with open("README.md") as readme_file:
with open("README.md", encoding="utf-8") as readme_file:
readme = readme_file.read()

with open("HISTORY.md") as history_file:
with open("HISTORY.md", encoding="utf-8") as history_file:
history = history_file.read()

# NOTE: Update pinned_reqs whenever install_requires or extras_require changes.
Expand All @@ -24,12 +24,16 @@
]

extras_require = {
"visualize": ["matplotlib>=3.0.0",],
"visualize": [
"matplotlib>=3.0.0",
"shapely>=2.0.0",
],
# All dependencies except for dev. Don't worry if there are duplicate
# dependencies, since setuptools automatically handles duplicates.
"all": [
### visualize ###
"matplotlib>=3.0.0",
"shapely>=2.0.0",
],
"dev": [
"pip>=20.3",
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
94 changes: 94 additions & 0 deletions tests/visualize/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
import shapely
from matplotlib.testing.decorators import image_comparison

from ribs.archives import CVTArchive, GridArchive, SlidingBoundariesArchive
Expand Down Expand Up @@ -673,6 +674,99 @@ def test_cvt_archive_heatmap_voronoi_style(cvt_archive):
cvt_archive_heatmap(cvt_archive, lw=3.0, ec="grey")


#
# cvt_archive_heatmap clip tests
#


@image_comparison(baseline_images=["cvt_archive_heatmap_noclip"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_noclip(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, clip=False)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_clip"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_clip(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(cvt_archive, clip=True)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)


@image_comparison(baseline_images=["cvt_archive_heatmap_clip_polygon"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_clip_polygon(cvt_archive):
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(
cvt_archive,
clip=shapely.Polygon(shell=np.array([
[-0.75, -0.375],
[-0.75, 0.375],
[-0.375, 0.75],
[0.375, 0.75],
[0.75, 0.375],
[0.75, -0.375],
[0.375, -0.75],
[-0.375, -0.75],
]),),
)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)


@image_comparison(
baseline_images=["cvt_archive_heatmap_clip_polygon_with_hole"],
remove_text=False,
extensions=["png"],
tol=CVT_IMAGE_TOLERANCE)
def test_cvt_archive_heatmap_clip_polygon_with_hole(cvt_archive):
"""This test will force some cells to be split in two."""
plt.figure(figsize=(8, 6))
cvt_archive_heatmap(
cvt_archive,
clip=shapely.Polygon(
shell=np.array([
[-0.75, -0.375],
[-0.75, 0.375],
[-0.375, 0.75],
[0.375, 0.75],
[0.75, 0.375],
[0.75, -0.375],
[0.375, -0.75],
[-0.375, -0.75],
]),
holes=[
# Two holes that split some cells into two parts, and some cells
# into three parts.
np.array([
[-0.5, 0],
[-0.5, 0.05],
[0.5, 0.05],
[0.5, 0],
]),
np.array([
[-0.5, 0.125],
[-0.5, 0.175],
[0.5, 0.175],
[0.5, 0.125],
]),
],
),
)
plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)


#
# Parallel coordinate plot test
#
Expand Down