Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] examplesを作成 #9

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
name: CI

on:
push:
branches:
- main
pull_request:
branches:
- main


jobs:
ruff:
ci:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
- name: Set up rye
uses: eifinger/setup-rye@v3
- name: Install dependencies
run: |
rye config --set-bool behavior.use-uv=true
rye sync --no-lock
- name: Run lint
run: |
rye lint
- name: Run tests
run: |
rye run cov
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
![python](https://img.shields.io/badge/python-3.10-blue)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![CI](https://github.com/nomutin/RSSM/actions/workflows/ci.yaml/badge.svg)](https://github.com/nomutin/RSSM/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/nomutin/RSSM/graph/badge.svg?token=YMR2H87R5C)](https://codecov.io/gh/nomutin/RSSM)

[RSSMs](https://danijar.com/project/dreamer/) for imitation learning.

Expand Down
1 change: 1 addition & 0 deletions example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Training/Evaluation examples."""
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
32 changes: 16 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
[project]
name = "rssm"
version = "0.1.0"
version = "0.1.2"
description = "Reccurent State-Space Model"
dependencies = [
"lightning>=2.2.1",
"wandb>=0.16.4",
"torchrl>=0.3.1",
"hydra-core>=1.3.2",
"distribution-extension @ git+https://github.com/nomutin/distribution-extension.git",
]
readme = "README.md"
Expand All @@ -21,16 +19,7 @@ managed = true
dev-dependencies = [
"mypy>=1.9.0",
"ruff>=0.4.2",
"kornia>=0.7.2",
"imageio>=2.34.1",
"moviepy>=1.0.3",
"gdown>=5.1.0",
"matplotlib>=3.8.4",
"rich>=13.7.1",
"einops>=0.8.0",
"torchvision>=0.18.0",
"jsonargparse[signatures]>=4.27.7",
"cnn @ git+https://github.com/nomutin/CNN",
"pytest-cov>=5.0.0",
]

[tool.hatch.metadata]
Expand All @@ -39,12 +28,21 @@ allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/rssm"]

[tool.pytest.ini_options]
filterwarnings = [
"ignore::UserWarning",
"ignore::DeprecationWarning",
]

[tool.rye.scripts]
cov = "pytest -ra --cov=src --cov-report=term --cov-report=xml"

[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true

[tool.ruff]
line-length = 79
line-length = 80
target-version = "py310"

[tool.ruff.lint]
Expand All @@ -69,8 +67,10 @@ known-first-party = ["rssm"]
[tool.ruff.lint.per-file-ignores]
"src/rssm/core.py" = ["PLR0913"]
"src/rssm/networks.py" = ["PLR0913"]
"src/rssm/dataset.py" = ["PLR0913"]
"src/rssm/callback.py" = ["SLF001"]
"example/dataset.py" = ["PLR0913"]
"example/callback.py" = ["SLF001"]
"tests/*.py" = ["S101"]
"tests/test__core.py" = ["PLR6301", "PLR2004"]

[tool.ruff.lint.pydocstyle]
convention = "numpy"
152 changes: 14 additions & 138 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13,98 +13,45 @@ aiohttp==3.9.5
# via fsspec
aiosignal==1.3.1
# via aiohttp
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
appdirs==1.4.4
# via wandb
async-timeout==4.0.3
# via aiohttp
attrs==23.2.0
# via aiohttp
beautifulsoup4==4.12.3
# via gdown
certifi==2024.2.2
# via requests
# via sentry-sdk
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via wandb
cloudpickle==3.0.0
# via tensordict
# via torchrl
cnn @ git+https://github.com/nomutin/CNN@589f47eb1f269de4532d25686b6fe2ea880711d3
contourpy==1.2.1
# via matplotlib
cycler==0.12.1
# via matplotlib
decorator==4.4.2
# via moviepy
coverage==7.5.4
# via pytest-cov
distribution-extension @ git+https://github.com/nomutin/distribution-extension.git@0b5c0cdf5bd19f6f21e373b24d5139deffe93c98
# via rssm
docker-pycreds==0.4.0
# via wandb
docstring-parser==0.16
# via jsonargparse
einops==0.8.0
# via cnn
# via distribution-extension
exceptiongroup==1.2.1
# via pytest
filelock==3.14.0
# via gdown
# via huggingface-hub
# via torch
# via triton
fonttools==4.51.0
# via matplotlib
frozenlist==1.4.1
# via aiohttp
# via aiosignal
fsspec==2024.3.1
# via huggingface-hub
# via lightning
# via pytorch-lightning
# via torch
gdown==5.1.0
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via wandb
huggingface-hub==0.23.4
# via timm
hydra-core==1.3.2
# via rssm
idna==3.7
# via requests
# via yarl
imageio==2.34.1
# via moviepy
imageio-ffmpeg==0.4.9
# via moviepy
importlib-resources==6.4.0
# via typeshed-client
iniconfig==2.0.0
# via pytest
jinja2==3.1.3
# via torch
jsonargparse==4.28.0
kiwisolver==1.4.5
# via matplotlib
kornia==0.7.2
kornia-rs==0.1.3
# via kornia
lightning==2.2.3
# via rssm
lightning-utilities==0.11.2
# via lightning
# via pytorch-lightning
# via torchmetrics
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.4
mdurl==0.1.2
# via markdown-it-py
moviepy==1.0.3
mpmath==1.3.0
# via sympy
multidict==6.0.5
Expand All @@ -116,16 +63,11 @@ mypy-extensions==1.0.0
networkx==3.3
# via torch
numpy==1.26.4
# via contourpy
# via imageio
# via lightning
# via matplotlib
# via moviepy
# via pytorch-lightning
# via tensordict
# via torchmetrics
# via torchrl
# via torchvision
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
Expand Down Expand Up @@ -154,121 +96,55 @@ nvidia-nvjitlink-cu12==12.4.127
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
omegaconf==2.3.0
# via hydra-core
packaging==24.0
# via huggingface-hub
# via hydra-core
# via kornia
# via lightning
# via lightning-utilities
# via matplotlib
# via pytest
# via pytorch-lightning
# via torchmetrics
# via torchrl
pillow==10.3.0
# via imageio
# via matplotlib
# via torchvision
proglog==0.1.10
# via moviepy
protobuf==4.25.3
# via wandb
psutil==5.9.8
# via wandb
pygments==2.18.0
# via rich
pyparsing==3.1.2
# via matplotlib
pysocks==1.7.1
# via requests
python-dateutil==2.9.0.post0
# via matplotlib
pluggy==1.5.0
# via pytest
pytest==8.2.2
# via pytest-cov
pytest-cov==5.0.0
pytorch-lightning==2.2.3
# via lightning
pyyaml==6.0.1
# via huggingface-hub
# via jsonargparse
# via lightning
# via omegaconf
# via pytorch-lightning
# via timm
# via wandb
requests==2.31.0
# via gdown
# via huggingface-hub
# via moviepy
# via wandb
rich==13.7.1
ruff==0.4.2
safetensors==0.4.3
# via timm
sentry-sdk==2.0.1
# via wandb
setproctitle==1.3.3
# via wandb
setuptools==69.5.1
# via imageio-ffmpeg
# via lightning-utilities
# via wandb
six==1.16.0
# via docker-pycreds
# via python-dateutil
smmap==5.0.1
# via gitdb
soupsieve==2.5
# via beautifulsoup4
sympy==1.12
# via torch
tensordict==0.4.0
# via torchrl
timm==1.0.7
# via cnn
tomli==2.0.1
# via coverage
# via mypy
# via pytest
torch==2.3.0
# via cnn
# via kornia
# via lightning
# via pytorch-lightning
# via tensordict
# via timm
# via torchgeometry
# via torchmetrics
# via torchrl
# via torchvision
torchgeometry==0.1.2
# via cnn
torchmetrics==1.3.2
# via lightning
# via pytorch-lightning
torchrl==0.4.0
# via rssm
torchvision==0.18.0
# via timm
tqdm==4.66.2
# via gdown
# via huggingface-hub
# via lightning
# via moviepy
# via proglog
# via pytorch-lightning
triton==2.3.0
# via torch
typeshed-client==2.5.1
# via jsonargparse
typing-extensions==4.11.0
# via huggingface-hub
# via lightning
# via lightning-utilities
# via mypy
# via pytorch-lightning
# via torch
# via typeshed-client
urllib3==2.2.1
# via requests
# via sentry-sdk
wandb==0.16.6
# via rssm
yarl==1.9.4
# via aiohttp
Loading
Loading