diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 0000000..62e728a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,48 @@ +name: Bug Report +description: Submit a bug report +title: "[Bug Report] Bug title" +labels: ["bug"] +body: + - type: textarea + id: description + attributes: + label: Describe the bug + description: A clear and concise description of what the bug is. + validations: + required: true + + - type: textarea + id: code-example + attributes: + label: Code example + description: | + Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. + This will be automatically formatted into code, so no need for backticks. + render: shell + + - type: textarea + id: system-info + attributes: + label: System info + description: | + Describe the characteristics of your environment: + * Describe how CogmentLab was installed (pip, docker, source, ...) + * Version of `cogment_lab` (by `cogment_lab.__version__`) + * What OS/version of Linux you're using. Note that while we will accept PRs to improve Windows support, we do not officially support it. + * Python version + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add any other context about the problem here. 
+ + - type: checkboxes + id: checklist + attributes: + label: Checklist + options: + - label: > + I have checked that there is no similar [issue](https://github.com/cogment/cogment_lab/issues) in + the repo + required: true diff --git a/.github/ISSUE_TEMPLATE/proposal.yml b/.github/ISSUE_TEMPLATE/proposal.yml new file mode 100644 index 0000000..0e6fe10 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/proposal.yml @@ -0,0 +1,49 @@ +name: Proposal +description: Propose changes that are not fixing bugs +title: "[Proposal] Proposal title" +labels: ["enhancement"] +body: + - type: textarea + id: proposal + attributes: + label: Proposal + description: A clear and concise description of the proposal. + validations: + required: true + + - type: textarea + id: motivation + attributes: + label: Motivation + description: | + Please outline the motivation for the proposal. + Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". + If this is related to another GitHub issue, please link here too. + + - type: textarea + id: pitch + attributes: + label: Pitch + description: A clear and concise description of what you want to happen. + + - type: textarea + id: alternatives + attributes: + label: Alternatives + description: A clear and concise description of any alternative solutions or features you've considered, if any. + + - type: textarea + id: additional-context + attributes: + label: Additional context + description: Add any other context or screenshots about the feature request here. 
+ + - type: checkboxes + id: checklist + attributes: + label: Checklist + options: + - label: > + I have checked that there is no similar [issue](https://github.com/cogment/cogment_lab/issues) in + the repo + required: true diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml new file mode 100644 index 0000000..9725bc4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -0,0 +1,21 @@ +name: Question +description: Ask a question +title: "[Question] Question title" +labels: ["question"] +body: + - type: markdown + attributes: + value: > + If you're a beginner and have basic questions, please ask on + [r/reinforcementlearning](https://www.reddit.com/r/reinforcementlearning/) or in the + [RL Discord](https://discord.com/invite/xhfNqQv) (if you're new please use the beginners channel). + Basic questions that are not bugs or feature requests will be closed without reply, because GitHub + issues are not an appropriate venue for these. Advanced/nontrivial questions, especially in areas where + documentation is lacking, are very much welcome. + - type: textarea + id: question + attributes: + label: Question + description: Your question + validations: + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..df1dcee --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,46 @@ +# Description + +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. + +Fixes # (issue) + +## Type of change + +Please delete options that are not relevant. 
+ +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] This change requires a documentation update + +### Screenshots + +Please attach before and after screenshots of the change if applicable. + + + +# Checklist: + +- [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up) +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes + + diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000..9df5640 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,62 @@ +# Configuration for probot-stale - https://github.com/probot/stale + +# Number of days of inactivity before an Issue or Pull Request becomes stale +daysUntilStale: 60 + +# Number of days of inactivity before an Issue or Pull Request with the stale label is closed. +# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. +daysUntilClose: 14 + +# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) +onlyLabels: + - more-information-needed + +# Issues or Pull Requests with these labels will never be considered stale. 
Set to `[]` to disable +exemptLabels: + - pinned + - security + - "[Status] Maybe Later" + +# Set to true to ignore issues in a project (defaults to false) +exemptProjects: true + +# Set to true to ignore issues in a milestone (defaults to false) +exemptMilestones: true + +# Set to true to ignore issues with an assignee (defaults to false) +exemptAssignees: true + +# Label to use when marking as stale +staleLabel: stale + +# Comment to post when marking as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. + +# Comment to post when removing the stale label. +# unmarkComment: > +# Your comment here. + +# Comment to post when closing a stale Issue or Pull Request. +# closeComment: > +# Your comment here. + +# Limit the number of actions per hour, from 1-30. Default is 30 +limitPerRun: 30 + +# Limit to only `issues` or `pulls` +only: issues + +# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls': +# pulls: +# daysUntilStale: 30 +# markComment: > +# This pull request has been automatically marked as stale because it has not had +# recent activity. It will be closed if no further activity occurs. Thank you +# for your contributions. 
+ +# issues: +# exemptLabels: +# - confirmed diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml new file mode 100644 index 0000000..46c27d1 --- /dev/null +++ b/.github/workflows/build-docs.yml @@ -0,0 +1,46 @@ +name: Build main branch documentation website +on: + push: + branches: [main] +permissions: + contents: write +jobs: + docs: + name: Generate Website + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install Gymnasium + run: pip install mujoco && pip install .[box2d] + + - name: Build Envs Docs + run: python docs/scripts/gen_mds.py && python docs/scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 links + run: python docs/scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r _build/.doctrees + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: main + clean: false diff --git a/.github/workflows/build-publish.yml b/.github/workflows/build-publish.yml new file mode 100644 index 0000000..dd96506 --- /dev/null +++ b/.github/workflows/build-publish.yml @@ -0,0 +1,68 @@ +# This workflow will build and (if release) publish Python distributions to PyPI +# For more information see: +# - https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +# - https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ +# +# derived from https://github.com/Farama-Foundation/PettingZoo/blob/e230f4d80a5df3baf9bd905149f6d4e8ce22be31/.github/workflows/build-publish.yml +name: build-publish + +on: 
+ push: + branches: [main] + pull_request: + branches: [main] + release: + types: [published] + +jobs: + build-wheels: + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-latest + python: 37 + platform: manylinux_x86_64 + - os: ubuntu-latest + python: 38 + platform: manylinux_x86_64 + - os: ubuntu-latest + python: 39 + platform: manylinux_x86_64 + - os: ubuntu-latest + python: 310 + platform: manylinux_x86_64 + - os: ubuntu-latest + python: 311 + platform: manylinux_x86_64 + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + - name: Install dependencies + run: python -m pip install --upgrade pip setuptools build + - name: Build sdist and wheels + run: python -m build + - name: Store wheels + uses: actions/upload-artifact@v3 + with: + path: dist + + publish: + runs-on: ubuntu-latest + needs: + - build-wheels + if: github.event_name == 'release' && github.event.action == 'published' + steps: + - name: Download dists + uses: actions/download-artifact@v3 + with: + name: artifact + path: dist + - name: Publish + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..88b3e92 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,37 @@ +name: build +on: [pull_request, push] + +permissions: + contents: read # to fetch code (actions/checkout) + +jobs: + build-all: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10'] + steps: + - uses: actions/checkout@v3 + - run: | + docker build -f bin/all-py.Dockerfile \ + --build-arg PYTHON_VERSION=${{ matrix.python-version }} \ + --tag cogment_lab-all-docker . 
+ - name: Start background services + run: docker run -d --name cogment_lab-test cogment_lab-all-docker /usr/local/bin/cogmentlab launch base + + - name: Run tests + run: docker run cogment_lab-all-docker pytest tests/* + + + build-necessary: + runs-on: + ubuntu-latest + steps: + - uses: actions/checkout@v3 + - run: | + docker build -f bin/necessary-py.Dockerfile \ + --build-arg PYTHON_VERSION='3.10' \ + --tag cogment_lab-necessary-docker . + - name: Run tests + run: | + docker run cogment_lab-necessary-docker pytest tests diff --git a/.github/workflows/docs-manual-versioning.yml b/.github/workflows/docs-manual-versioning.yml new file mode 100644 index 0000000..2156b5c --- /dev/null +++ b/.github/workflows/docs-manual-versioning.yml @@ -0,0 +1,71 @@ +name: Manual Docs Versioning +on: + workflow_dispatch: + inputs: + version: + description: 'Documentation version to create' + required: true + commit: + description: 'Commit used to build the Documentation version' + required: false + latest: + description: 'Latest version' + type: boolean + +permissions: + contents: write +jobs: + docs: + name: Generate Website for new version + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + if: inputs.commit == '' + + - uses: actions/checkout@v3 + if: inputs.commit != '' + with: + ref: ${{ inputs.commit }} + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install CogmentLab + run: pip install .[atari,accept-rom-license,box2d] + + - name: Build Envs Docs + run: python docs/scripts/gen_mds.py && python docs/scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 links + run: python docs/scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r 
_build/.doctrees + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: ${{ inputs.version }} + clean: false + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + if: inputs.latest + with: + folder: _build + clean-exclude: | + *.*.*/ + main diff --git a/.github/workflows/docs-versioning.yml b/.github/workflows/docs-versioning.yml new file mode 100644 index 0000000..54b50e0 --- /dev/null +++ b/.github/workflows/docs-versioning.yml @@ -0,0 +1,59 @@ +name: Docs Versioning +on: + push: + tags: + - 'v?*.*.*' +permissions: + contents: write +jobs: + docs: + name: Generate Website for new version + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Get tag + id: tag + uses: dawidd6/action-get-tag@v1 + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install CogmentLab + run: pip install .[atari,accept-rom-license,box2d] + + - name: Build Envs Docs + run: python docs/scripts/gen_mds.py && python docs/scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 links + run: python docs/scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r _build/.doctrees + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: ${{steps.tag.outputs.tag}} + clean: false + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + clean-exclude: | + *.*.*/ + main diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..80ce02a --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ 
-0,0 +1,21 @@ +# https://pre-commit.com +# This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. +name: pre-commit +on: + pull_request: + push: + branches: [main] + +permissions: + contents: read # to fetch code (actions/checkout) + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + - run: python -m pip install pre-commit + - run: python -m pre_commit --version + - run: python -m pre_commit install + - run: python -m pre_commit run --all-files diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..86adbeb --- /dev/null +++ b/.gitignore @@ -0,0 +1,45 @@ +*.swp +*.pyc +*.py~ +.DS_Store +.cache +.pytest_cache/ +__pycache__/ + +# Setuptools distribution and build folders. +/dist/ +/build +/wheels +/wheelhouse + +# Virtualenv +/env +/venv + +# Python egg metadata, regenerated from source files by setuptools. +/*.egg-info + +*.sublime-project +*.sublime-workspace + +logs/ + +.ipynb_checkpoints +ghostdriver.log + +junk +MUJOCO_LOG.txt + +rllab_mujoco + +tutorial/*.html + +# IDE files +.eggs +.tox + +# PyCharm project files +.idea +vizdoom.ini + +lib/ \ No newline at end of file diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..a4bd6cd --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,9 @@ +stages: + - lint + +licenses_checker: + stage: lint + image: registry.gitlab.com/ai-r/cogment/license-checker:latest + script: + - license-checker + diff --git a/.license.yaml b/.license.yaml new file mode 100644 index 0000000..b71f6b0 --- /dev/null +++ b/.license.yaml @@ -0,0 +1,6 @@ +license: + copyright_year: 2024 + author: "AI Redefined Inc. 
" + ignore: + - "**/.venv" + - "**/cog_settings.py" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6714a99 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,68 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-symlinks + - id: destroyed-symlinks + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-ast + - id: check-added-large-files + - id: check-merge-conflict + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable + - id: detect-private-key + - id: debug-statements + - repo: https://github.com/codespell-project/codespell + rev: v2.2.4 + hooks: + - id: codespell + args: + - --ignore-words-list=reacher,ure,referenc,wile + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + - --ignore=E203,W503,E741 + - --max-complexity=30 + - --max-line-length=456 + - --show-source + - --statistics + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: ["--py37-plus"] + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/python/black + rev: 23.3.0 + hooks: + - id: black + - repo: https://github.com/pycqa/pydocstyle + rev: 6.3.0 + hooks: + - id: pydocstyle + args: + - --source + - --explain + - --convention=google + additional_dependencies: ["tomli"] + - repo: local + hooks: + - id: pyright + name: pyright + entry: pyright + language: node + pass_filenames: false + types: [python] + additional_dependencies: ["pyright"] + args: + - --project=pyproject.toml diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cd482d8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS 
AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but +not limited to compiled object code, generated documentation, +and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. 
For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a 
NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..66398af --- /dev/null +++ b/README.md @@ -0,0 +1,97 @@ +![cog-lab](https://github.com/RedTachyon/cogment_lab/assets/19414946/165557d0-fdf0-4d0a-99f1-3fc321fa194c) + +# Human + AI = ❤️ + + +## Docs | Blog | Discord + + +[![Package version](https://img.shields.io/pypi/v/cogment-lab?color=%23007ec6&label=pypi%20package)](https://pypi.org/project/cogment-lab) +[![Downloads](https://pepy.tech/badge/cogment-lab)](https://pepy.tech/project/cogment-lab) +[![Supported Python versions](https://img.shields.io/pypi/pyversions/cogment-lab.svg)](https://pypi.org/project/cogment-lab) +[![License - Apache 2.0](https://img.shields.io/badge/license-Apache_2.0-green)](https://github.com/cogment-lab/blob/main/LICENSE) +[![Follow @AI_Redefined](https://img.shields.io/twitter/follow/nestframework.svg?style=social&label=Follow%20@AI_Redefined)](https://twitter.com/AI_Redefined) +[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/) + +# Introduction + +Cogment Lab is a toolkit for doing HILL RL -- that is, human-in-the-loop learning, with an emphasis on reinforcement learning. +It is based on [cogment](https://cogment.ai), a low-level framework for exchanging messages between +environments, AI agents and humans. +It's the perfect tool for when you want to interact with your environment yourself, and maybe even with trained AI agents. + +# Cogment interaction model + +While it typically isn't necessary to interact with Cogment directly to use Cogment Lab, it is useful to understand the principles on which it operates. + +Cogment exchanges messages between environments and actors. These messages contain the observations, actions, rewards, and anything +else that you might want to keep track of. + +Interactions are split into Trials, which correspond to the typical notion of an episode in RL. 
Each trial has a unique ID, and + +## Cogment Lab at a glance + +Cogment Lab (as well as Cogment in general) follows a microservice-based architecture. +Each environment, agent, and human interface (collectively: service) is launched as a subprocess, and exchanges messages with the orchestrator, +which in turn ensures synchronization and correct routing of messages. + +Generally speaking, you don't need to worry about any of that - Cogment Lab conveniently covers up all the rough edges, +allowing you to do your research without worries. + +Cogment Lab is inherently asynchronous - but if you're not familiar with async python, don't worry about it. +The only things you need to remember are: +- Wrap your code in `async def main()` +- Run it with `asyncio.run(main())` +- When calling certain functions use the `await` keyword, e.g. `data = await cog.get_episode_data(...)` + +If you are familiar with async programming, there's a lot of interesting things you can do with it - go crazy. + + +## Terminology + +- A `service` is anything that interacts with the Cogment orchestrator. It can be an environment or an actor, including human actors. +- An `actor` in particular is the service that interacts with an environment, and often wraps an `agent`. The internal structure of an actor is entirely up to the user. +- An `agent` is what we typically think of as an agent in RL - something that perceives its environment and acts upon it. We do not attempt to solve the agent foundation problem in this documentation. +- An `agent` is simultaneously the part of the environment that's taking an action - multiagent environments may have several agents, so we need to assign an actor to each agent. + + +## Known rough edges + +- When running the web UI, you can open the tab only once per launched process. So if you open the UI, you can run however many trials you want, as long as you don't close it. If you do close it, you should kill the process and start a new one. 
+ + +## Local installation + +- Requires Python 3.10 +- Install requirements in a virtual env with something similar to the following: + + ```console + $ python -m venv .venv + $ source .venv/bin/activate + $ pip install -r requirements.txt + $ pip install -e . + ``` +- For the examples you'll need to install the additional `examples_requirements.txt`. + + +### Apple silicon installation + +To run on M1/2/3 Macs, you'll need to perform these additional steps: + +``` +pip uninstall grpcio grpcio-tools +export GRPC_PYTHON_LDFLAGS=" -framework CoreFoundation" +pip install grpcio==1.48.2 grpcio-tools==1.48.2 --no-binary :all: +``` + + +## Usage + +Run `cogmentlab launch base`. + +Then run whatever scripts or notebooks you need. + +Terminology: +- Model: a relatively raw PyTorch (or other?) model, inheriting from `nn.Module` +- Agent: a model wrapped in some utility class to interact with np arrays +- Actor: a cogment service that may involve models and/or actors diff --git a/TODO b/TODO new file mode 100644 index 0000000..fa6f4e4 --- /dev/null +++ b/TODO @@ -0,0 +1,19 @@ +Idea: +- Support native gymnasium VectorEnv when we finish the API +- Support launching the same environment in several processes as cogment-based vectorization + + + +TODO: +- Improve robustness of the UI? (reconnecting, etc.) +- Better UI builder (define buttons, sliders, etc.) +- More QoL functions (are any missing?) +- Clean up logging (probably need to define a logger somewhere in top level of the library? maybe in each file?) +- Enable running services in the main process (async; actors work, do the same for envs and web_ui) +- Maybe rename web_ui to human? 
+- Add more conversions (observer, teacher for PettingZoo) +- Add support for vectorized envs, vectorized actors +- Add support for parallel envs +- Do batched evaluation with vectorized envs, homogeneous pettingzoo parallel envs +- Connect4 integration +- Instead of a fancy BaseEnv/BaseActor, use Gymnasium EZPickle (or reimplement, it's super simple) \ No newline at end of file diff --git a/bin/all-py.Dockerfile b/bin/all-py.Dockerfile new file mode 100644 index 0000000..037a76f --- /dev/null +++ b/bin/all-py.Dockerfile @@ -0,0 +1,44 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG PYTHON_VERSION +FROM python:$PYTHON_VERSION + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +RUN apt-get -y update \ + && apt-get install --no-install-recommends -y \ + unzip \ + libglu1-mesa-dev \ + libgl1-mesa-dev \ + libosmesa6-dev \ + xvfb \ + patchelf \ + ffmpeg cmake \ + && apt-get autoremove -y \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY . 
/usr/local/cogment_lab/ +WORKDIR /usr/local/cogment_lab/ + +RUN pip install -r requirements.txt --no-cache-dir +RUN pip install .[all] --no-cache-dir + +ENV PATH="/usr/local/bin:${PATH}" + + +RUN /usr/local/bin/cogmentlab install + +ENTRYPOINT ["/usr/local/cogment_lab/bin/docker_entrypoint"] diff --git a/bin/docker_entrypoint b/bin/docker_entrypoint new file mode 100755 index 0000000..a5fe2c9 --- /dev/null +++ b/bin/docker_entrypoint @@ -0,0 +1,26 @@ +#!/bin/bash +# This script is the entrypoint for our Docker image. + +set -ex + +# Set up display; otherwise rendering will fail +Xvfb -screen 0 1024x768x24 & +export DISPLAY=:0 + +# Wait for the file to come up +display=0 +file="/tmp/.X11-unix/X$display" +for i in $(seq 1 10); do + if [ -e "$file" ]; then + break + fi + + echo "Waiting for $file to be created (try $i/10)" + sleep "$i" +done +if ! [ -e "$file" ]; then + echo "Timing out: $file was not created" + exit 1 +fi + +exec "$@" diff --git a/bin/necessary-py.Dockerfile b/bin/necessary-py.Dockerfile new file mode 100644 index 0000000..f9f709f --- /dev/null +++ b/bin/necessary-py.Dockerfile @@ -0,0 +1,38 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +ARG PYTHON_VERSION +FROM python:$PYTHON_VERSION + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +RUN apt-get -y update \ + && apt-get install --no-install-recommends -y \ + unzip \ + libglu1-mesa-dev \ + libgl1-mesa-dev \ + libosmesa6-dev \ + xvfb \ + patchelf \ + ffmpeg cmake \ + && apt-get autoremove -y \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY . /usr/local/cogment_lab/ +WORKDIR /usr/local/cogment_lab/ + +RUN pip install .[dev] --no-cache-dir + +ENTRYPOINT ["/usr/local/cogment_lab/bin/docker_entrypoint"] diff --git a/cogment_lab/__init__.py b/cogment_lab/__init__.py new file mode 100644 index 0000000..c0d2dc3 --- /dev/null +++ b/cogment_lab/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import logging + +from cogment_lab.process_manager import Cogment +from cogment.utils import logger + +logger.addHandler(logging.NullHandler()) + +__version__ = "0.0.1" + + +__all__ = ["Cogment"] diff --git a/cogment_lab/actors/__init__.py b/cogment_lab/actors/__init__.py new file mode 100644 index 0000000..e3c0650 --- /dev/null +++ b/cogment_lab/actors/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .simple import RandomActor, ConstantActor diff --git a/cogment_lab/actors/configs.py b/cogment_lab/actors/configs.py new file mode 100644 index 0000000..782e122 --- /dev/null +++ b/cogment_lab/actors/configs.py @@ -0,0 +1,25 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict + + +class AgentConfig(TypedDict): + agent_name: str + agent_type: str # "cogment" or "custom" + + cogment_actor: str + + agent_class: str + agent_kwargs: dict diff --git a/cogment_lab/actors/nn_actor.py b/cogment_lab/actors/nn_actor.py new file mode 100644 index 0000000..ed64c67 --- /dev/null +++ b/cogment_lab/actors/nn_actor.py @@ -0,0 +1,86 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from cogment_lab.core import CogmentActor +from coltra import Observation, DAgent, CAgent +from coltra.models import BaseModel + + +class ColtraActor(CogmentActor): + def __init__(self, model: BaseModel): + super().__init__(model) + self.model = model + self.agent = DAgent(self.model) if self.model.discrete else CAgent(self.model) + + async def act(self, observation: np.ndarray, rendered_frame=None): + obs = Observation(vector=observation) + action, _, _ = self.agent.act(obs) + return action.discrete + + +class NNActor(CogmentActor): + def __init__(self, network: nn.Module, device: str = "cpu"): + super().__init__(network=network) + self.network = network + self.device = device + self.eps = 0.0 + self.num_actions: int | None = None + self.rng = np.random.default_rng(0) + + async def act(self, observation: np.ndarray, rendered_frame=None) -> int: + if self.num_actions is None: + observation = observation.copy() + obs = torch.from_numpy(observation).float().to(self.device) + [act_probs] = self.network(obs) + self.num_actions = act_probs.shape[0] + + if self.eps > 0.0 and self.rng.random() < self.eps: + return self.rng.integers(0, self.num_actions) + + observation = observation.copy() + obs = torch.from_numpy(observation).float().to(self.device) + [act_probs] = self.network(obs) + return act_probs.detach().cpu().numpy().argmax() + + def set_eps(self, eps: float): + self.eps = eps + + +class BoltzmannActor(CogmentActor): + def __init__(self, network: nn.Module, 
device: str = "cpu"): + super().__init__(network=network) + self.network = network + self.device = device + self.temperature = 1.0 + self.num_actions: int | None = None + self.rng = np.random.default_rng(0) + + async def act(self, observation: np.ndarray, rendered_frame=None) -> int: + observation = observation.copy() + obs = torch.from_numpy(observation).float().to(self.device) + with torch.no_grad(): + [act_vals] = self.network(obs) + act_probs = F.softmax(act_vals / self.temperature, dim=0) + + action = torch.multinomial(act_probs, 1).item() + + return action + + def set_temperature(self, temperature: float): + self.temperature = temperature diff --git a/cogment_lab/actors/runner.py b/cogment_lab/actors/runner.py new file mode 100644 index 0000000..8bc42e4 --- /dev/null +++ b/cogment_lab/actors/runner.py @@ -0,0 +1,53 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +from multiprocessing import Queue + +import cogment + +from cogment_lab.generated import cog_settings +from cogment_lab.utils.runners import setup_logging +from cogment_lab.core import BaseActor + + +async def register_actor(actor: BaseActor, actor_name: str, queue: Queue, port: int = 9002): + context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") + + context.register_actor(impl=actor.impl, impl_name=actor_name, actor_classes=["player"]) + + serve = context.serve_all_registered(cogment.ServedEndpoint(port=port)) + + queue.put(True) + + await serve + + +def actor_runner( + actor_class: type, + actor_args: tuple, + actor_kwargs: dict, + actor_name: str, + signal_queue: Queue, + port: int = 9002, + log_file: str | None = None, +): + """Given an actor, runs it""" + if log_file: + setup_logging(log_file) + actor = actor_class(*actor_args, **actor_kwargs) + + asyncio.run(register_actor(actor, actor_name, signal_queue, port)) diff --git a/cogment_lab/actors/simple.py b/cogment_lab/actors/simple.py new file mode 100644 index 0000000..dd7cbc6 --- /dev/null +++ b/cogment_lab/actors/simple.py @@ -0,0 +1,41 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import Any + +import gymnasium as gym +import numpy as np + +from cogment_lab.core import CogmentActor + + +class RandomActor(CogmentActor): + def __init__(self, action_space: gym.spaces.Space): + super().__init__(action_space) + self.gym_action_space = action_space + + async def act(self, observation: Any, rendered_frame=None): + return self.gym_action_space.sample() + + +class ConstantActor(CogmentActor): + def __init__(self, action: Any): + super().__init__(action) + if isinstance(action, list): + action = np.array(action) + self.action = action + + async def act(self, observation: Any, rendered_frame=None): + return self.action diff --git a/cogment_lab/cli/__init__.py b/cogment_lab/cli/__init__.py new file mode 100644 index 0000000..08d7dad --- /dev/null +++ b/cogment_lab/cli/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cogment_lab/cli/actor.py b/cogment_lab/cli/actor.py new file mode 100644 index 0000000..e656980 --- /dev/null +++ b/cogment_lab/cli/actor.py @@ -0,0 +1,66 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging + +import cogment +import yaml + +from cogment_lab.actors.configs import AgentConfig +from cogment_lab.core import BaseActor, NativeActor +from cogment_lab.generated import cog_settings +from cogment_lab.utils import import_object +from cogment_lab.utils.yaml_utils import gym_space_constructors + + +log = logging.getLogger(__name__) + + +def get_actor(agent_config: AgentConfig) -> BaseActor: + agent_type = agent_config.get("agent_type") + if agent_type is None: + raise ValueError("agent_type is not provided in config") + + if agent_type == "cogment": + impl_path = agent_config["cogment_actor"] + impl = import_object(impl_path) + agent = NativeActor(impl=impl) + elif agent_type == "custom": + agent_class = agent_config["agent_class"] + agent_kwargs = agent_config["agent_kwargs"] + cls = import_object(agent_class) + agent = cls(**agent_kwargs) + else: + raise NotImplementedError(f"Invalid agent_type: {agent_type}") + + return agent + + +async def create_agents(agent_configs: list[AgentConfig], port: int): + context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") + + for config in agent_configs: + agent = get_actor(config) + context.register_actor(impl=agent.impl, impl_name=config["agent_name"]) + + await context.serve_all_registered(cogment.ServedEndpoint(port=port)) + + +def actor_main(config_path: str): + gym_space_constructors() + with open(config_path) as config_file: + config = yaml.load(config_file, Loader=yaml.Loader) + log.info(config) + asyncio.run(create_agents(agent_configs=config["agents"], 
port=config["port"])) diff --git a/cogment_lab/cli/cli.py b/cogment_lab/cli/cli.py new file mode 100644 index 0000000..a32d3a6 --- /dev/null +++ b/cogment_lab/cli/cli.py @@ -0,0 +1,107 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import argparse +import logging +import os +import subprocess +import sys + +TEAL = "\033[36m" +RESET = "\033[0m" + +custom_format = f"%(asctime)s {TEAL}[%(levelname)s] [%(name)s]{RESET} %(message)s [thread:%(thread)d]" + +formatter = logging.Formatter(custom_format, datefmt="%Y-%m-%dT%H:%M:%S%z") + +handler = logging.StreamHandler() +handler.setFormatter(formatter) + +logging.basicConfig(level=logging.INFO, handlers=[handler]) + +sys.path.insert(0, "..") + + +def install_cogment(path: str | None = None): + try: + subprocess.run( + [ + "curl", + "--silent", + "-L", + "https://raw.githubusercontent.com/cogment/cogment/main/install.sh", + "--output", + "install-cogment.sh", + ], + check=True, + ) + subprocess.run(["chmod", "+x", "install-cogment.sh"], check=True) + cmd = ["sudo", "./install-cogment.sh"] + if path: + cmd += ["--install-dir", path] + cmd += ["--version", "2.19.1"] + if os.getenv("GITHUB_ACTIONS") == "true": + cmd = cmd[1:] # Remove sudo for github actions + subprocess.run(cmd, check=True) + logging.info("Cogment installed successfully.") + except subprocess.CalledProcessError as e: + logging.error(f"Installation failed: {e}") + finally: + if 
os.path.exists("install-cogment.sh"): + os.remove("install-cogment.sh") + logging.info("Cleanup completed.") + + +def main(): + parser = argparse.ArgumentParser(description="Cogment Lab CLI") + subparsers = parser.add_subparsers(dest="command") + + # launch subcommand + parser_launch = subparsers.add_parser("launch") + parser_launch.add_argument("file") + + # env subcommand + parser_env = subparsers.add_parser("env") + parser_env.add_argument("file") + + parser_actor = subparsers.add_parser("actor") + parser_actor.add_argument("file") + + parser_install = subparsers.add_parser("install") + parser_install.add_argument("path", nargs="?") + + args = parser.parse_args() + + if args.command == "install": + install_cogment(args.path) + elif args.command == "launch": + from cogment_lab.cli import launch + + launch.launch_main(args.file) + elif args.command == "env": + from cogment_lab.cli import env + + env.env_main(args.file) + elif args.command == "actor": + from cogment_lab.cli import actor + + actor.actor_main(args.file) + else: + print("Invalid command. Use 'launch', 'env', 'actor' or `install`.") + + +if __name__ == "__main__": + main() diff --git a/cogment_lab/cli/env.py b/cogment_lab/cli/env.py new file mode 100644 index 0000000..a7bb253 --- /dev/null +++ b/cogment_lab/cli/env.py @@ -0,0 +1,63 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging + +import cogment +import yaml + +from cogment_lab.core import BaseEnv, NativeEnv +from cogment_lab.envs.configs import EnvConfig +from cogment_lab.envs.environment import GymEnvironment +from cogment_lab.generated import cog_settings +from cogment_lab.utils import import_object + + +# log = logging.getLogger(__name__) + + +def get_environment(config: EnvConfig) -> BaseEnv: + """Given a config, generates an impl for a cogment environment""" + env_type = config.get("env_type") + if env_type is None: + raise ValueError("env_type is not provided in config") + + if env_type == "gymnasium": + env = GymEnvironment(config) + elif env_type == "cogment": + impl_path = config["cogment_env"] + impl = import_object(impl_path) + env = NativeEnv(impl=impl) + else: + raise NotImplementedError(f"Invalid env_type: {env_type}") + + return env + + +async def create_envs(env_configs: list[EnvConfig], port: int): + context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") + + for config in env_configs: + env = get_environment(config) + context.register_environment(impl=env.impl, impl_name=config["env_name"]) + + await context.serve_all_registered(cogment.ServedEndpoint(port=port)) + + +def env_main(config_path: str): + with open(config_path) as config_file: + config = yaml.load(config_file, Loader=yaml.Loader) + logging.info(config) + asyncio.run(create_envs(env_configs=config["environments"], port=config["port"])) diff --git a/cogment_lab/cli/launch.py b/cogment_lab/cli/launch.py new file mode 100644 index 0000000..22723f1 --- /dev/null +++ b/cogment_lab/cli/launch.py @@ -0,0 +1,53 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import subprocess + + +def launch_service(service_name: str): + try: + process = subprocess.Popen(["cogment", "services", service_name]) + logging.info(f"{service_name} launched successfully. PID: {process.pid}") + return process + except Exception as e: + logging.error(f"Failed to launch {service_name}: {e}") + return None + + +def launch_main(command: str): + services = command.split() + processes = [] + + if "base" in services or "all" in services: + if len(services) > 1: + raise ValueError("Cannot combine 'base' or 'all' with other services") + + if "base" in services: + services_to_run = ["orchestrator", "trial_datastore"] + elif "all" in services: + services_to_run = ["orchestrator", "trial_datastore", "model_registry", "directory", "web_proxy"] + else: + services_to_run = services + + for service in services_to_run: + process = launch_service(service) + if process: + processes.append(process) + + # Optional: Wait for all subprocesses to complete + for process in processes: + process.wait() + + return processes diff --git a/cogment_lab/constants.py b/cogment_lab/constants.py new file mode 100644 index 0000000..b6e9d85 --- /dev/null +++ b/cogment_lab/constants.py @@ -0,0 +1,15 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DEFAULT_RENDERED_WIDTH = 1024 diff --git a/cogment_lab/core.py b/cogment_lab/core.py new file mode 100644 index 0000000..df71bc1 --- /dev/null +++ b/cogment_lab/core.py @@ -0,0 +1,269 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import abc +import copy +import logging +from typing import Awaitable, Callable, Generic, TypeVar + +import cogment +import numpy as np +from cogment.actor import ActorSession +from cogment.environment import EnvironmentSession +from cogment.session import RecvEvent + +from cogment_lab.session_helpers import ActorSessionHelper, EnvironmentSessionHelper +from cogment_lab.specs import ( + AgentSpecs +) + +Action = TypeVar("Action") +Actions = dict[str, Action] + +Observation = TypeVar("Observation") +Observations = dict[str, Observation] + +Rewards = dict[str, float] + +Dones = dict[str, bool] + + +class State: + """Class to hold state information.""" + + def __setattr__(self, name, value): + self.__dict__[name] = value + + def __getattr__(self, name): + """Get attribute from internal dictionary.""" + return self.__dict__.get(name, None) + + +class BaseEnv(abc.ABC): + """Base environment class.""" + + agent_specs: dict[str, AgentSpecs] + + def __init__(self, *args, **kwargs): + """Initialize with arguments.""" + self.args = copy.deepcopy(args) + self.kwargs = copy.deepcopy(kwargs) + + async def impl(self, environment_session: EnvironmentSession): + """Abstract method to implement environment logic.""" + raise NotImplementedError() + + def get_constructor(self): + """Get constructor for this environment.""" + cls = self.__class__ + return lambda: cls(*self.args, **self.kwargs) + + +class CogmentEnv(BaseEnv, abc.ABC, Generic[Observation, Action]): + """Base Cogment environment class.""" + + environment_session: EnvironmentSession + session_helper: EnvironmentSessionHelper + agent_specs: dict[str, AgentSpecs] + actor_name: str + + def __init__(self, *args, **kwargs): + """Initialize.""" + super().__init__(*args, **kwargs) + + async def initialize(self, state: State, environment_session: EnvironmentSession): + """Initialize state and session.""" + state.environment_session = environment_session + return state + + 
@abc.abstractmethod + async def reset(self, state: State) -> tuple[State, Observations]: + """Reset environment state. + + Returns: + State: Updated state. + Observations: Initial observations. + """ + raise NotImplementedError() + + @abc.abstractmethod + async def step(self, state: State, action: Actions) -> tuple[State, Observations, Rewards, Dones, Dones, dict]: + """Take a step in the environment. + + Args: + state: Current state. + action: Actions from actors. + + Returns: + State: Updated state. + Observations: New observations. + Rewards: Rewards for each actor. + Dones: Whether each actor is done. + Dones: Whether each actor is truncated. + dict: Additional info. + """ + raise NotImplementedError() + + async def end(self, state: State): + """Clean up when done.""" + pass + + async def read_actions(self, state: State, event: RecvEvent): + """Read actions from event.""" + player_action = state.session_helper.get_action(event, state.actor_name) + return player_action.value + + async def impl(self, environment_session: EnvironmentSession): + """Implement environment logic.""" + state = State() + state = await self.initialize(state, environment_session) + state, observations = await self.reset(state) + + observations = list(observations.items()) + + logging.info(f"Starting environment session") + + environment_session.start(observations) + + async for event in environment_session.all_events(): + event: RecvEvent + if event.actions: + actions = await self.read_actions(state, event) + state, observations, rewards, terminateds, truncateds, info = await self.step(state, actions) + + dones = {actor_name: terminateds[actor_name] or truncateds[actor_name] for actor_name in terminateds} + + logging.info(f"Adding rewards: {rewards}") + for actor_name in state.actors: + if actor_name not in rewards: + rewards[actor_name] = float("nan") + for actor_name, reward in rewards.items(): + environment_session.add_reward(value=reward, to=[actor_name], confidence=1.0) + + 
observations = list(observations.items()) + + if all(dones.values()): + logging.info(f"Logging dones=True") + environment_session.end(observations) + # elif event.type != cogment.EventType.ACTIVE: + # logging.info("Logging event.type!=ACTIVE") + # environment_session.end(observations) + else: + logging.info(f"Logging a normal observation") + environment_session.produce_observations(observations) + + await self.end(state) + + +class BaseActor(abc.ABC): + """Base actor class.""" + + def __init__(self, *args, **kwargs): + """Initialize.""" + self.args = args + self.kwargs = kwargs + + async def impl(self, actor_session: ActorSession): + """Abstract method to implement actor logic.""" + raise NotImplementedError() + + +class NativeActor(BaseActor): + """Native actor wrapping a function.""" + + def __init__(self, impl: Callable[[ActorSession], Awaitable]): + """Initialize with implementation function.""" + super().__init__(impl) + self._impl = impl + + async def impl(self, actor_session: ActorSession): + """Call implementation function.""" + await self._impl(actor_session) + + +class CogmentActor(BaseActor, abc.ABC, Generic[Observation, Action]): + """Base Cogment actor class.""" + + actor_session: ActorSession + current_event: RecvEvent + session_helper: ActorSessionHelper + + def __init__(self, *args, **kwargs): + """Initialize.""" + super().__init__(*args, **kwargs) + + async def initialize(self, actor_session: ActorSession): + """Initialize session and helpers.""" + self.actor_session = actor_session + self.actor_session.start() + + self.session_helper = ActorSessionHelper(actor_session, None) + self.action_space = self.session_helper.get_action_space() + + @abc.abstractmethod + async def act(self, observation: Observation, rendered_frame: np.ndarray | None = None) -> Action: + """Choose an action based on observation. + + Args: + observation: Current observation. + rendered_frame: Optional rendered frame. + + Returns: + Action to take. 
+ """ + raise NotImplementedError() + + async def on_reward(self, rewards: list): + """Handle received rewards.""" + pass + + async def on_message(self, messages: list): + """Handle received messages.""" + pass + + async def end(self): + """Clean up when done.""" + pass + + async def impl(self, actor_session: ActorSession): + """Implement actor logic.""" + await self.initialize(actor_session) + async for event in actor_session.all_events(): + event: RecvEvent + self.current_event = event + if event.type != cogment.EventType.ACTIVE: + logging.info(f"Skipping event of type {event.type}") + continue + + if event.observation: + observation = self.session_helper.get_observation(event) + logging.info(f"Got observation: {observation}") + + if not observation.active: + action = None + elif not observation.alive: + action = None + else: + action = await self.act(observation.value, observation.rendered_frame) + logging.info(f"Got action: {action} with action_space: {self.action_space.gym_space}") + cog_action = self.action_space.create_serialize(action) + actor_session.do_action(cog_action) + if event.rewards: + await self.on_reward(event.rewards) + if event.messages: + await self.on_message(event.messages) + + await self.end() diff --git a/cogment_lab/envs/__init__.py b/cogment_lab/envs/__init__.py new file mode 100644 index 0000000..4f53414 --- /dev/null +++ b/cogment_lab/envs/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.envs.pettingzoo import AECEnvironment diff --git a/cogment_lab/envs/configs.py b/cogment_lab/envs/configs.py new file mode 100644 index 0000000..e087942 --- /dev/null +++ b/cogment_lab/envs/configs.py @@ -0,0 +1,65 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Module for environment configuration. + +Defines data structures for specifying environment configurations. +""" + +from __future__ import annotations + +from typing import Any, TypedDict + + +class EnvConfig(TypedDict, total=False): + """ + Configuration for a single environment. + + Attributes: + env_type: Type of the environment. + env_name: Name of the environment. + cogment_env: Cogment environment ID to use, instead of env_id/registration/make_args. + env_id: Environment ID. + registration: Environment registration details. + make_args: Arguments for make() to construct the environment. + reset_options: Reset options for the environment. + render: Whether to render the environment. + """ + + env_type: str + env_name: str + + # Either this... 
+ cogment_env: str | None + + # ...or all of this + env_id: str | None + registration: str | None + make_args: dict[str, Any] + reset_options: dict[str, Any] + render: bool + + +class EnvRunnerConfig(TypedDict): + """ + Configuration for an environment runner. + + Attributes: + envs: List of EnvConfig, one for each environment. + port: Port number for the runner. + """ + + envs: list[EnvConfig] + port: int diff --git a/cogment_lab/envs/conversions/__init__.py b/cogment_lab/envs/conversions/__init__.py new file mode 100644 index 0000000..08d7dad --- /dev/null +++ b/cogment_lab/envs/conversions/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cogment_lab/envs/conversions/observer.py b/cogment_lab/envs/conversions/observer.py new file mode 100644 index 0000000..5f8a66f --- /dev/null +++ b/cogment_lab/envs/conversions/observer.py @@ -0,0 +1,157 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any + +import gymnasium as gym +from pettingzoo import AECEnv, ParallelEnv +from pettingzoo.utils.agent_selector import agent_selector + + +class GymObserverAEC(AECEnv): + metadata = {"render_modes": ["rgb_array"], "name": "GymWithObserverEnv"} + + def __init__(self, gym_env_name: str, gym_make_kwargs: dict = {}, render_mode: str | None = None): + super().__init__() + logging.info( + f"Creating GymObserverAEC with gym_env_name={gym_env_name}, gym_make_kwargs={gym_make_kwargs}, render_mode={render_mode}" + ) + self.gym_env = gym.make(gym_env_name, render_mode=render_mode, **gym_make_kwargs) + self.possible_agents = ["gym", "observer"] + self.observation_spaces = { + "gym": self.gym_env.observation_space, + "observer": gym.spaces.Dict( + {"observation": self.gym_env.observation_space, "action": self.gym_env.action_space} + ), + } + self.action_spaces = { + "gym": self.gym_env.action_space, + "observer": self.gym_env.action_space, # Even though actions are ignored + } + self._agent_selector = agent_selector(self.possible_agents) + # self.reset() + + def reset(self, seed: int | None = None, options: dict | None = None): + logging.info(f"Resetting GymObserverAEC with seed={seed}, options={options}") + self.agents = self.possible_agents[:] + self.agent_selection = self._agent_selector.reset() + self._cumulative_rewards = {agent: 0 for agent in self.agents} + self.rewards = {agent: 0 for agent in self.agents} + self.terminations = {agent: False for agent in self.agents} + self.truncations = {agent: False for agent in self.agents} + self._last_observation, info = self.gym_env.reset(seed=seed, options=options) + self.infos = {"gym": info, "observer": info} + self._gym_action = None + + def step(self, action): + logging.info(f"Stepping GymObserverAEC with action={action}") + current_agent = self.agent_selection + + if current_agent == "gym": + # Execute main 
agent's action directly + observation, reward, terminated, truncated, info = self.gym_env.step(action) + self._last_observation = observation + self.rewards["gym"] = float(reward) + self.terminations["gym"] = terminated + self.truncations["gym"] = truncated + self.infos["gym"] = info + + self.rewards["observer"] = self.rewards["gym"] + self.terminations["observer"] = self.terminations["gym"] + self.truncations["observer"] = self.truncations["gym"] + self.infos["observer"] = self.infos["gym"] + + self._gym_action = action + + # Observer agent's turn; actions are ignored + elif current_agent == "observer": + pass + + self.agent_selection = self._agent_selector.next() + + def observe(self, agent): + if agent == "gym": + return self._last_observation + elif agent == "observer": + return {"observation": self._last_observation, "action": self._gym_action} + + def render(self): + return self.gym_env.render() + + def close(self): + self.gym_env.close() + + def observation_space(self, agent: str): + return self.observation_spaces[agent] + + def action_space(self, agent: str): + return self.action_spaces[agent] + +class GymObserverParallel(ParallelEnv): + metadata = {"render_modes": ["rgb_array"], "name": "GymWithObserverEnv"} + + def __init__(self, gym_env_name: str, gym_make_kwargs: dict = {}, render_mode: str | None = None): + super().__init__() + logging.info( + f"Creating GymObserverParallel with gym_env_name={gym_env_name}, gym_make_kwargs={gym_make_kwargs}, render_mode={render_mode}" + ) + self.gym_env = gym.make(gym_env_name, render_mode=render_mode, **gym_make_kwargs) + self.possible_agents = ["gym", "observer"] + self.observation_spaces = { + "gym": self.gym_env.observation_space, + "observer": self.gym_env.observation_space + } + self.action_spaces = { + "gym": self.gym_env.action_space, + "observer": self.gym_env.action_space, # Even though actions are ignored + } + # self.reset() + + def reset(self, seed: int | None = None, options: dict | None = None): + 
logging.info(f"Resetting GymObserverParallel with seed={seed}, options={options}") + self.agents = self.possible_agents[:] + + obs, info = self.gym_env.reset(seed=seed, options=options) + + infos = {"gym": info, "observer": info} + observations = {"gym": obs, "observer": obs} + + return observations, infos + + + def step(self, action: dict[str, Any]): + logging.info(f"Stepping GymObserverParallel with action={action}") + + obs, reward, terminated, truncated, info = self.gym_env.step(action["gym"]) + + observations = {"gym": obs, "observer": obs} + rewards = {"gym": reward, "observer": reward} + terminations = {"gym": terminated, "observer": terminated} + truncations = {"gym": truncated, "observer": truncated} + infos = {"gym": info, "observer": info} + + return observations, rewards, terminations, truncations, infos + + def render(self): + return self.gym_env.render() + + def close(self): + self.gym_env.close() + + def observation_space(self, agent: str): + return self.observation_spaces[agent] + + def action_space(self, agent: str): + return self.action_spaces[agent] diff --git a/cogment_lab/envs/conversions/teacher.py b/cogment_lab/envs/conversions/teacher.py new file mode 100644 index 0000000..abd7eb7 --- /dev/null +++ b/cogment_lab/envs/conversions/teacher.py @@ -0,0 +1,181 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gymnasium as gym +from pettingzoo import AECEnv, ParallelEnv +from pettingzoo.utils.agent_selector import agent_selector +import warnings +import numpy as np + + +class GymTeacherAEC(AECEnv): + metadata = {"render_modes": ["rgb_array"]} + + def __init__(self, gym_env_name: str, gym_make_kwargs: dict = {}, render_mode: str | None = None): + super().__init__() + self.gym_env = gym.make(gym_env_name, render_mode=render_mode, **gym_make_kwargs) + self.possible_agents = ["gym", "teacher"] + self.observation_spaces = { + "gym": self.gym_env.observation_space, + "teacher": gym.spaces.Dict( + {"observation": self.gym_env.observation_space, "action": self.gym_env.action_space} + ), + } + teacher_action_space = gym.spaces.Dict({"active": gym.spaces.Discrete(2), "action": self.gym_env.action_space}) + self.action_spaces = {"gym": self.gym_env.action_space, "teacher": teacher_action_space} + self._agent_selector = agent_selector(self.possible_agents) + self.override = False + # self.reset() + + def reset(self, seed: int | None = None, options: dict | None = None): + self.agents = self.possible_agents[:] + self.agent_selection = self._agent_selector.reset() + self._cumulative_rewards = {agent: 0 for agent in self.agents} + self.rewards = {agent: 0 for agent in self.agents} + self.terminations = {agent: False for agent in self.agents} + self.truncations = {agent: False for agent in self.agents} + self._last_observation, info = self.gym_env.reset(seed=seed, options=options) + self.infos = {"gym": info, "teacher": info} + self._gym_action = None + + def step(self, action): + current_agent = self.agent_selection + next_agent = self._agent_selector.next() + + if current_agent == "gym": + self._cumulative_rewards["gym"] = 0 + self._cumulative_rewards["teacher"] = 0 + + self._gym_action = action + + elif current_agent == "teacher": + if action["active"] == 1: + self.override = True + real_action = action["action"] + else: + self.override = False + real_action = 
self._gym_action + observation, reward, terminated, truncated, info = self.gym_env.step(real_action) + self._last_observation = observation + self.rewards["gym"] = float(reward) + self.rewards["teacher"] = float(reward) + + self.terminations["gym"] = terminated + self.terminations["teacher"] = terminated + + self.truncations["gym"] = truncated + self.truncations["teacher"] = truncated + + self.infos["gym"] = info + self.infos["teacher"] = info + + self._accumulate_rewards() + self.agent_selection = next_agent + + def observe(self, agent): + if agent == "gym": + return self._last_observation + elif agent == "teacher": + return {"observation": self._last_observation, "action": self._gym_action} + + def render(self): + img = self.gym_env.render() + if self.override: + W, H, _ = img.shape + N = 5 + + # Set the borders to red + img[:N, :, :] = [255, 0, 0] # Top border + img[-N:, :, :] = [255, 0, 0] # Bottom border + img[:, :N, :] = [255, 0, 0] # Left border + img[:, -N:, :] = [255, 0, 0] # Right border + return img + + def close(self): + self.gym_env.close() + + def observation_space(self, agent: str): + return self.observation_spaces[agent] + + def action_space(self, agent: str): + return self.action_spaces[agent] + + +class GymTeacherParallel(ParallelEnv): + metadata = {"render_modes": ["rgb_array"]} + + def __init__(self, gym_env_name: str, gym_make_kwargs: dict = {}, render_mode: str | None = None): + super().__init__() + self.gym_env = gym.make(gym_env_name, render_mode=render_mode, **gym_make_kwargs) + self.possible_agents = ["gym", "teacher"] + self.observation_spaces = { + "gym": self.gym_env.observation_space, + "teacher": self.gym_env.observation_space + } + + teacher_action_space = gym.spaces.Dict({"active": gym.spaces.Discrete(2), "action": self.gym_env.action_space}) + self.action_spaces = {"gym": self.gym_env.action_space, "teacher": teacher_action_space} + self._agent_selector = agent_selector(self.possible_agents) + self.override = False + # self.reset() 
+ + def reset(self, seed: int | None = None, options: dict | None = None): + self.agents = self.possible_agents[:] + + obs, info = self.gym_env.reset(seed=seed, options=options) + + infos = {"gym": info, "teacher": info} + observations = {"gym": obs, "teacher": obs} + + return observations, infos + + def step(self, action): + if action["teacher"]["active"] == 1: + self.override = True + real_action = action["teacher"]["action"] + else: + self.override = False + real_action = action["gym"] + + observation, reward, terminated, truncated, info = self.gym_env.step(real_action) + + observations = {"gym": observation, "teacher": observation} + rewards = {"gym": reward, "teacher": reward} + terminations = {"gym": terminated, "teacher": terminated} + truncations = {"gym": truncated, "teacher": truncated} + infos = {"gym": info, "teacher": info} + + return observations, rewards, terminations, truncations, infos + + def render(self): + img = self.gym_env.render() + if self.override: + W, H, _ = img.shape + N = 5 + + # Set the borders to red + img[:N, :, :] = [255, 0, 0] # Top border + img[-N:, :, :] = [255, 0, 0] # Bottom border + img[:, :N, :] = [255, 0, 0] # Left border + img[:, -N:, :] = [255, 0, 0] # Right border + return img + + def close(self): + self.gym_env.close() + + def observation_space(self, agent: str): + return self.observation_spaces[agent] + + def action_space(self, agent: str): + return self.action_spaces[agent] diff --git a/cogment_lab/envs/gymnasium.py b/cogment_lab/envs/gymnasium.py new file mode 100644 index 0000000..59a6eb1 --- /dev/null +++ b/cogment_lab/envs/gymnasium.py @@ -0,0 +1,265 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import importlib +import logging +import os +from typing import Any, Callable + +import gymnasium as gym +from cogment.environment import EnvironmentSession + +from cogment_lab.core import CogmentEnv, State +from cogment_lab.session_helpers import EnvironmentSessionHelper +from cogment_lab.specs import AgentSpecs + +log = logging.getLogger(__name__) + +# configure pygame to use a dummy video server to be able to render headlessly +os.environ["SDL_VIDEODRIVER"] = "dummy" + + +class GymEnvironment(CogmentEnv): + """ + Gymnasium integration for Cogment. + + Exposes a Gymnasium environment as a Cogment environment. + """ + + session_helper: EnvironmentSessionHelper + actor_name = "gym" + + def __init__( + self, + env_id: str | Callable[..., gym.Env], + registration: str | None = None, + make_kwargs: dict[str, Any] | None = None, + reset_options: dict[str, Any] | None = None, + render: bool = False, + reinitialize: bool = False, + dry: bool = False, + sub_dry: bool = True, + ): + """ + Initialize the GymEnvironment. + + Args: + env_id: The Gym environment ID. + registration: Optional Gym registration string. + make_kwargs: Optional args to pass to gym.make(). + reset_options: Optional reset options to pass to env.reset(). + render: Whether to render the environment. + reinitialize: Whether to reinitialize the environment each session. + dry: Whether to abstain from initializing the environment in this process. + sub_dry: Whether to abstain from initializing the environment in the initializer of the subprocess. 
+ """ + super().__init__( + env_id=env_id, + registration=registration, + make_kwargs=make_kwargs, + reset_options=reset_options, + render=render, + reinitialize=reinitialize, + dry=sub_dry, + sub_dry=sub_dry, + ) + + self.env_id = env_id + self.registration = registration + self.make_kwargs = make_kwargs or {} + self.reset_options = reset_options or {} + self.render = render + self.reinitialize = reinitialize + self.dry = dry + + if "render_mode" in self.make_kwargs: + raise ValueError("render_mode cannot be set in make_kwargs") + + if self.render: + self.make_kwargs["render_mode"] = "rgb_array" + + if isinstance(self.env_id, Callable): + self.env_maker = self.env_id + else: + self.env_maker = lambda **kwargs: gym.make(self.env_id, **kwargs) + + if self.registration: + importlib.import_module(self.registration) + + if not self.dry: + self.env = self.env_maker(**self.make_kwargs) + + self.agent_specs = { + "gym": AgentSpecs.create_homogeneous( + observation_space=self.env.observation_space, + action_space=self.env.action_space, + ) + } + + self.initialized = True + else: + self.env = None + self.agent_specs = {} + self.initialized = False + + self._is_closed = False + + def get_implementation_name(self): + """ + Get the name of the Gym environment. + + Returns: + The Gym environment ID. + """ + return self.env_id + + def get_agent_specs(self): + """ + Get the agent specs. + + Returns: + The agent specs dict. + """ + return self.agent_specs + + async def initialize(self, state: State, environment_session: EnvironmentSession): + """ + Initialize the environment session. + + Args: + state: The Cogment state. + environment_session: The Cogment environment session. + + Returns: + The updated state. 
+ """ + + logging.info("Initializing environment session") + + if not self.initialized and not self.reinitialize: + self.env = self.env_maker(**self.make_kwargs) + self.agent_specs = { + "gym": AgentSpecs.create_homogeneous( + observation_space=self.env.observation_space, + action_space=self.env.action_space, + ) + } + + state.env = self.env + state.agent_specs = self.agent_specs + elif self.initialized and not self.reinitialize: + state.env = self.env + state.agent_specs = self.agent_specs + elif self.reinitialize: + state.env = self.env_maker(**self.make_kwargs) + state.agent_specs = { + "gym": AgentSpecs.create_homogeneous( + observation_space=state.env.observation_space, + action_space=state.env.action_space, + ) + } + + state.environment_session = environment_session + state.session_helper = EnvironmentSessionHelper(environment_session, state.agent_specs) + state.session_cfg = state.environment_session.config + state.actors = state.session_helper.actors + state.actor_name = state.session_helper.actors[0] + + self.initialized = True + + return state + + async def reset(self, state: State): + """ + Reset the environment. + + Args: + state: The Cogment state. + + Returns: + A tuple with the updated state and a dict of observations. + """ + + logging.info("Resetting environment") + + obs, _info = state.env.reset(seed=state.session_cfg.seed, options=state.session_cfg.reset_args) # THIS + + state.observation_space = state.session_helper.get_observation_space(self.actor_name) + frame = state.env.render() if state.session_cfg.render else None + observation = state.observation_space.create_serialize(value=obs, rendered_frame=frame, active=True, alive=True) + + return state, {"*": observation} + + async def read_actions(self, state: State, event): + """ + Read the agent action from the event. + + Args: + state: The Cogment state. + event: The event from Cogment. + + Returns: + The agent action value. 
+ """ + player_action = state.session_helper.get_action(tick_data=event, actor_name=self.actor_name) + return player_action.value + + async def step(self, state: State, action): + """ + Step the environment. + + Args: + state: The Cogment state. + action: The agent action. + + Returns: + A tuple with the updated state, observations dict, rewards dict, + terminateds dict, truncateds dict, and info dict. + """ + + logging.info("Stepping environment") + + obs, reward, terminated, truncated, info = state.env.step(action) + logging.info(f"Step returned {obs=}, {reward=}, {terminated=}, {truncated=}, {info=}") + + observation = state.observation_space.create_serialize( + value=obs, rendered_frame=state.env.render() if state.session_cfg.render else None, active=True, alive=True + ) + + # observations = [("*", observation)] + observations = {"*": observation} + rewards = {self.actor_name: reward} + terminateds = {self.actor_name: terminated} + truncateds = {self.actor_name: truncated} + + return state, observations, rewards, terminateds, truncateds, info + + async def end(self, state: State): + """ + End the environment session. + + Args: + state: The Cogment state. + + Returns: + The updated state. + """ + + logging.info("Ending environment") + state.env.close() + if self.env is not None and not self._is_closed: + self.env.close() + self._is_closed = True diff --git a/cogment_lab/envs/pettingzoo.py b/cogment_lab/envs/pettingzoo.py new file mode 100644 index 0000000..597910d --- /dev/null +++ b/cogment_lab/envs/pettingzoo.py @@ -0,0 +1,526 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TypedDict, Any + +import numpy as np +from cogment.session import RecvEvent +from pettingzoo import AECEnv, ParallelEnv + +from cogment.environment import EnvironmentSession + +from cogment_lab.core import CogmentEnv, State +from cogment_lab.session_helpers import EnvironmentSessionHelper +from cogment_lab.specs import AgentSpecs +from cogment_lab.specs.observation_space import ObservationSpace +from cogment_lab.utils import import_object +from cogment_lab.generated.data_pb2 import ( + Observation as PbObservation, +) +import pettingzoo + + +class PZConfig(TypedDict): + """Configuration for a PettingZoo environment.""" + + env_path: str + make_args: dict + reset_options: dict + + +class AECEnvironment(CogmentEnv): + """Cogment environment wrapper for PettingZoo AEC environments.""" + + def __init__( + self, + env_path: str, + make_kwargs: dict | None = None, + reset_options: dict = None, + render: bool = False, + reinitialize: bool = False, + dry: bool = False, + sub_dry: bool = True, + ): + """ + Initialize the AECEnvironment. + + Args: + env_path: Path to the PettingZoo environment class. + make_kwargs: Arguments to pass to the environment constructor. + reset_options: Options to pass to reset(). + render: Whether to render the environment. + reinitialize: Whether to reinitialize the environment each session. + dry: Whether to abstain from initializing the environment in this process. 
+ sub_dry: Whether to abstain from initializing the environment in the initializer of the subprocess + """ + super().__init__( + env_path=env_path, + make_kwargs=make_kwargs, + reset_options=reset_options, + render=render, + reinitialize=reinitialize, + dry=sub_dry, + sub_dry=sub_dry, + ) + self.env_path = env_path + self.make_args = make_kwargs or {} + self.reset_options = reset_options or {} + self.render = render + self.reinitialize = reinitialize + self.dry = dry + + logging.info( + f"Creating AECEnvironment with {env_path=}, {make_kwargs=}, {reset_options=}, {render=}, {reinitialize=}, {dry=}, {sub_dry=}" + ) + + self.env_maker = import_object(self.env_path) + + assert callable(self.env_maker), f"Environment class at {self.env_path} is not callable" + + if render: + self.make_args["render_mode"] = "rgb_array" + + if not self.dry: + self.env: AECEnv = self.env_maker(**self.make_args) + self.agent_specs = self.create_agent_specs(self.env) + else: + self.env = None + self.agent_specs = {} + + self.initialized = False + + async def initialize(self, state: State, environment_session: EnvironmentSession): + """ + Initialize the environment session. + + Args: + state: The Cogment state. + environment_session: The Cogment environment session. + + Returns: + The updated state. 
+ """ + logging.info("Initializing environment session") + if not self.initialized and not self.reinitialize: + self.env = self.env_maker(**self.make_args) + self.agent_specs = self.create_agent_specs(self.env) + + state.env = self.env + state.agent_specs = self.agent_specs + elif self.initialized and not self.reinitialize: + state.env = self.env + state.agent_specs = self.agent_specs + elif self.reinitialize: + state.env = self.env_maker(**self.make_args) + state.agent_specs = self.create_agent_specs(state.env) + + self.initialized = True + + state.environment_session = environment_session + state.session_helper = EnvironmentSessionHelper(environment_session, state.agent_specs) + state.session_cfg = state.environment_session.config + state.actors = state.session_helper.actors + + state.observation_spaces = {agent: state.session_helper.get_observation_space(agent) for agent in state.actors} + + return state + + async def reset(self, state: State): + """ + Reset the environment. + + Args: + state: The Cogment state. + + Returns: + A tuple of (state, observations) + """ + logging.info("Resetting environment") + + state.env.reset(seed=state.session_cfg.seed) + + obs, reward, term, trunc, info = state.env.last() + agent = state.env.agent_selection + + state.actor_name = agent + + frame = state.env.render() if state.session_cfg.render else None + logging.info(f"Creating observation from {obs=}") + observation = state.observation_spaces[agent].create_serialize( + value=obs, rendered_frame=frame, active=True, alive=True + ) + if frame is not None: + logging.info(f"Frame shape at reset: {frame.shape}") + else: + logging.info("Frame at reset is None") + + observations = {agent: observation} + observations = self.fill_observations_(state=state, observations=observations, frame=frame) + + return state, observations + + async def step(self, state: State, action: Any): + """ + Take an environment step. + + Args: + state: The Cogment state. + action: The action to take. 
+ + Returns: + A tuple of (state, observations, rewards, terminated, truncated, info) + """ + logging.info("Stepping environment") + + state.env.step(action) + obs, reward, terminated, truncated, info = state.env.last() + agent = state.env.agent_selection + + frame = state.env.render() if state.session_cfg.render else None + + if frame is not None: + logging.info(f"Frame shape at step: {frame.shape}") + else: + logging.info("Frame at step is None") + + observation = state.observation_spaces[agent].create_serialize( + value=obs, + rendered_frame=frame, + active=True, + alive=not (terminated or truncated), + ) + + # observations = [(agent, observation)] + observations = {agent: observation} + rewards = {agent: reward} + terminateds = state.env.terminations + truncateds = state.env.truncations + + state.actor_name = agent + + observations = self.fill_observations_(state, observations, frame=frame) + + return state, observations, rewards, terminateds, truncateds, info + + async def end(self, state: State): + """ + End the environment session. + + Args: + state: The Cogment state. + """ + logging.info("Ending environment") + state.env.close() + + @staticmethod + def fill_observations_( + state: State, observations: dict[str, PbObservation], frame: np.ndarray + ) -> dict[str, PbObservation]: + """ + Fill in any missing observations with the default observation. Mutates the observations dict. + + Args: + state: The Cogment state. + observations: The observations dict. + + Returns: + The filled observations dict. + """ + if "*" in observations: + return observations + for actor_name in state.actors: + if actor_name not in observations: + observations[actor_name] = state.observation_spaces[actor_name].create_serialize( + rendered_frame=frame, active=False + ) + + return observations + + @staticmethod + def create_agent_specs(env: AECEnv): + """ + Create the agent specs from a PettingZoo AEC environment. + + Args: + env: The PettingZoo AEC environment. 
+ + Returns: + The agent specs dict. + """ + num_agents = len(env.possible_agents) + + # Check all observation and action spaces + + is_homogeneous = True + observation_space = env.observation_space(env.possible_agents[0]) + action_space = env.action_space(env.possible_agents[0]) + for agent in env.possible_agents: + if env.observation_space(agent) != observation_space: + is_homogeneous = False + break + if env.action_space(agent) != action_space: + is_homogeneous = False + break + + if is_homogeneous: + one_agent_specs = AgentSpecs.create_homogeneous( + observation_space=observation_space, + action_space=action_space, + ) + agent_specs = {agent: one_agent_specs for agent in env.possible_agents} + else: + agent_specs = { + agent: AgentSpecs.create_homogeneous( + observation_space=env.observation_space(agent), + action_space=env.action_space(agent), + ) + for agent in env.possible_agents + } + + return agent_specs + + +class ParallelEnvironment(CogmentEnv): + """Cogment environment wrapper for PettingZoo Parallel environments.""" + + def __init__( + self, + env_path: str, + make_kwargs: dict | None = None, + reset_options: dict | None = None, + render: bool = False, + reinitialize: bool = False, + dry: bool = False, + sub_dry: bool = True, + ): + """ + Initialize the ParallelEnvironment. + + Args: + env_path: Path to the PettingZoo environment class. + make_kwargs: Arguments to pass to the environment constructor. + reset_options: Options to pass to reset(). + render: Whether to render the environment. + reinitialize: Whether to reinitialize the environment each session. + dry: Whether to abstain from initializing the environment in this process. 
+ sub_dry: Whether to abstain from initializing the environment in the initializer of the subprocess + """ + super().__init__( + env_path=env_path, + make_kwargs=make_kwargs, + reset_options=reset_options, + render=render, + reinitialize=reinitialize, + dry=sub_dry, + sub_dry=sub_dry, + ) + self.env_path = env_path + self.make_args = make_kwargs or {} + self.reset_options = reset_options or {} + self.render = render + self.reinitialize = reinitialize + self.dry = dry + + logging.info( + f"Creating ParallelEnvironment with {env_path=}, {make_kwargs=}, {reset_options=}, {render=}, {reinitialize=}, {dry=}, {sub_dry=}" + ) + + self.env_maker = import_object(self.env_path) + + assert callable(self.env_maker), f"Environment class at {self.env_path} is not callable" + + if render: + self.make_args["render_mode"] = "rgb_array" + + if not self.dry: + self.env: ParallelEnv = self.env_maker(**self.make_args) + self.agent_specs = self.create_agent_specs(self.env) + else: + self.env = None + self.agent_specs = {} + + self.initialized = False + + async def initialize(self, state: State, environment_session: EnvironmentSession): + """ + Initialize the environment session. + + Args: + state: The Cogment state. + environment_session: The Cogment environment session. + + Returns: + The updated state. 
+ """ + logging.info("Initializing environment session") + if not self.initialized and not self.reinitialize: + self.env: ParallelEnv = self.env_maker(**self.make_args) + self.agent_specs = self.create_agent_specs(self.env) + + state.env = self.env + state.agent_specs = self.agent_specs + elif self.initialized and not self.reinitialize: + state.env: ParallelEnv = self.env + state.agent_specs = self.agent_specs + elif self.reinitialize: + state.env: ParallelEnv = self.env_maker(**self.make_args) + state.agent_specs = self.create_agent_specs(state.env) + + self.initialized = True + + state.environment_session = environment_session + state.session_helper = EnvironmentSessionHelper(environment_session, state.agent_specs) + state.session_cfg = state.environment_session.config + state.actors = state.session_helper.actors + + state.observation_spaces = {agent: state.session_helper.get_observation_space(agent) for agent in state.actors} + + return state + + async def reset(self, state: State): + """ + Reset the environment. + + Args: + state: The Cogment state. + + Returns: + A tuple of (state, observations) + """ + logging.info("Resetting environment") + + obs, info = state.env.reset(seed=state.session_cfg.seed) + + frame = state.env.render() if state.session_cfg.render else None + logging.info(f"Creating observation from {obs=}") + + + + if frame is not None: + logging.info(f"Frame shape at reset: {frame.shape}") + else: + logging.info("Frame at reset is None") + + observations = { + agent: + state.observation_spaces[agent].create_serialize( + value=obs[agent], rendered_frame=frame, active=True, alive=True + ) + for agent in obs + } + + # observations = {agent: observation} + # observations = self.fill_observations_(state=state, observations=observations, frame=frame) + + return state, observations + + async def step(self, state: State, action: dict[str, Any]): + """ + Take an environment step. + + Args: + state: The Cogment state. + action: The action to take. 
+ + Returns: + A tuple of (state, observations, rewards, terminated, truncated, info) + """ + logging.info("Stepping environment") + + # Step exactly once; ParallelEnv.step() returns the per-agent dicts directly. + obs, rewards, terminated, truncated, info = state.env.step(action) + + frame = state.env.render() if state.session_cfg.render else None + + if frame is not None: + logging.info(f"Frame shape at step: {frame.shape}") + else: + logging.info("Frame at step is None") + + + observations = { + agent: + state.observation_spaces[agent].create_serialize( + value=obs[agent], + rendered_frame=frame, + active=True, + alive=not (terminated[agent] or truncated[agent]) + ) + for agent in obs + } + + + return state, observations, rewards, terminated, truncated, info + + async def end(self, state: State): + """ + End the environment session. + + Args: + state: The Cogment state. + """ + logging.info("Ending environment") + state.env.close() + + + async def read_actions(self, state: State, event: RecvEvent): + """Read actions from event.""" + player_actions = {agent: state.session_helper.get_action(event, agent).value + for agent in state.actors} + return player_actions + + + @staticmethod + def create_agent_specs(env: ParallelEnv): + """ + Create the agent specs from a PettingZoo Parallel environment. + + Args: + env: The PettingZoo Parallel environment. + + Returns: + The agent specs dict. 
+ """ + num_agents = len(env.possible_agents) + + # Check all observation and action spaces + + is_homogeneous = True + observation_space = env.observation_space(env.possible_agents[0]) + action_space = env.action_space(env.possible_agents[0]) + for agent in env.possible_agents: + if env.observation_space(agent) != observation_space: + is_homogeneous = False + break + if env.action_space(agent) != action_space: + is_homogeneous = False + break + + if is_homogeneous: + one_agent_specs = AgentSpecs.create_homogeneous( + observation_space=observation_space, + action_space=action_space, + ) + agent_specs = {agent: one_agent_specs for agent in env.possible_agents} + else: + agent_specs = { + agent: AgentSpecs.create_homogeneous( + observation_space=env.observation_space(agent), + action_space=env.action_space(agent), + ) + for agent in env.possible_agents + } + + return agent_specs diff --git a/cogment_lab/envs/runner.py b/cogment_lab/envs/runner.py new file mode 100644 index 0000000..862c06e --- /dev/null +++ b/cogment_lab/envs/runner.py @@ -0,0 +1,74 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Registers an environment with Cogment and runs the Cogment server""" + +from __future__ import annotations + +import asyncio +import logging +from multiprocessing import Queue + +import cogment + +from cogment_lab.core import BaseEnv +from cogment_lab.generated import cog_settings +from cogment_lab.utils.runners import setup_logging + + +async def register_env(env: BaseEnv, env_name: str, signal_queue: Queue, port: int = 9001): + """Registers an environment with Cogment and runs the Cogment server + + Args: + env (BaseEnv): The environment to register + env_name (str): The name to register the environment under + signal_queue (Queue): A queue to signal when server has started + port (int, optional): The port for the Cogment server. Defaults to 9001. + + """ + context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") + + context.register_environment(impl=env.impl, impl_name=env_name) + logging.info(f"Registered environment {env_name} with cogment") + + serve = context.serve_all_registered(cogment.ServedEndpoint(port=port)) + signal_queue.put(True) + await serve + + +def env_runner( + env_class: type, + env_args: tuple, + env_kwargs: dict, + env_name: str, + signal_queue: Queue, + port: int = 9001, + log_file: str | None = None, +): + """Given an environment, runs it + + Args: + env_class (type): The environment class to instantiate + env_args (tuple): Positional arguments for the environment + env_kwargs (dict): Keyword arguments for the environment + env_name (str): The name to register the environment under + signal_queue (Queue): A queue to signal when server has started + port (int, optional): The port for the Cogment server. Defaults to 9001. + log_file (str | None, optional): File path to write logs to. Defaults to None. 
+ """ + if log_file: + setup_logging(log_file) + env = env_class(*env_args, **env_kwargs) + + asyncio.run(register_env(env, env_name, signal_queue, port)) diff --git a/cogment_lab/generated/__init__.py b/cogment_lab/generated/__init__.py new file mode 100644 index 0000000..08d7dad --- /dev/null +++ b/cogment_lab/generated/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cogment_lab/generated/cog_settings.py b/cogment_lab/generated/cog_settings.py new file mode 100644 index 0000000..0d61b2d --- /dev/null +++ b/cogment_lab/generated/cog_settings.py @@ -0,0 +1,35 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import cogment as _cog +from types import SimpleNamespace + +import cogment_lab.generated.data_pb2 as data_pb +import cogment_lab.generated.ndarray_pb2 as ndarray_pb +import cogment_lab.generated.spaces_pb2 as spaces_pb + + +_player_class = _cog.actor.ActorClass( + name="player", + config_type=data_pb.AgentConfig, + action_space=data_pb.PlayerAction, + observation_space=data_pb.Observation, + ) + + +actor_classes = _cog.actor.ActorClassList(_player_class) + +trial = SimpleNamespace(config_type=data_pb.TrialConfig) + +environment = SimpleNamespace(config_type=data_pb.EnvironmentConfig) diff --git a/cogment_lab/generated/data_pb2.py b/cogment_lab/generated/data_pb2.py new file mode 100644 index 0000000..64100ad --- /dev/null +++ b/cogment_lab/generated/data_pb2.py @@ -0,0 +1,61 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: data.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import cogment_lab.generated.ndarray_pb2 as ndarray__pb2 +import cogment_lab.generated.spaces_pb2 as spaces__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 
\x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _ENVIRONMENTCONFIG_RESETARGSENTRY._options = None + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b'8\001' + _ENVIRONMENTSPECS._serialized_start=57 + _ENVIRONMENTSPECS._serialized_end=272 + _AGENTSPECS._serialized_start=274 + _AGENTSPECS._serialized_end=389 + _VALUE._serialized_start=391 + _VALUE._serialized_end=480 + _ENVIRONMENTCONFIG._serialized_start=483 + _ENVIRONMENTCONFIG._serialized_end=724 + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start=656 + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end=724 + _HFHUBMODEL._serialized_start=726 + _HFHUBMODEL._serialized_end=773 + _AGENTCONFIG._serialized_start=776 + _AGENTCONFIG._serialized_end=940 + _TRIALCONFIG._serialized_start=942 + _TRIALCONFIG._serialized_end=955 + _OBSERVATION._serialized_start=958 + _OBSERVATION._serialized_end=1094 + _PLAYERACTION._serialized_start=1096 + _PLAYERACTION._serialized_end=1154 +# @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/generated/ndarray_pb2.py b/cogment_lab/generated/ndarray_pb2.py new file mode 100644 index 0000000..6f63473 --- /dev/null +++ b/cogment_lab/generated/ndarray_pb2.py @@ -0,0 
+1,41 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: ndarray.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xb8\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r*\x83\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals()) +if 
_descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _DTYPE._serialized_start=227 + _DTYPE._serialized_end=358 + _ARRAY._serialized_start=40 + _ARRAY._serialized_end=224 +# @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/generated/spaces_pb2.py b/cogment_lab/generated/spaces_pb2.py new file mode 100644 index 0000000..9af8bef --- /dev/null +++ b/cogment_lab/generated/spaces_pb2.py @@ -0,0 +1,52 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: spaces.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import cogment_lab.generated.ndarray_pb2 as ndarray__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"\x89\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x42\x06\n\x04kindb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _DISCRETE._serialized_start=51 + _DISCRETE._serialized_end=87 + _BOX._serialized_start=89 + _BOX._serialized_end=179 + 
_MULTIBINARY._serialized_start=181 + _MULTIBINARY._serialized_end=234 + _MULTIDISCRETE._serialized_start=236 + _MULTIDISCRETE._serialized_end=294 + _DICT._serialized_start=296 + _DICT._serialized_end=420 + _DICT_SUBSPACE._serialized_start=355 + _DICT_SUBSPACE._serialized_end=420 + _SPACE._serialized_start=423 + _SPACE._serialized_end=688 +# @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/humans/__init__.py b/cogment_lab/humans/__init__.py new file mode 100644 index 0000000..08d7dad --- /dev/null +++ b/cogment_lab/humans/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/cogment_lab/humans/actor.py b/cogment_lab/humans/actor.py new file mode 100644 index 0000000..819e848 --- /dev/null +++ b/cogment_lab/humans/actor.py @@ -0,0 +1,212 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import base64 +import io +import logging +import multiprocessing as mp +import os +from typing import Any +import json + +from jinja2 import Template + +import cogment +import numpy as np +import uvicorn +from PIL import Image +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse, HTMLResponse +from fastapi.staticfiles import StaticFiles + +from cogment_lab.core import CogmentActor +from cogment_lab.generated import cog_settings + + +def image_to_msg(img: np.ndarray | None) -> str | None: + if img is None: + return None + img = Image.fromarray(img) + img_byte_array = io.BytesIO() + img.save(img_byte_array, format="PNG") + base64_encoded_result_bytes = base64.b64encode(img_byte_array.getvalue()) + base64_encoded_result_str = base64_encoded_result_bytes.decode("ascii") + return f"data:image/png;base64,{base64_encoded_result_str}" + + +def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int: + if isinstance(action_map, list): + action_map = {action: i for i, action in enumerate(action_map)} + + if data.startswith("{"): + # This is a JSON object + action = json.loads(data) + elif data not in action_map: + action = action_map["no-op"] + else: + action = action_map[data] + + logging.info(f"Processed action {action} from {data} with action_map {action_map}") + return action + + +async def start_fastapi( + port: int, + send_queue: asyncio.Queue, + recv_queue: asyncio.Queue, + actions: list[str] | dict[str, Any] | None = None, + fps: float = 30.0, + html_override: str | None = None, + file_override: str | None = None, + jinja_parameters: dict[str, Any] | None = None, +): + app = FastAPI() + + if actions is None: + actions = ["no-op", "ArrowLeft", "ArrowRight"] + + if jinja_parameters is None: + jinja_parameters = {} + + @app.get("/") + async def get(): + logging.info("Serving index.html") + if html_override is not None: + # Render HTML from string 
+ template = Template(html_override) + rendered_html = template.render(**jinja_parameters) + return HTMLResponse(rendered_html) + elif file_override is not None and os.path.isfile(file_override): + # Render HTML from file + with open(file_override, "r") as file: + file_content = file.read() + template = Template(file_content) + rendered_html = template.render(**jinja_parameters) + return HTMLResponse(rendered_html) + else: + # Fallback option: Serve static file + static_directory_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "static") + return FileResponse(os.path.join(static_directory_path, "index.html")) + + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + logging.info("Waiting for socket connection") + logging.info(f"Setting last_action_data") + last_action_data = "no-op" + logging.info(f"Set {last_action_data=}") + await websocket.accept() + logging.info("Client connected") + while True: + try: + logging.info("Waiting for frame") + frame: np.ndarray = await recv_queue.get() + if not isinstance(frame, np.ndarray): + logging.warning(f"Got frame of type {type(frame)}") + continue + logging.info(f"Got frame with shape {frame.shape}") + msg = image_to_msg(frame) + if msg is not None: + await websocket.send_text(msg) + + try: + action_data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0 / fps) + last_action_data = action_data + logging.info(f"Got action {action_data}, updated {last_action_data=}") + except asyncio.TimeoutError: + logging.info(f"Timed out waiting for action, using {last_action_data=}") + action_data = last_action_data + + action = msg_to_action(action_data, actions) + + await send_queue.put(action) + except WebSocketDisconnect: + logging.info("Client disconnected, waiting for new connection.") + await websocket.close() + await websocket.accept() # Accept a new WebSocket connection + except Exception as e: + logging.error("An error occurred: %s", e) + break # Break the loop in case of 
non-WebSocketDisconnect exceptions + + current_file_path = os.path.abspath(os.path.dirname(__file__)) + static_directory_path = os.path.join(current_file_path, "static") + app.mount("/static", StaticFiles(directory=static_directory_path), name="static") + + config = uvicorn.Config(app, host="0.0.0.0", port=port) + server = uvicorn.Server(config) + + await server.serve() + + +class HumanPlayer(CogmentActor): + def __init__(self, send_queue: asyncio.Queue, recv_queue: asyncio.Queue): + super().__init__(send_queue, recv_queue) + self.send_queue = send_queue + self.recv_queue = recv_queue + + async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int: + logging.info( + f"Getting an action with {observation=} and " + f"{rendered_frame.shape=}" + if rendered_frame is not None + else f"Getting an action with {observation=} and no frame" + ) + await self.send_queue.put(rendered_frame) + action = await self.recv_queue.get() + return action + + +async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: asyncio.Queue, signal_queue: mp.Queue): + context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") + + human_player = HumanPlayer(send_queue, recv_queue) + + logging.info("Registering actor") + + context.register_actor(impl=human_player.impl, impl_name="web_ui", actor_classes=["player"]) + logging.info("Serving actor") + + serve = context.serve_all_registered(cogment.ServedEndpoint(port=port)) + + signal_queue.put(True) + + await serve + + +async def shutdown(): + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + asyncio.get_event_loop().stop() + + +def signal_handler(sig, frame): + asyncio.create_task(shutdown()) + + +async def main(app_port: int = 8000, cogment_port: int = 8999): + app_to_actor = asyncio.Queue() + actor_to_app = asyncio.Queue() + fastapi_task = asyncio.create_task(start_fastapi(port=app_port, send_queue=app_to_actor, 
recv_queue=actor_to_app)) + cogment_task = asyncio.create_task( + run_cogment_actor(port=cogment_port, send_queue=actor_to_app, recv_queue=app_to_actor, signal_queue=mp.Queue()) + ) + + await asyncio.gather(fastapi_task, cogment_task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/cogment_lab/humans/runner.py b/cogment_lab/humans/runner.py new file mode 100644 index 0000000..15f9566 --- /dev/null +++ b/cogment_lab/humans/runner.py @@ -0,0 +1,118 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import asyncio +import signal +from multiprocessing import Queue +from typing import Any + +from cogment_lab.humans.actor import start_fastapi, run_cogment_actor +from cogment_lab.utils.runners import setup_logging + + +# def human_actor_runner( +# app_port: int = 8000, +# cogment_port: int = 8999, +# log_file: str | None = None +# ): +# """Runs the human actor along with the FastAPI server""" +# if log_file: +# setup_logging(log_file) +# +# # Queues for communication between FastAPI and Cogment actor +# app_to_actor = asyncio.Queue() +# actor_to_app = asyncio.Queue() +# +# # Asyncio tasks for the FastAPI server and Cogment actor +# fastapi_task = start_fastapi(port=app_port, send_queue=app_to_actor, recv_queue=actor_to_app) +# cogment_task = asyncio.create_task(run_cogment_actor(port=cogment_port, send_queue=actor_to_app, recv_queue=app_to_actor)) +# +# # Run the asyncio event loop +# asyncio.run(asyncio.gather(fastapi_task, cogment_task)) + + +async def shutdown(): + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + asyncio.get_event_loop().stop() + + +def signal_handler(sig, frame): + asyncio.create_task(shutdown()) + + +async def human_actor_main( + app_port: int, + cogment_port: int, + signal_queue: Queue, + actions: list[str] | None = None, + fps: float = 30, + html_override: str | None = None, + file_override: str | None = None, + jinja_parameters: dict[str, Any] | None = None, +): + app_to_actor = asyncio.Queue() + actor_to_app = asyncio.Queue() + fastapi_task = asyncio.create_task( + start_fastapi( + port=app_port, + send_queue=app_to_actor, + recv_queue=actor_to_app, + actions=actions, + fps=fps, + html_override=html_override, + file_override=file_override, + jinja_parameters=jinja_parameters, + ) + ) + cogment_task = asyncio.create_task( + run_cogment_actor( + port=cogment_port, 
send_queue=actor_to_app, recv_queue=app_to_actor, signal_queue=signal_queue + ) + ) + + await asyncio.gather(fastapi_task, cogment_task) + + +def human_actor_runner( + app_port: int, + cogment_port: int, + signal_queue: Queue, + log_file: str | None = None, + actions: list[str] | None = None, + fps: float = 30, + html_override: str | None = None, + file_override: str | None = None, + jinja_parameters: dict[str, Any] | None = None, +): + if log_file: + setup_logging(log_file) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + for sig in [signal.SIGINT, signal.SIGTERM]: + loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame)) + + try: + loop.run_until_complete( + human_actor_main( + app_port, cogment_port, signal_queue, actions, fps, html_override, file_override, jinja_parameters + ) + ) + finally: + loop.close() diff --git a/cogment_lab/humans/static/index.html b/cogment_lab/humans/static/index.html new file mode 100644 index 0000000..6f86525 --- /dev/null +++ b/cogment_lab/humans/static/index.html @@ -0,0 +1,60 @@ + + + + + + + RL Env Interface + + +

RL Env Interface

+ Cogment UI
+ + + + diff --git a/cogment_lab/process_manager.py b/cogment_lab/process_manager.py new file mode 100644 index 0000000..38d4697 --- /dev/null +++ b/cogment_lab/process_manager.py @@ -0,0 +1,491 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import datetime +import logging +import os +from asyncio import Task +from collections.abc import Sequence +from multiprocessing import Process, Queue +from typing import Callable, Any, Coroutine + +import cogment + +from cogment_lab.core import BaseEnv, BaseActor +from cogment_lab.generated import cog_settings, data_pb2 +from cogment_lab.humans.runner import human_actor_runner +from cogment_lab.actors.runner import actor_runner +from cogment_lab.envs.runner import env_runner + +import multiprocessing as mp +import atexit + +from cogment_lab.utils.trial_utils import get_actor_params, format_data_multiagent, TrialData + +ORCHESTRATOR_ENDPOINT = f"grpc://localhost:9000" +ENVIRONMENT_ENDPOINT = f"grpc://localhost:9001" +RANDOM_AGENT_ENDPOINT = f"grpc://localhost:9002" +HUMAN_AGENT_ENDPOINT = f"grpc://localhost:8999" +DATASTORE_ENDPOINT = f"grpc://localhost:9003" + + +AgentName = str +ImplName = str +TrialName = str + + +class Cogment: + """Main Cogment class for managing experiments""" + + def __init__( + self, + user_id: str = "cogment_lab", + torch_mode: bool = False, + log_dir: str | None = None, + mp_method: str | None = None, + ): + 
"""Initializes the Cogment instance + + Args: + user_id (str, optional): User ID. Defaults to "cogment_lab". + torch_mode (bool, optional): Whether to use PyTorch multiprocessing. Defaults to False. + log_dir (str, optional): Directory to store logs. Defaults to "logs". + mp_method (str | None, optional): Multiprocessing method to use. Defaults to None. + """ + self.processes: dict[ImplName, Process] = {} + self.tasks: dict[ImplName, Task] = {} + + self._register_shutdown_hook() + self.torch_mode = torch_mode + self.log_dir = log_dir + + self.envs: dict[ImplName, BaseEnv] = {} + self.actors: dict[ImplName, BaseActor] = {} + + self.context = cogment.Context(cog_settings=cog_settings, user_id=user_id) + self.controller = self.context.get_controller(endpoint=cogment.Endpoint(ORCHESTRATOR_ENDPOINT)) + self.datastore = self.context.get_datastore(endpoint=cogment.Endpoint(DATASTORE_ENDPOINT)) + + self.env_ports: dict[ImplName, int] = {} + self.actor_ports: dict[ImplName, int] = {} + + self.mp_ctx = mp.get_context(mp_method) if mp_method else mp.get_context() + + self.trial_envs: dict[TrialName, ImplName] = {} + + def _add_process( + self, target: Callable, args: tuple, name: ImplName, use_torch: bool | None = None, force: bool = False + ): + """Adds a process to the list of processes + + Args: + target (Callable): The process target function + args (tuple): Arguments for the process target + name (ImplName): Name of the process + use_torch (bool | None, optional): Whether to use PyTorch multiprocessing. Defaults to None. + force (bool, optional): Whether to force adding the process if it already exists. Defaults to False. 
+ + Raises: + ValueError: If the process already exists and force is False + """ + if use_torch is None: + use_torch = self.torch_mode + + if name in self.processes and not force: + raise ValueError(f"Process {name} already exists") + + if use_torch: + from torch.multiprocessing import Process as TorchProcess + + p = TorchProcess(target=target, args=args) + else: + p = self.mp_ctx.Process(target=target, args=args) + p.start() + self.processes[name] = p + + def _add_task(self, target: Coroutine, name: ImplName) -> Task: + """Adds a task to the list of tasks + + Args: + target (Coroutine): The task target function + name (ImplName): Name of the task + + Returns: + Task: The task instance + """ + if name in self.tasks: + raise ValueError(f"Task {name} already exists") + + task = asyncio.create_task(target) + self.tasks[name] = task + return task + + def run_env( + self, env: BaseEnv, env_name: ImplName, port: int = 9001, log_file: str | None = None + ) -> Coroutine[bool]: + """Given an environment, runs it in a subprocess + + Args: + env (BaseEnv): The environment instance + env_name (ImplName): Name for the environment + port (int, optional): Port to run the environment on. Defaults to 9001. + log_file (str | None, optional): Log file path. Defaults to None. 
+ + Returns: + bool: Whether the environment startup succeeded + """ + env_class = type(env) + env_args = env.args + env_kwargs = env.kwargs + + signal_queue = Queue(1) + + if self.log_dir is not None and log_file: + log_file = os.path.join(self.log_dir, log_file) + + self._add_process( + target=env_runner, + name=env_name, + args=(env_class, env_args, env_kwargs, env_name, signal_queue, port, log_file), + ) + logging.info(f"Started environment {env_name} on port {port} with log file {log_file}") + + self.envs[env_name] = env + self.env_ports[env_name] = port + + return self.is_ready(signal_queue) + + def run_actor( + self, actor: BaseActor, actor_name: ImplName, port: int = 9002, log_file: str | None = None + ) -> Coroutine[bool]: + """Given an actor, runs it + + Args: + actor (BaseActor): The actor instance + actor_name (ImplName): Name for the actor + port (int, optional): Port to run the actor on. Defaults to 9002. + log_file (str | None, optional): Log file path. Defaults to None. + + Returns: + bool: Whether the actor startup succeeded + """ + actor_class = type(actor) + actor_args = actor.args + actor_kwargs = actor.kwargs + + signal_queue = Queue(1) + + if self.log_dir is not None and log_file: + log_file = os.path.join(self.log_dir, log_file) + + self._add_process( + target=actor_runner, + name=actor_name, + args=(actor_class, actor_args, actor_kwargs, actor_name, signal_queue, port, log_file), + ) + logging.info(f"Started actor {actor_name} on port {port} with log file {log_file}") + + self.actors[actor_name] = actor + self.actor_ports[actor_name] = port + + return self.is_ready(signal_queue) + + def run_local_actor( + self, actor: BaseActor, actor_name: ImplName, port: int = 9002, log_file: str | None = None + ) -> Task: + """Given an actor, runs it locally + + Args: + actor (BaseActor): The actor instance + actor_name (ImplName): Name for the actor + port (int, optional): Port to run the actor on. Defaults to 9002. 
+ log_file (str | None, optional): Log file path. Defaults to None. + + Returns: + Task: The asyncio task serving the registered actor + """ + + if self.log_dir is not None and log_file: + log_file = os.path.join(self.log_dir, log_file) + + self.context.register_actor(impl=actor.impl, impl_name=actor_name, actor_classes=["player"]) + + serve = self._add_task(self.context.serve_all_registered(cogment.ServedEndpoint(port=port)), actor_name) + + logging.info(f"Started actor {actor_name} on port {port} with log file {log_file}") + + self.actors[actor_name] = actor + self.actor_ports[actor_name] = port + + return serve + + def run_web_ui( + self, + app_port: int = 8000, + cogment_port: int = 8999, + actions: list[str] | dict[str, Any] = [], + log_file: str | None = None, + fps: int = 30, + html_override: str | None = None, + file_override: str | None = None, + jinja_parameters: dict[str, Any] | None = None, + ) -> Coroutine[bool]: + """Runs the human actor in a separate process + + Args: + app_port (int, optional): Port for web UI. Defaults to 8000. + cogment_port (int, optional): Port for Cogment connection. Defaults to 8999. + actions (list[str] | dict[str, Any], optional): Allowed actions. Defaults to []. + log_file (str | None, optional): Log file path. Defaults to None. + fps (int, optional): Frames per second for environment. Defaults to 30. + html_override (str | None, optional): HTML override file path. Defaults to None. + file_override (str | None, optional): File override file path. Defaults to None. + jinja_parameters (dict[str, Any] | None, optional): Jinja parameters for the HTML override. Defaults to None.
+ + Returns: + bool: Whether the web UI startup succeeded + """ + + signal_queue = Queue(1) + + if self.log_dir is not None and log_file: + log_file = os.path.join(self.log_dir, log_file) + + self._add_process( + target=human_actor_runner, + name="web_ui", + args=( + app_port, + cogment_port, + signal_queue, + log_file, + actions, + fps, + html_override, + file_override, + jinja_parameters, + ), + ) + logging.info(f"Started web UI on port {app_port} with log file {log_file}") + + self.actor_ports["web_ui"] = cogment_port + + return self.is_ready(signal_queue) + + def stop_service(self, name: ImplName, timeout: float = 1.0): + """Stops a process or a task. + + Args: + name (str): Name of the process or task to stop + timeout (float, optional): How long (in seconds) to wait for the process to stop before killing it. Defaults to 1.0. + """ + if name in self.processes: + self._stop_process(name, timeout) + elif name in self.tasks: + self._stop_task(name) + else: + raise ValueError(f"Service {name} does not exist") + + def _stop_process(self, name: ImplName, timeout: float = 1.0): + """Stops a process + + Args: + name (str): Name of the process to stop + timeout (float, optional): How long (in seconds) to wait for the process to stop before killing it. Defaults to 1.0. 
+ """ + if name not in self.processes: + raise ValueError(f"Process {name} does not exist") + logging.info(f"Stopping process {name}") + process = self.processes[name] + if timeout == 0.0: + process.kill() + process.join() + else: + process.terminate() + process.join(timeout=timeout) + if process.is_alive(): + process.kill() + process.join() + del self.processes[name] + + def _stop_task(self, name: ImplName): + """Stops a task + + Args: + name (str): Name of the task to stop + """ + if name not in self.tasks: + raise ValueError(f"Task {name} does not exist") + logging.info(f"Stopping task {name}") + task = self.tasks[name] + task.cancel() + del self.tasks[name] + + def stop_all_services(self, timeout: float = 1.0): + """Stops all processes and tasks + + Args: + timeout (float, optional): How long (in seconds) to wait for the processes to stop before killing them. Defaults to 1.0. + """ + for name in list(self.processes.keys()): + self._stop_process(name, timeout) + for name in list(self.tasks.keys()): + self._stop_task(name) + + def _cleanup_wrapper(self): + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + asyncio.run(self.cleanup()) + loop.close() + + async def cleanup(self, timeout: float = 1.0): + """Cleans up all processes. 
Idempotent.""" + tasks = list(self.tasks.keys()) + processes = list(self.processes.keys()) + + for name in tasks: + try: + self._stop_task(name) + except Exception as e: + logging.warning(f"Failed to stop task {name}: {e}") + + for name in processes: + try: + self._stop_process(name, timeout) + except Exception as e: + logging.warning(f"Failed to stop process {name}: {e}") + + try: + await self.context._grpc_server.stop(None) + except asyncio.exceptions.CancelledError: + logging.info("Server already stopped") + except AttributeError: + logging.info("Server not started") + + def _register_shutdown_hook(self): + """Registers the cleanup method to run on shutdown""" + atexit.register(self._cleanup_wrapper) + + async def start_trial( + self, + env_name: ImplName, + actor_impls: dict[AgentName, ImplName] | ImplName, + session_config: dict[str, Any] = {}, + trial_name: TrialName | None = None, + ): + """Starts a new trial + + Args: + env_name (ImplName): Name of the environment implementation + actor_impls (dict[AgentName, ImplName] | ImplName): Actor implementations mapped to agent names + session_config (dict[str, Any]): kwargs for the environment session + trial_name (str | None, optional): Trial name. Defaults to None. 
+ + Returns: + str: The trial ID + """ + if trial_name is None: + trial_name = f"{env_name}-{datetime.datetime.now().isoformat()}" + + if isinstance(actor_impls, str): + actor_impls = {"gym": actor_impls} + + env = self.envs[env_name] + actor_params = [ + get_actor_params( + name=agent_name, + implementation=actor_impl, + agent_specs=env.agent_specs[agent_name], + endpoint=f"grpc://localhost:{self.actor_ports[actor_impl]}", + ) + for agent_name, actor_impl in actor_impls.items() + ] + + env_config = data_pb2.EnvironmentConfig(**session_config) + + trial_params = cogment.TrialParameters( + cog_settings, + environment_name=env_name, + environment_endpoint=f"grpc://localhost:{self.env_ports[env_name]}", + environment_config=env_config, + actors=actor_params, + environment_implementation=env_name, + datalog_endpoint=DATASTORE_ENDPOINT, + ) + + trial_id = await self.controller.start_trial(trial_id_requested=trial_name, trial_params=trial_params) + + logging.info(f"Started trial {trial_id} with name {trial_name}") + + self.trial_envs[trial_id] = env_name + + return trial_id + + async def get_trial_data( + self, + trial_id: str, + env_name: str | None = None, + fields: Sequence[str] = ( + "observations", + "actions", + "rewards", + "done", + "next_observations", + "last_observation", + ), + ) -> dict[str, TrialData]: + """Gets trial data from the datastore, formatting it appropriately.""" + if env_name is None: + env_name = self.trial_envs[trial_id] + env = self.envs[env_name] + agent_specs = env.agent_specs + + data = await format_data_multiagent(self.datastore, trial_id, agent_specs, fields) + + return data + + async def get_trial(self, trial_id: str): + """Gets a trial by ID + + Args: + trial_id (str): The trial ID + + Returns: + Trial: The trial instance + """ + [trial] = await self.datastore.get_trials(ids=[trial_id]) + return trial + + def __del__(self): + """Cleanup on delete""" + self.stop_all_services() + + async def is_ready(self, queue: Queue): + """Waits for a 
readiness signal on a queue + + Args: + queue (Queue): The queue to wait on + + Returns: + Any: The object that was put on the queue + """ + while queue.empty(): + await asyncio.sleep(0.1) + return queue.get() diff --git a/cogment_lab/protos/cogment.yaml b/cogment_lab/protos/cogment.yaml new file mode 100644 index 0000000..3712420 --- /dev/null +++ b/cogment_lab/protos/cogment.yaml @@ -0,0 +1,20 @@ +import: + proto: + - ndarray.proto + - spaces.proto + - data.proto + +environment: + config_type: cogment_lab.EnvironmentConfig + +trial: + config_type: cogment_lab.TrialConfig + +# Static configuration +actor_classes: + - name: player + action: + space: cogment_lab.PlayerAction + observation: + space: cogment_lab.Observation + config_type: cogment_lab.AgentConfig diff --git a/cogment_lab/protos/data.proto b/cogment_lab/protos/data.proto new file mode 100644 index 0000000..b622e7f --- /dev/null +++ b/cogment_lab/protos/data.proto @@ -0,0 +1,79 @@ +// Copyright 2024 AI Redefined Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; + +package cogment_lab; + +import "ndarray.proto"; +import "spaces.proto"; + + +message EnvironmentSpecs { + string implementation = 1; + bool turn_based = 2; + int32 num_players = 3; + spaces.Space observation_space = 4; + spaces.Space action_space = 5; + string web_components_file = 6; +} + +message AgentSpecs { + spaces.Space observation_space = 1; + spaces.Space action_space = 2; +} + +message Value { + oneof value_type { + string string_value = 1; + int32 int_value = 2; + float float_value = 3; + } +} +message EnvironmentConfig { + string run_id = 1; + bool render = 2; + int32 render_width = 3; + uint32 seed = 4; + bool flatten = 5; + map<string, Value> reset_args = 6; +} + +message HFHubModel { + string repo_id = 1; + string filename = 2; +} + +message AgentConfig { + string run_id = 1; + AgentSpecs agent_specs = 2; + uint32 seed = 3; + string model_id = 4; + int32 model_iteration = 5; + int32 model_update_frequency = 6; +} + +message TrialConfig { +} + +message Observation { + nd_array.Array value = 1; + bool active = 2; + bool alive = 3; + optional bytes rendered_frame = 4; +} + +message PlayerAction { + nd_array.Array value = 1; +} \ No newline at end of file diff --git a/cogment_lab/protos/ndarray.proto b/cogment_lab/protos/ndarray.proto new file mode 100644 index 0000000..24dc3a9 --- /dev/null +++ b/cogment_lab/protos/ndarray.proto @@ -0,0 +1,38 @@ +// Copyright 2024 AI Redefined Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +syntax = "proto3"; + +package cogment_lab.nd_array; + +enum DType { + DTYPE_UNKNOWN = 0; + DTYPE_FLOAT32 = 1; + DTYPE_FLOAT64 = 2; + DTYPE_INT8 = 3; + DTYPE_INT32 = 4; + DTYPE_INT64 = 5; + DTYPE_UINT8 = 6; +} + +message Array { + DType dtype = 1; + repeated uint32 shape = 2; + bytes raw_data = 3; + bytes npy_data = 4; + repeated double double_data = 5; + repeated sint32 int32_data = 6; + repeated sint64 int64_data = 7; + repeated uint32 uint32_data = 8; +} diff --git a/cogment_lab/protos/spaces.proto b/cogment_lab/protos/spaces.proto new file mode 100644 index 0000000..04d4686 --- /dev/null +++ b/cogment_lab/protos/spaces.proto @@ -0,0 +1,55 @@ +// Copyright 2024 AI Redefined Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; + +import "ndarray.proto"; + +package cogment_lab.spaces; + +message Discrete { + int32 n = 1; + int32 start = 2; +} + +message Box { + nd_array.Array low = 2; + nd_array.Array high = 3; +} + +message MultiBinary { + nd_array.Array n = 1; +} + +message MultiDiscrete { + nd_array.Array nvec = 1; +} + +message Dict { + message SubSpace { + string key = 1; + Space space = 2; + } + repeated SubSpace spaces = 1; +} + +message Space { + oneof kind { + Discrete discrete = 1; + Box box = 2; + Dict dict = 3; + MultiBinary multi_binary = 4; + MultiDiscrete multi_discrete = 5; + } +} diff --git a/cogment_lab/py.typed b/cogment_lab/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/cogment_lab/session_helpers.py b/cogment_lab/session_helpers.py new file mode 100644 index 0000000..26fa7cc --- /dev/null +++ b/cogment_lab/session_helpers.py @@ -0,0 +1,173 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import Optional, Any + +from cogment.actor import ActorSession +from cogment.environment import EnvironmentSession +from cogment.model_registry_v2 import ModelRegistry +from cogment.session import RecvEvent, ActorInfo + +from cogment_lab.specs import AgentSpecs +from cogment_lab.specs.action_space import ActionSpace, Action +from cogment_lab.specs.observation_space import ObservationSpace, Observation + + +class ActorSessionHelper: + """ + Cogment Verse actor session helper + + Provides additional methods to the regular Cogment actor session. + """ + + def __init__(self, actor_session: ActorSession, model_registry: Optional[ModelRegistry]): + self.actor_session = actor_session + self.agent_specs = AgentSpecs.deserialize(self.actor_session.config.agent_specs) + self.action_space = self.agent_specs.get_action_space(seed=self.actor_session.config.seed) + self.observation_space = self.agent_specs.get_observation_space() + self.model_registry = model_registry + + def get_action_space(self) -> ActionSpace: + return self.action_space + + def get_observation_space(self) -> ObservationSpace: + return self.observation_space + + def get_observation(self, event: RecvEvent) -> Observation: + """ + Return the cogment verse observation for the current event. + + If the event does not contain an observation, return None. + """ + if not event.observation: + return None + + return self.observation_space.deserialize(event.observation.observation) + + def get_render(self, event: RecvEvent) -> bytes: + """ + Return the render for the current event. + + If the event does not contain a render, return None. + """ + if not event.observation: + return None + + return event.observation.render + + +class EnvironmentSessionHelper: + """ + A session helper for environments. 
+ """ + + actor_infos: list[ActorInfo] + + def __init__(self, environment_session: EnvironmentSession, agent_specs: dict[str, AgentSpecs]): + self.actor_infos = environment_session.get_active_actors() + + assert set(agent_specs.keys()) == set( + actor_info.actor_name for actor_info in self.actor_infos + ), f"Agent specs and active actors do not match. {agent_specs.keys()} != {self.actor_infos}" + + # Mapping actor_name to actor_idx + self.actor_idxs = {actor_info.actor_name: actor_idx for (actor_idx, actor_info) in enumerate(self.actor_infos)} + # Mapping actor_idx to actor_name + self.actors = [actor_info.actor_name for actor_info in self.actor_infos] + + if isinstance(agent_specs, AgentSpecs): + agent_specs = {actor_name: agent_specs for actor_name in self.actors} + + self.agent_specs = agent_specs + self.observation_spaces = { + agent_id: specs.get_observation_space(environment_session.config.render_width) + for (agent_id, specs) in agent_specs.items() + } + + def get_observation_space(self, actor_name: str) -> ObservationSpace: + return self.observation_spaces[actor_name] + + def get_action_space(self, actor_name: str) -> ActionSpace: + return self.agent_specs[actor_name].get_action_space() + + def _get_actor_idx(self, actor_name: str) -> int: + actor_idx = self.actor_idxs.get(actor_name) + + if actor_idx is None: + raise RuntimeError(f"No actor with name [{actor_name}] found!") + + return actor_idx + + def get_action(self, tick_data: Any, actor_name: str) -> Action | None: + # For environments, tick_data is an event + event: RecvEvent = tick_data + + if not event.actions: + return None + + actor_idx = self._get_actor_idx(actor_name) + action_space = self.get_action_space(actor_name) + + return action_space.deserialize( + event.actions[actor_idx].action, + ) + + def get_observation(self, tick_data: Any, actor_name: str): + """ + Return the cogment verse observation of a given actor at a tick. + + If no observation, returns None.
+ """ + raise NotImplementedError + + def get_player_actions(self, tick_data: Any, actor_name: str) -> Action | None: + """ + Return the cogment verse player action of a given actor at a tick. + + If only a single player actor is present, no `actor_name` is required. + + If no action, returns None. + """ + event = tick_data + if not event.actions: + return None + + actions = [ + self.get_action(tick_data, actor_name) + for player_actor_name in self.actors + if player_actor_name == actor_name + ] + if len(actions) == 0: + raise RuntimeError(f"No player actors having name [{actor_name}]") + return actions[0] + + def get_player_observations(self, tick_data: Any, actor_name: str): + if actor_name is None: + observations = [self.get_observation(tick_data, player_actor_name) for player_actor_name in self.actors] + if len(observations) == 0: + raise RuntimeError("No player actors") + if len(observations) > 1: + raise RuntimeError("More than 1 player actor, please provide an actor name") + return observations[0] + + observations = [ + self.get_observation(tick_data, actor_name) + for player_actor_name in self.actors + if player_actor_name == actor_name + ] + if len(observations) == 0: + raise RuntimeError(f"No player actors having name [{actor_name}]") + return observations[0] diff --git a/cogment_lab/specs/__init__.py b/cogment_lab/specs/__init__.py new file mode 100644 index 0000000..4279b84 --- /dev/null +++ b/cogment_lab/specs/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from .encode_rendered_frame import encode_rendered_frame +from .environment_specs import AgentSpecs diff --git a/cogment_lab/specs/action_space.py b/cogment_lab/specs/action_space.py new file mode 100644 index 0000000..180e3f2 --- /dev/null +++ b/cogment_lab/specs/action_space.py @@ -0,0 +1,135 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gymnasium as gym + +from cogment_lab.generated.data_pb2 import ( # pylint: disable=import-error + PlayerAction, +) +from .ndarray_serialization import deserialize_ndarray, serialize_ndarray + + +# pylint: disable=attribute-defined-outside-init +class Action: + """ + Cogment Verse actor action + + Properties: + flat_value: + The action value, as a flat numpy array. + value: + The action value, as a numpy array. + """ + + def __init__(self, gym_space: gym.Space, pb_action=None, value=None): + """ + Action constructor. + Shouldn't be called directly, prefer the factory function of ActionSpace. 
+ """ + self._gym_space = gym_space + + if pb_action is not None: + assert value is None + self._pb_action = pb_action + return + + self._value = value + + def _compute_flat_value(self): + if hasattr(self, "_value"): + value = self._value + if value is None: + return None + return gym.spaces.flatten(self._gym_space, self._value) + + if not self._pb_action.value != b"": + # This happens whenever value is None + return None + + return deserialize_ndarray(self._pb_action.value) + + @property + def flat_value(self): + if not hasattr(self, "_flat_value"): + self._flat_value = self._compute_flat_value() + return self._flat_value + + def _compute_value(self): + flat_value = self.flat_value + if flat_value is None: + return None + return gym.spaces.unflatten(self._gym_space, flat_value) + + @property + def value(self): + if not hasattr(self, "_value"): + self._value = self._compute_value() + return self._value + + +class ActionSpace: + """ + Cogment Verse action space + + Properties: + gym_space: + Wrapped Gym space for the action values (cf. 
https://www.gymlibrary.dev/api/spaces/) + actor_class: + Class of the actor for which this space will serialize Action protobuf messages + seed: + Random seed used when generating random actions + """ + + def __init__(self, gym_space: gym.Space, seed: int = None): + self.gym_space = gym_space + + if seed: + self.gym_space.seed(int(seed)) + + def create(self, value=None): + """ + Create an Action + """ + return Action(self.gym_space, value=value) + + def sample(self, mask=None): + """ + Generate a random Action + """ + return Action(self.gym_space, value=self.gym_space.sample(mask=mask)) + + def serialize( + self, + action, + ): + """ + Serialize an Action to an Action protobuf message + """ + if action.value is None: + return PlayerAction() + + serialized_value = serialize_ndarray(action.flat_value) + return PlayerAction(value=serialized_value) + + def deserialize(self, pb_action): + """ + Deserialize an Action from an Action protobuf message + """ + return Action(self.gym_space, pb_action=pb_action) + + def create_serialize(self, value=None): + """ + Create and serialize an Action to an Action protobuf message + """ + return self.serialize(self.create(value)) diff --git a/cogment_lab/specs/encode_rendered_frame.py b/cogment_lab/specs/encode_rendered_frame.py new file mode 100644 index 0000000..4d6aad8 --- /dev/null +++ b/cogment_lab/specs/encode_rendered_frame.py @@ -0,0 +1,68 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import cv2 +import numpy as np + +MAX_RENDERED_WIDTH = 2048 + + +def encode_rendered_frame(rendered_frame: np.ndarray, max_size: int = MAX_RENDERED_WIDTH) -> bytes: + if max_size <= 0: + max_size = MAX_RENDERED_WIDTH + # gRPC max message size hack + height, width = rendered_frame.shape[:2] + if max(height, width) > max_size: + if height > width: + new_height = max_size + new_width = int(new_height / height * width) + else: + new_width = max_size + new_height = int(height / width * new_width) + rendered_frame = cv2.resize(rendered_frame, (new_width, new_height), interpolation=cv2.INTER_AREA) + + # note rgb -> bgr for cv2 + result, encoded_frame = cv2.imencode(".jpg", rendered_frame[:, :, ::-1]) + assert result + + return encoded_frame.tobytes() + + +def decode_rendered_frame(encoded_frame: bytes) -> np.ndarray: + """ + Decode the rendered frame from bytes to a NumPy array. + + Args: + encoded_frame (bytes): The encoded frame as a byte array. + + Returns: + np.ndarray: The decoded rendered frame as a NumPy array. + """ + if encoded_frame is None or len(encoded_frame) == 0: + return None + + # Convert the byte array back to a NumPy array + encoded_frame_np = np.frombuffer(encoded_frame, dtype=np.uint8) + + # Decode the image from the byte array + decoded_frame = cv2.imdecode(encoded_frame_np, cv2.IMREAD_COLOR) + + # Check if the decoding was successful + if decoded_frame is None: + raise ValueError("Failed to decode the rendered frame.") + + # Convert from BGR to RGB + decoded_frame = decoded_frame[:, :, ::-1] + + return decoded_frame diff --git a/cogment_lab/specs/environment_specs.py b/cogment_lab/specs/environment_specs.py new file mode 100644 index 0000000..04bfc35 --- /dev/null +++ b/cogment_lab/specs/environment_specs.py @@ -0,0 +1,93 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import gymnasium as gym + +from cogment_lab.generated.data_pb2 import ( + AgentSpecs as PbAgentSpecs, +) +from .action_space import ActionSpace +from .ndarray_serialization import SerializationFormat +from .observation_space import ObservationSpace +from .spaces_serialization import deserialize_space, serialize_gym_space +from ..constants import DEFAULT_RENDERED_WIDTH + + +class AgentSpecs: + """ + Representation of the specification of an agent within Cogment Lab. + """ + + def __init__(self, agent_specs_pb: PbAgentSpecs): + """ + AgentSpecs constructor. + Shouldn't be called directly, prefer the factory function such as AgentSpecs.deserialize or AgentSpecs.create_homogeneous. 
+ """ + self._pb = agent_specs_pb + + def get_observation_space(self, render_width: int = DEFAULT_RENDERED_WIDTH) -> ObservationSpace: + """ + Build an instance of the observation space for this environment + + Parameters: + render_width: optional + maximum width for the serialized rendered frame in observation + + NOTE: In the future we'll want to support different observation spaces per agent role + """ + return ObservationSpace(deserialize_space(self._pb.observation_space), render_width) + + def get_action_space(self, seed: int | None = None) -> ActionSpace: + """ + Build an instance of the action space for this environment + + Parameters: + seed: optional + the seed used when generating random actions + + NOTE: In the future we'll want to support different action spaces per agent role + """ + return ActionSpace(deserialize_space(self._pb.action_space), seed) + + @classmethod + def create_homogeneous( + cls, + observation_space: gym.Space, + action_space: gym.Space, + serialization_format: SerializationFormat = SerializationFormat.STRUCTURED, + ): + """ + Factory function building a homogeneous AgentSpecs, i.e. with all actors having the same action and observation spaces. + """ + return cls.deserialize( + PbAgentSpecs( + observation_space=serialize_gym_space(observation_space, serialization_format), + action_space=serialize_gym_space(action_space, serialization_format), + ) + ) + + def serialize(self): + """ + Serialize to an AgentSpecs protobuf message + """ + return self._pb + + @classmethod + def deserialize(cls, agent_specs_pb: PbAgentSpecs): + """ + Factory function building an AgentSpecs instance from an AgentSpecs protobuf message + """ + return cls(agent_specs_pb) diff --git a/cogment_lab/specs/ndarray_serialization.py b/cogment_lab/specs/ndarray_serialization.py new file mode 100644 index 0000000..5178058 --- /dev/null +++ b/cogment_lab/specs/ndarray_serialization.py @@ -0,0 +1,145 @@ +# Copyright 2024 AI Redefined Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import io +from enum import Enum + +import numpy as np + +from cogment_lab.generated.ndarray_pb2 import ( + DTYPE_FLOAT32, + DTYPE_FLOAT64, + DTYPE_INT8, + DTYPE_INT32, + DTYPE_INT64, + DTYPE_UINT8, + DTYPE_UNKNOWN, + Array, +) + + +PB_DTYPE_FROM_DTYPE = { + "float32": DTYPE_FLOAT32, + "float64": DTYPE_FLOAT64, + "int8": DTYPE_INT8, + "int32": DTYPE_INT32, + "int64": DTYPE_INT64, + "uint8": DTYPE_UINT8, +} + +DTYPE_FROM_PB_DTYPE = { + DTYPE_FLOAT32: np.dtype("float32"), + DTYPE_FLOAT64: np.dtype("float64"), + DTYPE_INT8: np.dtype("int8"), + DTYPE_INT32: np.dtype("int32"), + DTYPE_INT64: np.dtype("int64"), + DTYPE_UINT8: np.dtype("uint8"), +} + +DOUBLE_DTYPES = frozenset(["float32", "float64"]) +UINT32_DTYPES = frozenset(["uint8"]) +INT32_DTYPES = frozenset(["int8", "int32"]) +INT64_DTYPES = frozenset(["int64"]) + + +class SerializationFormat(Enum): + RAW = 1 + NPY = 2 + STRUCTURED = 3 + + +def serialize_ndarray( + nd_array: np.ndarray, + serialization_format: SerializationFormat = SerializationFormat.RAW, +) -> Array: + str_dtype = str(nd_array.dtype) + pb_dtype = PB_DTYPE_FROM_DTYPE.get(str_dtype, DTYPE_UNKNOWN) + + # SerializationFormat.RAW + if serialization_format is SerializationFormat.RAW: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + raw_data=nd_array.tobytes(order="C"), + ) + + # SerializationFormat.NPY + if serialization_format is SerializationFormat.NPY: + 
buffer = io.BytesIO() + np.save(buffer, nd_array, allow_pickle=False) + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + npy_data=buffer.getvalue(), + ) + + # SerializationFormat.STRUCTURED: + if str_dtype in DOUBLE_DTYPES: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + double_data=nd_array.ravel(order="C").tolist(), + ) + if str_dtype in UINT32_DTYPES: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + uint32_data=nd_array.ravel(order="C").tolist(), + ) + if str_dtype in INT32_DTYPES: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + int32_data=nd_array.ravel(order="C").tolist(), + ) + if str_dtype in INT64_DTYPES: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + int64_data=nd_array.ravel(order="C").tolist(), + ) + + raise RuntimeError( + f"[{str_dtype}] is not a supported numpy dtype for serialization format [{serialization_format}]" + ) + + +def deserialize_ndarray(pb_array: Array) -> np.ndarray | None: + dtype = DTYPE_FROM_PB_DTYPE.get(pb_array.dtype) + str_dtype = str(dtype) + shape = tuple(pb_array.shape) + + if len(pb_array.raw_data) > 0: + return np.frombuffer(pb_array.raw_data, dtype=dtype).reshape(shape, order="C") + + if len(pb_array.npy_data) > 0: + buffer = io.BytesIO(pb_array.npy_data) + return np.load(buffer, allow_pickle=False) + + # SerializationFormat.STRUCTURED + if str_dtype in DOUBLE_DTYPES: + return np.array(pb_array.double_data, dtype=dtype).reshape(shape, order="C") + if str_dtype in UINT32_DTYPES: + return np.array(pb_array.uint32_data, dtype=dtype).reshape(shape, order="C") + if str_dtype in INT32_DTYPES: + return np.array(pb_array.int32_data, dtype=dtype).reshape(shape, order="C") + if str_dtype in INT64_DTYPES: + return np.array(pb_array.int64_data, dtype=dtype).reshape(shape, order="C") + + return None + # raise RuntimeError( + # f"[{str_dtype}] is not a supported numpy dtype for serialization format [SerializationFormat.STRUCTURED]" + # ) diff --git 
a/cogment_lab/specs/observation_space.py b/cogment_lab/specs/observation_space.py new file mode 100644 index 0000000..b768409 --- /dev/null +++ b/cogment_lab/specs/observation_space.py @@ -0,0 +1,227 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gymnasium as gym +import numpy as np + +from cogment_lab.constants import DEFAULT_RENDERED_WIDTH +from cogment_lab.generated.data_pb2 import ( + Observation as PbObservation, +) +from .encode_rendered_frame import encode_rendered_frame, decode_rendered_frame +from .ndarray_serialization import deserialize_ndarray, serialize_ndarray + + +# pylint: disable=attribute-defined-outside-init +class Observation: + """ + Cogment Verse actor observation + + Properties: + flat_value: + The observation value, as a flat numpy array. + value: + The observation value, as a numpy array. + active: optional + Boolean indicating if the object is active. + alive: optional + Boolean indicating if the object is alive. + rendered_frame: optional + Environment's rendered frame as a numpy array of RGB pixels. + """ + + def __init__( + self, + gym_space: gym.Space, + pb_observation=None, + value=None, + active=None, + alive=None, + rendered_frame=None, + ): + """ + Observation constructor. + Shouldn't be called directly, prefer the factory function of ObservationSpace. 
+ """ + + self._gym_space = gym_space + + if pb_observation is not None: + assert value is None + assert active is None + assert alive is None + assert rendered_frame is None + self._pb_observation = pb_observation + return + + self._value = value + self._active = active + self._alive = alive + self._rendered_frame = rendered_frame + + self._pb_observation = PbObservation( + active=active, + alive=alive, + ) + + def _compute_flat_value(self): + if hasattr(self, "_value"): + return gym.spaces.flatten(self._gym_space, self._value) + + if not self._pb_observation.value != b"" or self._pb_observation.value is None: + return None + + return deserialize_ndarray(self._pb_observation.value) + + @property + def flat_value(self): + if not hasattr(self, "_flat_value"): + self._flat_value = self._compute_flat_value() + return self._flat_value + + def _compute_value(self): + return gym.spaces.unflatten(self._gym_space, self.flat_value) if self.flat_value is not None else None + + @property + def value(self): + if not hasattr(self, "_value"): + self._value = self._compute_value() + return self._value + + def _deserialize_rendered_frame(self): + if not self._pb_observation.rendered_frame != b"": + return None + return decode_rendered_frame(self._pb_observation.rendered_frame) + + @property + def rendered_frame(self): + if not hasattr(self, "_rendered_frame"): + self._rendered_frame = self._deserialize_rendered_frame() + return self._rendered_frame + + @property + def active(self): + return self._pb_observation.active if self._pb_observation.active != b"" else self._active + + @property + def alive(self): + return self._pb_observation.alive if self._pb_observation.alive != b"" else self._alive + + def __repr__(self): + return f"Observation(value={self.value.shape if isinstance(self.value, np.ndarray) else self.value}, active={self.active}, alive={self.alive}, rendered_frame={self.rendered_frame.shape if self.rendered_frame is not None else 'None'})@{hex(id(self))}" + + def 
__str__(self): + return self.__repr__() + + +class ObservationSpace: + """ + Cogment Verse observation space + + Properties: + gym_space: + Wrapped Gym space for the observation values + render_width: + Maximum width for the serialized rendered frame in observations + """ + + def __init__(self, space: gym.Space, render_width: int = DEFAULT_RENDERED_WIDTH): + """ + ObservationSpace constructor. + Shouldn't be called directly, prefer the factory function of EnvironmentSpecs. + """ + if isinstance(space, gym.spaces.Dict) and ("action_mask" in space.spaces): + # Check the observation space defines an action_mask "component" (like PettingZoo does) + assert "observation" in space.spaces + assert len(space.spaces) == 2 + + self.gym_space = space.spaces["observation"] + self.action_mask_gym_space = space.spaces["action_mask"] + else: + # "Standard" observation space, no action_mask + self.gym_space = space + self.action_mask_gym_space = None + + # Other configuration + self.render_width = render_width + + def create( + self, + value=None, + active=None, + alive=None, + rendered_frame=None, + ) -> Observation: + """ + Create an Observation + """ + return Observation( + self.gym_space, + value=value, + active=active, + alive=alive, + rendered_frame=rendered_frame, + ) + + def serialize( + self, + observation: Observation, + ) -> PbObservation: + """ + Serialize an Observation to an Observation protobuf message + """ + + serialized_value = None + if observation.value is not None: + flat_value = gym.spaces.flatten(self.gym_space, observation.value) + serialized_value = serialize_ndarray(flat_value) + + serialized_rendered_frame = None + if observation.rendered_frame is not None: + serialized_rendered_frame = encode_rendered_frame( + rendered_frame=observation.rendered_frame, max_size=self.render_width + ) + + return PbObservation( + value=serialized_value, + active=observation.active, + alive=observation.alive, + rendered_frame=serialized_rendered_frame, + ) + + def 
deserialize(self, pb_observation: PbObservation) -> Observation: + """ + Deserialize an Observation from an Observation protobuf message + """ + + return Observation(self.gym_space, pb_observation=pb_observation) + + def create_serialize( + self, + value=None, + active=None, + alive=None, + rendered_frame=None, + ) -> PbObservation: + """ + Create a serialized Observation + """ + return self.serialize( + self.create( + value=value, + active=active, + alive=alive, + rendered_frame=rendered_frame, + ) + ) diff --git a/cogment_lab/specs/spaces_serialization.py b/cogment_lab/specs/spaces_serialization.py new file mode 100644 index 0000000..f09c2dc --- /dev/null +++ b/cogment_lab/specs/spaces_serialization.py @@ -0,0 +1,100 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gymnasium as gym +import numpy as np + +from cogment_lab.generated.spaces_pb2 import ( # pylint: disable=import-error + Box, + Dict, + Discrete, + MultiBinary, + MultiDiscrete, + Space, +) + +from .ndarray_serialization import ( + SerializationFormat, + deserialize_ndarray, + serialize_ndarray, +) + + +def serialize_gym_space(space: gym.Space, serialization_format=SerializationFormat.STRUCTURED) -> Space: + if isinstance(space, gym.spaces.Discrete): + return Space(discrete=Discrete(n=space.n, start=space.start)) + if isinstance(space, gym.spaces.Box): + low = space.low + high = space.high + return Space( + box=Box( + low=serialize_ndarray(low, serialization_format=serialization_format), + high=serialize_ndarray(high, serialization_format=serialization_format), + ), + ) + + if isinstance(space, gym.spaces.MultiBinary): + if isinstance(space.n, np.ndarray): + size = space.n + elif isinstance(space.n, int): + size = np.array([space.n], dtype=np.dtype("int32")) + else: + size = np.array(space.n, dtype=np.dtype("int32")) + return Space( + multi_binary=MultiBinary(n=serialize_ndarray(size, serialization_format=serialization_format)), + ) + + if isinstance(space, gym.spaces.MultiDiscrete): + nvec = space.nvec + return Space( + multi_discrete=MultiDiscrete(nvec=serialize_ndarray(nvec, serialization_format=serialization_format)), + ) + + if isinstance(space, gym.spaces.Dict): + spaces = [] + for key, gym_sub_space in space.spaces.items(): + spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space, serialization_format))) + return Space(dict=Dict(spaces=spaces)) + raise RuntimeError(f"[{type(space)}] is not a supported space type") + + +def deserialize_space(pb_space: Space) -> gym.Space: + space_kind = pb_space.WhichOneof("kind") + if space_kind == "discrete": + discrete_space_pb = pb_space.discrete + return gym.spaces.Discrete(n=discrete_space_pb.n, start=discrete_space_pb.start) + if space_kind == "box": + box_space_pb = pb_space.box + low = 
deserialize_ndarray(box_space_pb.low) + high = deserialize_ndarray(box_space_pb.high) + return gym.spaces.Box(low=low, high=high, shape=low.shape, dtype=low.dtype) + if space_kind == "multi_binary": + multi_binary_space_pb = pb_space.multi_binary + size = deserialize_ndarray(multi_binary_space_pb.n) + if size.size > 1: + return gym.spaces.MultiBinary(n=size) + return gym.spaces.MultiBinary(n=size[0]) + if space_kind == "multi_discrete": + multi_discrete_space_pb = pb_space.multi_discrete + nvec = deserialize_ndarray(multi_discrete_space_pb.nvec) + return gym.spaces.MultiDiscrete(nvec=nvec) + if space_kind == "dict": + dict_space_pb = pb_space.dict + spaces = [] + for sub_space in dict_space_pb.spaces: + spaces.append((sub_space.key, deserialize_space(sub_space.space))) + + return gym.spaces.Dict(spaces=spaces) + + raise RuntimeError(f"[{space_kind}] is not a supported space kind") diff --git a/cogment_lab/utils/__init__.py b/cogment_lab/utils/__init__.py new file mode 100644 index 0000000..7b0b8d1 --- /dev/null +++ b/cogment_lab/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .import_class import import_object diff --git a/cogment_lab/utils/coltra_utils.py b/cogment_lab/utils/coltra_utils.py new file mode 100644 index 0000000..32ecd9b --- /dev/null +++ b/cogment_lab/utils/coltra_utils.py @@ -0,0 +1,63 @@ +# Copyright 2024 AI Redefined Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from coltra import Agent +from coltra.models import BaseModel + +from cogment_lab.utils.trial_utils import TrialData +from coltra.buffers import OnPolicyRecord, Observation, Action + + +def convert_trial_data_to_coltra(trial_data: TrialData, agent: Agent) -> OnPolicyRecord: + """Convert TrialData to OnPolicyRecord. + + Args: + trial_data (TrialData): TrialData instance + agent (Agent): Agent instance used to evaluate values + + Returns: + OnPolicyRecord: Converted OnPolicyRecord instance + """ + obs = trial_data.observations + action = trial_data.actions + reward = trial_data.rewards + done = trial_data.done + # state = None # Assuming 'state' is not provided in TrialData + + # last_value = agent.act(Observation(vector=trial_data.last_observation), get_value=True)[2]["value"] + # value = agent.act(Observation(vector=trial_data.observations), get_value=True)[2]["value"] + + last_value, _ = agent.value(Observation(vector=trial_data.last_observation), ()) + value, _ = agent.value(Observation(vector=trial_data.observations), ()) + + last_value = last_value.detach().squeeze(-1).cpu().numpy() + value = value.detach().squeeze(-1).cpu().numpy() + + # Check if required fields are not None + if obs is None or action is None or reward is None or done is None: + raise ValueError("Missing required fields in TrialData for conversion") + + # Create an OnPolicyRecord instance with the mapped fields + 
on_policy_record = OnPolicyRecord( + obs=Observation(vector=obs).tensor(), + action=Action(discrete=action).tensor(), + reward=torch.tensor(reward.astype(np.float32)), + value=torch.tensor(value.astype(np.float32)), + done=torch.tensor(done.astype(np.float32)), + last_value=torch.tensor(last_value.astype(np.float32)), + ) + + return on_policy_record diff --git a/cogment_lab/utils/grpc.py b/cogment_lab/utils/grpc.py new file mode 100644 index 0000000..2a8a2d2 --- /dev/null +++ b/cogment_lab/utils/grpc.py @@ -0,0 +1,101 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from cogment_lab.core import Observation +from cogment_lab.generated import data_pb2 +from google.protobuf.json_format import ParseDict + + +def extend_actor_config( + actor_config_template: dict, + run_id: str, + agent_specs: data_pb2.AgentSpecs, + seed: int, +) -> data_pb2.AgentConfig: + """Extends an actor configuration template with additional parameters. + + Args: + actor_config_template: Template for actor configuration, possibly None. + run_id: Identifier for the run. + agent_specs: Specifications for the environment. + seed: Seed for random number generation. + + Returns: + An instance of AgentConfig with extended configuration. 
+ """ + config = data_pb2.AgentConfig() + if actor_config_template: + ParseDict(actor_config_template, config) + config.run_id = run_id + config.agent_specs.CopyFrom(agent_specs) + config.seed = seed + return config + + +def create_value(val: str | int | float) -> data_pb2.Value: + """Creates a Value protobuf message from a Python value. + + Args: + val: The input string, integer, or float. + + Returns: + A Value protobuf message containing the input value. + """ + value_message = data_pb2.Value() + if isinstance(val, str): + value_message.string_value = val + elif isinstance(val, int): + value_message.int_value = val + elif isinstance(val, float): + value_message.float_value = val + else: + raise ValueError("Unsupported type") + return value_message + + +def get_env_config( + run_id: str | None = None, + render: bool | None = None, + render_width: int | None = None, + seed: int | None = None, + flatten: bool | None = None, + reset_args_dict: dict[str, str | int | float] | None = None, +) -> data_pb2.EnvironmentConfig: + """Generates an EnvironmentConfig protobuf message. + + Args: + run_id: Identifier for the run. + render: Whether to render the environment. + render_width: Render width if rendering. + seed: Random seed. + flatten: Whether to flatten the observation. + reset_args_dict: Dictionary of reset argument values. + + Returns: + An EnvironmentConfig protobuf message. + """ + env_config = data_pb2.EnvironmentConfig() + + env_config.run_id = run_id + env_config.render = render + env_config.render_width = render_width + env_config.seed = seed + env_config.flatten = flatten + + for key, val in (reset_args_dict or {}).items(): + env_config.reset_args[key].CopyFrom(create_value(val)) + + return env_config diff --git a/cogment_lab/utils/import_class.py b/cogment_lab/utils/import_class.py new file mode 100644 index 0000000..4cb4ddf --- /dev/null +++ b/cogment_lab/utils/import_class.py @@ -0,0 +1,29 @@ +# Copyright 2024 AI Redefined Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from importlib import import_module + + +def import_object(class_name: str): + """Imports an object from a module based on a string + + Args: + class_name (str): The full path to the object e.g. "package.module.Class" + + Returns: + object: The imported object + """ + module_path, class_name = class_name.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) diff --git a/cogment_lab/utils/runners.py b/cogment_lab/utils/runners.py new file mode 100644 index 0000000..b42ac6e --- /dev/null +++ b/cogment_lab/utils/runners.py @@ -0,0 +1,61 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import subprocess +import sys + + +def setup_logging(log_file: str): + """ + Set up logging to file and stdout/stderr. 
+ + Args: + log_file: Path to log file + """ + # Redirect stdout and stderr to log file + dirname = os.path.dirname(log_file) + if dirname: + os.makedirs(dirname, exist_ok=True) + with open(log_file, "a") as f: + os.dup2(f.fileno(), sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stderr.fileno()) + + # Configure logging + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler(log_file)], # , logging.StreamHandler()], + ) + + +def process_cleanup(): + """ + Clean up any leftover processes related to cogment_lab. + """ + try: + pid = os.getpid() + + command = ( + f"ps aux | grep 'cogment-lab' | grep 'multiprocessing' | grep -v grep | " + f"awk '{{if ($2 != {pid}) print $2}}' | xargs -r kill -9" + ) + + subprocess.run(command, shell=True, check=True) + print("Processes terminated successfully.") + except subprocess.CalledProcessError as e: + print(f"An error occurred: {e}") diff --git a/cogment_lab/utils/trial_utils.py b/cogment_lab/utils/trial_utils.py new file mode 100644 index 0000000..b087cb8 --- /dev/null +++ b/cogment_lab/utils/trial_utils.py @@ -0,0 +1,415 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Literal, Any + +import cogment +import numpy as np +from cogment import ActorParameters +from cogment.datastore import Datastore, DatastoreSample +import gymnasium as gym + +from cogment_lab.core import CogmentEnv +from cogment_lab.generated import cog_settings +from cogment_lab.specs import AgentSpecs +from cogment_lab.utils.grpc import extend_actor_config + + +def get_actor_params( + name: str, + implementation: str, + agent_specs: AgentSpecs, + endpoint: str = "grpc://localhost:9002", + base_params: dict | None = None, + run_id: str = "run_id", + seed: int = 0, +) -> ActorParameters: + """ + Create and return actor parameters for a given actor. + + Args: + name (str): The name of the actor. + implementation (str): The implementation type of the actor. + agent_specs (AgentSpecs): The agent specifications. + endpoint (str): The endpoint URL, defaults to "grpc://localhost:9002". + base_params (dict | None): Base parameters for the actor, optional. + run_id (str): The run ID, defaults to "run_id". + seed (int): The seed for randomness, defaults to 0. + + Returns: + ActorParameters: The configured actor parameters. + """ + if base_params is None: + base_params = {} + params = ActorParameters( + cog_settings, + name=name, + class_name="player", + endpoint=endpoint, + implementation=implementation, + config=extend_actor_config(base_params, run_id, agent_specs.serialize(), seed), + ) + return params + + +@dataclass +class TrialData: + """ + Dataclass to store structured trial data for reinforcement learning. + + Attributes: + observations (np.ndarray | dict[str, np.ndarray] | None): Observations from the trial. + actions (np.ndarray | dict[str, np.ndarray] | None): Actions taken during the trial. + rewards (np.ndarray | None): Rewards received during the trial. + done (np.ndarray | None): Done flags for each step of the trial. 
+ next_observations (np.ndarray | dict[str, np.ndarray] | None): Next observations after each step. + last_observation (np.ndarray | dict[str, np.ndarray] | None): The final observation of the trial. + """ + + observations: np.ndarray | dict[str, np.ndarray] | None = field(default=None) + actions: np.ndarray | dict[str, np.ndarray] | None = field(default=None) + rewards: np.ndarray | None = field(default=None) + done: np.ndarray | None = field(default=None) + next_observations: np.ndarray | dict[str, np.ndarray] | None = field(default=None) + last_observation: np.ndarray | dict[str, np.ndarray] | None = field(default=None) + + +def initialize_buffer(space: gym.Space | None, length: int) -> np.ndarray | dict[str, np.ndarray]: + """ + Initializes a buffer based on the given gym space and length. + + Args: + space (gym.Space | None): The gym space to base the buffer on. If None, an empty buffer is created. + length (int): The length of the buffer. + + Returns: + Union[np.ndarray, Dict[str, np.ndarray]]: The initialized buffer, either as an ndarray or a dictionary of ndarrays. + """ + if space is None: + return np.empty((length,), dtype=np.float32) + elif isinstance(space, gym.spaces.Dict): + return {key: np.empty((length,) + space[key].shape, dtype=space[key].dtype) for key in space.spaces.keys()} + elif isinstance(space, gym.spaces.Tuple): + return {i: np.empty((length,) + space[i].shape, dtype=space[i].dtype) for i in range(len(space.spaces))} + else: # Simple space + return np.empty((length,) + space.shape, dtype=space.dtype) + + +def write_to_buffer( + buffer: np.ndarray | dict[str, np.ndarray], + data: np.ndarray | dict[str, Any], + idx: int, +): + """ + Writes data to a specified index in the buffer. + + Args: + buffer (np.ndarray | dict[str, np.ndarray]): The buffer to write data to. + data (np.ndarray | dict[str, Any]): The data to write. + idx (int): The index at which to write the data. 
+ + """ + if isinstance(buffer, dict): + for key in buffer.keys(): + buffer[key][idx] = data[key] + else: + buffer[idx] = data + + +async def format_data( + datastore: Datastore, + trial_id: str, + fields: Sequence[str] = ( + "observations", + "actions", + "done", + "next_observations", + "last_observation", + ), + agent_specs: AgentSpecs | None = None, + env: CogmentEnv | None = None, +) -> TrialData: + """ + Formats trial data from a Cogment trial into a structured format for reinforcement learning. + + Args: + datastore (Datastore): The datastore to fetch trial data from. + trial_id (str): The identifier of the trial. + fields (Sequence[str]): The fields to include in the formatted data. + agent_specs (AgentSpecs | None): The agent specifications, optional. + env (CogmentEnv | None): A Cogment environment, optional. + + Returns: + TrialData: The formatted trial data. + + Raises: + AssertionError: If neither or both of agent_specs and env are provided. + """ + assert agent_specs is not None or env is not None, "Either agent_specs or env must be provided" + assert agent_specs is None or env is None, "Only one of agent_specs and env can be provided" + + if agent_specs is None: + agent_specs = env.agent_specs + + trials = await datastore.get_trials([trial_id]) + samples = [] + async for sample in datastore.all_samples(trials): + samples.append(sample) + if sample.trial_state == cogment.TrialState.ENDED: + break + # if len(samples) >= sample_count: + # break + + data = extract_data_from_samples(samples, fields, agent_specs) + + return data + + +def extract_data_from_samples( + samples: list[DatastoreSample], + fields: Sequence[str] = ( + "observations", + "actions", + "rewards", + "done", + "next_observations", + "last_observation", + ), + agent_specs: AgentSpecs | None = None, + actor_name: str = "gym", +) -> TrialData: + """ + Extracts trial data into a TrialData instance from a list of DatastoreSamples.
+ + Args: + samples (list[DatastoreSample]): The samples to extract data from. + fields (Sequence[str]): The fields to extract into the TrialData. + agent_specs (AgentSpecs | None): The environment specifications. + actor_name (str): The name of the actor to extract data for. + + Returns: + TrialData: The extracted trial data. + """ + sample_count = len(samples) + if sample_count == 0: + raise ValueError("No samples provided") + + cog_observation_space = agent_specs.get_observation_space() + observation_space = cog_observation_space.gym_space + + cog_action_space = agent_specs.get_action_space() + action_space = cog_action_space.gym_space + + data = TrialData() + if "observations" in fields: + data.observations = initialize_buffer(observation_space, sample_count - 1) + if "actions" in fields: + data.actions = initialize_buffer(action_space, sample_count - 1) + if "rewards" in fields: + data.rewards = initialize_buffer(None, sample_count - 1) + if "done" in fields: + data.done = initialize_buffer(None, sample_count - 1) + data.done[-1] = True + if "next_observations" in fields: + data.next_observations = initialize_buffer(observation_space, sample_count - 1) + if "last_observation" in fields: + data.last_observation = initialize_buffer(observation_space, 1) + + for i, sample in enumerate(samples[:-1]): + if "observations" in fields: + obs = cog_observation_space.deserialize(sample.actors_data[actor_name].observation).value + write_to_buffer(data.observations, obs, i) + if "actions" in fields: + action = cog_action_space.deserialize(sample.actors_data[actor_name].action).value + write_to_buffer(data.actions, action, i) + if "rewards" in fields: + write_to_buffer(data.rewards, sample.actors_data[actor_name].reward, i) + if "done" in fields and i < sample_count - 2: + write_to_buffer(data.done, False, i) + if "next_observations" in fields: + next_obs = cog_observation_space.deserialize(samples[i + 1].actors_data[actor_name].observation).value + 
write_to_buffer(data.next_observations, next_obs, i) + if "last_observation" in fields: + last_obs = agent_specs.get_observation_space().deserialize(samples[-1].actors_data[actor_name].observation).value + write_to_buffer(data.last_observation, last_obs, 0) + + return data + + +def extract_rewards_from_samples( + samples: list[DatastoreSample], + agent_specs: AgentSpecs | None = None, + actor_name: str = "gym", +) -> np.ndarray: + """ + Extracts rewards from trial samples into a NumPy array. + + Args: + samples (list[DatastoreSample]): The samples to extract rewards from. + agent_specs (AgentSpecs | None): The environment specifications (currently unused). + actor_name (str): The name of the actor to extract rewards for. + + Returns: + np.ndarray: The extracted rewards. + """ + sample_count = len(samples) + + rewards = initialize_buffer(None, sample_count - 1) + + for i, sample in enumerate(samples[:-1]): + write_to_buffer(rewards, sample.actors_data[actor_name].reward, i) + + return rewards + + +def concat_trial_field( + field_data: list[np.ndarray | dict[str, np.ndarray] | None], +) -> np.ndarray | dict[str, np.ndarray] | None: + """ + Concatenates a list of fields (either np.ndarray or dict of np.ndarrays) from TrialData instances. + Filters out None values before concatenation. + + Args: + field_data: List of fields to be concatenated. + + Returns: + Concatenated field as np.ndarray or dict of np.ndarrays, or None if all elements are None. 
+ """ + # Filter out None values + valid_field_data = [data for data in field_data if data is not None] + + if not valid_field_data: + return None + + if all(isinstance(data, np.ndarray) for data in valid_field_data): + return np.concatenate(valid_field_data, axis=0) + elif all(isinstance(data, dict) for data in valid_field_data): + keys = valid_field_data[0].keys() + return {key: np.concatenate([data[key] for data in valid_field_data], axis=0) for key in keys} + else: + raise TypeError("Inconsistent field types in TrialData list.") + + +def concatenate(trial_data_list: list[TrialData]) -> TrialData: + """ + Concatenates a list of TrialData instances into a single TrialData instance. + + Args: + trial_data_list: List of TrialData instances to be concatenated. + + Returns: + A single concatenated TrialData instance. + """ + observations = concat_trial_field([trial.observations for trial in trial_data_list]) + actions = concat_trial_field([trial.actions for trial in trial_data_list]) + rewards = concat_trial_field([trial.rewards for trial in trial_data_list]) + done = concat_trial_field([trial.done for trial in trial_data_list]) + next_observations = concat_trial_field([trial.next_observations for trial in trial_data_list]) + + # Handle 'last_observation' separately + last_trial = trial_data_list[-1] + last_observation = ( + last_trial.last_observation if last_trial.last_observation is not None else last_trial.next_observations[-1] + ) + + return TrialData( + observations=observations, + actions=actions, + rewards=rewards, + done=done, + next_observations=next_observations, + last_observation=last_observation, + ) + + +async def format_data_multiagent( + datastore: Datastore, + trial_id: str, + actor_agent_specs: dict[str, AgentSpecs], + fields: Sequence[str] = ( + "observations", + "actions", + "rewards", + "done", + "next_observations", + "last_observation", + ), +) -> dict[str, TrialData]: + """ + Formats trial data from a multiagent Cogment trial into structured 
formats for reinforcement learning. + + Args: + datastore (Datastore): The datastore to fetch trial data from. + trial_id (str): The identifier of the trial. + actor_agent_specs (dict[str, AgentSpecs]): A dictionary mapping actor IDs to their agent specifications. + fields (Sequence[str]): The fields to include in the formatted data. + + Returns: + dict[str, TrialData]: A dictionary mapping actor IDs to their formatted trial data. + """ + trials = [] + while len(trials) == 0: + try: + trials = await datastore.get_trials([trial_id]) + except cogment.CogmentError: + continue + + # Initialize a dictionary to store samples for each actor + actor_samples = {actor_id: [] for actor_id in actor_agent_specs.keys()} + actor_reward_samples = {actor_id: [] for actor_id in actor_agent_specs.keys()} + + # Get all samples + all_samples = [] + async for sample in datastore.all_samples(trials): + all_samples.append(sample) + + # Sort according to tick_id -- this might not be necessary with some versions of Cogment + all_samples.sort(key=lambda x: x.tick_id) + + for sample in all_samples: + for actor_id in actor_agent_specs.keys(): + # Add the sample to the list for an actor if the observation for that actor is not None + if ( + sample.actors_data.get(actor_id) + and sample.actors_data[actor_id].observation is not None + and sample.actors_data[actor_id].observation.value is not None + and sample.actors_data[actor_id].observation.value.raw_data != b"" + ): + actor_samples[actor_id].append(sample) + + if ( + sample.actors_data.get(actor_id) + and sample.actors_data[actor_id].reward is not None + and sample.actors_data[actor_id].reward == sample.actors_data[actor_id].reward # Check for NaN + ): + actor_reward_samples[actor_id].append(sample) + + # Extract data for each actor + actor_data = {} + + for actor_id, samples in actor_samples.items(): + actor_data[actor_id] = extract_data_from_samples( + samples, fields, actor_agent_specs[actor_id], actor_name=actor_id + ) + + 
for actor_id, reward_samples in actor_reward_samples.items(): + actor_data[actor_id].rewards = extract_rewards_from_samples( + reward_samples, actor_agent_specs[actor_id], actor_name=actor_id + ) + + return actor_data diff --git a/cogment_lab/utils/yaml_utils.py b/cogment_lab/utils/yaml_utils.py new file mode 100644 index 0000000..c57ec0a --- /dev/null +++ b/cogment_lab/utils/yaml_utils.py @@ -0,0 +1,181 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gymnasium as gym +import numpy as np +import yaml + + +def gym_space_constructors(): + """Registers YAML constructors for Gym spaces. + + This allows Gym spaces to be created automatically from YAML files + by registering constructors for each space type. + """ + + def box_constructor(loader, node): + """YAML constructor for Box spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Box space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + return gym.spaces.Box( + low=np.array(values.get("low", -np.inf)), + high=np.array(values.get("high", np.inf)), + shape=values.get("shape", None), + dtype=values.get("dtype", np.float32), + seed=values.get("seed", None), + ) + + def discrete_constructor(loader, node): + """YAML constructor for Discrete spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Discrete space constructed from the YAML node. 
+ """ + values = loader.construct_mapping(node) + return gym.spaces.Discrete(n=values["n"], seed=values.get("seed", None), start=values.get("start", 0)) + + def multibinary_constructor(loader, node): + """YAML constructor for MultiBinary spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A MultiBinary space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + return gym.spaces.MultiBinary(n=values["n"], seed=values.get("seed", None)) + + def multidiscrete_constructor(loader, node): + """YAML constructor for MultiDiscrete spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A MultiDiscrete space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + return gym.spaces.MultiDiscrete( + nvec=np.array(values["nvec"]), + dtype=values.get("dtype", np.int64), + seed=values.get("seed", None), + start=values.get("start", None), + ) + + def text_constructor(loader, node): + """YAML constructor for Text spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Text space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + return gym.spaces.Text( + max_length=values["max_length"], + min_length=values.get("min_length", 1), + charset=values.get("charset", "alphanumeric"), + seed=values.get("seed", None), + ) + + def dict_constructor(loader, node): + """YAML constructor for Dict spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Dict space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + spaces = values.pop("spaces", None) + seed = values.pop("seed", None) + return gym.spaces.Dict(spaces=spaces, seed=seed, **values) + + def tuple_constructor(loader, node): + """YAML constructor for Tuple spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Tuple space constructed from the YAML node. 
+ """ + values = loader.construct_mapping(node, deep=True) # a mapping with "spaces" and an optional "seed" + spaces = values.pop("spaces", None) + seed = values.pop("seed", None) + return gym.spaces.Tuple(spaces=spaces, seed=seed) + + def sequence_constructor(loader, node): + """YAML constructor for Sequence spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Sequence space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + space = values.get("space") + seed = values.get("seed", None) + stack = values.get("stack", False) + return gym.spaces.Sequence(space=space, seed=seed, stack=stack) + + def graph_constructor(loader, node): + """YAML constructor for Graph spaces. + + Args: + loader: The YAML loader. + node: The YAML node. + + Returns: + A Graph space constructed from the YAML node. + """ + values = loader.construct_mapping(node) + node_space = values.pop("node_space") + edge_space = values.pop("edge_space", None) + seed = values.pop("seed", None) + return gym.spaces.Graph(node_space=node_space, edge_space=edge_space, seed=seed) + + # Register constructors + yaml.add_constructor("!Box", box_constructor) + yaml.add_constructor("!Discrete", discrete_constructor) + yaml.add_constructor("!MultiBinary", multibinary_constructor) + yaml.add_constructor("!MultiDiscrete", multidiscrete_constructor) + yaml.add_constructor("!Text", text_constructor) + + yaml.add_constructor("!Dict", dict_constructor) + yaml.add_constructor("!Tuple", tuple_constructor) + yaml.add_constructor("!Graph", graph_constructor) + yaml.add_constructor("!Sequence", sequence_constructor) + yaml.add_constructor("!Tuple", tuple_constructor) diff --git a/examples/demos/active-lunar/lunar-active.py b/examples/demos/active-lunar/lunar-active.py new file mode 100644 index 0000000..3c19a47 --- /dev/null +++ b/examples/demos/active-lunar/lunar-active.py @@ -0,0 +1,201 @@ +# Copyright 2024 AI Redefined Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import random + +import numpy as np +import torch +import wandb +from coltra.models import FCNetwork +from torch import optim +from tqdm import trange +from typarse import BaseParser + +from cogment_lab.actors import ConstantActor +from cogment_lab.actors.nn_actor import NNActor, BoltzmannActor +from cogment_lab.envs import AECEnvironment, GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.utils.runners import process_cleanup +from shared import ReplayBuffer, get_current_eps, dqn_loss + + +class Parser(BaseParser): + wandb_project: str = "test" + wandb_name: str = "test" + + env_name: str = "LunarLander-v2" + + batch_size: int = 128 + gamma: float = 0.99 + replay_buffer_capacity: int = 50000 + learning_rate: float = 6.3e-4 + num_episodes: int = 500 + seed: int = 0 + + human_episodes: int = 100 + + +async def main(): + args = Parser() + + process_cleanup() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + wandb.init(project=args.wandb_project, name=args.wandb_name) + + wandb.config.batch_size = args.batch_size + wandb.config.gamma = args.gamma + wandb.config.replay_buffer_capacity = args.replay_buffer_capacity + wandb.config.learning_rate = args.learning_rate + wandb.config.num_episodes = args.num_episodes + wandb.config.seed = args.seed + wandb.config.env_name = args.env_name + + logpath = 
f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + cenv = AECEnvironment( + env_path="cogment_lab.envs.conversions.teacher.GymTeacherAEC", + make_kwargs={"gym_env_name": args.env_name, "gym_make_kwargs": {}, "render_mode": "rgb_array"}, + render=True, + reinitialize=True, + ) + + await cog.run_env(cenv, "lunar", port=9021, log_file="env-aec.log") + + obs_len = cenv.env.observation_space("gym").shape[0] + + cenv_single = GymEnvironment(env_id=args.env_name, reinitialize=True, render=True) + + await cog.run_env(cenv_single, "lunar-single", port=9022, log_file="env-gym.log") + + # Create and run the learner network + + replay_buffer = ReplayBuffer(args.replay_buffer_capacity, obs_len) + + # Run the agent + network = FCNetwork( + input_size=obs_len, output_sizes=[cenv.env.action_space("gym").n], hidden_sizes=[256, 256], activation="tanh" + ) + + actor = NNActor(network, "cpu") + optimizer = optim.Adam(network.parameters(), lr=args.learning_rate) + + cog.run_local_actor(actor, "dqn", port=9012, log_file="dqn.log") + + # Run the human teacher + + # # Lunar lander actions + actions = { + "no-op": {"active": 0, "action": 0}, + "ArrowDown": {"active": 1, "action": 0}, + "ArrowRight": {"active": 1, "action": 1}, + "ArrowUp": {"active": 1, "action": 2}, + "ArrowLeft": {"active": 1, "action": 3}, + } + + await cog.run_web_ui(actions=actions, log_file="human.log", fps=60) + + print(f"Launched web UI at http://localhost:8000") + + total_timesteps = 0 + ep_rewards = [] + + for episode in (pbar := trange(args.num_episodes)): + actor.set_eps(get_current_eps(episode)) + if episode == args.human_episodes: + cog.stop_service("lunar") + + if episode < args.human_episodes: + trial_id = await cog.start_trial( + env_name="lunar", + actor_impls={"gym": "dqn", "teacher": "web_ui"}, + session_config={"render": True, "seed": episode}, + ) + else: + trial_id = await cog.start_trial( + env_name="lunar-single", actor_impls={"gym": "dqn"}, 
session_config={"render": True, "seed": episode} + ) + + trial_data_task = asyncio.create_task(cog.get_trial_data(trial_id)) + + gradient_updates = 0 + + + trial_data = await trial_data_task + + + # Logging + dqn_data = trial_data["gym"] + + total_reward = dqn_data.rewards.sum() + pbar.set_description(f"mean_reward: {total_reward:.3}") + ep_rewards.append(total_reward) + + total_timesteps += len(dqn_data.rewards) + + # Add data to replay buffer + + for t in range(len(dqn_data.done)): + state = dqn_data.observations[t] + dqn_action = dqn_data.actions[t] + try: + human_data = trial_data["teacher"] + human_active = human_data.actions["active"][t] + human_action = human_data.actions["action"][t] + except (IndexError, KeyError): + human_active = 0 + human_action = 0 + + action = human_action if human_active == 1 else dqn_action + + + reward = dqn_data.rewards[t] + next_state = dqn_data.next_observations[t] + done = dqn_data.done[t] + + replay_buffer.push(state, action, reward, next_state, done) + + # Train, once per datapoint + + if len(replay_buffer) > args.batch_size: + batch = replay_buffer.sample(args.batch_size) + + loss = dqn_loss(network, batch, args.gamma) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + gradient_updates += 1 + + log_dict = { + "episode": episode, + "reward": total_reward, + "ep_length": len(dqn_data.rewards), + "total_timesteps": total_timesteps, + "gradient_updates": gradient_updates, + } + + wandb.log(log_dict) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/demos/active-lunar/lunar-base.py b/examples/demos/active-lunar/lunar-base.py new file mode 100644 index 0000000..fee5683 --- /dev/null +++ b/examples/demos/active-lunar/lunar-base.py @@ -0,0 +1,157 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import random + +import numpy as np +import torch +import wandb +from coltra.models import FCNetwork +from torch import optim +from tqdm import trange +from typarse import BaseParser + +from cogment_lab.actors import ConstantActor +from cogment_lab.actors.nn_actor import NNActor, BoltzmannActor +from cogment_lab.envs import AECEnvironment, GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.utils.runners import process_cleanup +from shared import ReplayBuffer, get_current_eps, dqn_loss + + +class Parser(BaseParser): + wandb_project: str = "test" + wandb_name: str = "test" + + env_name: str = "LunarLander-v2" + + batch_size: int = 128 + gamma: float = 0.99 + replay_buffer_capacity: int = 50000 + learning_rate: float = 6.3e-4 + num_episodes: int = 500 + seed: int = 0 + +async def main(): + args = Parser() + + process_cleanup() + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + wandb.init(project=args.wandb_project, name=args.wandb_name) + + wandb.config.batch_size = args.batch_size + wandb.config.gamma = args.gamma + wandb.config.replay_buffer_capacity = args.replay_buffer_capacity + wandb.config.learning_rate = args.learning_rate + wandb.config.num_episodes = args.num_episodes + wandb.config.seed = args.seed + wandb.config.env_name = args.env_name + + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + + cenv = GymEnvironment(env_id=args.env_name, reinitialize=True, render=True) + + obs_len = 
cenv.env.observation_space.shape[0] + + await cog.run_env(cenv, "lunar", port=9021, log_file="env-gym.log") + + # Create and run the learner network + + replay_buffer = ReplayBuffer(args.replay_buffer_capacity, obs_len) + + # Run the agent + network = FCNetwork( + input_size=obs_len, output_sizes=[cenv.env.action_space.n], hidden_sizes=[256, 256], activation="tanh" + ) + + actor = NNActor(network, "cpu") + optimizer = optim.Adam(network.parameters(), lr=args.learning_rate) + + cog.run_local_actor(actor, "dqn", port=9012, log_file="dqn.log") + + total_timesteps = 0 + ep_rewards = [] + + for episode in (pbar := trange(args.num_episodes)): + actor.set_eps(get_current_eps(episode)) + # This baseline has no human teacher, so the "lunar" environment + # keeps serving every episode (Parser defines no human_episodes here). + + + trial_id = await cog.start_trial( + env_name="lunar", actor_impls={"gym": "dqn"}, session_config={"render": True, "seed": episode} + ) + + trial_data_task = asyncio.create_task(cog.get_trial_data(trial_id)) + + gradient_updates = 0 + + trial_data = await trial_data_task + + # Logging + dqn_data = trial_data["gym"] + + total_reward = dqn_data.rewards.sum() + pbar.set_description(f"mean_reward: {total_reward:.3}") + ep_rewards.append(total_reward) + + total_timesteps += len(dqn_data.rewards) + + # Add data to replay buffer + + for t in range(len(dqn_data.done)): + state = dqn_data.observations[t] + action = dqn_data.actions[t] + + reward = dqn_data.rewards[t] + next_state = dqn_data.next_observations[t] + done = dqn_data.done[t] + + replay_buffer.push(state, action, reward, next_state, done) + + # Train, once per datapoint + + if len(replay_buffer) > args.batch_size: + batch = replay_buffer.sample(args.batch_size) + + loss = dqn_loss(network, batch, args.gamma) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + gradient_updates += 1 + + log_dict = { + "episode": episode, + "reward": total_reward, + "ep_length": len(dqn_data.rewards), + 
"total_timesteps": total_timesteps, + "gradient_updates": 
gradient_updates, + } + + wandb.log(log_dict) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/demos/active-lunar/shared.py b/examples/demos/active-lunar/shared.py new file mode 100644 index 0000000..0a3330e --- /dev/null +++ b/examples/demos/active-lunar/shared.py @@ -0,0 +1,144 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from torch import nn +import hashlib + +from cogment_lab import Cogment + +EPS_INIT = 0.9 +EPS_FINAL = 0.001 +EPS_DECAY = 300 + + +class ReplayBuffer: + def __init__(self, capacity: int, obs_size: int, seed: int = 0): + self.capacity = capacity + self.buffer_counter = 0 + + self.rng = np.random.default_rng(seed) + + # Pre-allocate memory + self.states = np.zeros((capacity, obs_size), dtype=np.float32) + self.actions = np.zeros(capacity, dtype=np.int32) + self.rewards = np.zeros(capacity, dtype=np.float32) + self.next_states = np.zeros((capacity, obs_size), dtype=np.float32) + self.dones = np.zeros(capacity, dtype=np.float32) + + def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: float): + index = self.buffer_counter % self.capacity + + self.states[index] = state + self.actions[index] = action + self.rewards[index] = reward + self.next_states[index] = next_state + self.dones[index] = done + + self.buffer_counter += 1 + + def sample(self, batch_size: int) -> tuple[np.ndarray, ...]: + max_buffer_size = 
min(self.buffer_counter, self.capacity) + batch_indices = self.rng.choice(max_buffer_size, batch_size, replace=False) + + return ( + self.states[batch_indices], + self.actions[batch_indices], + self.rewards[batch_indices], + self.next_states[batch_indices], + self.dones[batch_indices], + ) + + def __len__(self): + return min(self.buffer_counter, self.capacity) + + +def get_current_eps( + current_step: int, eps_start: float = EPS_INIT, eps_final: float = EPS_FINAL, eps_decay_duration: int = EPS_DECAY +) -> float: + """ + Calculate the epsilon value for epsilon-greedy exploration in DQN. + + Parameters: + current_step (int): The current step index. + eps_start (float): The starting value of epsilon. + eps_final (float): The final value of epsilon. + eps_decay_duration (int): The number of steps over which epsilon is decayed linearly. + + Returns: + float: The current epsilon value. + """ + current_step = min(current_step, eps_decay_duration) + + decay_rate = (eps_start - eps_final) / eps_decay_duration + + current_epsilon = eps_start - decay_rate * current_step + + current_epsilon = max(current_epsilon, eps_final) + + return current_epsilon + + +def hash_model(model: nn.Module): + model_state = model.state_dict() + model_weights = [] + + for key, value in model_state.items(): + model_weights.append(value.cpu().numpy().tobytes()) + + model_hash = hashlib.sha256(b"".join(model_weights)).hexdigest() + + return model_hash + + +def dqn_loss(model: nn.Module, batch: tuple[np.ndarray, ...], γ: float) -> torch.Tensor: + states, actions, rewards, next_states, dones = batch + + states = torch.from_numpy(states) + actions = torch.from_numpy(actions).to(torch.int64) + rewards = torch.from_numpy(rewards) + next_states = torch.from_numpy(next_states) + dones = torch.from_numpy(dones) + + current_q_values = model(states)[0].gather(1, actions.unsqueeze(1)).squeeze(1) + next_q_values = model(next_states)[0].max(1)[0] + expected_q_values = rewards + γ * next_q_values * (1 - dones) + + 
loss = nn.MSELoss()(current_q_values, expected_q_values.detach()) + + return loss + + +async def evaluate_model( + cog: Cogment, env_name: str, actor_impls: dict[str, str], num_episodes: int = 10 +) -> tuple[float, float]: + total_rewards = [] + episode_lengths = [] + + for ep in range(num_episodes): + trial_id = await cog.start_trial( + env_name=env_name, actor_impls=actor_impls, session_config={"render": False, "seed": 10_000 + ep} + ) + + trial_data = await cog.get_trial_data(trial_id) + dqn_data = trial_data["gym"] + + total_rewards.append(dqn_data.rewards.sum()) + episode_lengths.append(len(dqn_data.rewards)) + + mean_reward = np.mean(total_rewards) + mean_length = np.mean(episode_lengths) + + return mean_reward, mean_length diff --git a/examples/docs_example.py b/examples/docs_example.py new file mode 100644 index 0000000..b2ea9f0 --- /dev/null +++ b/examples/docs_example.py @@ -0,0 +1,87 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import datetime + +from cogment_lab import Cogment +from cogment_lab.actors import RandomActor, ConstantActor +from cogment_lab.envs import GymEnvironment +from cogment_lab.utils.runners import process_cleanup + +LUNAR_LANDER_ACTIONS = ["no-op", "ArrowRight", "ArrowUp", "ArrowLeft"] + + +async def main(): + # Clean up leftover processes, then create the global process manager + process_cleanup() + + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + # Launch the environment + env = GymEnvironment( + env_id="LunarLander-v2", # ID of a Gymnasium environment + render=True, # True if we want to see the rendering at some point + ) + + await cog.run_env( + env=env, + env_name="lunar", + port=9011, # Typically, we use ports 901x for environments and 902x for actors + log_file="env.log", + ) + + # Launch a constant actor + constant_actor = ConstantActor(0) + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9021, log_file="constant.log") + + # Launch a random actor + random_actor = RandomActor(env.env.action_space) + + await cog.run_actor(actor=random_actor, actor_name="random", port=9022, log_file="random.log") + + # Launch an episode + episode_id = await cog.start_trial( + env_name="lunar", # Which environment + actor_impls={"gym": "random"}, # Which actor(s) will act + ) + + # Compute the total reward of the episode + data = await cog.get_trial_data(trial_id=episode_id) + random_reward = data["gym"].rewards.sum() + + print(f"Random agent's reward: {random_reward}") + + # Launch a human actor UI + await cog.run_web_ui(actions=LUNAR_LANDER_ACTIONS, log_file="human.log", fps=60) + + episode_id = await cog.start_trial(env_name="lunar", actor_impls={"gym": "web_ui"}, session_config={"render": True}) + + print("Go to http://localhost:8000 in your browser and see how well you do!") + + data = await cog.get_trial_data(trial_id=episode_id) + + human_reward = data["gym"].rewards.sum() + + if 
human_reward > random_reward: + 
print(f"Good job! You beat a random agent with a score of {human_reward}!") + else: + print(f"Awkward... You lost with a score of {human_reward}...") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/gymnasium/bc-training.ipynb b/examples/gymnasium/bc-training.ipynb new file mode 100644 index 0000000..732e3b4 --- /dev/null +++ b/examples/gymnasium/bc-training.ipynb @@ -0,0 +1,746 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T17:26:58.848013Z", + "start_time": "2023-12-13T17:26:56.922017Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from coltra import HomogeneousGroup\n", + "from coltra.buffers import Observation\n", + "from coltra.models import MLPModel\n", + "from coltra.policy_optimization import CrowdPPOptimizer\n", + "from tqdm import trange\n", + "\n", + "from cogment_lab.actors import ColtraActor\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import concatenate\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:26:58.911299Z", + "start_time": "2023-12-13T17:26:58.848302Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": 
"2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:26:59.807004Z", + "start_time": "2023-12-13T17:26:59.288309Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T18:26:59.286916\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:03.568016Z", + "start_time": "2023-12-13T17:27:01.302459Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We'll train on \n", + "\n", + "cenv = GymEnvironment(\n", + " env_id=\"MountainCar-v0\",\n", + " render=True,\n", + " make_kwargs={\"max_episode_steps\": 401},\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"mcar\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:07.411552Z", + "start_time": "2023-12-13T17:27:05.008940Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a model using coltra\n", + "\n", + "model = MLPModel(\n", + " 
config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "# Put the model in shared memory so that the actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'mcar': ,\n 'coltra': }" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:09.742973Z", + "start_time": "2023-12-13T17:27:09.738867Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:15.628829Z", + "start_time": "2023-12-13T17:27:13.486915Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "MOUNTAIN_CAR_ACTIONS = [\"no-op\", \"ArrowLeft\", \"ArrowRight\"]\n", + "\n", + "actions = MOUNTAIN_CAR_ACTIONS\n", + "\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [00:04<00:00, 2.18it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: -400.1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Estimate random agent performance\n", + "\n", + 
"episodes = []\n", + "for i in trange(10):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id)\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:27.047222Z", + "start_time": "2023-12-13T17:27:22.445172Z" + } + }, + "id": "b7cde51d7dc0c3a9" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "mean_reward: -4e+02: 100%|██████████| 10/10 [00:14<00:00, 1.47s/it]\n" + ] + } + ], + "source": [ + "# Train for a bit with PPO\n", + "\n", + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 128,\n", + "})\n", + "\n", + "all_rewards = []\n", + "\n", + "for t in (pbar := trange(10)):\n", + " num_steps = 0\n", + " episodes = []\n", + " while num_steps < 1000:\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"mcar\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " num_steps += len(data.rewards)\n", + " \n", + " all_data = concatenate(episodes)\n", + " \n", + " record = convert_trial_data_to_coltra(all_data, actor.agent)\n", + " metrics = ppo.train_on_data({\"crowd\": record}, shape=(1,) + record.reward.shape)\n", + " \n", + " mean_reward = metrics[\"crowd/mean_episode_reward\"]\n", + " all_rewards.append(mean_reward)\n", + " pbar.set_description(f\"mean_reward: 
{mean_reward:.3}\")\n", + " \n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:27:55.773514Z", + "start_time": "2023-12-13T17:27:41.057103Z" + } + }, + "id": "fd1d49a788c06eb4" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": "[]" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(all_rewards)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:28:10.099466Z", + "start_time": "2023-12-13T17:28:09.937126Z" + } + }, + "id": "e48689f2a72e24dc" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Reinitialize the agent\n", + "\n", + "cog.stop_service(\"coltra\")\n", + "\n", + "model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "# Put the model in shared memory so that the actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:28:21.263799Z", + "start_time": "2023-12-13T17:28:19.093753Z" + } + }, + "id": "5c1585be28fdae6c" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "# Get some human episodes\n", + "episodes = []\n", + "for i in range(1):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"gym\": \"web_ui\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"mcar\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " \n", + "all_data = concatenate(episodes)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:08.137344Z", + "start_time": 
"2023-12-13T17:29:54.830169Z" + } + }, + "id": "8f1381b80d4c8799" + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: -151.0\n", + "rewards: [-151.0]\n" + ] + } + ], + "source": [ + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")\n", + "print(f\"rewards: {[sum(e.rewards) for e in episodes]}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:11.105847Z", + "start_time": "2023-12-13T17:30:11.099839Z" + } + }, + "id": "73d139b8e5d005d8" + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [], + "source": [ + "cog.stop_service(\"web_ui\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:20.978109Z", + "start_time": "2023-12-13T17:30:19.950558Z" + } + }, + "id": "73c751b525abf31e" + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [], + "source": [ + "all_obs = Observation(vector=all_data.observations).tensor()\n", + "all_actions = torch.tensor(all_data.actions)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:28.238655Z", + "start_time": "2023-12-13T17:30:28.234181Z" + } + }, + "id": "cdcef32e17f9afc1" + }, + { + "cell_type": "code", + "execution_count": 23, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loss: 0.28: 100%|██████████| 500/500 [00:00<00:00, 1018.11it/s] \n" + ] + } + ], + "source": [ + "losses = []\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "for t in (pbar := trange(500)):\n", + " preds = model(all_obs)[0].logits\n", + " loss = F.cross_entropy(preds, all_actions)\n", + " \n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " pbar.set_description(f\"loss: {loss.item():.3}\")\n", + " \n", + " losses.append(loss.item())\n", + " " + 
], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:58.439176Z", + "start_time": "2023-12-13T17:30:57.943699Z" + } + }, + "id": "9a1cb51957e9f672" + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [ + { + "data": { + "text/plain": "[]" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGtklEQVR4nO3de1xUdf4/8NdcmBluM4DAcHEU8a4oJCphN1sxdNvU/e2FyrKobNesb0WXjW+bbtk3tt3v+m3b3GzdTK39ltW325ZRNl7KRCy8oSKKgIgww30GBpiBmfP7Y2BsEtRB4AzD6/l4nAd6zucc3nM0eXU+lyMRBEEAERERkReTil0AERER0aUwsBAREZHXY2AhIiIir8fAQkRERF6PgYWIiIi8HgMLEREReT0GFiIiIvJ6DCxERETk9eRiF9AfHA4HqqqqEBwcDIlEInY5REREdBkEQUBzczNiYmIglV78GYpPBJaqqirodDqxyyAiIqI+OHv2LEaOHHnRNj4RWIKDgwE4P7BarRa5GiIiIrocZrMZOp3O9XP8YnwisHR3A6nVagYWIiKiIeZyhnNw0C0RERF5PQYWIiIi8noMLEREROT1GFiIiIjI6zGwEBERkddjYCEiIiKvx8BCREREXo+BhYiIiLweAwsRERF5PQYWIiIi8noMLEREROT1GFiIiIjI6zGwXERTqw2v7T6NJ98/LHYpREREwxoDy0XYHQJezD2Bd7+vxLmmNrHLISIiGrYYWC5iRJASM0eHAQC2HzOIXA0REdHwxcByCTdN1QIAvjxuFLkSIiKi4YuB5RJumhIFAMgva0BTq03kaoiIiIYnBpZLGDUiAJOigmF3CNhxokbscoiIiIYlBpbLcNNU51OWLziOhYiISBQMLJfhpinOcSy7T9aizWYXuRoiIqLhh4HlMkyNUSM2xB/tHQ7sKakTuxwiIqJhh4HlMkgkEszvesryJbuFiIiIBh0Dy2Xqnt78VZERnXaHyNUQERENLwwsl2l2XBg0/n5obO1AwZlGscshIiIaVhhYLpNcJsW8yZEAuIgcERHRYGNg8UD3InJfHjdAEASRqyEiIho+GFg8cMOECKj8pDjb0Iai6maxyyEiIho2GFg84K+Q4brxEQCcT1mIiIhocDCweOgm1/RmjmMhIiIaLH0KLOvWrUNcXBxUKhVSUlKwf//+i7ZvamrCypUrER0dDaVSiQkTJmDbtm2u43/4wx8gkUjctkmTJvWltAE3b7IWUglwvNqMsw2tYpdDREQ0LHgcWLZu3YqsrCysXr0aBw4cQGJiItLT01FT0/OLAW02G+bPn4/y8nK8//77KC4uxoYNGxAbG+vWburUqaiurnZte/bs6dsnGmBhgQrMigsDAGznbCEiIqJBIff0hLVr12L58uXIzMwEAKxfvx6fffYZNm7ciKeeeuqC9hs3bkRDQwP27t0LPz8/AEBcXNyFhcjliIqK8rQcUdw0NQr5ZQ348rgB91w7RuxyiIiIfJ5HT1hsNhsKCgqQlpZ2/gJSKdLS0pCXl9fjOZ988glSU1OxcuVKaLVaJCQk4IUXXoDd7v4SwVOnTiEmJgbx8fFYunQpKioqeq3DarXCbDa7bYOpexzL/rIGNFpsg/q9iYiIhiOPAktdXR3sdju0Wq3bfq1WC4Oh51kzpaWleP/992G327Ft2zY888wz+Mtf/oLnn3/e1SYlJQWbNm1Cbm4uXn31VZSVleG6665Dc3PPU4dzcnKg0Whcm06n8+RjXDFdWAAmR6vhEAD9iZ67woiIiKj/DPgsIYfDgcjISPzjH/9AcnIyMjIy8PTTT2P9+vWuNgsXLsSvfvUrTJ8+Henp6di2bRuamprw7rvv9njN7OxsmEwm13b27NmB/hgX6H7K8gVfhkhERDTgPAos4eHhkMlkMBrdB5sajcZex59ER0djwoQJkMlkrn2TJ0+GwWCAzdZzd0pISAgmTJiAkpKSHo8rl
Uqo1Wq3bbClT3V+3m9O1aLNZr9EayIiIroSHgUWhUKB5ORk6PV61z6HwwG9Xo/U1NQez7nmmmtQUlICh+P8G45PnjyJ6OhoKBSKHs9paWnB6dOnER0d7Ul5g2pydDBGhvqjvcOBr0/Vil0OERGRT/O4SygrKwsbNmzA5s2bUVRUhBUrVsBisbhmDS1btgzZ2dmu9itWrEBDQwMefvhhnDx5Ep999hleeOEFrFy50tXm8ccfx+7du1FeXo69e/fi5z//OWQyGW677bZ++IgDQyKRnH+3EBeRIyIiGlAeT2vOyMhAbW0tVq1aBYPBgKSkJOTm5roG4lZUVEAqPZ+DdDodvvjiCzz66KOYPn06YmNj8fDDD+N3v/udq01lZSVuu+021NfXIyIiAtdeey327duHiIiIfviIA+emqVps/LYM+hNGdNodkMu4cDAREdFAkAg+8Nphs9kMjUYDk8k0qONZOu0OzPqvr9DY2oG3l1+N1LEjBu17ExERDXWe/PzmI4ErIJdJMW9y17uF+DJEIiKiAcPAcoV++DJEH3hYRURE5JUYWK7QdeMjoPKT4lxTG45XD+6Ku0RERMMFA8sV8lfIcO045+BgfRFXvSUiIhoIDCz9IG1yJABAX8TpzURERAOBgaUf/GSSM7AcrjShxtwucjVERES+h4GlH0SqVUgcqQEA7ODLEImIiPodA0s/6Z7e/BXHsRAREfU7BpZ+ktYVWPaU1KK9gy9DJCIi6k8MLP1kcnQwYjQqtHc48G1JndjlEBER+RQGln4ikUjYLURERDRAGFj60byu6c07TnDVWyIiov7EwNKPro4fgQCFDEazFUfPcdVbIiKi/sLA0o9UfjJcP9656u12LiJHRETUbxhY+tk8rnpLRETU7xhY+tmNkyIhkQDHqsyoNrWJXQ4REZFPYGDpZ+FBSlylCwHAlyESERH1FwaWAdA9vZndQkRERP2DgWUAdK96++3perTaOkWuhoiIaOhjYBkAE7RB0IX5w9bpwDenuOotERHRlWJgGQASiQTzJrFbiIiIqL8wsAyQ7m6hHSdq4XBw1VsiIqIrwcAyQGaPCUOwUo66FisOVzaJXQ4REdGQxsAyQBRyKa6f4Fz1ltObiYiIrgwDywBKm+Jc9fYrjmMhIiK6IgwsA2juhEhIJcAJQzMqG1vFLoeIiGjIYmAZQKGBCswcHQaA3UJERERXgoFlgHW/DJHdQkRERH3HwDLAupfpzy9tQIuVq94SERH1BQPLABsbEYgx4YGw2R345mSt2OUQERENSQwsA8y56q2zW2g7u4WIiIj6pE+BZd26dYiLi4NKpUJKSgr2799/0fZNTU1YuXIloqOjoVQqMWHCBGzbtu2KrjmUdHcL7SquhZ2r3hIREXnM48CydetWZGVlYfXq1Thw4AASExORnp6OmpqeZ8HYbDbMnz8f5eXleP/991FcXIwNGzYgNja2z9ccambGhUKtkqPBYsPBikaxyyEiIhpyJIIgePS//CkpKZg1axZeeeUVAIDD4YBOp8NDDz2Ep5566oL269evx5///GecOHECfn5+/XLNHzObzdBoNDCZTFCr1Z58nEHzH28fxCeHq/DbG8biqYWTxC6HiIhIdJ78/PboCYvNZkNBQQHS0tLOX0AqRVpaGvLy8no855NPPkFqaipWrlwJrVaLhIQEvPDCC7Db7X2+5lCUNoVvbyYiIuoruSeN6+rqYLfbodVq3fZrtVqcOHGix3NKS0uxY8cOLF26FNu2bUNJSQkeeOABdHR0YPXq1X26ptVqhdVqdf3ebDZ78jFEccOECMilEpyqacGZegtGjwgUuyQiIqIhY8BnCTkcDkRGRuIf//gHkpOTkZGRgaeffhrr16/v8zVzcnKg0Whcm06n68eKB4bG3w+z4pyr3n7FVW+JiIg84lFgCQ8Ph0wmg9Ho3q1hNBoRFRXV4znR0dGYMGECZDKZa9/kyZNhMBhgs9n6dM3s7GyYTCbXdvbsWU8+h
mi6V71ltxAREZFnPAosCoUCycnJ0Ov1rn0OhwN6vR6pqak9nnPNNdegpKQEDofDte/kyZOIjo6GQqHo0zWVSiXUarXbNhSkdU1v3l/WAHN7h8jVEBERDR0edwllZWVhw4YN2Lx5M4qKirBixQpYLBZkZmYCAJYtW4bs7GxX+xUrVqChoQEPP/wwTp48ic8++wwvvPACVq5cednX9BVx4YEYGxGIToeA3cVc9ZaIiOhyeTToFgAyMjJQW1uLVatWwWAwICkpCbm5ua5BsxUVFZBKz+cgnU6HL774Ao8++iimT5+O2NhYPPzww/jd73532df0JWlTtDi9uxT6IiNuSYwRuxwiIqIhweN1WLzRUFiHpdv35Q345fo8BKvkKPj9fCjkfDsCERENTwO2DgtduatGhSI8SInm9k7sPV0ndjlERERDAgPLIJNJJUif6uzqyj1qELkaIiKioYGBRQQLE6IBAF8eN6LT7rhEayIiImJgEUFKfBhCAvzQYLFhf3mD2OUQERF5PQYWEfjJpJg/md1CREREl4uBRSQLpzlX8c09aoDDMeQnahEREQ0oBhaRXDMuHMFKOWqarTh4tlHscoiIiLwaA4tIlHIZftL1bqHPC9ktREREdDEMLCJamODsFvr8qAE+sH4fERHRgGFgEdENEyLh7yfDuaY2HD1nFrscIiIir8XAIiJ/hQxzJ0YAAD4/Wi1yNURERN6LgUVkCxLOzxZitxAREVHPGFhE9pNJkVDIpCits+CksUXscoiIiLwSA4vIglV+uG58OAB2CxEREfWGgcULLJzmfLcQV70lIiLqGQOLF5g/WQu5VIIThmaU1VnELoeIiMjrMLB4AU2AH1LHjgDAbiEiIqKeMLB4iYUJ7BYiIiLqDQOLl7hpqhZSCXCk0oTKxlaxyyEiIvIqDCxeIjxIiVlxYQD4lIWIiOjHGFi8yMIfLCJHRERE5zGweJEFXeNYCioaUWNuF7kaIiIi78HA4kWiNCpcNSoEggB8cYxPWYiIiLoxsHiZ7m6hbYUMLERERN0YWLxM9/Tm/LJ61LdYRa6GiIjIOzCweBldWACmxqjhEIDtx41il0NEROQVGFi8UHe30OecLURERASAgcUrdc8W2nu6Dqa2DpGrISIiEh8DixcaFxmE8ZFB6LAL0BexW4iIiIiBxUuxW4iIiOg8BhYv1d0t9PXJWlisnSJXQ0REJC4GFi81OToYo0cEwNrpwM7iGrHLISIiElWfAsu6desQFxcHlUqFlJQU7N+/v9e2mzZtgkQicdtUKpVbm7vvvvuCNgsWLOhLaT5DIpFgAbuFiIiIAPQhsGzduhVZWVlYvXo1Dhw4gMTERKSnp6OmpvenAGq1GtXV1a7tzJkzF7RZsGCBW5u3337b09J8TvcicjtP1KC9wy5yNUREROLxOLCsXbsWy5cvR2ZmJqZMmYL169cjICAAGzdu7PUciUSCqKgo16bVai9oo1Qq3dqEhoZ6WprPSRypQYxGhVabHV+frBW7HCIiItF4FFhsNhsKCgqQlpZ2/gJSKdLS0pCXl9freS0tLRg9ejR0Oh0WL16MY8eOXdBm165diIyMxMSJE7FixQrU19f3ej2r1Qqz2ey2+SKJRIL0rm6hXHYLERHRMOZRYKmrq4Pdbr/gCYlWq4XB0PMP1IkTJ2Ljxo34+OOP8dZbb8HhcGDOnDmorKx0tVmwYAG2bNkCvV6PF198Ebt378bChQtht/fcDZKTkwONRuPadDqdJx9jSOnuFtpeZISt0yFyNUREROKQD/Q3SE1NRWpqquv3c+bMweTJk/Haa69hzZo1AIBbb73VdXzatGmYPn06xo4di127dmHevHkXXDM7OxtZWVmu35vNZp8NLcmjQxEepERdixV7T9dh7sRIsUsiIiIadB49YQkPD4dMJoPR6L76qtFoRFRU1GVdw8/PD1dddRVKSkp6bRMfH4/w8PBe2yiVSqjVarfNV8mkEixIcD7R2lZYLXI1RERE4
vAosCgUCiQnJ0Ov17v2ORwO6PV6t6coF2O321FYWIjo6Ohe21RWVqK+vv6ibYaTm6fFAHCOY2G3EBERDUcezxLKysrChg0bsHnzZhQVFWHFihWwWCzIzMwEACxbtgzZ2dmu9s899xy+/PJLlJaW4sCBA7jjjjtw5swZ3HfffQCcA3KfeOIJ7Nu3D+Xl5dDr9Vi8eDHGjRuH9PT0fvqYQ9vsMWGIDFbC3N6Jb05xthAREQ0/Ho9hycjIQG1tLVatWgWDwYCkpCTk5ua6BuJWVFRAKj2fgxobG7F8+XIYDAaEhoYiOTkZe/fuxZQpUwAAMpkMR44cwebNm9HU1ISYmBjcdNNNWLNmDZRKZT99zKFNJpXgp9OisWlvOf59uArzJl84LZyIiMiXSQRBEMQu4kqZzWZoNBqYTCafHc9ScKYRv3h1LwIVMhQ8Mx8qP5nYJREREV0RT35+811CQ8SMUSGIDfGHxWbHzhN8txAREQ0vDCxDhEQiwc+mOwchf3qEs4WIiGh4YWAZQn423TlbSH/CCIu1U+RqiIiIBg8DyxCSEKtG3IgAtHc48FWR8dInEBER+QgGliHE2S3kfMrCbiEiIhpOGFiGmFsSnYFld3EtTG0dIldDREQ0OBhYhpiJUcEYHxkEm92B7cfZLURERMMDA8sQ1P2U5d+Hq0SuhIiIaHAwsAxB3dObvy2pQ4PFJnI1REREA4+BZQiKjwjC1Bg1Oh0Cco8axC6HiIhowDGwDFHnZwuxW4iIiHwfA8sQ1d0ttK+0HjXN7SJXQ0RENLAYWIYoXVgAknQhcAjA54XsFiIiIt/GwDKEdT9l4WwhIiLydQwsQ9jPpsdAIgG+P9OIqqY2scshIiIaMAwsQ1iURoVZo8MAAJ9xqX4iIvJhDCxD3C2Jzm4hzhYiIiJfxsAyxC1IiIZUAhyuNOFMvUXscoiIiAYEA8sQFxGsxJyx4QD4BmciIvJdDCw+oHu2EAMLERH5KgYWH7AgIQpyqQRF1WaU1LSIXQ4REVG/Y2DxASEBClw3vrtbiINviYjI9zCw+IhbEp3vFvr34SoIgiByNURERP2LgcVHzJ+ihUIuxelaC04YmsUuh4iIqF8xsPiIYJUfbpwYAYBL9RMRke9hYPEhP5vu7Bb69Eg1u4WIiMinMLD4kHmTI+HvJ0NFQyuOVJrELoeIiKjfMLD4kACFHPMmRwLgbCEiIvItDCw+pnu20KdHquFwsFuIiIh8AwOLj7lhQgSClXJUm9pxoKJR7HKIiIj6BQOLj1H5yTB/qhYAZwsREZHvYGDxQbd0zRb6rNAAO7uFiIjIB/QpsKxbtw5xcXFQqVRISUnB/v37e227adMmSCQSt02lUrm1EQQBq1atQnR0NPz9/ZGWloZTp071pTQCcM24cGj8/VDXYkV+ab3Y5RAREV0xjwPL1q1bkZWVhdWrV+PAgQNITExEeno6ampqej1HrVajurratZ05c8bt+J/+9Ce8/PLLWL9+PfLz8xEYGIj09HS0t7d7/okICrkUCxOiAAD/5huciYjIB3gcWNauXYvly5cjMzMTU6ZMwfr16xEQEICNGzf2eo5EIkFUVJRr02q1rmOCIOCll17C73//eyxevBjTp0/Hli1bUFVVhY8++qhPH4rOLyL3+dFq2DodIldDRER0ZTwKLDabDQUFBUhLSzt/AakUaWlpyMvL6/W8lpYWjB49GjqdDosXL8axY8dcx8rKymAwGNyuqdFokJKS0us1rVYrzGaz20buro4Pg1atRFNrBz4/yqcsREQ0tHkUWOrq6mC3292ekACAVquFwWDo8ZyJEydi48aN+Pjjj/HWW2/B4XBgzpw5qKysBADXeZ5cMycnBxqNxrXpdDpPPsawIJdJcfvs0QCALXlnLtGaiIjIuw34LKHU1FQsW7YMSUlJuOGGG/DBBx8gIiICr732Wp+vmZ2dDZPJ5NrOnj3bjxX7jttm6yCXSlBwphHHqrhUPxERDV0eBZbw8HDIZ
DIYjUa3/UajEVFRUZd1DT8/P1x11VUoKSkBANd5nlxTqVRCrVa7bXShSLUKC7oG377JpyxERDSEeRRYFAoFkpOTodfrXfscDgf0ej1SU1Mv6xp2ux2FhYWIjo4GAIwZMwZRUVFu1zSbzcjPz7/sa1LvlqXGAQA+OnQOptYOcYshIiLqI4+7hLKysrBhwwZs3rwZRUVFWLFiBSwWCzIzMwEAy5YtQ3Z2tqv9c889hy+//BKlpaU4cOAA7rjjDpw5cwb33XcfAOcMokceeQTPP/88PvnkExQWFmLZsmWIiYnBkiVL+udTDmOz4kIxKSoY7R0OvFfArjMiIhqa5J6ekJGRgdraWqxatQoGgwFJSUnIzc11DZqtqKiAVHo+BzU2NmL58uUwGAwIDQ1FcnIy9u7diylTprjaPPnkk7BYLLj//vvR1NSEa6+9Frm5uRcsMEeek0gkuDN1NJ7+8Cje2ncG91wzBlKpROyyiIiIPCIRBGHIr91uNpuh0WhgMpk4nqUHFmsnrn5Bj2ZrJzbfMxs3TIgQuyQiIiKPfn7zXULDQKBSjl8kjwQAvJlXLm4xREREfcDAMkzcmepck0V/ogZnG1pFroaIiMgzDCzDxNiIIFw7LhyCALyVzynOREQ0tDCwDCPLup6yvPvdWbR32EWuhoiI6PIxsAwj8yZrERvij8bWDvz7cJXY5RAREV02BpZhRCaV4I6rnU9ZNueVwwcmiBER0TDBwDLMZMzSQSGX4ug5Mw6ebRK7HCIiosvCwDLMhAUqsCgxBgCweW+5uMUQERFdJgaWYeiurvcLbSusRk1zu7jFEBERXQYGlmFo2kgNrhoVgg67gHf28/1CRETk/RhYhqm758QBAP6VfwYddoe4xRAREV0CA8swtTAhGuFBShjNVnx5zCh2OURERBfFwDJMKeRS3D5bB8A5xZmIiMibMbAMY7enjIZMKsH+sgYUVZvFLoeIiKhXDCzDWJRGhQVTowAAW/L4fiEiIvJeDCzDXPf7hT48WImmVpvI1RAREfWMgWWYmz0mDJOigtHe4cC733OKMxEReScGlmFOIpEg85o4AM5uIbuD7xciIiLvw8BCWJwUi5AAP1Q2tkFfxCnORETkfRhYCCo/GW6dNQoA8M89ZSJXQ0REdCEGFgLgHHzrJ3NOcf6+vEHscoiIiNwwsBAAICbEH7+YMRIA8MrOEpGrISIicsfAQi4r5o6FVALsKq5FYaVJ7HKIiIhcGFjIZfSIQCxOigUAvLLzlMjVEBERncfAQm4emDsWEgnwxTEjig3NYpdDREQEgIGFfmS8NhgLE5zL9a/dXixyNURERE4MLHSBR9MmQNr1lKXgDGcMERGR+BhY6ALjtcH49UwdAOCFbScgCFz9loiIxMXAQj16dP4EqPykKDjTiC+Pc/VbIiISFwML9UirVuG+a+MBAC/mnkCn3SFyRURENJwxsFCvfnNDPMICFSittWAr3+RMREQiYmChXgWr/PAfPxkHAHjpq1NosXaKXBEREQ1XfQos69atQ1xcHFQqFVJSUrB///7LOu+dd96BRCLBkiVL3PbffffdkEgkbtuCBQv6Uhr1s9tTRiNuRABqm614ZQeX7CciInF4HFi2bt2KrKwsrF69GgcOHEBiYiLS09NRU1Nz0fPKy8vx+OOP47rrruvx+IIFC1BdXe3a3n77bU9LowGgkEvxzM+mAABe31OKsjqLyBUREdFw5HFgWbt2LZYvX47MzExMmTIF69evR0BAADZu3NjrOXa7HUuXLsWzzz6L+Pj4HtsolUpERUW5ttDQUE9LowHyk0mRmDsxAh12AWs+PS52OURENAx5FFhsNhsKCgqQlpZ2/gJSKdLS0pCXl9frec899xwiIyNx77339tpm165diIyMxMSJE7FixQrU19f32tZqtcJsNrttNHAkEgme+dkU+Mkk2HGiBjtOcJozERENLo8CS11dHex2O7Rardt+rVYLg8HQ4zl79uzB66+/jg0bN
vR63QULFmDLli3Q6/V48cUXsXv3bixcuBB2u73H9jk5OdBoNK5Np9N58jGoD8ZGBOGea8YAAJ7793FYO3v+syEiIhoIAzpLqLm5GXfeeSc2bNiA8PDwXtvdeuutWLRoEaZNm4YlS5bg008/xXfffYddu3b12D47Oxsmk8m1nT3LKbeD4cGfjENEsBLl9a14fU+Z2OUQEdEw4lFgCQ8Ph0wmg9Ho3iVgNBoRFRV1QfvTp0+jvLwct9xyC+RyOeRyObZs2YJPPvkEcrkcp0+f7vH7xMfHIzw8HCUlPc9KUSqVUKvVbhsNvGCVH7IXTgIAvLKjBAZTu8gVERHRcOFRYFEoFEhOToZer3ftczgc0Ov1SE1NvaD9pEmTUFhYiEOHDrm2RYsW4cYbb8ShQ4d67cqprKxEfX09oqOjPfw4NNCWJMVixqgQtNrs+K9tRWKXQ0REw4THXUJZWVnYsGEDNm/ejKKiIqxYsQIWiwWZmZkAgGXLliE7OxsAoFKpkJCQ4LaFhIQgODgYCQkJUCgUaGlpwRNPPIF9+/ahvLwcer0eixcvxrhx45Cent6/n5aumFQqwbOLEiCVAP8+XIU9p+rELomIiIYBjwNLRkYG/vu//xurVq1CUlISDh06hNzcXNdA3IqKClRXV1/29WQyGY4cOYJFixZhwoQJuPfee5GcnIxvvvkGSqXS0/JoEEwbqcGdV48GAKz6+CgH4BIR0YCTCIIgiF3ElTKbzdBoNDCZTBzPMkhMbR2Y95fdqGux4rH5E/DQvPFil0REREOMJz+/+S4h6hONvx+e+dlkAMArO0tQUd8qckVEROTLGFiozxYlxmDO2BGwdjqw+pOj8IGHdURE5KUYWKjPJBIJnlucAD+ZBDuLa/HFsZ4XDyQiIrpSDCx0RcZFBuE3148FADz77+OwWDtFroiIiHwRAwtdsZU3jsPIUH9Um9rxV/0pscshIiIfxMBCV8xfIcNzi6cCAF7fU4YTBr6MkoiI+hcDC/WLn0zSIn2qFnaHgN9/eBQOBwfgEhFR/2FgoX6z6pap8PeT4fszjXj/QKXY5RARkQ9hYKF+Exvij0fSnAvIPf/pcVQ1tYlcERER+QoGFupX91w7BokjNTC3d+KRrYdgZ9cQERH1AwYW6ld+Milevu0qBCpk2F/WgHU7S8QuiYiIfAADC/W70SMCsWZJAgDgr/pTKDjTIHJFREQ01DGw0ID4+VWxWJwUA7tDwMPvHIKprUPskoiIaAhjYKEBIZFI8PySBOjC/FHZ2Ian/u8I3zVERER9xsBCAyZY5Ye/3TYDfjIJPj9qwBvflotdEhERDVEMLDSgknQhePqnkwEAL2wrwoGKRpErIiKioYiBhQbcXXPicPO0aHQ6BDz4rwNotNjELomIiIYYBhYacBKJBH/8xTSMCQ9Elakdj757iEv3ExGRRxhYaFAEq/zw96UzoJRLsau4Fq/uPi12SURENIQwsNCgmRytdq3P8pcvi7GruEbkioiIaKhgYKFB9euZOtw2WweHADz09kGU1VnELomIiIYABhYadH9YNBXJo0PR3N6J5Vu+R3M7F5UjIqKLY2ChQaeUy/DqHTMQpVahpKYFj249zEG4RER0UQwsJIrIYBVeuzMZCrkUXxUZ8dJXJ8UuiYiIvBgDC4kmUReCnJ9PAwC8vKMEuUerRa6IiIi8FQMLieoXySNx77VjAABZ7x5GUbVZ5IqIiMgbMbCQ6LIXTsK148LRarPjnk3fwWhuF7skIiLyMgwsJDq5TIp1t8/A2IhAVJvacc+m72CxdopdFhEReREGFvIKmgA/vHH3bIwIVOBYlRn/8fZB2DlziIiIujCwkNcYNSIAG+6aCaVcCv2JGqz59LjYJRERkZdgYCGvMmNUKP4nIwkAsGlvOTbuKRO3ICIi8goMLOR1fjotGtkLJwEA1nx2HJ8eqRK5IiIiElufAsu6desQFxcHlUqFlJQU7N+//7LOe+eddyCRSLBkyRK3/YIgY
NWqVYiOjoa/vz/S0tJw6tSpvpRGPuL+6+OxLHU0BAF4dOsh7DlVJ3ZJREQkIo8Dy9atW5GVlYXVq1fjwIEDSExMRHp6OmpqLv7m3fLycjz++OO47rrrLjj2pz/9CS+//DLWr1+P/Px8BAYGIj09He3tnN46XEkkEqy+ZSpunhaNDruA37z5PQorTWKXRUREIvE4sKxduxbLly9HZmYmpkyZgvXr1yMgIAAbN27s9Ry73Y6lS5fi2WefRXx8vNsxQRDw0ksv4fe//z0WL16M6dOnY8uWLaiqqsJHH33k8Qci3yGTSrA2IxFzxo6AxWbH3W/s59udiYiGKY8Ci81mQ0FBAdLS0s5fQCpFWloa8vLyej3vueeeQ2RkJO69994LjpWVlcFgMLhdU6PRICUlpddrWq1WmM1mt418k1Iuw2t3JmNqjBr1FhuWbcxHDReWIyIadjwKLHV1dbDb7dBqtW77tVotDAZDj+fs2bMHr7/+OjZs2NDj8e7zPLlmTk4ONBqNa9PpdJ58DBpiglV+2JQ5G6NHBOBsQxvueuM7mNs7xC6LiIgG0YDOEmpubsadd96JDRs2IDw8vN+um52dDZPJ5NrOnj3bb9cm7xQRrMSWe2YjPEiJomoz7t/yPdo77GKXRUREg0TuSePw8HDIZDIYjUa3/UajEVFRURe0P336NMrLy3HLLbe49jkcDuc3lstRXFzsOs9oNCI6OtrtmklJST3WoVQqoVQqPSmdfMDoEYHYlDkLt/5jH/aVNuCRdw5h3dIZkEklYpdGREQDzKMnLAqFAsnJydDr9a59DocDer0eqampF7SfNGkSCgsLcejQIde2aNEi3HjjjTh06BB0Oh3GjBmDqKgot2uazWbk5+f3eE0a3hJiNfjHsmQoZFLkHjPgmY+PQhC4hD8Rka/z6AkLAGRlZeGuu+7CzJkzMXv2bLz00kuwWCzIzMwEACxbtgyxsbHIycmBSqVCQkKC2/khISEA4Lb/kUcewfPPP4/x48djzJgxeOaZZxATE3PBei1EADBnbDheujUJK//3AP43vwIRQUo8On+C2GUREdEA8jiwZGRkoLa2FqtWrYLBYEBSUhJyc3Ndg2YrKioglXo2NObJJ5+ExWLB/fffj6amJlx77bXIzc2FSqXytDwaJn46LRprFifg9x8dxV/1pxAerMSdV48WuywiIhogEsEHnqebzWZoNBqYTCao1Wqxy6FB9D/bT+Kv+lOQSIB1t8/AT6dFX/okIiLyCp78/Oa7hGhIeyRtPJamjIIgAI+8cwh7T3MJfyIiX8TAQkOaRCLBc4sTsDAhCja7A/dvKcDRc1zCn4jI1zCw0JAnk0rwPxlJuDo+DC3WTizbuB8lNc1il0VERP2IgYV8gspPhg3LZmJarAYNFhuW/jMfFfWtYpdFRET9hIGFfEawyg9b7pmNCdogGM1WLH19HwwmvneIiMgXMLCQTwkNVOCte1MQ1/XeoaX/3Ie6FqvYZRER0RViYCGfE6lW4a37UhCjUeF0rQXLXt8PUxtflkhENJQxsJBPGhkagLfuS0F4kALHq824+439sFg7xS6LiIj6iIGFfFZ8RBDevDcFGn8/HKxownK+4ZmIaMhiYCGfNjlajc33zEagQoa9p+ux8l8H0GF3iF0WERF5iIGFfF6SLgSv3z0LSrkU+hM1yP6gkG94JiIaYhhYaFi4On4E1t+ZDKkEeL+gEu99Xyl2SURE5AEGFho2bpwYicdumggAeObjo1zCn4hoCGFgoWFlxQ1jcePECFg7Hbj7je9QXmcRuyQiIroMDCw0rEilErx061WYHK1GXYsVS/+Zz9VwiYiGAAYWGnY0/s4l/MeEB+JcUxvueD0ftc1cDZeIyJsxsNCwFBGsxJv3zka0RoWSmhb8+rU8VDW1iV0WERH1goGFhq2RoQF4e/nViA3xR1mdBb9an8cxLUREXoqBhYa1uPBAvPfbVMR3dQ/9cv1eHKxoFLssIiL6EQYWGvZiQ
vyx9TepmBKtRl2LDbf+Yx8+O1ItdllERPQDDCxEcI5pee+3qUibHAlrpwMr//cA1u0s4Yq4RERegoGFqEugUo7X7pyJe64ZAwD48xfFePy9I7B18t1DRERiY2Ah+gGZVIJVt0zBmiUJkEkl+L8Dlbjz9Xw0Wmxil0ZENKwxsBD14M6rR+P1u2YiSClHflkD/t+re1FS0yJ2WUREwxYDC1Ev5k6MxP+tmOOa9vzzdd9ixwmj2GUREQ1LDCxEFzExKhgfP3gNZseFodnaiXs3f8/BuEREImBgIbqE8CAl3rovBUtTRkEQnINxH3r7IFptnWKXRkQ0bDCwEF0GhVyK//r5NPzXzxMgl0rw6ZFq/PLVPFQ2topdGhHRsMDAQuSBpSmj8fb9VyM8SIHj1WYseuVb7CutF7ssIiKfx8BC5KFZcWH45MFrkRCrRoPFhjv+mY8NX5fC4eC4FiKigcLAQtQHMSH+eO83c7A4KQadDgH/ta0I927+Dg1cr4WIaED0KbCsW7cOcXFxUKlUSElJwf79+3tt+8EHH2DmzJkICQlBYGAgkpKS8Oabb7q1ufvuuyGRSNy2BQsW9KU0okHjr5DhpYwkPL8kAQq5FDuLa7Hwr1+zi4iIaAB4HFi2bt2KrKwsrF69GgcOHEBiYiLS09NRU1PTY/uwsDA8/fTTyMvLw5EjR5CZmYnMzEx88cUXbu0WLFiA6upq1/b222/37RMRDSKJRII7rh6Nj1deg7ERgTCarbh9wz6s3X6SS/oTEfUjieDhghIpKSmYNWsWXnnlFQCAw+GATqfDQw89hKeeeuqyrjFjxgzcfPPNWLNmDQDnE5ampiZ89NFHnlXfxWw2Q6PRwGQyQa1W9+kaRFeq1daJVR8fw/sFlQCAcZFByPl/0zArLkzkyoiIvJMnP789esJis9lQUFCAtLS08xeQSpGWloa8vLxLni8IAvR6PYqLi3H99de7Hdu1axciIyMxceJErFixAvX1vT9Wt1qtMJvNbhuR2AIUcvz3rxLxyu1XITxIgZKaFvxqfR6yPyiEqbVD7PKIiIY0jwJLXV0d7HY7tFqt236tVguDwdDreSaTCUFBQVAoFLj55pvxt7/9DfPnz3cdX7BgAbZs2QK9Xo8XX3wRu3fvxsKFC2G323u8Xk5ODjQajWvT6XSefAyiAfWz6TH4KusG3DrL+ffy7f0VmLd2N977/ixnEhER9ZFHXUJVVVWIjY3F3r17kZqa6tr/5JNPYvfu3cjPz+/xPIfDgdLSUrS0tECv12PNmjX46KOPMHfu3B7bl5aWYuzYsfjqq68wb968C45brVZYrVbX781mM3Q6HbuEyOvkl9Yj+8NClNZaAABTotX4/c2TMWdcuMiVERGJb8C6hMLDwyGTyWA0ur8Azmg0IioqqvdvIpVi3LhxSEpKwmOPPYZf/vKXyMnJ6bV9fHw8wsPDUVJS0uNxpVIJtVrtthF5o5T4Efj84euQvXASglVyHK824/Z/5iPzjf04XsWuTCKiy+VRYFEoFEhOToZer3ftczgc0Ov1bk9cLsXhcLg9IfmxyspK1NfXIzo62pPyiLySUi7Db24Yi91P3Ii758RBLpVgZ3EtfvryN/iPtw+irM4idolERF7P42nNWVlZ2LBhAzZv3oyioiKsWLECFosFmZmZAIBly5YhOzvb1T4nJwfbt29HaWkpioqK8Je//AVvvvkm7rjjDgBAS0sLnnjiCezbtw/l5eXQ6/VYvHgxxo0bh/T09H76mETiCwtU4A+LpuLLR6/Hz6Y7w/gnh6uQtnY3fv9RIepbeg/xRETDndzTEzIyMlBbW4tVq1bBYDAgKSkJubm5roG4FRUVkErP5yCLxYIHHngAlZWV8Pf3x6RJk/DWW28hIyMDACCTyXDkyBFs3rwZTU1NiImJwU033YQ1a9ZAqVT208ck8h7xEUF45fYZWDHXhP/+ohg7i2vx1r4KfHyoCvddG4+75oxGSIBC7DKJiLyKx+uwe
COuw0JD2b7Sejz37+M4Xu0c0xKokOH2lFG477p4aNUqkasjIho4nvz8ZmAh8gJ2h4DPCqvx950lOGFoBgAoZFL8IjkWv7l+LOLCA0WukIio/zGwEA1RgiBgV3Et/r6rBN+VNwIApBLg5ukxWHHDWEyJ4d9vIvIdDCxEPuC78gb8fWcJdhbXuvbdODECD9w4jsv9E5FPYGAh8iHHq8x4dfdpfHakCt0L5c6KC8UDc8dh7sQISCQScQskIuojBhYiH1ReZ8FrX5fi/woqYbM73wQ9OVqN394Qj4UJ0VDIPV6lgIhIVAwsRD7MaG7H63vK8K99Z2CxOd+3FR6kQMYsHW6bPQojQwNErpCI6PIwsBANA02tNmzJO4N/5Z+B0excdE4qAW6cGIk7rh6N6ydEQCZldxEReS8GFqJhpMPugL7IiDf3ncG3JfWu/bowf9w+ezR+PXMkRgRxEUYi8j4MLETDVGltC/6VX4H3vj8Lc3snAOd6LgsSovDrmTrMGTsCUj51ISIvwcBCNMy12ez495Eq/GvfGRyuNLn2x2hU+EXySPxixkguRkdEomNgISKXwkoT3v3+LD4+dM711AVwTo3+ZfJI/HRaNIJVfiJWSETDFQMLEV2gvcOOr4qMeL+gEl+frHWt6aLykyJtshaLEmNww8QIKOUycQslomGDgYWILspgaseHB8/hvYKzKK21uParVXIsTIjGoqQYXB0/grOMiGhAMbAQ0WURBAFHKk345HAVPj1S5ZoeDQARwUr8bHo0bp4WjRmjQjlYl4j6HQMLEXnM7hCwv6wBnxw+h22FBpjaOlzHIoKVSJ+qxcKEaMweEwY/GVfVJaIrx8BCRFfE1unA1ydr8emRKuiLatBsPT9YNyTAD2mTtViYEIVrxoVD5ccxL0TUNwwsRNRvbJ0OfHu6Dl8cNeDL40Y0WGyuY0FKOW6cFIkFU6Mwd2IEApVyESsloqGGgYWIBkSn3YHvyhvxxTEDco8aYDC3u44p5VJcNz4C8yZHYu7ECERr/EWslIiGAgYWIhpwDoeAw5VNyO0KL2fqW92OT4oKxg0TI3DjxEgkjw7luBciugADCxENKkEQcMLQjO3HjdhVXINDZ5tc67wAQLBSjmvHh2PuxAjMnRgJrVolXrFE5DUYWIhIVI0WG74+VYvdxbXYfbIW9T8Y9wIAk6PVmNv19GXGqBDI+fSFaFhiYCEir+FwCCg8Z8LO4hrsKq7F4com/PBfnWCVHNePj8ANEyNw3fhwjn0hGkYYWIjIa9W3WPHNqTrsLK7B1ydr0dja4XY8bkQAro4fgdSxI5AaPwKR7D4i8lkMLEQ0JNi7Bu7uKq7F7uIaFJ4zuY19AYD4iECkxo/A1V1bRLBSnGKJqN8xsBDRkGRu78B3ZQ3YV1qPvNJ6HKsy48f/Qo2PDELq2PMBJixQIU6xRHTFGFiIyCeYWjuQX+YML3mn63HC0HxBm0lRwa7wcnV8GEICGGCIhgoGFiLySY0WmzPAnHaGmJPGFrfjEgkwOUrtegIzKy6UAYbIizGwENGwUNdiRX5pA/JK65B3uh6nay0XtJmgDcLMuDDMHB2KWXFhGBnqD4mEb54m8gYMLEQ0LNWY27GvrAF5p+uRX1qP0roLA4xWrcTM0WFIHh2KGaNDMSVaDYWc68AQiYGBhYgIzicwBWca8X15A74rb8TRcyZ0/mgaklIuxbRYDWaMDsWMUSGYMSqUU6mJBgkDCxFRD9psdhyubML35Q04UNGEAxWNaPrROjAAEBvij+kjNZg+MgSJIzVIGKmBWuUnQsVEvo2BhYjoMgiCgNI6Cw6cacSBiiYcrGhEsbH5gqnUgHM9mMSRIZgWq0GiToOpMRqo/GSDXzSRDxnwwLJu3Tr8+c9/hsFgQGJiIv72t79h9uzZPbb94IMP8MILL6CkpAQdHR0YP348HnvsMdx5552uNoIgYPXq1diwYQOamppwzTXX4NVXX8X48eMvqx4GFiLqL83tHSg8Z
8KRShOOVDbh8FkTzjW1XdBOJpVggjYYiV1PYqaP1GBiVDDfSk3kgQENLFu3bsWyZcuwfv16pKSk4KWXXsJ7772H4uJiREZGXtB+165daGxsxKRJk6BQKPDpp5/isccew2effYb09HQAwIsvvoicnBxs3rwZY8aMwTPPPIPCwkIcP34cKtWl+5IZWIhoINW3WHHknAlHznaFmEoT6lqsF7RTyKWYEq12hZhEnQbx4UGQSjkriagnAxpYUlJSMGvWLLzyyisAAIfDAZ1Oh4ceeghPPfXUZV1jxowZuPnmm7FmzRoIgoCYmBg89thjePzxxwEAJpMJWq0WmzZtwq233nrJ6zGwENFgEgQB1aZ2HKls6noS4wwy5vbOC9oGKeWYGqNGos75FCZxZAinVhN18eTnt9yTC9tsNhQUFCA7O9u1TyqVIi0tDXl5eZc8XxAE7NixA8XFxXjxxRcBAGVlZTAYDEhLS3O102g0SElJQV5eXo+BxWq1wmo9/383ZrPZk49BRHRFJBIJYkL8ERPijwUJ0QCc/76V17e6upEKzzXh6DkzWqydyC9rQH5Zg+t8jb8fpkSrMTVGjamxakyN0SA+PBBydicR9cqjwFJXVwe73Q6tVuu2X6vV4sSJE72eZzKZEBsbC6vVCplMhr///e+YP38+AMBgMLiu8eNrdh/7sZycHDz77LOelE5ENKAkEgnGhAdiTHggFifFAgA67Q6U1LbgyFkTDlc2ofCcCUXVZpjaOpyvGyitd52vlEsxqTvExKgxJVqNydFqDuwl6uJRYOmr4OBgHDp0CC0tLdDr9cjKykJ8fDzmzp3bp+tlZ2cjKyvL9Xuz2QydTtdP1RIR9Q+5TIpJUWpMilLj17Oc/0bZOh04aWzG8SozjlWZcKzKjKJqMyw2Ow6fbcLhs02u86USYGxEUFeI0bi+agI4xZqGH48CS3h4OGQyGYxGo9t+o9GIqKioXs+TSqUYN24cACApKQlFRUXIycnB3LlzXecZjUZER0e7XTMpKanH6ymVSiiVfMU8EQ09CrkUCbEaJMRqADhDjMMhoLzegmNV5q7NhONVZtRbbDhV04JTNS346FCV6xqxIf7uISZWjSi1iuNiyKd5FFgUCgWSk5Oh1+uxZMkSAM5Bt3q9Hg8++OBlX8fhcLjGoIwZMwZRUVHQ6/WugGI2m5Gfn48VK1Z4Uh4R0ZAklUoQHxGE+Igg3JIYA8A5JsZotrqewnR/rWxsw7km5/bl8fP/8xgWqHB2Jf0gyMSNCISMM5TIR3jcJZSVlYW77roLM2fOxOzZs/HSSy/BYrEgMzMTALBs2TLExsYiJycHgHO8ycyZMzF27FhYrVZs27YNb775Jl599VUAzn7fRx55BM8//zzGjx/vmtYcExPjCkVERMONRCJBlEaFKI0K8yafH+Nnau3AsWpTV5eSM8icrrWgwWLDN6fq8M2pOldbhVyK+PBAjI0MwriIIIyNDML4yCDERwRCKefYGBpaPA4sGRkZqK2txapVq2AwGJCUlITc3FzXoNmKigpIpedHulssFjzwwAOorKyEv78/Jk2ahLfeegsZGRmuNk8++SQsFgvuv/9+NDU14dprr0Vubu5lrcFCRDScaAL8MGdsOOaMDXfta++w44Sh+QdPY8w4UW2GtdOBE4ZmnDA0u11DJnUOEJ6oDcYEbTDiwgOgCwvAqLAAjAhUsGuJvBKX5ici8kF2h4BzjW0oqW1GSU0LTtdYUFLbgpPGZjT3sF5MtwCFDLrQAOjC/KELC4Au1BlkdGHOfQGKQZmrQcME3yVEREQ96h4bU2xsxklDM04am3GmoRWVDa2oNrf3+B6lHwoPUvwoyJwPNtEaFdeSIY8wsBARkcesnXZUNbWjoqEVZ7u2ioZWnG1sRUV9a48r+f6QXOpcUO+HQWZkaAAig5WIDFZCq1YhUMknNHTegK10S0REvkspl7kWv+uJqbUDZxt/FGQa2lDZ0IrKxjbY7A5UdB3rTaBCB
q1ahYhgJSLVKkQGKxEepMSIIAUiur6OCFJiRKCCi+aRGwYWIiK6LJoAP2gCuteQcedwCDA2t6OivhVnG9tQ0dXNVNnUhtpmK2rM7bDY7LDY7Cits6C0znLJ7xeslGNEkMIVaEYEOcNNeJACIwK7vnb9XuPvx8HCPo6BhYiIrphUKkG0xh/RGn+k9NKmxdqJGnM7apqtMJrbnUGm2Yq6FivqW2yot1hR1+z82mEX0GztRLO1E+X1vT+x6SaTSqBWyaH294Na5Qe1v9z5tevXwSq/Hx3/QRt/PwQqZAw8Xo6BhYiIBkWQUo6grgXyLkYQBJjbO88HmRZnqKn7Uaipb7GhtsWK5vZO2B0CGls70Nja0afapBI4Q82Pgs4PA41KIUOAnwz+ChlUfjL4+8kQqJQjQOH8GqySIzSAXVkDhYGFiIi8ikQigcbfDxp/P4yNuHR7a6cdjZYOmNs7YG5zfm1u7+z6dadrn7mt8wdtzu/vsAtwCICprQOmtg4AbVdUv8pPikCFHP4KmetrQNfmr5C7Qk+gUoYAhRz+ft3HnL8//2v381VyGaTDeOViBhYiIhrSlHIZojQyRGk8X2xUEARYOx2u8GLqJdS0Wu1o63Bu7Tbn11abHW02Oyy2TlisnWhu70SnQ0B7hwPtHTbg0sN0PCKRAIEKOQKVMoQFKjE5KhjXT4hAlEYFqUQCqcTZNaeSyxAb4g+1v9ynurk4rZmIiKgfCIKAFmsnmlo70NoVZNpszmDT+oNfO8NOJyxWZ+Bp7bCjzdbZ1e58CPph+76QSpxPq6QSICbEH1q1CqEBfhgTHoSr48MQFqhAoFKOIKXc2bXlN/hPcLgOCxERkY9wOAS0d9rRYnWGnJb2TtQ0t+O78kbkldajpb0DggDYBQEOQYDFakeDxebx95FIgICucTlqfz+MCFRAAKCQSRGlUSFao8KDPxnXr++h4josREREPkIqlXSNbZEDwd17NW4vxfyxVpuzi0oQgE6Hc32cBosNDRYbDp1twpFKE1ranV1ZFlsnHAIgCHBNPa9ptqLkR9dUyKXImj9hoD7mJTGwEBER+RhXwOkyMjTA9etlqe5tBcE57sb5BKcTLdZOmNo60GCxQSaVoNVmh9HcjjabXdQxMQwsREREw5hEIoF/18ykiGCl2OX0im+pIiIiIq/HwEJERERej4GFiIiIvB4DCxEREXk9BhYiIiLyegwsRERE5PUYWIiIiMjrMbAQERGR12NgISIiIq/HwEJERERej4GFiIiIvB4DCxEREXk9BhYiIiLyej7xtmZBEAAAZrNZ5EqIiIjocnX/3O7+OX4xPhFYmpubAQA6nU7kSoiIiMhTzc3N0Gg0F20jES4n1ng5h8OBqqoqBAcHQyKR9Ou1zWYzdDodzp49C7Va3a/XpvN4nwcP7/Xg4H0eHLzPg2cg7rUgCGhubkZMTAyk0ouPUvGJJyxSqRQjR44c0O+hVqv5H8Mg4H0ePLzXg4P3eXDwPg+e/r7Xl3qy0o2DbomIiMjrMbAQERGR12NguQSlUonVq1dDqVSKXYpP430ePLzXg4P3eXDwPg8ese+1Twy6JSIiIt/GJyxERETk9RhYiIiIyOsxsBAREZHXY2AhIiIir8fAcgnr1q1DXFwcVCoVUlJSsH//frFLGlK+/vpr3HLLLYiJiYFEIsFHH33kdlwQBKxatQrR0dHw9/dHWloaTp065damoaEBS5cuhVqtRkhICO699160tLQM4qfwfjk5OZg1axaCg4MRGRmJJUuWoLi42K1Ne3s7Vq5ciREjRiAoKAi/+MUvYDQa3dpUVFTg5ptvRkBAACIjI/HEE0+gs7NzMD+KV3v11Vcxffp018JZqamp+Pzzz13HeY8Hxh//+EdIJBI88sgjrn281/3jD3/4AyQSids2adIk13Gvus8C9eqdd94RFAqFsHHjRuHYsWPC8uXLhZCQEMFoNIpd2pCxbds24emnnxY++OADAYDw4Ycfuh3/4x//K
Gg0GuGjjz4SDh8+LCxatEgYM2aM0NbW5mqzYMECITExUdi3b5/wzTffCOPGjRNuu+22Qf4k3i09PV144403hKNHjwqHDh0SfvrTnwqjRo0SWlpaXG1++9vfCjqdTtDr9cL3338vXH311cKcOXNcxzs7O4WEhAQhLS1NOHjwoLBt2zYhPDxcyM7OFuMjeaVPPvlE+Oyzz4STJ08KxcXFwn/+538Kfn5+wtGjRwVB4D0eCPv37xfi4uKE6dOnCw8//LBrP+91/1i9erUwdepUobq62rXV1ta6jnvTfWZguYjZs2cLK1eudP3ebrcLMTExQk5OjohVDV0/DiwOh0OIiooS/vznP7v2NTU1CUqlUnj77bcFQRCE48ePCwCE7777ztXm888/FyQSiXDu3LlBq32oqampEQAIu3fvFgTBeV/9/PyE9957z9WmqKhIACDk5eUJguAMl1KpVDAYDK42r776qqBWqwWr1Tq4H2AICQ0NFf75z3/yHg+A5uZmYfz48cL27duFG264wRVYeK/7z+rVq4XExMQej3nbfWaXUC9sNhsKCgqQlpbm2ieVSpGWloa8vDwRK/MdZWVlMBgMbvdYo9EgJSXFdY/z8vIQEhKCmTNnutqkpaVBKpUiPz9/0GseKkwmEwAgLCwMAFBQUICOjg63ez1p0iSMGjXK7V5PmzYNWq3W1SY9PR1msxnHjh0bxOqHBrvdjnfeeQcWiwWpqam8xwNg5cqVuPnmm93uKcC/z/3t1KlTiImJQXx8PJYuXYqKigoA3neffeLlhwOhrq4Odrvd7Q8BALRaLU6cOCFSVb7FYDAAQI/3uPuYwWBAZGSk23G5XI6wsDBXG3LncDjwyCOP4JprrkFCQgIA531UKBQICQlxa/vje93Tn0X3MXIqLCxEamoq2tvbERQUhA8//BBTpkzBoUOHeI/70TvvvIMDBw7gu+++u+AY/z73n5SUFGzatAkTJ05EdXU1nn32WVx33XU4evSo191nBhYiH7Ny5UocPXoUe/bsEbsUnzRx4kQcOnQIJpMJ77//Pu666y7s3r1b7LJ8ytmzZ/Hwww9j+/btUKlUYpfj0xYuXOj69fTp05GSkoLRo0fj3Xffhb+/v4iVXYhdQr0IDw+HTCa7YDS00WhEVFSUSFX5lu77eLF7HBUVhZqaGrfjnZ2daGho4J9DDx588EF8+umn2LlzJ0aOHOnaHxUVBZvNhqamJrf2P77XPf1ZdB8jJ4VCgXHjxiE5ORk5OTlITEzEX//6V97jflRQUICamhrMmDEDcrkccrkcu3fvxssvvwy5XA6tVst7PUBCQkIwYcIElJSUeN3faQaWXigUCiQnJ0Ov17v2ORwO6PV6pKamiliZ7xgzZgyioqLc7rHZbEZ+fr7rHqempqKpqQkFBQWuNjt27IDD4UBKSsqg1+ytBEHAgw8+iA8//BA7duzAmDFj3I4nJyfDz8/P7V4XFxejoqLC7V4XFha6BcTt27dDrVZjypQpg/NBhiCHwwGr1cp73I/mzZuHwsJCHDp0yLXNnDkTS5cudf2a93pgtLS04PTp04iOjva+v9P9OoTXx7zzzjuCUqkUNm3aJBw/fly4//77hZCQELfR0HRxzc3NwsGDB4WDBw8KAIS1a9cKBw8eFM6cOSMIgnNac0hIiPDxxx8LR44cERYvXtzjtOarrrpKyM/PF/bs2SOMHz+e05p/ZMWKFYJGoxF27drlNj2xtbXV1ea3v/2tMGrUKGHHjh3C999/L6Smpgqpqamu493TE2+66Sbh0KFDQm5urhAREcFpoD/w1FNPCbt37xbKysqEI0eOCE899ZQgkUiEL7/8UhAE3uOB9MNZQoLAe91fHnvsMWHXrl1CWVmZ8O233wppaWlCeHi4UFNTIwiCd91nBpZL+Nvf/iaMGjVKUCgUwuzZs4V9+/aJXdKQsnPnTgHABdtdd90lCIJzavMzzzwjaLVaQalUCvPmzROKi4vdrlFfXy/cdtttQ
lBQkKBWq4XMzEyhublZhE/jvXq6xwCEN954w9Wmra1NeOCBB4TQ0FAhICBA+PnPfy5UV1e7Xae8vFxYuHCh4O/vL4SHhwuPPfaY0NHRMcifxnvdc889wujRowWFQiFEREQI8+bNc4UVQeA9Hkg/Diy81/0jIyNDiI6OFhQKhRAbGytkZGQIJSUlruPedJ8lgiAI/fvMhoiIiKh/cQwLEREReT0GFiIiIvJ6DCxERETk9RhYiIiIyOsxsBAREZHXY2AhIiIir8fAQkRERF6PgYWIiIi8HgMLEREReT0GFiIiIvJ6DCxERETk9RhYiIiIyOv9f0M/f0plXRGtAAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(losses)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:30:59.243854Z", + "start_time": "2023-12-13T17:30:59.170246Z" + } + }, + "id": "78ac13e102efe3a0" + }, + { + "cell_type": "code", + "execution_count": 25, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [00:01<00:00, 5.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: -170.4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Estimate agent performance\n", + "\n", + "episodes = []\n", + "for i in trange(10):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"mcar\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:31:03.547789Z", + "start_time": "2023-12-13T17:31:01.595787Z" + } + }, + "id": "d221ebfe535f9317" + }, + { + "cell_type": "code", + "execution_count": 27, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "mean_reward: -1.81e+02: 100%|██████████| 30/30 [01:12<00:00, 2.40s/it]\n" + ] + }, + { + "data": { + 
"text/plain": "[]" + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_rewards = []\n", + "\n", + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gamma\": 0.99,\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 256,\n", + "})\n", + "\n", + "for t in (pbar := trange(30)):\n", + " num_steps = 0\n", + " episodes = []\n", + " while num_steps < 2000:\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"mcar\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " num_steps += len(data.rewards)\n", + " \n", + " all_data = concatenate(episodes)\n", + " \n", + " last_value = actor.agent.act(Observation(vector=all_data.last_observation), get_value=True)[2][\"value\"]\n", + " values = actor.agent.act(Observation(vector=all_data.observations), get_value=True)[2][\"value\"]\n", + " \n", + " record = convert_trial_data_to_coltra(all_data, agent=actor.agent)\n", + " metrics = ppo.train_on_data({\"crowd\": record}, shape=(1,) + record.reward.shape)\n", + " \n", + " mean_reward = metrics[\"crowd/mean_episode_reward\"]\n", + " all_rewards.append(mean_reward)\n", + " pbar.set_description(f\"mean_reward: {mean_reward:.3}\")\n", + " \n", + "plt.plot(all_rewards)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:32:23.974676Z", + "start_time": "2023-12-13T17:31:11.816425Z" + } + }, + "id": "56b220d45561a042" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "await cog.cleanup()" + ], + "metadata": { + "collapsed": false + }, + "id": "1fb0dfc4957749d2" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/gymnasium/interactive.ipynb b/examples/gymnasium/interactive.ipynb new file mode 100644 index 0000000..5a0213c --- /dev/null +++ b/examples/gymnasium/interactive.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T17:33:49.536453Z", + "start_time": "2023-12-13T17:33:47.755623Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import format_data_multiagent\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:33:49.592824Z", + "start_time": "2023-12-13T17:33:49.544145Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:33:52.081428Z", + "start_time": "2023-12-13T17:33:52.075851Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T18:33:52.073150\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:33:58.345454Z", + "start_time": "2023-12-13T17:33:56.187819Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "\n", + "cenv = GymEnvironment(\n", + " env_id=\"CartPole-v1\",\n", + " render=True\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"cartpole\",\n", + " port=9001, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:34:03.722724Z", + "start_time": "2023-12-13T17:33:59.337837Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space)\n", + "constant_actor = ConstantActor(0)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + 
}, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'cartpole': ,\n 'random': ,\n 'constant': }" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:34:03.723625Z", + "start_time": "2023-12-13T17:34:03.721429Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:34:07.883447Z", + "start_time": "2023-12-13T17:34:05.647704Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "MOUNTAIN_CAR_ACTIONS = [\"no-op\", \"ArrowLeft\", \"ArrowRight\"]\n", + "LUNAR_LANDER_ACTIONS = [\"no-op\", \"ArrowRight\", \"ArrowUp\", \"ArrowLeft\"]\n", + "PONG_ACTIONS = [\"no-op\", \"ArrowUp\", \"ArrowDown\"]\n", + "CARTPOLE_ACTIONS = [\"no-op\", \"ArrowRight\"]\n", + "\n", + "# Change this if you use a different environment. 
Only discrete actions are supported for now.\n", + "# no-op is the default action when no key is pressed\n", + "\n", + "actions = CARTPOLE_ACTIONS\n", + "\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "trial_id = await cog.start_trial(\n", + " env_name=\"cartpole\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"gym\": \"random\",\n", + " },\n", + ")\n", + "\n", + "data = await cog.get_trial_data(trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:34:30.419894Z", + "start_time": "2023-12-13T17:34:29.936857Z" + } + }, + "id": "8052ff03998b0b52" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "(19, 4)" + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"gym\"].observations.shape" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:52:54.113076Z", + "start_time": "2023-12-13T17:52:54.105989Z" + } + }, + "id": "1800cbfeca577ec8" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "await cog.cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:53:01.851686Z", + "start_time": "2023-12-13T17:53:00.771Z" + } + }, + "id": "5d5770465a23d064" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/gymnasium/interactive.py 
b/examples/gymnasium/interactive.py new file mode 100644 index 0000000..fb43108 --- /dev/null +++ b/examples/gymnasium/interactive.py @@ -0,0 +1,72 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime + +from cogment_lab.actors import RandomActor, ConstantActor +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.utils.runners import process_cleanup +from cogment_lab.utils.trial_utils import format_data_multiagent + + +async def main(): + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + print(logpath) + + # Launch an environment in a subprocess + + cenv = GymEnvironment(env_id="CartPole-v1", render=True) + + print("Starting env") + + await cog.run_env(env=cenv, env_name="cartpole", port=9001, log_file="env.log") + + # Launch two dummy actors in subprocesses + + print("Starting actors") + + random_actor = RandomActor(cenv.env.action_space) + constant_actor = ConstantActor(0) + + await cog.run_actor(actor=random_actor, actor_name="random", port=9021, log_file="actor-random.log") + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9022, log_file="actor-constant.log") + + # Start a trial + + print("Starting trial") + + trial_id = await cog.start_trial( + env_name="cartpole", + session_config={"render": True}, + actor_impls={ + "gym": "random", + }, + ) + + 
print("Waiting for trial to finish") + + data = await format_data_multiagent(datastore=cog.datastore, trial_id=trial_id, actor_agent_specs=cenv.agent_specs) + + print(data) + return + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/gymnasium/local-training.ipynb b/examples/gymnasium/local-training.ipynb new file mode 100644 index 0000000..6c7c300 --- /dev/null +++ b/examples/gymnasium/local-training.ipynb @@ -0,0 +1,318 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T17:55:13.220688Z", + "start_time": "2023-12-13T17:55:13.216030Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import concatenate\n", + "\n", + "from coltra import HomogeneousGroup\n", + "from coltra.models import MLPModel\n", + "from coltra.policy_optimization import CrowdPPOptimizer\n", + "\n", + "from cogment_lab.actors import ColtraActor\n", + "\n", + "from tqdm import trange\n", + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:53:33.296582Z", + "start_time": "2023-12-13T17:53:33.235940Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": 
"2023-12-13T17:53:33.624181Z", + "start_time": "2023-12-13T17:53:33.621354Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T18:53:33.619848\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:55:18.921705Z", + "start_time": "2023-12-13T17:55:16.764819Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We'll train on CartPole-v1\n", + "\n", + "cenv = GymEnvironment(\n", + " env_id=\"CartPole-v1\",\n", + " render=False,\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"cartpole\",\n", + " port=9001, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:55:20.063208Z", + "start_time": "2023-12-13T17:55:20.059040Z" + } + }, + "outputs": [], + "source": [ + "# Create a model using coltra\n", + "\n", + "model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "actor_task = 
cog.run_local_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "{'cartpole': ,\n 'coltra': wait_for=>}" + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes | cog.tasks" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:55:31.732225Z", + "start_time": "2023-12-13T17:55:31.729176Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 128,\n", + "})" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T17:55:38.366171Z", + "start_time": "2023-12-13T17:55:38.201298Z" + } + }, + "id": "582b6bb1bf0c81df" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/10 [00:00]" + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(all_rewards)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:00:12.198836Z", + "start_time": "2023-12-13T18:00:12.058221Z" + } + }, + "id": "457bbd9d81cac391" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/gymnasium/local-training.py b/examples/gymnasium/local-training.py new file mode 100644 index 0000000..2163e2e --- /dev/null +++ b/examples/gymnasium/local-training.py @@ -0,0 +1,114 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import datetime + +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra +from cogment_lab.utils.runners import process_cleanup +from cogment_lab.utils.trial_utils import concatenate + +from coltra import HomogeneousGroup +from coltra.models import MLPModel +from coltra.policy_optimization import CrowdPPOptimizer + +from cogment_lab.actors import ColtraActor + +from tqdm import trange +import matplotlib.pyplot as plt + + +async def main(): + process_cleanup() + + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + print(logpath) + + # We'll train on CartPole-v1 + + cenv = GymEnvironment( + env_id="CartPole-v1", + render=False, + ) + + await cog.run_env(env=cenv, env_name="cartpole", port=9001, log_file="env.log") + + print("Env started") + + # Create a model using coltra + + model = MLPModel( + config={ + "hidden_sizes": [64, 64], + }, + observation_space=cenv.env.observation_space, + action_space=cenv.env.action_space, + ) + + actor = ColtraActor(model=model) + + actor_task = cog.run_local_actor(actor=actor, actor_name="coltra", port=9021, log_file="actor.log") + + print("Actor started") + + ppo = CrowdPPOptimizer( + HomogeneousGroup(actor.agent), + config={ + "gae_lambda": 0.95, + "minibatch_size": 128, + }, + ) + + all_rewards = [] + + for t in (pbar := trange(100)): + num_steps = 0 + episodes = [] + while num_steps < 1000: # Collect at least 1000 steps per training iteration + trial_id = await cog.start_trial( + env_name="cartpole", + session_config={"render": False}, + actor_impls={ + "gym": "coltra", + }, + ) + multi_data = await cog.get_trial_data(trial_id=trial_id) + data = multi_data["gym"] + episodes.append(data) + num_steps += len(data.rewards) + + all_data = concatenate(episodes) + + # Preprocess data + + record = convert_trial_data_to_coltra(all_data, actor.agent) + 
+ # Run a PPO step + metrics = ppo.train_on_data({"crowd": record}, shape=(1, len(record.reward))) + + mean_reward = metrics["crowd/mean_episode_reward"] + all_rewards.append(mean_reward) + pbar.set_description(f"mean_reward: {mean_reward:.3}") + + plt.plot(all_rewards) + plt.show() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/gymnasium/observe.ipynb b/examples/gymnasium/observe.ipynb new file mode 100644 index 0000000..d0ff154 --- /dev/null +++ b/examples/gymnasium/observe.ipynb @@ -0,0 +1,293 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T18:23:08.799669Z", + "start_time": "2023-12-13T18:23:07.078180Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.pettingzoo import AECEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:09.292376Z", + "start_time": "2023-12-13T18:23:09.234508Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:09.470752Z", + "start_time": "2023-12-13T18:23:09.465578Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T19:23:09.453551\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:11.879060Z", + "start_time": "2023-12-13T18:23:09.667673Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "\n", + "cenv = AECEnvironment(env_path=\"cogment_lab.envs.conversions.observer.GymObserverAECAEC\",\n", + " make_kwargs={\"gym_env_name\": \"LunarLander-v2\"},\n", + " render=True)\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"lunar\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:16.261481Z", + "start_time": "2023-12-13T18:23:11.876240Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space(\"gym\"))\n", + "constant_actor = ConstantActor(0)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " 
actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'lunar': ,\n 'random': ,\n 'constant': }" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:16.267399Z", + "start_time": "2023-12-13T18:23:16.260197Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:23.468397Z", + "start_time": "2023-12-13T18:23:21.316245Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LUNAR_LANDER_ACTIONS = [\"no-op\", \"ArrowRight\", \"ArrowUp\", \"ArrowLeft\"]\n", + "\n", + "# Change this if you use a different environment. 
Only discrete actions are supported for now.\n", + "\n", + "actions = LUNAR_LANDER_ACTIONS\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "# Get data from a trial with a random `gym` actor observed through the web UI\n", + "# You can change the values in `actor_impls` between `web_ui`, `random`, and `constant` to see the different behaviors\n", + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"lunar\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"gym\": \"random\",\n", + " \"observer\": \"web_ui\",\n", + " },\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:24.017807Z", + "start_time": "2023-12-13T18:23:24.005380Z" + } + }, + "id": "efef1ac3ff97fe90" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "data = await cog.get_trial_data(trial_id)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:23:35.523005Z", + "start_time": "2023-12-13T18:23:26.560614Z" + } + }, + "id": "8052ff03998b0b52" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "await cog.cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:22:59.869879Z", + "start_time": "2023-12-13T18:22:58.827007Z" + } + }, + "id": "9e64d0d548ac34ef" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/gymnasium/simple.py b/examples/gymnasium/simple.py new file mode 100644
index 0000000..1b5660f --- /dev/null +++ b/examples/gymnasium/simple.py @@ -0,0 +1,65 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime + +from cogment_lab.actors import ConstantActor +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment + + +async def main(): + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + print(logpath) + + cenv = GymEnvironment( + env_id="MountainCar-v0", + render=True, + make_kwargs={"max_episode_steps": 10}, + ) + + await cog.run_env(env=cenv, env_name="mcar", port=9011, log_file="env.log") + + # Launch a constant actor in a subprocess + + constant_actor = ConstantActor(1) + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9022, log_file="actor-constant.log") + + # Run a single trial with the constant actor and inspect its data + + trial_id = await cog.start_trial( + env_name="mcar", + session_config={"render": False}, + actor_impls={ + "gym": "constant", + }, + ) + multi_data = await cog.get_trial_data(trial_id=trial_id) + data = multi_data["gym"] + + # mean_reward = np.mean([sum(e.rewards) for e in episodes]) + print(f"Reward shape: {data.rewards.shape}") + print(f"Rewards: {data.rewards}") + print(f"Observations: {data.observations}") + print(f"Last observation: {data.last_observation}") + print(f"Actions: {data.actions}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git 
a/examples/gymnasium/training.ipynb b/examples/gymnasium/training.ipynb new file mode 100644 index 0000000..97e9582 --- /dev/null +++ b/examples/gymnasium/training.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T18:24:37.254269Z", + "start_time": "2023-12-13T18:24:35.182905Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import concatenate\n", + "\n", + "from coltra import HomogeneousGroup\n", + "from coltra.buffers import Observation\n", + "from coltra.models import MLPModel\n", + "from coltra.policy_optimization import CrowdPPOptimizer\n", + "\n", + "from cogment_lab.actors import ColtraActor\n", + "\n", + "from tqdm import trange\n", + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:39.351501Z", + "start_time": "2023-12-13T18:24:39.276395Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:40.457060Z", + "start_time": "2023-12-13T18:24:40.454080Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T19:24:40.452142\n" + ] + }, + { 
+ "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:52.181519Z", + "start_time": "2023-12-13T18:24:49.916590Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We'll train on CartPole-v1\n", + "\n", + "cenv = GymEnvironment(\n", + " env_id=\"CartPole-v1\",\n", + " render=False,\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"cartpole\",\n", + " port=9001, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:56.781677Z", + "start_time": "2023-12-13T18:24:54.489069Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a model using coltra\n", + "\n", + "model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "# Put the model in shared memory so that the actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await 
cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'cartpole': ,\n 'coltra': }" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:56.786229Z", + "start_time": "2023-12-13T18:24:56.781215Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 128,\n", + "})" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:24:58.339199Z", + "start_time": "2023-12-13T18:24:58.171380Z" + } + }, + "id": "582b6bb1bf0c81df" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "mean_reward: 27.6: 100%|██████████| 10/10 [00:15<00:00, 1.54s/it]\n" + ] + } + ], + "source": [ + "all_rewards = []\n", + "\n", + "for t in (pbar := trange(10)):\n", + " num_steps = 0\n", + " episodes = []\n", + " while num_steps < 1000: # Collect at least 1000 steps per training iteration\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"cartpole\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"cartpole\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " num_steps += len(data.rewards)\n", + " \n", + " all_data = concatenate(episodes)\n", + "\n", + " # Preprocess data\n", + " record = convert_trial_data_to_coltra(all_data, actor.agent)\n", + 
"\n", + " # Run a PPO step\n", + " metrics = ppo.train_on_data({\"crowd\": record}, shape=(1,) + record.reward.shape)\n", + " \n", + " mean_reward = metrics[\"crowd/mean_episode_reward\"]\n", + " all_rewards.append(mean_reward)\n", + " pbar.set_description(f\"mean_reward: {mean_reward:.3}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:25:49.058929Z", + "start_time": "2023-12-13T18:25:33.652328Z" + } + }, + "id": "56b220d45561a042" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "data": { + "text/plain": "[]" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(all_rewards)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T18:26:38.543323Z", + "start_time": "2023-12-13T18:26:38.246992Z" + } + }, + "id": "457bbd9d81cac391" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/lunar-interactive-bc.ipynb b/examples/lunar-interactive-bc.ipynb new file mode 100644 index 0000000..79c63fc --- /dev/null +++ b/examples/lunar-interactive-bc.ipynb @@ -0,0 +1,661 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T19:12:43.977147Z", + "start_time": "2023-12-13T19:12:42.044264Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "\n", + "import numpy as np\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import format_data_multiagent, concatenate\n", + "\n", + "from coltra import HomogeneousGroup\n", + "from coltra.buffers import Observation, Action\n", + "from coltra.models import MLPModel\n", + "from coltra.policy_optimization import CrowdPPOptimizer\n", + "\n", + "from cogment_lab.actors import ColtraActor\n", + "\n", + "from tqdm import trange\n", + 
"import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:44.045367Z", + "start_time": "2023-12-13T19:12:43.986529Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:44.910776Z", + "start_time": "2023-12-13T19:12:44.396144Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T20:12:44.394188\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. 
logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:47.125359Z", + "start_time": "2023-12-13T19:12:44.911292Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We'll train on \n", + "\n", + "cenv = GymEnvironment(\n", + " env_id=\"LunarLander-v2\",\n", + " render=True,\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"mcar\",\n", + " port=9001, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:49.518262Z", + "start_time": "2023-12-13T19:12:47.125042Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a model using coltra\n", + "\n", + "model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "# Put the model in shared memory so that the actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'mcar': ,\n 'coltra': }" + }, + "execution_count": 6, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:49.519197Z", + "start_time": "2023-12-13T19:12:49.514813Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:12:51.661932Z", + "start_time": "2023-12-13T19:12:49.518157Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "LUNAR_LANDER_ACTIONS = [\"no-op\", \"ArrowRight\", \"ArrowUp\", \"ArrowLeft\"]\n", + "\n", + "actions = LUNAR_LANDER_ACTIONS\n", + "\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [00:01<00:00, 9.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: -52.153119627747216\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Estimate random agent performance\n", + "\n", + "episodes = []\n", + "for i in trange(10):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await format_data_multiagent(datastore=cog.datastore, trial_id=trial_id, actor_agent_specs=cenv.agent_specs)\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": 
"2023-12-13T19:12:54.225530Z", + "start_time": "2023-12-13T19:12:53.217519Z" + } + }, + "id": "b7cde51d7dc0c3a9" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Reinitialize the agent\n", + "\n", + "cog.stop_service(\"coltra\")\n", + "\n", + "model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + ")\n", + "\n", + "# Put the model in shared memory so that the actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:13:03.878130Z", + "start_time": "2023-12-13T19:13:01.612883Z" + } + }, + "id": "5c1585be28fdae6c" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "# Get some human episodes\n", + "episodes = []\n", + "for i in range(3):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"gym\": \"web_ui\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id)\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " \n", + "all_data = concatenate(episodes)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:35:11.374153Z", + "start_time": "2023-12-13T19:34:33.073648Z" + } + }, + "id": "8f1381b80d4c8799" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: 77.66748484175616\n", + 
"rewards: [-90.41456591431051, 277.5436417415343, 45.87337869804469]\n" + ] + } + ], + "source": [ + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")\n", + "print(f\"rewards: {[sum(e.rewards) for e in episodes]}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:35:19.616407Z", + "start_time": "2023-12-13T19:35:19.613200Z" + } + }, + "id": "73d139b8e5d005d8" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [], + "source": [ + "cog.stop_service(\"web_ui\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:35:31.859929Z", + "start_time": "2023-12-13T19:35:30.845308Z" + } + }, + "id": "73c751b525abf31e" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "all_obs = Observation(vector=all_data.observations).tensor()\n", + "all_actions = torch.tensor(all_data.actions)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:35:34.727391Z", + "start_time": "2023-12-13T19:35:34.724168Z" + } + }, + "id": "cdcef32e17f9afc1" + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "loss: 0.186: 100%|██████████| 500/500 [00:00<00:00, 738.80it/s]\n" + ] + } + ], + "source": [ + "losses = []\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "for t in (pbar := trange(500)):\n", + " preds = model(all_obs)[0].logits\n", + " loss = F.cross_entropy(preds, all_actions)\n", + " \n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " pbar.set_description(f\"loss: {loss.item():.3}\")\n", + " \n", + " losses.append(loss.item())\n", + " " + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:36:04.511395Z", + "start_time": "2023-12-13T19:36:03.830625Z" + } + }, + "id": 
"9a1cb51957e9f672" + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [ + { + "data": { + "text/plain": "[]" + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(losses)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:36:05.177923Z", + "start_time": "2023-12-13T19:36:05.113102Z" + } + }, + "id": "78ac13e102efe3a0" + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [00:02<00:00, 4.85it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean_reward: -93.20616755490191\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Estimate agent performance\n", + "\n", + "episodes = []\n", + "for i in trange(10):\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await format_data_multiagent(datastore=cog.datastore, trial_id=trial_id, actor_agent_specs=cenv.agent_specs)\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + "mean_reward = np.mean([sum(e.rewards) for e in episodes])\n", + "print(f\"mean_reward: {mean_reward}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:36:08.699440Z", + "start_time": "2023-12-13T19:36:06.634848Z" + } + }, + "id": "d221ebfe535f9317" + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "mean_reward: 87.1: 100%|██████████| 100/100 [04:30<00:00, 2.70s/it]\n" + ] + }, + { + "data": { + "text/plain": "[]" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "all_rewards = []\n", + "\n", + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gamma\": 0.99,\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 256,\n", + "})\n", + "\n", + "for t in (pbar := trange(100)):\n", + " num_steps = 0\n", + " episodes = []\n", + " while num_steps < 2000:\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"mcar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await format_data_multiagent(datastore=cog.datastore, trial_id=trial_id, actor_agent_specs=cenv.agent_specs)\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " num_steps += len(data.rewards)\n", + " \n", + " all_data = concatenate(episodes)\n", + " \n", + " record = convert_trial_data_to_coltra(all_data, actor.agent)\n", + " metrics = ppo.train_on_data({\"crowd\": record}, shape=(1,) + record.reward.shape)\n", + " \n", + " mean_reward = metrics[\"crowd/mean_episode_reward\"]\n", + " all_rewards.append(mean_reward)\n", + " pbar.set_description(f\"mean_reward: {mean_reward:.3}\")\n", + " \n", + "plt.plot(all_rewards)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:41:16.217109Z", + "start_time": "2023-12-13T19:36:45.952578Z" + } + }, + "id": "56b220d45561a042" + }, + { + "cell_type": "code", + "execution_count": 25, + "outputs": [], + "source": [ + "await cog.cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T19:41:48.682828Z", + "start_time": "2023-12-13T19:41:48.645559Z" + } + }, + "id": "e23f3f5ecd4ffc2a" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/lunarlander-teach.ipynb b/examples/lunarlander-teach.ipynb new file mode 100644 index 0000000..23bb0a9 --- /dev/null +++ b/examples/lunarlander-teach.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-15T15:16:54.108165Z", + "start_time": "2024-01-15T15:16:52.130847Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "import os\n", + "\n", + "from cogment_lab.envs.gymnasium import GymEnvironment\n", + "from cogment_lab.envs.pettingzoo import AECEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra\n", + "from cogment_lab.utils.runners import process_cleanup\n", + "from cogment_lab.utils.trial_utils import format_data_multiagent, concatenate\n", + "\n", + "from coltra import HomogeneousGroup, DAgent\n", + "from coltra.buffers import Observation\n", + "from coltra.models import MLPModel\n", + "from coltra.policy_optimization import CrowdPPOptimizer\n", + "\n", + "from cogment_lab.actors import ColtraActor, RandomActor, ConstantActor\n", + "\n", + "from tqdm import trange\n", + "import matplotlib.pyplot as plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:16:54.168837Z", + "start_time": "2024-01-15T15:16:54.116233Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + 
"cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:16:55.754069Z", + "start_time": "2024-01-15T15:16:55.749952Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2024-01-15T16:16:55.746880\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:17:00.641327Z", + "start_time": "2024-01-15T15:16:58.281030Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cenv = GymEnvironment(\n", + " env_id=\"LunarLander-v2\",\n", + " render=False,\n", + ")\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"lunar\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:17:27.461202Z", + "start_time": "2024-01-15T15:17:20.645759Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a model using coltra\n", + "\n", + "if os.path.exists(\"models/lunar\"):\n", + " agent = DAgent.load(\"models/lunar\")\n", + " model = agent.model\n", + "else:\n", + " model = MLPModel(\n", + " config={\n", + " \"hidden_sizes\": [64, 64],\n", + " }, \n", + " observation_space=cenv.env.observation_space, \n", + " action_space=cenv.env.action_space\n", + " )\n", + "\n", + "# Put the model in shared memory so that the 
actor can access it\n", + "model.share_memory()\n", + "actor = ColtraActor(model=model)\n", + "\n", + "\n", + "await cog.run_actor(\n", + " actor=actor,\n", + " actor_name=\"coltra\",\n", + " port=9021,\n", + " log_file=\"actor.log\"\n", + ")\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space)\n", + "constant_actor = ConstantActor(0)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9022, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9023,\n", + " log_file=\"actor-constant.log\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [], + "source": [ + "ppo = CrowdPPOptimizer(HomogeneousGroup(actor.agent), config={\n", + " \"gae_lambda\": 0.95,\n", + " \"minibatch_size\": 128,\n", + "})" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:17:36.574395Z", + "start_time": "2024-01-15T15:17:36.402215Z" + } + }, + "id": "582b6bb1bf0c81df" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "mean_reward: -96.0: 100%|██████████| 100/100 [02:15<00:00, 1.35s/it] \n" + ] + } + ], + "source": [ + "all_rewards = []\n", + "\n", + "for t in (pbar := trange(100)):\n", + " num_steps = 0\n", + " episodes = []\n", + " while num_steps < 1000: # Collect at least 1000 steps per training iteration\n", + " trial_id = await cog.start_trial(\n", + " env_name=\"lunar\",\n", + " session_config={\"render\": False},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " },\n", + " )\n", + " multi_data = await cog.get_trial_data(trial_id=trial_id, env_name=\"lunar\")\n", + " data = multi_data[\"gym\"]\n", + " episodes.append(data)\n", + " num_steps += len(data.rewards)\n", + " \n", + " all_data = concatenate(episodes)\n", + "\n", + " # Preprocess data\n", + " record = 
convert_trial_data_to_coltra(all_data, actor.agent)\n", + "\n", + " # Run a PPO step\n", + " metrics = ppo.train_on_data({\"crowd\": record}, shape=(1,) + record.reward.shape)\n", + " \n", + " mean_reward = metrics[\"crowd/mean_episode_reward\"]\n", + " all_rewards.append(mean_reward)\n", + " pbar.set_description(f\"mean_reward: {mean_reward:.3}\")\n", + " \n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:20:01.409375Z", + "start_time": "2024-01-15T15:17:45.954719Z" + } + }, + "id": "56b220d45561a042" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "[]" + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(all_rewards)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:22:36.059537Z", + "start_time": "2024-01-15T15:22:35.985945Z" + } + }, + "id": "457bbd9d81cac391" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "cog.stop_service(\"lunar\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:22:39.592481Z", + "start_time": "2024-01-15T15:22:39.571297Z" + } + }, + "id": "4099d057af5affcf" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cenv = AECEnvironment(env_path=\"cogment_lab.envs.conversions.teacher.GymTeacherAEC\",\n", + " make_kwargs={\"gym_env_name\": \"LunarLander-v2\", \n", + " \"gym_make_kwargs\": {}, \n", + " \"render_mode\": \"rgb_array\"},\n", + " render=True)\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"lunar-teach\",\n", + " port=9011, \n", + " log_file=\"env.log\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:22:53.079736Z", + "start_time": "2024-01-15T15:22:50.933845Z" + } + }, + "id": "8dc36a4b38990059" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actions = {\n", + " \"no-op\": {\"active\": 0, \"action\": 0},\n", + " \"ArrowDown\": {\"active\": 1, \"action\": 0},\n", + " \"ArrowRight\": {\"active\": 1, \"action\": 1},\n", + " \"ArrowUp\": {\"active\": 1, \"action\": 2},\n", + " \"ArrowLeft\": {\"active\": 1, \"action\": 3},\n", + "}\n", + "\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60)" 
+ ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:22:57.152462Z", + "start_time": "2024-01-15T15:22:54.906515Z" + } + }, + "id": "d4a8788421ffd8f4" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"lunar-teach\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"gym\": \"coltra\",\n", + " \"teacher\": \"web_ui\",\n", + " },\n", + ")\n", + "\n", + "data = await cog.get_trial_data(trial_id=trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:23:24.879012Z", + "start_time": "2024-01-15T15:22:59.156914Z" + } + }, + "id": "94ef1fc6ba14903f" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "data": { + "text/plain": "-438.36224" + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"gym\"].rewards.sum()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:23:28.514466Z", + "start_time": "2024-01-15T15:23:28.507954Z" + } + }, + "id": "cad1754096b0b41b" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "await cog.cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T15:23:33.250140Z", + "start_time": "2024-01-15T15:23:32.158203Z" + } + }, + "id": "d8c9467e9c724d7b" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/models/lunar/agent.pt 
b/examples/models/lunar/agent.pt new file mode 100644 index 0000000..3627621 Binary files /dev/null and b/examples/models/lunar/agent.pt differ diff --git a/examples/models/lunar/model.pt b/examples/models/lunar/model.pt new file mode 100644 index 0000000..4877901 Binary files /dev/null and b/examples/models/lunar/model.pt differ diff --git a/examples/pettingzoo/pz-atari-interactive.ipynb b/examples/pettingzoo/pz-atari-interactive.ipynb new file mode 100644 index 0000000..03f1732 --- /dev/null +++ b/examples/pettingzoo/pz-atari-interactive.ipynb @@ -0,0 +1,285 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-09T16:47:13.176123Z", + "start_time": "2024-01-09T16:47:11.657179Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.pettingzoo import AECEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:13.230710Z", + "start_time": "2024-01-09T16:47:13.184354Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:13.750134Z", + "start_time": "2024-01-09T16:47:13.746042Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2024-01-09T17:47:13.230277\n" + ] + } + ], + "source": [ + 
"logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:16.131854Z", + "start_time": "2024-01-09T16:47:13.771295Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "# \n", + "# cenv = AECEnvironment(env_path=\"pettingzoo.butterfly.cooperative_pong_v5.env\",\n", + "# render=True)\n", + "\n", + "\n", + "cenv = AECEnvironment(env_path=\"pettingzoo.atari.pong_v3.env\",\n", + " render=True,\n", + " make_kwargs={\"max_cycles\": 1000})\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"pong\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:21.823661Z", + "start_time": "2024-01-09T16:47:17.528096Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space(\"first_0\"))\n", + "constant_actor = ConstantActor(1)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "{'pong': ,\n 
'random': ,\n 'constant': }" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:21.829448Z", + "start_time": "2024-01-09T16:47:21.822231Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:47:23.978317Z", + "start_time": "2024-01-09T16:47:21.827487Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# PONG_ACTIONS = [\"no-op\", \"ArrowUp\", \"ArrowDown\"]\n", + "\n", + "PONG_ACTIONS = [\"no-op\", \"F\", \"ArrowRight\", \"ArrowLeft\", \"ArrowUp\", \"ArrowDown\"]\n", + "\n", + "actions = PONG_ACTIONS\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "# Get data from a random + random trial\n", + "# You can change the values in `actor_impls` between `web_ui`, `random`, and `constant` to see the different behaviors\n", + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"pong\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"first_0\": \"random\",\n", + " \"second_0\": \"web_ui\",\n", + " },\n", + ")\n", + "\n", + "data = await cog.get_trial_data(trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:48:10.877960Z", + "start_time": "2024-01-09T16:47:25.907512Z" + } + }, + "id": "8052ff03998b0b52" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "data": { + "text/plain": "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,\n 5, 5, 5, 5, 5, 5, 5, 
0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5,\n 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,\n 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 0, 5, 5, 5, 5, 5, 0, 0, 4, 4, 4, 4,\n 4, 4, 0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0,\n 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0, 4, 4,\n 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,\n 5, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4,\n 4, 4, 4, 4, 4, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0,\n 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 4, 4, 4,\n 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 0, 0, 0,\n 4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 0, 0, 4, 4, 4, 4, 5, 5, 5, 5, 5, 0,\n 0, 4, 4, 4, 0, 5, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 0, 0, 4, 4, 4,\n 0, 0, 0, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 0, 0,\n 0, 0, 0, 0, 5, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0, 0, 0, 0, 0, 0, 0])" + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"second_0\"].actions" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-09T16:48:20.388506Z", + "start_time": "2024-01-09T16:48:20.385703Z" + } + }, + "id": "1800cbfeca577ec8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pettingzoo/pz-interactive.ipynb b/examples/pettingzoo/pz-interactive.ipynb new file mode 100644 index 0000000..b812132 --- /dev/null +++ b/examples/pettingzoo/pz-interactive.ipynb @@ -0,0 +1,286 @@ +{ + "cells": [ + 
{ + "cell_type": "code", + "execution_count": 2, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T21:10:04.931015Z", + "start_time": "2023-12-13T21:10:02.952918Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.pettingzoo import AECEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:10:04.993162Z", + "start_time": "2023-12-13T21:10:04.931405Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:10:04.998015Z", + "start_time": "2023-12-13T21:10:04.992371Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T22:10:04.992766\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. 
logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:12:23.631838Z", + "start_time": "2023-12-13T21:12:20.897201Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "\n", + "cenv = AECEnvironment(env_path=\"pettingzoo.butterfly.cooperative_pong_v5.env\",\n", + " render=True)\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"pong\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:12:28.040971Z", + "start_time": "2023-12-13T21:12:23.630748Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space(\"paddle_0\"))\n", + "constant_actor = ConstantActor(1)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "{'pong': ,\n 'random': ,\n 'constant': }" + }, + "execution_count": 7, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:12:41.346850Z", + "start_time": "2023-12-13T21:12:41.339815Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:12:47.732796Z", + "start_time": "2023-12-13T21:12:45.491361Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PONG_ACTIONS = [\"no-op\", \"ArrowUp\", \"ArrowDown\"]\n", + "\n", + "actions = PONG_ACTIONS\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [], + "source": [ + "# Get data from a random + random trial\n", + "# You can change the values in `actor_impls` between `web_ui`, `random`, and `constant` to see the different behaviors\n", + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"pong\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"paddle_0\": \"constant\",\n", + " \"paddle_1\": \"web_ui\",\n", + " },\n", + ")\n", + "\n", + "data = await cog.get_trial_data(trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:13:35.139518Z", + "start_time": "2023-12-13T21:13:25.521318Z" + } + }, + "id": "8052ff03998b0b52" + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [ + { + "data": { + "text/plain": "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 0, 1, 1, 1, 0, 2, 2,\n 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n 0, 0, 0])" + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"paddle_1\"].actions" + ], 
+ "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:17:41.560517Z", + "start_time": "2023-12-13T21:17:41.555599Z" + } + }, + "id": "1800cbfeca577ec8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pettingzoo/pz-kaz.ipynb b/examples/pettingzoo/pz-kaz.ipynb new file mode 100644 index 0000000..bea18da --- /dev/null +++ b/examples/pettingzoo/pz-kaz.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2023-12-13T21:21:42.488660Z", + "start_time": "2023-12-13T21:21:40.354720Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.pettingzoo import AECEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:21:46.747801Z", + "start_time": "2023-12-13T21:21:46.684053Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + 
"end_time": "2023-12-13T21:21:48.556919Z", + "start_time": "2023-12-13T21:21:48.549677Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2023-12-13T22:21:48.548102\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ariel/PycharmProjects/cogment_lab/venv/lib/python3.10/site-packages/cogment/context.py:213: UserWarning: No logging handler defined (e.g. logging.basicConfig)\n", + " warnings.warn(\"No logging handler defined (e.g. logging.basicConfig)\")\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:21:54.881072Z", + "start_time": "2023-12-13T21:21:51.714545Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "\n", + "cenv = AECEnvironment(env_path=\"pettingzoo.butterfly.knights_archers_zombies_v10.env\",\n", + " render=True)\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"kaz\",\n", + " port=9001, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:22:06.145915Z", + "start_time": "2023-12-13T21:22:01.760364Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space(\"knight_0\"))\n", + "constant_actor = ConstantActor(5)\n", 
+ "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "{'kaz': ,\n 'random': ,\n 'constant': }" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:22:23.177880Z", + "start_time": "2023-12-13T21:22:23.174452Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:22:39.092581Z", + "start_time": "2023-12-13T21:22:36.945998Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PONG_ACTIONS = [\"ArrowUp\", \"ArrowDown\", \"ArrowLeft\", \"ArrowRight\", \"f\", \"no-op\"]\n", + "\n", + "actions = PONG_ACTIONS\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60*4)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "# Get data from a random + random trial\n", + "# You can change the values in `actor_impls` between `web_ui`, `random`, and `constant` to see the different behaviors\n", + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"kaz\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"knight_0\": \"random\",\n", + " \"knight_1\": \"random\",\n", + " \"archer_0\": \"web_ui\",\n", + " \"archer_1\": \"random\",\n", + " },\n", + ")" + ], + "metadata": { + "collapsed": false, 
+ "ExecuteTime": { + "end_time": "2023-12-13T21:22:48.142771Z", + "start_time": "2023-12-13T21:22:48.129780Z" + } + }, + "id": "e41e4ba2ff066f5c" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "\n", + "data = await cog.get_trial_data(trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-13T21:23:46.501926Z", + "start_time": "2023-12-13T21:22:53.704091Z" + } + }, + "id": "8052ff03998b0b52" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "data": { + "text/plain": "array([4, 0, 2, 4, 3, 2, 3, 2, 5, 2, 0, 5, 0, 0, 1, 2, 4, 3, 2, 5, 4, 5,\n 3, 1, 5, 0, 2, 4, 5, 0, 4, 1, 2, 5, 5, 0, 1, 5, 0, 2, 5, 4, 3, 4,\n 2, 2, 2, 1, 2, 2, 4, 0, 2, 3, 2, 4, 2, 1, 3, 5, 1, 4, 0, 4, 5, 1,\n 5, 4, 3, 4, 1, 5, 0, 4, 4, 2, 0, 5, 2, 5, 2, 2, 4, 3, 4, 3, 2, 0,\n 3, 0, 4, 3, 3, 3, 1, 0, 1, 0, 0, 1, 3, 3, 2, 1, 1, 3, 2, 3, 1, 5,\n 5, 3, 4, 5, 5, 3, 0, 0, 5, 5, 4, 2, 3, 2, 2, 5, 0, 5, 0, 0, 3, 3,\n 3, 5, 3, 4, 3, 2, 0, 0, 0, 4, 1, 1, 1, 0, 1, 0, 2, 5, 0, 4, 5, 0,\n 3, 0, 4, 1, 1, 2, 2, 0, 5, 0, 0, 0, 1, 5, 4, 0, 0, 5, 2, 0, 1, 5,\n 3, 2, 3, 5, 4, 0])" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"knight_0\"].actions" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-12-08T16:45:25.290174Z", + "start_time": "2023-12-08T16:45:25.287374Z" + } + }, + "id": "1800cbfeca577ec8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pettingzoo/pz-parallel-interactive.ipynb 
b/examples/pettingzoo/pz-parallel-interactive.ipynb new file mode 100644 index 0000000..ca96f2a --- /dev/null +++ b/examples/pettingzoo/pz-parallel-interactive.ipynb @@ -0,0 +1,278 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "initial_id", + "metadata": { + "ExecuteTime": { + "end_time": "2024-01-15T20:53:56.772459Z", + "start_time": "2024-01-15T20:53:56.368460Z" + } + }, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "\n", + "from cogment_lab.actors import RandomActor, ConstantActor\n", + "from cogment_lab.envs.pettingzoo import ParallelEnvironment\n", + "from cogment_lab.process_manager import Cogment\n", + "from cogment_lab.utils.runners import process_cleanup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processes terminated successfully.\n" + ] + } + ], + "source": [ + "# Cleans up potentially hanging background processes from previous runs\n", + "process_cleanup()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:53:56.840328Z", + "start_time": "2024-01-15T20:53:56.773095Z" + } + }, + "id": "d431ab6f9d8d29cb" + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2658232039e652c3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:53:56.844588Z", + "start_time": "2024-01-15T20:53:56.841026Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logs/logs-2024-01-15T21:53:56.839732\n" + ] + } + ], + "source": [ + "logpath = f\"logs/logs-{datetime.datetime.now().isoformat()}\"\n", + "\n", + "cog = Cogment(log_dir=logpath)\n", + "\n", + "print(logpath)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a074d1b3-b399-4e34-a68b-e86adb20caee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:53:57.824755Z", + "start_time": 
"2024-01-15T20:53:57.007830Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch an environment in a subprocess\n", + "\n", + "cenv = ParallelEnvironment(env_path=\"pettingzoo.butterfly.cooperative_pong_v5.parallel_env\",\n", + " render=True)\n", + "\n", + "await cog.run_env(env=cenv, \n", + " env_name=\"pong\",\n", + " port=9011, \n", + " log_file=\"env.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3374d134b845beb2", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:53:58.771737Z", + "start_time": "2024-01-15T20:53:57.822248Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Launch two dummy actors in subprocesses\n", + "\n", + "random_actor = RandomActor(cenv.env.action_space(\"paddle_0\"))\n", + "constant_actor = ConstantActor(1)\n", + "\n", + "await cog.run_actor(actor=random_actor, \n", + " actor_name=\"random\", \n", + " port=9021, \n", + " log_file=\"actor-random.log\")\n", + "\n", + "await cog.run_actor(actor=constant_actor,\n", + " actor_name=\"constant\",\n", + " port=9022,\n", + " log_file=\"actor-constant.log\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "data": { + "text/plain": "{'pong': ,\n 'random': ,\n 'constant': }" + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Check what's running\n", + "\n", + "cog.processes" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:53:58.773901Z", + "start_time": "2024-01-15T20:53:58.771546Z" + } + }, + "id": "896164c911313b40" + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "835c4d6ecb2afb23", + "metadata": { + "collapsed": false, + 
"ExecuteTime": { + "end_time": "2024-01-15T20:48:09.979970Z", + "start_time": "2024-01-15T20:48:09.351179Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "True" + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PONG_ACTIONS = [\"no-op\", \"ArrowUp\", \"ArrowDown\"]\n", + "\n", + "actions = PONG_ACTIONS\n", + "await cog.run_web_ui(actions=actions, log_file=\"human.log\", fps=60)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "# Get data from a random + random trial\n", + "# You can change the values in `actor_impls` between `web_ui`, `random`, and `constant` to see the different behaviors\n", + "\n", + "trial_id = await cog.start_trial(\n", + " env_name=\"pong\",\n", + " session_config={\"render\": True},\n", + " actor_impls={\n", + " \"paddle_0\": \"random\",\n", + " \"paddle_1\": \"constant\",\n", + " },\n", + ")\n", + "\n", + "data = await cog.get_trial_data(trial_id)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:54:09.441520Z", + "start_time": "2024-01-15T20:54:07.117269Z" + } + }, + "id": "b4854b1295d75e9d" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "data": { + "text/plain": "array([2, 2, 2, 2, 1, 0, 2, 1, 0, 1, 0, 2, 1, 1, 0, 1, 1, 2, 1, 1, 1, 0,\n 1, 2])" + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"paddle_0\"].actions" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-01-15T20:54:09.652354Z", + "start_time": "2024-01-15T20:54:09.649361Z" + } + }, + "id": "1800cbfeca577ec8" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", 
+ "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/pettingzoo/pz-simple.py b/examples/pettingzoo/pz-simple.py new file mode 100644 index 0000000..551d872 --- /dev/null +++ b/examples/pettingzoo/pz-simple.py @@ -0,0 +1,83 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import asyncio + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from coltra import HomogeneousGroup +from coltra.buffers import Observation +from coltra.models import MLPModel +from coltra.policy_optimization import CrowdPPOptimizer +from tqdm import trange + +from cogment_lab.actors import ColtraActor, ConstantActor +from cogment_lab.envs import AECEnvironment +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.utils.coltra_utils import convert_trial_data_to_coltra +from cogment_lab.utils.runners import process_cleanup +from cogment_lab.utils.trial_utils import concatenate + + +async def main(): + logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" + + cog = Cogment(log_dir=logpath) + + print(logpath) + + cenv = AECEnvironment( + env_path="pettingzoo.butterfly.cooperative_pong_v5.env", render=False, make_kwargs={"max_cycles": 20} + ) + + await cog.run_env(env=cenv, env_name="pong", port=9011, 
log_file="env.log") + + # Launch a constant actor in a subprocess + + constant_actor = ConstantActor(1) + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9022, log_file="actor-constant.log") + + # Run a trial with the constant actor controlling both paddles + + trial_id = await cog.start_trial( + env_name="pong", + session_config={"render": True}, + actor_impls={ + "paddle_0": "constant", + "paddle_1": "constant", + }, + ) + + data = await cog.get_trial_data(trial_id=trial_id) + + # mean_reward = np.mean([sum(e.rewards) for e in episodes]) + print(f"Reward shape: {data['paddle_0'].rewards.shape}") + print(f"Rewards: {data['paddle_0'].rewards}") + print(f"Action shape: {data['paddle_0'].actions.shape}") + print(f"Actions: {data['paddle_0'].actions}") + + # Other agent + print(f"Reward shape: {data['paddle_1'].rewards.shape}") + print(f"Rewards: {data['paddle_1'].rewards}") + print(f"Action shape: {data['paddle_1'].actions.shape}") + print(f"Actions: {data['paddle_1'].actions}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..99fab85 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,192 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "cogment_lab" +dynamic = ["version"] +readme = "README.md" +license = "" + +dependencies = [ + "cogment[generate]>=2.0.0,<3.0.0", + "grpcio>=1.44", + "PyYAML~=6.0.1", + "starlette>=0.21.0", + "uvicorn>=0.17.6", + "Gymnasium[classic_control]~=0.29", + "PettingZoo~=1.23.1", + "numpy", + "opencv-python>=4.8", + "fastapi>=0.105", + "pillow>=9.0" +] + +[project.scripts] +cogmentlab = "cogment_lab.cli.cli:main" + +[tool.hatch.version] +path = "cogment_lab/__init__.py" + +[tool.hatch.build.targets.sdist] +include = [ + "/cogment_lab", +] + +# Package ###################################################################### + + +[project.optional-dependencies] +# Update dependencies in `all` if any are added or removed +all = [ +
"pytest >=7.1.3", + "pytest-asyncio", + "coltra>=0.2.1", + "torch", + "Grid2Op==1.9.6", + "lightsim2grid", + "numba==0.56.4", + "matplotlib", + "wandb>=0.13.9", +] +dev = [ + "pytest >=7.1.3", + "pytest-asyncio", + "ruff>=0.1.7", + "jupyter>=1.0.0", + "jupyterlab>=3.5.3", +] +coltra = [ + "coltra>=0.2.1", + "torch", + "wandb>=0.13.9", +] +grid2op = [ + "Grid2Op==1.9.6", + "lightsim2grid", + "numba==0.56.4", +] + +[project.urls] +Homepage = "https://cogment.ai/lab" +Repository = "https://github.com/cogment/cogment_lab" +Documentation = "https://cogment.ai/lab" +"Bug Report" = "https://github.com/cogment/cogment_lab/issues" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +include = ["cogment_lab", "cogment_lab.*"] + +[tool.setuptools.package-data] +cogment_lab = [ + "py.typed", +] + +# Linters and Test tools ####################################################### + +[tool.black] +safe = true + +[tool.isort] +atomic = true +profile = "black" +src_paths = ["cogment_lab", "tests", "docs/scripts"] +extra_standard_library = ["typing_extensions"] +indent = 4 +lines_after_imports = 2 +multi_line_output = 3 + +[tool.pyright] +include = ["cogment_lab/**", "tests/**"] +exclude = ["**/node_modules", "**/__pycache__"] +strict = [] + +typeCheckingMode = "basic" +pythonVersion = "3.7" +pythonPlatform = "All" +typeshedPath = "typeshed" +enableTypeIgnoreComments = true + +# This is required as the CI pre-commit does not download the module (i.e. 
numpy, pygame, box2d) +# Therefore, we have to ignore missing imports +reportMissingImports = "none" +# Some modules are missing type stubs, which is an issue when running pyright locally +reportMissingTypeStubs = false +# For warning and error, will raise an error when +reportInvalidTypeVarUse = "none" + +# reportUnknownMemberType = "warning" # -> raises 6035 warnings +# reportUnknownParameterType = "warning" # -> raises 1327 warnings +# reportUnknownVariableType = "warning" # -> raises 2585 warnings +# reportUnknownArgumentType = "warning" # -> raises 2104 warnings +reportGeneralTypeIssues = "none" # -> commented out raises 489 errors +reportUntypedFunctionDecorator = "none" # -> pytest.mark.parametrize issues + +reportPrivateUsage = "warning" +reportUnboundVariable = "warning" + +[tool.pytest.ini_options] +filterwarnings = [] + + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "scripts", + "cogment_lab/generated" +] + +# Same as Black. +line-length = 120 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix` is provided). +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c220c0e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +cogment[generate]>=2.10.1,<3.0.0 +PyYAML~=6.0.1 +starlette>=0.21.0 +uvicorn[standard]==0.17.6 +fastapi>=0.103.2 +pillow>=9.0 + +# environments +Gymnasium~=0.29 +PettingZoo~=1.23.1 + +# actors +torch +jax +jaxlib +numpy + +# For testing +black~=22.3.0 +pylint~=2.14 +pytest>=7.0,<8.0 +pytest-benchmark~=4.0.0 +pytest-timeout~=2.1.0 + +Grid2Op==1.9.6 +lightsim2grid +numba==0.56.4 + +# RL +coltra>=0.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..d50072a --- /dev/null +++ b/setup.py @@ -0,0 +1,56 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Sets up the project.""" + +import pathlib + +from setuptools import setup + + +CWD = pathlib.Path(__file__).absolute().parent + + +def get_version(): + """Gets the cogment_lab version.""" + path = CWD / "cogment_lab" / "__init__.py" + content = path.read_text() + + for line in content.splitlines(): + if line.startswith("__version__"): + return line.strip().split()[-1].strip().strip('"') + raise RuntimeError("bad version data in __init__.py") + + +def get_description(): + """Gets the description from the readme.""" + with open("README.md") as fh: + long_description = "" + header_count = 0 + for line in fh: + if line.startswith("##"): + header_count += 1 + if header_count < 2: + long_description += line + else: + break + return long_description + + +setup( + name="cogment_lab", + version=get_version(), + long_description=get_description(), + entry_points={"console_scripts": ["cogment_lab=cogment_lab.cli.cli:main"]}, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..7392fcb --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing for CogmentLab.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..46545dc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,29 @@ +# Copyright 2024 AI Redefined Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import time + +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def cogment_lab_subprocess(): + # pass + process = subprocess.Popen(["cogmentlab", "launch", "base"]) + + yield process + + process.terminate() + process.wait() diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..a3e6480 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,15 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests of the core functionalities.""" diff --git a/tests/test_gymnasium.py b/tests/test_gymnasium.py new file mode 100644 index 0000000..aa6ed94 --- /dev/null +++ b/tests/test_gymnasium.py @@ -0,0 +1,54 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest + +from cogment_lab.actors import ConstantActor +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment + + +@pytest.mark.asyncio +async def test_cartpole(): + """Test the cartpole environment.""" + + cog = Cogment(log_dir="logs") + + cenv = GymEnvironment(env_id="CartPole-v1", render=True) + + await cog.run_env(env=cenv, env_name="cartpole", port=9011, log_file="env.log") + + constant_actor = ConstantActor(0) + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9021, log_file="actor-constant.log") + + trial_id = await cog.start_trial( + env_name="cartpole", + session_config={"render": True}, + actor_impls={ + "gym": "constant", + }, + ) + + data = await cog.get_trial_data(trial_id=trial_id, env_name="cartpole") + + assert isinstance(data, dict) + assert isinstance(data["gym"].observations, np.ndarray) + assert isinstance(data["gym"].actions, np.ndarray) + assert isinstance(data["gym"].rewards, np.ndarray) + + await cog.cleanup() diff --git a/tests/test_pettingzoo.py b/tests/test_pettingzoo.py new file mode 100644 index 0000000..91d9a3e --- /dev/null +++ b/tests/test_pettingzoo.py @@ -0,0 +1,56 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest + +from cogment_lab.actors import ConstantActor +from cogment_lab.envs.pettingzoo import AECEnvironment +from cogment_lab.process_manager import Cogment + + +@pytest.mark.asyncio +async def test_pong(): + """Test the cooperative pong environment.""" + + cog = Cogment(log_dir="logs") + + cenv = AECEnvironment(env_path="pettingzoo.butterfly.cooperative_pong_v5.env", render=False) + + await cog.run_env(env=cenv, env_name="pong", port=9012, log_file="env.log") + + constant_actor = ConstantActor(1) + + await cog.run_actor(actor=constant_actor, actor_name="constant", port=9022, log_file="actor-constant.log") + + trial_id = await cog.start_trial( + env_name="pong", + session_config={"render": False}, + actor_impls={ + "paddle_0": "constant", + "paddle_1": "constant", + }, + ) + + data = await cog.get_trial_data(trial_id=trial_id, env_name="pong") + + for agent_name in ["paddle_0", "paddle_1"]: + assert isinstance(data, dict) + assert isinstance(data[agent_name].observations, np.ndarray) + assert isinstance(data[agent_name].actions, np.ndarray) + assert isinstance(data[agent_name].rewards, np.ndarray) + + await cog.cleanup()