diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 29c96ae06..a034a0850 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -70,8 +70,12 @@ jobs: pytest tests/archives tests/emitters tests/schedulers - name: Install extras deps run: pip install -r pinned_reqs/extras_visualize.txt - - name: Test extras + - name: Test visualize extra run: pytest tests/visualize + - name: Install QDax + run: pip install qdax + - name: Test visualize extra for QDax + run: pytest tests/visualize_qdax coverage: runs-on: ubuntu-latest steps: @@ -87,7 +91,11 @@ jobs: - name: Test coverage env: NUMBA_DISABLE_JIT: 1 - run: pytest tests + # Exclude `visualize_qdax` since we don't install QDax here. We also + # exclude `tests` since we don't want the base directory here. + run: + pytest $(find tests -maxdepth 1 -type d -not -name 'tests' -not -name + 'visualize_qdax') benchmarks: runs-on: ubuntu-latest steps: diff --git a/HISTORY.md b/HISTORY.md index 744b8209b..cb4cb7138 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -7,6 +7,7 @@ #### API - Drop Python 3.7 support and upgrade dependencies (#350) +- Add visualization of QDax repertoires (#353) #### Documentation diff --git a/docs/conf.py b/docs/conf.py index b2f96552c..15de98353 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -255,4 +255,5 @@ "python": ("https://docs.python.org/3/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "sklearn": ("https://scikit-learn.org/stable/", None), + "qdax": ("https://qdax.readthedocs.io/en/latest/", None), } diff --git a/ribs/visualize.py b/ribs/visualize.py index 907d2c3b8..b905691bb 100644 --- a/ribs/visualize.py +++ b/ribs/visualize.py @@ -27,6 +27,8 @@ from matplotlib.cm import ScalarMappable from scipy.spatial import Voronoi # pylint: disable=no-name-in-module +from ribs.archives import CVTArchive + # Matplotlib functions tend to have a ton of args. # pylint: disable = too-many-arguments @@ -782,3 +784,51 @@ def parallel_axes_plot(archive, ax=host_ax, pad=cbar_pad, orientation=cbar_orientation) + + +def qdax_repertoire_heatmap( + repertoire, + ranges, + *args, + **kwargs, +): + # pylint: disable = line-too-long + """Plots a heatmap of a QDax MapElitesRepertoire. + + Internally, this function converts a + :class:`~qdax.core.containers.mapelites_repertoire.MapElitesRepertoire` into + a :class:`~ribs.archives.CVTArchive` and plots it with + :meth:`cvt_archive_heatmap`. + + Args: + repertoire (qdax.core.containers.mapelites_repertoire.MapElitesRepertoire): + A MAP-Elites repertoire output by an algorithm in QDax. + ranges (array-like of (float, float)): Upper and lower bound of each + dimension of the measure space, e.g. ``[(-1, 1), (-2, 2)]`` + indicates the first dimension should have bounds :math:`[-1,1]` + (inclusive), and the second dimension should have bounds + :math:`[-2,2]` (inclusive). + *args: Positional arguments to pass to :meth:`cvt_archive_heatmap`. + **kwargs: Keyword arguments to pass to :meth:`cvt_archive_heatmap`. + """ + # pylint: enable = line-too-long + + # Construct a CVTArchive. We set solution_dim to 0 since we are only + # plotting and do not need to have the solutions available. + cvt_archive = CVTArchive( + solution_dim=0, + cells=repertoire.centroids.shape[0], + ranges=ranges, + custom_centroids=repertoire.centroids, + ) + + # Add everything to the CVTArchive. + occupied = repertoire.fitnesses != -np.inf + cvt_archive.add( + np.empty((occupied.sum(), 0)), + repertoire.fitnesses[occupied], + repertoire.descriptors[occupied], + ) + + # Plot the archive. + cvt_archive_heatmap(cvt_archive, *args, **kwargs) diff --git a/tests/README.md b/tests/README.md index c061712dd..728a3436d 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,12 +1,17 @@ # Tests -This directory contains tests and micro-benchmarks for ribs. The tests mirror +This directory contains tests and micro-benchmarks for pyribs. The tests mirror the directory structure of `ribs`. To run these tests, install the dev dependencies for ribs with `pip install ribs[dev]` or `pip install -e .[dev]` (from the root directory of the repo). For information on running tests, see [CONTRIBUTING.md](../CONTRIBUTING.md). +## Visualization Tests + +We divide the visualization tests into `visualize` and `visualize_qdax`, where +`visualize_qdax` tests visualizations of QDax components. + ## Additional Tests This directory also contains: diff --git a/tests/visualize_qdax/__init__.py b/tests/visualize_qdax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/visualize_qdax/baseline_images/visualize_qdax_test/qdax_repertoire_heatmap.png b/tests/visualize_qdax/baseline_images/visualize_qdax_test/qdax_repertoire_heatmap.png new file mode 100644 index 000000000..1bf29b9fd Binary files /dev/null and b/tests/visualize_qdax/baseline_images/visualize_qdax_test/qdax_repertoire_heatmap.png differ diff --git a/tests/visualize_qdax/visualize_qdax_test.py b/tests/visualize_qdax/visualize_qdax_test.py new file mode 100644 index 000000000..ac13c7c3a --- /dev/null +++ b/tests/visualize_qdax/visualize_qdax_test.py @@ -0,0 +1,60 @@ +"""Tests for ribs.visualize that use qdax. + +Instructions are identical as in visualize_test.py, but images are stored in +tests/visualize_qdax_test/baseline_images/visualize_qdax_test instead. +""" +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import pytest +from matplotlib.testing.decorators import image_comparison +from qdax.core.containers.mapelites_repertoire import (MapElitesRepertoire, + compute_cvt_centroids) + +from ribs.visualize import qdax_repertoire_heatmap + + +@pytest.fixture(autouse=True) +def clean_matplotlib(): + """Cleans up matplotlib figures before and after each test.""" + # Before the test. + plt.close("all") + + yield + + # After the test. + plt.close("all") + + +@image_comparison(baseline_images=["qdax_repertoire_heatmap"], + remove_text=False, + extensions=["png"]) +def test_qdax_repertoire_heatmap(): + plt.figure(figsize=(8, 6)) + + # Compute the CVT centroids. + centroids, _ = compute_cvt_centroids( + num_descriptors=2, + num_init_cvt_samples=1000, + num_centroids=100, + minval=-1, + maxval=1, + random_key=jax.random.PRNGKey(42), + ) + + # Create initial population. + init_pop_x, init_pop_y = jnp.meshgrid(jnp.linspace(-1, 1, 50), + jnp.linspace(-1, 1, 50)) + init_pop = jnp.stack((init_pop_x.flatten(), init_pop_y.flatten()), axis=1) + + # Create repertoire with the initial population inserted. + repertoire = MapElitesRepertoire.init( + genotypes=init_pop, + # Negative sphere function. + fitnesses=-jnp.sum(jnp.square(init_pop), axis=1), + descriptors=init_pop, + centroids=centroids, + ) + + # Plot heatmap. + qdax_repertoire_heatmap(repertoire, ranges=[(-1, 1), (-1, 1)])