diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index d26a231..17a9b33 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -1,13 +1,32 @@
 name: CI
 on:
+  push:
+    branches:
+      - main
   pull_request:
     branches:
       - main
+
 jobs:
-  ruff:
+  ci:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: chartboost/ruff-action@v1
+      - name: Set up rye
+        uses: eifinger/setup-rye@v3
+      - name: Install dependencies
+        run: |
+          rye config --set-bool behavior.use-uv=true
+          rye sync --no-lock
+      - name: Run lint
+        run: |
+          rye lint
+      - name: Run tests
+        run: |
+          rye run cov
+      - name: Upload coverage reports to Codecov
+        uses: codecov/codecov-action@v3
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
diff --git a/README.md b/README.md
index de87450..3b7f7ab 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 ![python](https://img.shields.io/badge/python-3.10-blue)
 [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
 [![CI](https://github.com/nomutin/RSSM/actions/workflows/ci.yaml/badge.svg)](https://github.com/nomutin/RSSM/actions/workflows/ci.yaml)
+[![codecov](https://codecov.io/gh/nomutin/RSSM/graph/badge.svg?token=YMR2H87R5C)](https://codecov.io/gh/nomutin/RSSM)
 
 [RSSMs](https://danijar.com/project/dreamer/) for imitation learning.
 
diff --git a/example/__init__.py b/example/__init__.py
new file mode 100644
index 0000000..0e1c1bb
--- /dev/null
+++ b/example/__init__.py
@@ -0,0 +1 @@
+"""Training/Evaluation examples."""
diff --git a/src/rssm/callback.py b/example/callback.py
similarity index 100%
rename from src/rssm/callback.py
rename to example/callback.py
diff --git a/config/.gitkeep b/example/config/.gitkeep
similarity index 100%
rename from config/.gitkeep
rename to example/config/.gitkeep
diff --git a/config/two_buttons_v1.yaml b/example/config/two_buttons_v1.yaml
similarity index 100%
rename from config/two_buttons_v1.yaml
rename to example/config/two_buttons_v1.yaml
diff --git a/config/two_buttons_v2.yaml b/example/config/two_buttons_v2.yaml
similarity index 100%
rename from config/two_buttons_v2.yaml
rename to example/config/two_buttons_v2.yaml
diff --git a/data/.gitkeep b/example/data/.gitkeep
similarity index 100%
rename from data/.gitkeep
rename to example/data/.gitkeep
diff --git a/src/rssm/dataset.py b/example/dataset.py
similarity index 100%
rename from src/rssm/dataset.py
rename to example/dataset.py
diff --git a/scripts/download.py b/example/download.py
similarity index 100%
rename from scripts/download.py
rename to example/download.py
diff --git a/scripts/train.py b/example/train.py
similarity index 100%
rename from scripts/train.py
rename to example/train.py
diff --git a/src/rssm/transform.py b/example/transform.py
similarity index 100%
rename from src/rssm/transform.py
rename to example/transform.py
diff --git a/pyproject.toml b/pyproject.toml
index bef8f64..645cde7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,12 +1,10 @@
 [project]
 name = "rssm"
-version = "0.1.0"
+version = "0.1.2"
 description = "Reccurent State-Space Model"
 dependencies = [
     "lightning>=2.2.1",
-    "wandb>=0.16.4",
     "torchrl>=0.3.1",
-    "hydra-core>=1.3.2",
     "distribution-extension @ git+https://github.com/nomutin/distribution-extension.git",
 ]
 readme = "README.md"
@@ -21,16 +19,7 @@ managed = true
 dev-dependencies = [
     "mypy>=1.9.0",
     "ruff>=0.4.2",
-    "kornia>=0.7.2",
-    "imageio>=2.34.1",
-    "moviepy>=1.0.3",
-    "gdown>=5.1.0",
-    "matplotlib>=3.8.4",
-    "rich>=13.7.1",
-    "einops>=0.8.0",
-    "torchvision>=0.18.0",
-    "jsonargparse[signatures]>=4.27.7",
-    "cnn @ git+https://github.com/nomutin/CNN",
+    "pytest-cov>=5.0.0",
 ]
 
 [tool.hatch.metadata]
@@ -39,12 +28,21 @@ allow-direct-references = true
 
 [tool.hatch.build.targets.wheel]
 packages = ["src/rssm"]
 
+[tool.pytest.ini_options]
+filterwarnings = [
+    "ignore::UserWarning",
+    "ignore::DeprecationWarning",
+]
+
+[tool.rye.scripts]
+cov = "pytest -ra --cov=src --cov-report=term --cov-report=xml"
+
 [tool.mypy]
 python_version = "3.10"
 ignore_missing_imports = true
 
 [tool.ruff]
-line-length = 79
+line-length = 80
 target-version = "py310"
 
 [tool.ruff.lint]
@@ -69,8 +67,10 @@ known-first-party = ["rssm"]
 [tool.ruff.lint.per-file-ignores]
 "src/rssm/core.py" = ["PLR0913"]
 "src/rssm/networks.py" = ["PLR0913"]
-"src/rssm/dataset.py" = ["PLR0913"]
-"src/rssm/callback.py" = ["SLF001"]
+"example/dataset.py" = ["PLR0913"]
+"example/callback.py" = ["SLF001"]
+"tests/*.py" = ["S101"]
+"tests/test__core.py" = ["PLR6301", "PLR2004"]
 
 [tool.ruff.lint.pydocstyle]
 convention = "numpy"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index b84e047..57ba444 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -13,98 +13,45 @@ aiohttp==3.9.5
     # via fsspec
 aiosignal==1.3.1
     # via aiohttp
-antlr4-python3-runtime==4.9.3
-    # via hydra-core
-    # via omegaconf
-appdirs==1.4.4
-    # via wandb
 async-timeout==4.0.3
     # via aiohttp
 attrs==23.2.0
     # via aiohttp
-beautifulsoup4==4.12.3
-    # via gdown
-certifi==2024.2.2
-    # via requests
-    # via sentry-sdk
-charset-normalizer==3.3.2
-    # via requests
-click==8.1.7
-    # via wandb
 cloudpickle==3.0.0
     # via tensordict
     # via torchrl
-cnn @ git+https://github.com/nomutin/CNN@589f47eb1f269de4532d25686b6fe2ea880711d3
-contourpy==1.2.1
-    # via matplotlib
-cycler==0.12.1
-    # via matplotlib
-decorator==4.4.2
-    # via moviepy
+coverage==7.5.4
+    # via pytest-cov
 distribution-extension @ git+https://github.com/nomutin/distribution-extension.git@0b5c0cdf5bd19f6f21e373b24d5139deffe93c98
     # via rssm
-docker-pycreds==0.4.0
-    # via wandb
-docstring-parser==0.16
-    # via jsonargparse
 einops==0.8.0
-    # via cnn
     # via distribution-extension
+exceptiongroup==1.2.1
+    # via pytest
 filelock==3.14.0
-    # via gdown
-    # via huggingface-hub
     # via torch
     # via triton
-fonttools==4.51.0
-    # via matplotlib
 frozenlist==1.4.1
     # via aiohttp
     # via aiosignal
 fsspec==2024.3.1
-    # via huggingface-hub
     # via lightning
     # via pytorch-lightning
     # via torch
-gdown==5.1.0
-gitdb==4.0.11
-    # via gitpython
-gitpython==3.1.43
-    # via wandb
-huggingface-hub==0.23.4
-    # via timm
-hydra-core==1.3.2
-    # via rssm
 idna==3.7
-    # via requests
     # via yarl
-imageio==2.34.1
-    # via moviepy
-imageio-ffmpeg==0.4.9
-    # via moviepy
-importlib-resources==6.4.0
-    # via typeshed-client
+iniconfig==2.0.0
+    # via pytest
 jinja2==3.1.3
     # via torch
-jsonargparse==4.28.0
-kiwisolver==1.4.5
-    # via matplotlib
-kornia==0.7.2
-kornia-rs==0.1.3
-    # via kornia
 lightning==2.2.3
     # via rssm
 lightning-utilities==0.11.2
     # via lightning
     # via pytorch-lightning
     # via torchmetrics
-markdown-it-py==3.0.0
-    # via rich
 markupsafe==2.1.5
     # via jinja2
-matplotlib==3.8.4
-mdurl==0.1.2
-    # via markdown-it-py
-moviepy==1.0.3
 mpmath==1.3.0
     # via sympy
 multidict==6.0.5
@@ -116,16 +63,11 @@ mypy-extensions==1.0.0
 networkx==3.3
     # via torch
 numpy==1.26.4
-    # via contourpy
-    # via imageio
     # via lightning
-    # via matplotlib
-    # via moviepy
     # via pytorch-lightning
     # via tensordict
     # via torchmetrics
     # via torchrl
-    # via torchvision
 nvidia-cublas-cu12==12.1.3.1
     # via nvidia-cudnn-cu12
     # via nvidia-cusolver-cu12
@@ -154,121 +96,55 @@ nvidia-nvjitlink-cu12==12.4.127
     # via nvidia-cusolver-cu12
     # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
     # via torch
-omegaconf==2.3.0
-    # via hydra-core
 packaging==24.0
-    # via huggingface-hub
-    # via hydra-core
-    # via kornia
     # via lightning
     # via lightning-utilities
-    # via matplotlib
+    # via pytest
     # via pytorch-lightning
     # via torchmetrics
     # via torchrl
-pillow==10.3.0
-    # via imageio
-    # via matplotlib
-    # via torchvision
-proglog==0.1.10
-    # via moviepy
-protobuf==4.25.3
-    # via wandb
-psutil==5.9.8
-    # via wandb
-pygments==2.18.0
-    # via rich
-pyparsing==3.1.2
-    # via matplotlib
-pysocks==1.7.1
-    # via requests
-python-dateutil==2.9.0.post0
-    # via matplotlib
+pluggy==1.5.0
+    # via pytest
+pytest==8.2.2
+    # via pytest-cov
+pytest-cov==5.0.0
 pytorch-lightning==2.2.3
     # via lightning
 pyyaml==6.0.1
-    # via huggingface-hub
-    # via jsonargparse
     # via lightning
-    # via omegaconf
     # via pytorch-lightning
-    # via timm
-    # via wandb
-requests==2.31.0
-    # via gdown
-    # via huggingface-hub
-    # via moviepy
-    # via wandb
-rich==13.7.1
 ruff==0.4.2
-safetensors==0.4.3
-    # via timm
-sentry-sdk==2.0.1
-    # via wandb
-setproctitle==1.3.3
-    # via wandb
 setuptools==69.5.1
-    # via imageio-ffmpeg
     # via lightning-utilities
-    # via wandb
-six==1.16.0
-    # via docker-pycreds
-    # via python-dateutil
-smmap==5.0.1
-    # via gitdb
-soupsieve==2.5
-    # via beautifulsoup4
 sympy==1.12
     # via torch
 tensordict==0.4.0
     # via torchrl
-timm==1.0.7
-    # via cnn
 tomli==2.0.1
+    # via coverage
     # via mypy
+    # via pytest
 torch==2.3.0
-    # via cnn
-    # via kornia
     # via lightning
     # via pytorch-lightning
     # via tensordict
-    # via timm
-    # via torchgeometry
     # via torchmetrics
     # via torchrl
-    # via torchvision
-torchgeometry==0.1.2
-    # via cnn
 torchmetrics==1.3.2
     # via lightning
     # via pytorch-lightning
 torchrl==0.4.0
     # via rssm
-torchvision==0.18.0
-    # via timm
 tqdm==4.66.2
-    # via gdown
-    # via huggingface-hub
     # via lightning
-    # via moviepy
-    # via proglog
     # via pytorch-lightning
 triton==2.3.0
     # via torch
-typeshed-client==2.5.1
-    # via jsonargparse
 typing-extensions==4.11.0
-    # via huggingface-hub
     # via lightning
     # via lightning-utilities
     # via mypy
     # via pytorch-lightning
     # via torch
-    # via typeshed-client
-urllib3==2.2.1
-    # via requests
-    # via sentry-sdk
-wandb==0.16.6
-    # via rssm
 yarl==1.9.4
     # via aiohttp
diff --git a/requirements.lock b/requirements.lock
index caabff3..4eb9259 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -13,29 +13,15 @@ aiohttp==3.9.5
     # via fsspec
 aiosignal==1.3.1
     # via aiohttp
-antlr4-python3-runtime==4.9.3
-    # via hydra-core
-    # via omegaconf
-appdirs==1.4.4
-    # via wandb
 async-timeout==4.0.3
     # via aiohttp
 attrs==23.2.0
     # via aiohttp
-certifi==2024.2.2
-    # via requests
-    # via sentry-sdk
-charset-normalizer==3.3.2
-    # via requests
-click==8.1.7
-    # via wandb
 cloudpickle==3.0.0
     # via tensordict
     # via torchrl
 distribution-extension @ git+https://github.com/nomutin/distribution-extension.git@0b5c0cdf5bd19f6f21e373b24d5139deffe93c98
     # via rssm
-docker-pycreds==0.4.0
-    # via wandb
 einops==0.8.0
     # via distribution-extension
 filelock==3.14.0
@@ -48,14 +34,7 @@ fsspec==2024.3.1
     # via lightning
     # via pytorch-lightning
     # via torch
-gitdb==4.0.11
-    # via gitpython
-gitpython==3.1.43
-    # via wandb
-hydra-core==1.3.2
-    # via rssm
 idna==3.7
-    # via requests
     # via yarl
 jinja2==3.1.3
     # via torch
@@ -108,39 +87,19 @@ nvidia-nvjitlink-cu12==12.4.127
     # via nvidia-cusparse-cu12
 nvidia-nvtx-cu12==12.1.105
     # via torch
-omegaconf==2.3.0
-    # via hydra-core
 packaging==24.0
-    # via hydra-core
     # via lightning
     # via lightning-utilities
     # via pytorch-lightning
     # via torchmetrics
     # via torchrl
-protobuf==4.25.3
-    # via wandb
-psutil==5.9.8
-    # via wandb
 pytorch-lightning==2.2.3
     # via lightning
 pyyaml==6.0.1
     # via lightning
-    # via omegaconf
     # via pytorch-lightning
-    # via wandb
-requests==2.31.0
-    # via wandb
-sentry-sdk==2.0.1
-    # via wandb
-setproctitle==1.3.3
-    # via wandb
 setuptools==69.5.1
     # via lightning-utilities
-    # via wandb
-six==1.16.0
-    # via docker-pycreds
-smmap==5.0.1
-    # via gitdb
 sympy==1.12
     # via torch
 tensordict==0.4.0
@@ -166,10 +125,5 @@ typing-extensions==4.11.0
     # via lightning-utilities
     # via pytorch-lightning
     # via torch
-urllib3==2.2.1
-    # via requests
-    # via sentry-sdk
-wandb==0.16.6
-    # via rssm
 yarl==1.9.4
     # via aiohttp
diff --git a/scripts/__init__.py b/scripts/__init__.py
deleted file mode 100644
index 92e6543..0000000
--- a/scripts/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Scripts for the project.
-
-When executing the code, always call the code in this directory.
-"""
diff --git a/src/rssm/core.py b/src/rssm/core.py
index aed1f15..b231e86 100644
--- a/src/rssm/core.py
+++ b/src/rssm/core.py
@@ -1,22 +1,17 @@
 """Reccurent State Space Model (RSSM)."""
 
-import tempfile
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TypeAlias
 
-import torch
-import wandb
 from distribution_extension import kl_divergence
-from hydra.utils import instantiate
 from lightning import LightningModule
 from torch import Tensor, nn
 
-from rssm.custom_types import DataGroup, LossDict
+from rssm.networks import Representation, Transition
 from rssm.objective import likelihood
 from rssm.state import State, stack_states
 
-if TYPE_CHECKING:
-    from rssm.networks import Representation, Transition
+DataGroup: TypeAlias = tuple[Tensor, Tensor, Tensor, Tensor]
+LossDict: TypeAlias = dict[str, Tensor]
 
 
 class RSSM(LightningModule):
@@ -31,14 +26,14 @@ class RSSM(LightningModule):
 
     Parameters
     ----------
-    representation : DictConfig of Representation
+    representation : Representation
         Representation model (Approx. Posterior).
-    transition : DictConfig of Transition
+    transition : Transition
         Transition model (Prior).
-    encoder : DictConfig of nn.Module
+    encoder : nn.Module
         Observation encoder. I/O: [*B, C, H, W] -> [*B, obs_embed_size].
-    decoder : DictConfig of nn.Module
+    decoder : nn.Module
         Observation decoder. I/O: [*B, obs_embed_size] -> [*B, C, H, W].
     init_proj : DictConfig of nn.Module
@@ -54,21 +49,20 @@
     def __init__(
         self,
         *,
-        representation: dict[str, Any],
-        transition: dict[str, Any],
-        encoder: dict[str, Any],
-        decoder: dict[str, Any],
-        init_proj: dict[str, Any],
+        representation: Representation,
+        transition: Transition,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        init_proj: nn.Module,
         kl_coeff: float,
         use_kl_balancing: bool,
     ) -> None:
         super().__init__()
-        self.save_hyperparameters()
-        self.representation: Representation = instantiate(representation)
-        self.transition: Transition = instantiate(transition)
-        self.encoder: nn.Module = instantiate(encoder)
-        self.decoder: nn.Module = instantiate(decoder)
-        self.init_proj: nn.Module = instantiate(init_proj)
+        self.representation = representation
+        self.transition = transition
+        self.encoder = encoder
+        self.decoder = decoder
+        self.init_proj = init_proj
         self.kl_coeff = kl_coeff
         self.use_kl_balancing = use_kl_balancing
 
@@ -178,18 +172,3 @@ def shared_step(self, batch: DataGroup) -> LossDict:
             "recon": recon_loss,
             "kl": kl_div,
         }
-
-    @classmethod
-    def load_from_wandb(cls, reference: str) -> "RSSM":
-        """Load the model from wandb checkpoint."""
-        run = wandb.Api().artifact(reference)  # type: ignore[no-untyped-call]
-        with tempfile.TemporaryDirectory() as tmpdir:
-            ckpt = Path(run.download(root=tmpdir))
-            model = cls.load_from_checkpoint(
-                checkpoint_path=ckpt / "model.ckpt",
-                map_location=torch.device("cpu"),
-            )
-            if not isinstance(model, cls):
-                msg = f"Model is not an instance of {cls}"
-                raise TypeError(msg)
-            return model
diff --git a/src/rssm/custom_types.py b/src/rssm/custom_types.py
deleted file mode 100644
index cf13ef5..0000000
--- a/src/rssm/custom_types.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""Custom types for RSSM."""
-
-from typing import TypeAlias
-
-from torch import Tensor
-
-Slice = slice | int | tuple[slice | int, ...]
-
-
-DataGroup: TypeAlias = tuple[Tensor, Tensor, Tensor, Tensor]
-LossDict: TypeAlias = dict[str, Tensor]
diff --git a/src/rssm/state.py b/src/rssm/state.py
index 8f6cae5..62d1ca8 100644
--- a/src/rssm/state.py
+++ b/src/rssm/state.py
@@ -7,7 +7,7 @@
 from distribution_extension.utils import cat_distribution, stack_distribution
 from torch import Tensor
 
-from rssm.custom_types import Slice
+Slice = slice | int | tuple[slice | int, ...]
 
 
 class State:
diff --git a/src/rssm/visualize.py b/src/rssm/visualize.py
deleted file mode 100644
index 444fbd9..0000000
--- a/src/rssm/visualize.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""Data visualization utilities."""
-
-import matplotlib.pyplot as plt
-import torch
-from einops import pack, unpack
-from matplotlib import figure
-from torch import Tensor, uint8
-from wandb import Image, Video
-
-
-def visualize_2d_data(
-    data: Tensor,
-    indices: list[int],
-    x_label: str,
-    y_label: str,
-) -> figure.Figure:
-    """Visualize 2D data."""
-    fig, axe = plt.subplots(figsize=(8, 8), tight_layout=True)
-    for idx in indices:
-        axe.plot(data[idx, :, 0], data[idx, :, 1], alpha=0.5)
-    axe.set_xlabel(xlabel=x_label)
-    axe.set_ylabel(ylabel=y_label)
-    return fig
-
-
-def pca(data: Tensor, n_components: int = 2) -> tuple[Tensor, Tensor]:
-    """
-    Apply PCA on 2D+ Tensor.
-
-    References
-    ----------
-    * https://pytorch.org/docs/stable/generated/torch.pca_lowrank.html
-
-    Returns
-    -------
-    Tensor
-        PCA-transformed data. Tensor shaped [batch*, n_components].
-    Tensor
-        Explained variance ratio. Tensor shaped [n_components].
-
-    """
-    data, ps = pack([data], "* d")
-    _, s, v = torch.pca_lowrank(data, q=n_components)
-    [data_pca] = unpack(torch.matmul(data, v), ps, "* d")
-    ratio = (s**2) / (data.shape[0] - 1) / data.var(dim=0).sum()
-    return data_pca, ratio
-
-
-def to_wandb_images(tensors: Tensor) -> list[Image]:
-    """Convert batched image tensor to wandb images."""
-    tensors = tensors.detach().cpu()
-    if tensors.dtype == uint8:
-        tensors = tensors.float() / 255
-    return [Image(tensor) for tensor in tensors]
-
-
-def to_pca_wandb_image(tensor: Tensor, indices: list[int]) -> Image:
-    """Apply PCA on 2D+ Tensor and convert to wandb image."""
-    tensor_pca, variance_ratio = pca(tensor.detach().cpu())
-    fig = visualize_2d_data(
-        data=tensor_pca,
-        indices=indices,
-        x_label=f"PC1({variance_ratio[0]:.2f})",
-        y_label=f"PC2({variance_ratio[1]:.2f})",
-    )
-    return Image(fig)
-
-
-def to_wandb_movie(tensor: Tensor, fps: float) -> Video:
-    """Convert image tensor to wandb video."""
-    tensor = tensor.detach().cpu().mul(255)
-    return Video(tensor.to(dtype=uint8), fps=fps)  # type: ignore[arg-type]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..2f78f5a
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Unittests."""
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..102718b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,67 @@
+"""Common constants and fixtures in the test."""
+
+import pytest
+import torch
+from distribution_extension import MultiOneHot, Normal
+from torch import Tensor
+
+from rssm.state import State
+
+BATCH_SIZE = 4
+SEQ_LEN = 8
+DETERMINISTIC_SIZE = 64
+STOCHASTIC_SIZE = 16
+CATEGORY_SIZE = 4
+CLASS_SIZE = 4
+ACTION_SIZE = 15
+HIDDEN_SIZE = 32
+OBS_EMBED_SIZE = 7
+
+
+@pytest.fixture()
+def action_bd() -> Tensor:
+    """Create a batch of actions."""
+    return torch.rand(BATCH_SIZE, ACTION_SIZE)
+
+
+@pytest.fixture()
+def observation_bd() -> Tensor:
+    """Create a batch of observations."""
+    return torch.rand(BATCH_SIZE, 3, 64, 64)
+
+
+@pytest.fixture()
+def obs_embed_bd() -> Tensor:
+    """Create a batch of observation embeddings."""
+    return torch.rand(BATCH_SIZE, OBS_EMBED_SIZE)
+
+
+@pytest.fixture()
+def state_bd() -> State:
+    """Create a batch of states(continuous)."""
+    deter = torch.rand(BATCH_SIZE, DETERMINISTIC_SIZE)
+    mean = torch.rand(BATCH_SIZE, STOCHASTIC_SIZE)
+    std = torch.rand(BATCH_SIZE, STOCHASTIC_SIZE)
+    distribution = Normal(mean, std)
+    return State(deter=deter, distribution=distribution)
+
+
+@pytest.fixture()
+def state_discrete_bd() -> State:
+    """Create a batch of states(discrete)."""
+    deter = torch.rand(BATCH_SIZE, DETERMINISTIC_SIZE)
+    logit = torch.rand(BATCH_SIZE, CATEGORY_SIZE, CLASS_SIZE)
+    distribution = MultiOneHot(logit)
+    return State(deter=deter, distribution=distribution)
+
+
+@pytest.fixture()
+def action_bld() -> Tensor:
+    """Create a batch of actions."""
+    return torch.rand(BATCH_SIZE, SEQ_LEN, ACTION_SIZE)
+
+
+@pytest.fixture()
+def observation_bld() -> Tensor:
+    """Create a batch of observations."""
+    return torch.rand(BATCH_SIZE, SEQ_LEN, 3, 64, 64)
diff --git a/tests/test__core.py b/tests/test__core.py
new file mode 100644
index 0000000..1062361
--- /dev/null
+++ b/tests/test__core.py
@@ -0,0 +1,210 @@
+"""Tests of `core.py`."""
+
+import pytest
+from torch import Tensor, nn, rand
+
+from rssm.core import RSSM
+from rssm.networks import (
+    RepresentationV1,
+    RepresentationV2,
+    TransitionV1,
+    TransitionV2,
+)
+from rssm.state import State
+from tests.conftest import (
+    ACTION_SIZE,
+    BATCH_SIZE,
+    CATEGORY_SIZE,
+    CLASS_SIZE,
+    DETERMINISTIC_SIZE,
+    HIDDEN_SIZE,
+    OBS_EMBED_SIZE,
+    SEQ_LEN,
+    STOCHASTIC_SIZE,
+)
+
+
+class DummyEncoder(nn.Module):
+    """A dummy encoder for testing."""
+
+    def forward(self, observation: Tensor) -> Tensor:
+        """Encode observation([*B, C, H, W] -> [*B, D])."""
+        if observation.ndim == 4:
+            return rand(BATCH_SIZE, OBS_EMBED_SIZE)
+        return rand(BATCH_SIZE, SEQ_LEN, OBS_EMBED_SIZE)
+
+
+class DummyDecoder(nn.Module):
+    """A dummy decoder for testing."""
+
+    def forward(self, feature: Tensor) -> Tensor:
+        """Decode feature([*B, D] -> [*B, C, H, W])."""
+        if feature.ndim == 2:
+            return rand(BATCH_SIZE, 3, 64, 64)
+        return rand(BATCH_SIZE, SEQ_LEN, 3, 64, 64)
+
+
+@pytest.fixture()
+def continuous_rssm() -> RSSM:
+    """Create a continuous RSSM instance."""
+    representation = RepresentationV1(
+        deterministic_size=DETERMINISTIC_SIZE,
+        stochastic_size=STOCHASTIC_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        obs_embed_size=OBS_EMBED_SIZE,
+        activation_name="ReLU",
+    )
+    transition = TransitionV1(
+        action_size=ACTION_SIZE,
+        deterministic_size=DETERMINISTIC_SIZE,
+        stochastic_size=STOCHASTIC_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        activation_name="ReLU",
+    )
+    init_proj = nn.Linear(OBS_EMBED_SIZE, DETERMINISTIC_SIZE)
+    return RSSM(
+        representation=representation,
+        transition=transition,
+        encoder=DummyEncoder(),
+        decoder=DummyDecoder(),
+        init_proj=init_proj,
+        kl_coeff=1.0,
+        use_kl_balancing=False,
+    )
+
+
+@pytest.fixture()
+def discrete_rssm() -> RSSM:
+    """Create a discrete RSSM instance."""
+    representation = RepresentationV2(
+        deterministic_size=DETERMINISTIC_SIZE,
+        category_size=CATEGORY_SIZE,
+        class_size=CLASS_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        obs_embed_size=OBS_EMBED_SIZE,
+        activation_name="ReLU",
+    )
+    transition = TransitionV2(
+        action_size=ACTION_SIZE,
+        deterministic_size=DETERMINISTIC_SIZE,
+        category_size=CATEGORY_SIZE,
+        class_size=CLASS_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        activation_name="ReLU",
+    )
+    init_proj = nn.Linear(OBS_EMBED_SIZE, DETERMINISTIC_SIZE)
+    return RSSM(
+        representation=representation,
+        transition=transition,
+        encoder=DummyEncoder(),
+        decoder=DummyDecoder(),
+        init_proj=init_proj,
+        kl_coeff=1.0,
+        use_kl_balancing=True,
+    )
+
+
+def test__initial_state(
+    continuous_rssm: RSSM,
+    discrete_rssm: RSSM,
+    observation_bd: Tensor,
+) -> None:
+    """Test `initial_state` method."""
+    state = continuous_rssm.initial_state(observation_bd)
+    assert state.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert state.stoch.shape == (BATCH_SIZE, STOCHASTIC_SIZE)
+
+    state = discrete_rssm.initial_state(observation_bd)
+    assert state.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert state.stoch.shape == (BATCH_SIZE, CATEGORY_SIZE * CLASS_SIZE)
+
+
+def test__rollout_representation(
+    action_bld: Tensor,
+    observation_bld: Tensor,
+    state_bd: State,
+    continuous_rssm: RSSM,
+    discrete_rssm: RSSM,
+) -> None:
+    """Test `rollout_representation` method."""
+    prior, posterior = continuous_rssm.rollout_representation(
+        actions=action_bld,
+        observations=observation_bld,
+        prev_state=state_bd,
+    )
+    assert prior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE)
+    assert posterior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert posterior.stoch.shape == (BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE)
+
+    prior, posterior = discrete_rssm.rollout_representation(
+        actions=action_bld,
+        observations=observation_bld,
+        prev_state=state_bd,
+    )
+    feature_size = CATEGORY_SIZE * CLASS_SIZE
+    assert prior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, SEQ_LEN, feature_size)
+    assert posterior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert posterior.stoch.shape == (BATCH_SIZE, SEQ_LEN, feature_size)
+
+
+def test__rollout_transition(
+    action_bld: Tensor,
+    state_bd: State,
+    continuous_rssm: RSSM,
+    discrete_rssm: RSSM,
+) -> None:
+    """Test `rollout_transition` method."""
+    prior = continuous_rssm.rollout_transition(
+        actions=action_bld,
+        prev_state=state_bd,
+    )
+    assert prior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE)
+
+    prior = discrete_rssm.rollout_transition(
+        actions=action_bld,
+        prev_state=state_bd,
+    )
+    feature_size = CATEGORY_SIZE * CLASS_SIZE
+    assert prior.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, SEQ_LEN, feature_size)
+
+
+def test__training_step(
+    action_bld: Tensor,
+    observation_bld: Tensor,
+    continuous_rssm: RSSM,
+    discrete_rssm: RSSM,
+) -> None:
+    """Test `training_step` method."""
+    batch = (action_bld, observation_bld, action_bld, observation_bld)
+    loss = continuous_rssm.training_step(batch)
+    assert "loss" in loss
+    assert "kl" in loss
+    assert "recon" in loss
+
+    loss = discrete_rssm.training_step(batch)
+    assert "loss" in loss
+    assert "kl" in loss
+    assert "recon" in loss
+
+
+def test__validation_step(
+    action_bld: Tensor,
+    observation_bld: Tensor,
+    continuous_rssm: RSSM,
+    discrete_rssm: RSSM,
+) -> None:
+    """Test `validation_step` method."""
+    batch = (action_bld, observation_bld, action_bld, observation_bld)
+    loss = continuous_rssm.validation_step(batch, 0)
+    assert "val_loss" in loss
+    assert "val_kl" in loss
+    assert "val_recon" in loss
+
+    loss = discrete_rssm.validation_step(batch, 0)
+    assert "val_loss" in loss
+    assert "val_kl" in loss
+    assert "val_recon" in loss
diff --git a/tests/test__networks.py b/tests/test__networks.py
new file mode 100644
index 0000000..9d43228
--- /dev/null
+++ b/tests/test__networks.py
@@ -0,0 +1,101 @@
+"""Tests of `networks.py`."""
+
+
+from torch import Tensor
+
+from rssm import State
+from rssm.networks import (
+    RepresentationV1,
+    RepresentationV2,
+    TransitionV1,
+    TransitionV2,
+)
+from tests.conftest import (
+    ACTION_SIZE,
+    BATCH_SIZE,
+    CATEGORY_SIZE,
+    CLASS_SIZE,
+    DETERMINISTIC_SIZE,
+    HIDDEN_SIZE,
+    OBS_EMBED_SIZE,
+    STOCHASTIC_SIZE,
+)
+
+
+def test__representation_v1(
+    obs_embed_bd: Tensor,
+    state_bd: State,
+) -> None:
+    """Test the RepresentationV1 class and `forward()` method."""
+    representation = RepresentationV1(
+        deterministic_size=DETERMINISTIC_SIZE,
+        stochastic_size=STOCHASTIC_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        obs_embed_size=OBS_EMBED_SIZE,
+        activation_name="ReLU",
+    )
+    posterior = representation.forward(
+        obs_embed=obs_embed_bd,
+        prior_state=state_bd,
+    )
+    assert posterior.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert posterior.stoch.shape == (BATCH_SIZE, STOCHASTIC_SIZE)
+
+
+def test__transition_v1(action_bd: Tensor, state_bd: State) -> None:
+    """Test the TransitionV1 class and `forward()` method."""
+    transition = TransitionV1(
+        action_size=ACTION_SIZE,
+        deterministic_size=DETERMINISTIC_SIZE,
+        stochastic_size=STOCHASTIC_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        activation_name="ReLU",
+    )
+    prior = transition.forward(
+        action=action_bd,
+        prev_state=state_bd,
+    )
+    assert prior.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, STOCHASTIC_SIZE)
+
+
+def test__representation_v2(
+    obs_embed_bd: Tensor,
+    state_discrete_bd: State,
+) -> None:
+    """Test the RepresentationV2 class and `forward()` method."""
+    representation = RepresentationV2(
+        deterministic_size=DETERMINISTIC_SIZE,
+        category_size=CATEGORY_SIZE,
+        class_size=CLASS_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        obs_embed_size=OBS_EMBED_SIZE,
+        activation_name="ReLU",
+    )
+    posterior = representation.forward(
+        obs_embed=obs_embed_bd,
+        prior_state=state_discrete_bd,
+    )
+    assert posterior.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert posterior.stoch.shape == (BATCH_SIZE, CATEGORY_SIZE * CLASS_SIZE)
+
+
+def test__transition_v2(
+    action_bd: Tensor,
+    state_discrete_bd: State,
+) -> None:
+    """Test the TransitionV2 class and `forward()` method."""
+    transition = TransitionV2(
+        action_size=ACTION_SIZE,
+        deterministic_size=DETERMINISTIC_SIZE,
+        category_size=CATEGORY_SIZE,
+        class_size=CLASS_SIZE,
+        hidden_size=HIDDEN_SIZE,
+        activation_name="ReLU",
+    )
+    prior = transition.forward(
+        action=action_bd,
+        prev_state=state_discrete_bd,
+    )
+    assert prior.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+    assert prior.stoch.shape == (BATCH_SIZE, CATEGORY_SIZE * CLASS_SIZE)
diff --git a/tests/test__objective.py b/tests/test__objective.py
new file mode 100644
index 0000000..a9d2fdd
--- /dev/null
+++ b/tests/test__objective.py
@@ -0,0 +1,12 @@
+"""Tests of `objective.py`."""
+
+from torch import Tensor
+
+from rssm.objective import likelihood
+
+
+def test__likelihood(observation_bld: Tensor) -> None:
+    """Test the `likelihood()` function."""
+    prediction = target = observation_bld
+    loss = likelihood(prediction, target, event_ndims=3)
+    assert loss.shape == ()
diff --git a/tests/test__state.py b/tests/test__state.py
new file mode 100644
index 0000000..7df422e
--- /dev/null
+++ b/tests/test__state.py
@@ -0,0 +1,125 @@
+"""Tests for `state.py`."""
+
+import pytest
+import torch
+from distribution_extension import Normal
+
+from rssm.state import State, cat_states, stack_states
+from tests.conftest import (
+    BATCH_SIZE,
+    DETERMINISTIC_SIZE,
+    SEQ_LEN,
+    STOCHASTIC_SIZE,
+)
+
+
+@pytest.fixture()
+def state() -> State:
+    """Create a State instance."""
+    deter = torch.rand(
+        BATCH_SIZE,
+        SEQ_LEN,
+        DETERMINISTIC_SIZE,
+        requires_grad=True,
+    )
+    mean = torch.rand(BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE, requires_grad=True)
+    std = torch.rand(BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE, requires_grad=True)
+    distribution = Normal(mean, std)
+    return State(deter=deter, distribution=distribution)
+
+
+def test_init(state: State) -> None:
+    """Test the __init__ method."""
+    assert state.deter.shape == (BATCH_SIZE, SEQ_LEN, DETERMINISTIC_SIZE)
+    assert state.stoch.shape == (BATCH_SIZE, SEQ_LEN, STOCHASTIC_SIZE)
+    assert state.feature.shape == (
+        BATCH_SIZE,
+        SEQ_LEN,
+        DETERMINISTIC_SIZE + STOCHASTIC_SIZE,
+    )
+
+
+def test__iter__(state: State) -> None:
+    """Test the __iter__ method."""
+    for s in state:
+        assert s.deter.shape == (SEQ_LEN, DETERMINISTIC_SIZE)
+        assert s.stoch.shape == (SEQ_LEN, STOCHASTIC_SIZE)
+        assert s.feature.shape == (
+            SEQ_LEN,
+            DETERMINISTIC_SIZE + STOCHASTIC_SIZE
+        )
+
+
+def test__getitem__(state: State) -> None:
+    """Test the __getitem__ method."""
+    for t in range(SEQ_LEN):
+        s = state[:, t]
+        assert s.deter.shape == (BATCH_SIZE, DETERMINISTIC_SIZE)
+        assert s.stoch.shape == (BATCH_SIZE, STOCHASTIC_SIZE)
+        assert s.feature.shape == (
+            BATCH_SIZE,
+            DETERMINISTIC_SIZE + STOCHASTIC_SIZE
+        )
+
+
+def test__to(state: State) -> None:
+    """Test the to method."""
+    state = state.to(torch.device("cpu"))
+    assert state.deter.device == torch.device("cpu")
+    assert state.stoch.device == torch.device("cpu")
+
+
+def test__detach(state: State) -> None:
+    """Test the detach method."""
+    assert state.deter.requires_grad is True
+    assert state.stoch.requires_grad is True
+    state = state.detach()
+    assert state.deter.requires_grad is False
+    assert state.stoch.requires_grad is False
+
+
+def test__stack_state(state: State) -> None:
+    """Test the stack_states function."""
+    num_states = 2
+    states = [state] * num_states
+    state = stack_states(states, dim=1)
+    assert state.deter.shape == (
+        BATCH_SIZE,
+        num_states,
+        SEQ_LEN,
+        DETERMINISTIC_SIZE,
+    )
+    assert state.stoch.shape == (
+        BATCH_SIZE,
+        num_states,
+        SEQ_LEN,
+        STOCHASTIC_SIZE,
+    )
+    assert state.feature.shape == (
+        BATCH_SIZE,
+        num_states,
+        SEQ_LEN,
+        DETERMINISTIC_SIZE + STOCHASTIC_SIZE,
+    )
+
+
+def test__cat_state(state: State) -> None:
+    """Test the cat_states function."""
+    num_states = 2
+    states = [state] * num_states
+    state = cat_states(states, dim=1)
+    assert state.deter.shape == (
+        BATCH_SIZE,
+        num_states * SEQ_LEN,
+        DETERMINISTIC_SIZE,
+    )
+    assert state.stoch.shape == (
+        BATCH_SIZE,
+        num_states * SEQ_LEN,
+        STOCHASTIC_SIZE,
+    )
+    assert state.feature.shape == (
+        BATCH_SIZE,
+        num_states * SEQ_LEN,
+        DETERMINISTIC_SIZE + STOCHASTIC_SIZE,
+    )
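Not part of the diff above: with this change `RSSM.__init__` accepts already-built `nn.Module`s instead of Hydra `DictConfig`s, so a model can be constructed directly in Python. A minimal sketch of the new construction path, mirroring the `continuous_rssm` fixture in `tests/test__core.py`; the sizes are illustrative and the encoder/decoder below are placeholders, not values prescribed by the diff.

```python
from torch import nn

from rssm.core import RSSM
from rssm.networks import RepresentationV1, TransitionV1

# Illustrative sizes only; any encoder/decoder matching the documented shapes
# ([*B, C, H, W] <-> [*B, obs_embed_size]) can be plugged in.
model = RSSM(
    representation=RepresentationV1(
        deterministic_size=64,
        stochastic_size=16,
        hidden_size=32,
        obs_embed_size=7,
        activation_name="ReLU",
    ),
    transition=TransitionV1(
        action_size=15,
        deterministic_size=64,
        stochastic_size=16,
        hidden_size=32,
        activation_name="ReLU",
    ),
    encoder=nn.Identity(),  # placeholder: substitute a real image encoder
    decoder=nn.Identity(),  # placeholder: substitute a real image decoder
    init_proj=nn.Linear(7, 64),
    kl_coeff=1.0,
    use_kl_balancing=False,
)
```

The checks wired into the new CI job can be reproduced locally with `rye lint` and `rye run cov` (the coverage script added under `[tool.rye.scripts]`).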