Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization of 3D QDax repertoires #373

Merged
merged 5 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Add `rasterized` arg for heatmaps (#359)
- Support 1D cvt_archive_heatmap ({pr}`362`)
- Add 3D plots for CVTArchive ({pr}`371`)
- Add visualization of 3D QDax repertoires ({pr}`372`)

#### Documentation

Expand Down
5 changes: 4 additions & 1 deletion ribs/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,24 @@
ribs.visualize.grid_archive_heatmap
ribs.visualize.parallel_axes_plot
ribs.visualize.sliding_boundaries_archive_heatmap
ribs.visualize.qdax_repertoire_3d_plot
ribs.visualize.qdax_repertoire_heatmap
"""
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap
from ribs.visualize._grid_archive_heatmap import grid_archive_heatmap
from ribs.visualize._parallel_axes_plot import parallel_axes_plot
from ribs.visualize._qdax_repertoire_heatmap import qdax_repertoire_heatmap
from ribs.visualize._sliding_boundaries_archive_heatmap import \
sliding_boundaries_archive_heatmap
from ribs.visualize._visualize_qdax import (qdax_repertoire_3d_plot,
qdax_repertoire_heatmap)

__all__ = [
"cvt_archive_3d_plot",
"cvt_archive_heatmap",
"grid_archive_heatmap",
"parallel_axes_plot",
"sliding_boundaries_archive_heatmap",
"qdax_repertoire_3d_plot",
"qdax_repertoire_heatmap",
]
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
"""Provides qdax_repertoire_heatmap."""
"""Provides visualization functions for QDax repertoires."""
import numpy as np

from ribs.archives import CVTArchive
from ribs.visualize._cvt_archive_3d_plot import cvt_archive_3d_plot
from ribs.visualize._cvt_archive_heatmap import cvt_archive_heatmap


def _as_cvt_archive(repertoire, ranges):
"""Converts a QDax repertoire into a CVTArchive."""

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

return cvt_archive


def qdax_repertoire_heatmap(
repertoire,
ranges,
Expand Down Expand Up @@ -33,22 +57,36 @@ def qdax_repertoire_heatmap(
"""
# pylint: enable = line-too-long

# Construct a CVTArchive. We set solution_dim to 0 since we are only
# plotting and do not need to have the solutions available.
cvt_archive = CVTArchive(
solution_dim=0,
cells=repertoire.centroids.shape[0],
ranges=ranges,
custom_centroids=repertoire.centroids,
)
cvt_archive_heatmap(_as_cvt_archive(repertoire, ranges), *args, **kwargs)

# Add everything to the CVTArchive.
occupied = repertoire.fitnesses != -np.inf
cvt_archive.add(
np.empty((occupied.sum(), 0)),
repertoire.fitnesses[occupied],
repertoire.descriptors[occupied],
)

# Plot the archive.
cvt_archive_heatmap(cvt_archive, *args, **kwargs)
def qdax_repertoire_3d_plot(
repertoire,
ranges,
*args,
**kwargs,
):
# pylint: disable = line-too-long
"""Plots a QDax MapElitesRepertoire with 3D measure space.

Internally, this function converts a
:class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into
a :class:`~ribs.archives.CVTArchive` and plots it with
:meth:`cvt_archive_3d_plot`.

Args:
repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire):
A MAP-Elites repertoire output by an algorithm in QDax.
ranges (array-like of (float, float)): Upper and lower bound of each
dimension of the measure space, e.g. ``[(-1, 1), (-2, 2), (-3, 3)]``
indicates the first dimension should have bounds :math:`[-1,1]`
(inclusive), the second dimension should have bounds :math:`[-2,2]`,
and the third dimension should have bounds :math:`[-3,3]`
(inclusive). This is needed since the MapElitesRepertoire does not
store measure space bounds.
*args: Positional arguments to pass to :meth:`cvt_archive_3d_plot`.
**kwargs: Keyword arguments to pass to :meth:`cvt_archive_3d_plot`.
"""
# pylint: enable = line-too-long

cvt_archive_3d_plot(_as_cvt_archive(repertoire, ranges), *args, **kwargs)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
52 changes: 47 additions & 5 deletions tests/visualize_qdax/visualize_qdax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire,
compute_cvt_centroids)

from ribs.visualize import qdax_repertoire_heatmap
from ribs.visualize import qdax_repertoire_3d_plot, qdax_repertoire_heatmap


@pytest.fixture(autouse=True)
Expand All @@ -26,10 +26,11 @@ def clean_matplotlib():
plt.close("all")


@image_comparison(baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in visualize_test.py
@image_comparison(
baseline_images=["qdax_repertoire_heatmap"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in cvt_archive_heatmap_test.py
def test_qdax_repertoire_heatmap():
plt.figure(figsize=(8, 6))

Expand Down Expand Up @@ -59,3 +60,44 @@ def test_qdax_repertoire_heatmap():

# Plot heatmap.
qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)])


@image_comparison(
baseline_images=["qdax_repertoire_3d_plot"],
remove_text=False,
extensions=["png"],
tol=0.1) # See CVT_IMAGE_TOLERANCE in cvt_archive_3d_plot_test.py
def test_qdax_repertoire_3d_plot():
plt.figure(figsize=(8, 6))

random_key = jax.random.PRNGKey(42)

# Compute the CVT centroids.
random_key, subkey = jax.random.split(random_key)
centroids, _ = compute_cvt_centroids(
num_descriptors=3,
num_init_cvt_samples=1000,
num_centroids=500,
minval=-1,
maxval=1,
random_key=subkey,
)

# Create initial population.
random_key, *subkeys = jax.random.split(random_key, 4)
x = jax.random.uniform(subkeys[0], (10000,), minval=-1.0, maxval=1.0)
y = jax.random.uniform(subkeys[1], (10000,), minval=-1.0, maxval=1.0)
z = jax.random.uniform(subkeys[2], (10000,), minval=-1.0, maxval=1.0)
init_pop = jnp.stack((x, y, z), axis=1)

# Create repertoire with the initial population inserted.
repertoire = MapElitesRepertoire.init(
genotypes=init_pop,
# Negative sphere function.
fitnesses=-jnp.sum(jnp.square(init_pop), axis=1),
descriptors=init_pop,
centroids=centroids,
)

# Plot heatmap.
qdax_repertoire_3d_plot(repertoire, ranges=[(-1, 1), (-1, 1), (-1, 1)])