From 59bf9920d3b8e2dcc342d0f900e940c491d02504 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 13 Dec 2023 06:43:21 +0100 Subject: [PATCH] Several further improvements in notebooks --- docs/conf.py | 2 +- notebooks/_config.yml | 2 +- .../{Images => _static/images}/offline_RL.jpg | Bin .../images}/offline_RL_1.jpg | Bin .../images}/offline_RL_2.jpg | Bin .../images}/offline_RL_3.jpg | Bin .../images}/policy_constraint_vs_support.png | 0 .../images}/q_value_iterative.png | 0 .../{Images => _static/images}/stiching.png | 0 notebooks/nb_131_Minari_Overview.ipynb | 102 --- .../nb_150_Imitation_learning_theory.ipynb | 113 --- notebooks/nb_190_Final_Remarks.ipynb | 78 -- notebooks/nb_20_IntroductionToControl.ipynb | 4 +- notebooks/nb_30_ControlAndPlanning.ipynb | 5 +- .../nb_40_RecentDevelopmentsInControl.ipynb | 4 +- notebooks/nb_50_IntroRL.ipynb | 304 ++++++-- notebooks/nb_70_TrainingRLAgents.ipynb | 259 ++++-- notebooks/nb_75_EnvironmentEngineering.ipynb | 387 ++++++--- notebooks/nb_80_OfflineRL.ipynb | 273 ------- ...ineRL.ipynb => nb_90_IntroOfflineRL.ipynb} | 311 ++++++++ ...ipynb => nb_91_RLOpenSourceDatasets.ipynb} | 255 +++++- notebooks/nb_92_Minari_Overview.ipynb | 349 +++++++++ ...pynb => nb_93_CollectDataWithMinari.ipynb} | 274 ++++++- .../nb_94_Imitation_learning_theory.ipynb | 357 +++++++++ ...=> nb_95_imitation_learning-example.ipynb} | 297 +++++-- ..._I.ipynb => nb_96_Offline_RL_part_I.ipynb} | 254 +++++- ...97_Offpolicy_distributional_shift_1.ipynb} | 312 ++++++-- ...98_Offpolicy_distributional_shift_2.ipynb} | 311 ++++++-- ....ipynb => nb_990_Offline_RL_part_II.ipynb} | 246 +++++- ...nb_991_offline_rl_algorithms_theory.ipynb} | 266 ++++++- ...b => nb_992_Offline_rl_algorithms_I.ipynb} | 309 ++++++-- ... 
=> nb_992_Offline_rl_algorithms_II.ipynb} | 333 +++++++- poetry.lock | 735 +++++++++++++++++- pyproject.toml | 14 +- 34 files changed, 5038 insertions(+), 1118 deletions(-) rename notebooks/{Images => _static/images}/offline_RL.jpg (100%) rename notebooks/{Images => _static/images}/offline_RL_1.jpg (100%) rename notebooks/{Images => _static/images}/offline_RL_2.jpg (100%) rename notebooks/{Images => _static/images}/offline_RL_3.jpg (100%) rename notebooks/{Images => _static/images}/policy_constraint_vs_support.png (100%) rename notebooks/{Images => _static/images}/q_value_iterative.png (100%) rename notebooks/{Images => _static/images}/stiching.png (100%) delete mode 100644 notebooks/nb_131_Minari_Overview.ipynb delete mode 100644 notebooks/nb_150_Imitation_learning_theory.ipynb delete mode 100644 notebooks/nb_190_Final_Remarks.ipynb delete mode 100644 notebooks/nb_80_OfflineRL.ipynb rename notebooks/{nb_120_IntroOfflineRL.ipynb => nb_90_IntroOfflineRL.ipynb} (54%) rename notebooks/{nb_130_RLOpenSourceDatasets.ipynb => nb_91_RLOpenSourceDatasets.ipynb} (52%) create mode 100644 notebooks/nb_92_Minari_Overview.ipynb rename notebooks/{nb_140_CollectDataWithMinari.ipynb => nb_93_CollectDataWithMinari.ipynb} (62%) create mode 100644 notebooks/nb_94_Imitation_learning_theory.ipynb rename notebooks/{nb_151_imitation_learning-example.ipynb => nb_95_imitation_learning-example.ipynb} (71%) rename notebooks/{nb_160_Offline_RL_part_I.ipynb => nb_96_Offline_RL_part_I.ipynb} (51%) rename notebooks/{nb_161_Offpolicy_distributional_shift_1.ipynb => nb_97_Offpolicy_distributional_shift_1.ipynb} (67%) rename notebooks/{nb_162_Offpolicy_distributional_shift_2.ipynb => nb_98_Offpolicy_distributional_shift_2.ipynb} (62%) rename notebooks/{nb_170_Offline_RL_part_II.ipynb => nb_990_Offline_RL_part_II.ipynb} (66%) rename notebooks/{nb_180_offline_rl_algorithms_theory.ipynb => nb_991_offline_rl_algorithms_theory.ipynb} (52%) rename notebooks/{nb_181_Offline_rl_algorithms_I.ipynb => 
nb_992_Offline_rl_algorithms_I.ipynb} (61%) rename notebooks/{nb_182_Offline_rl_algorithms_II.ipynb => nb_992_Offline_rl_algorithms_II.ipynb} (54%) diff --git a/docs/conf.py b/docs/conf.py index 1405fc0b..1911a72e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -141,7 +141,7 @@ def lineno_from_object_name(source_file, object_name): # built documents. # # The full version, including alpha/beta/rc tags. -version = pkg_resources.get_distribution(project).version +version = "0.1.0" release = version # The short X.Y version. major_v, minor_v = version.split(".")[:2] diff --git a/notebooks/_config.yml b/notebooks/_config.yml index acc4942d..448a002c 100644 --- a/notebooks/_config.yml +++ b/notebooks/_config.yml @@ -1,7 +1,7 @@ # Book settings # Learn more at https://jupyterbook.org/customize/config.html -title: thesan_output +title: Reinforcement Learning and Control author: appliedAI TransferLab logo: _static/images/transferlab-logo.svg diff --git a/notebooks/Images/offline_RL.jpg b/notebooks/_static/images/offline_RL.jpg similarity index 100% rename from notebooks/Images/offline_RL.jpg rename to notebooks/_static/images/offline_RL.jpg diff --git a/notebooks/Images/offline_RL_1.jpg b/notebooks/_static/images/offline_RL_1.jpg similarity index 100% rename from notebooks/Images/offline_RL_1.jpg rename to notebooks/_static/images/offline_RL_1.jpg diff --git a/notebooks/Images/offline_RL_2.jpg b/notebooks/_static/images/offline_RL_2.jpg similarity index 100% rename from notebooks/Images/offline_RL_2.jpg rename to notebooks/_static/images/offline_RL_2.jpg diff --git a/notebooks/Images/offline_RL_3.jpg b/notebooks/_static/images/offline_RL_3.jpg similarity index 100% rename from notebooks/Images/offline_RL_3.jpg rename to notebooks/_static/images/offline_RL_3.jpg diff --git a/notebooks/Images/policy_constraint_vs_support.png b/notebooks/_static/images/policy_constraint_vs_support.png similarity index 100% rename from notebooks/Images/policy_constraint_vs_support.png rename to 
notebooks/_static/images/policy_constraint_vs_support.png diff --git a/notebooks/Images/q_value_iterative.png b/notebooks/_static/images/q_value_iterative.png similarity index 100% rename from notebooks/Images/q_value_iterative.png rename to notebooks/_static/images/q_value_iterative.png diff --git a/notebooks/Images/stiching.png b/notebooks/_static/images/stiching.png similarity index 100% rename from notebooks/Images/stiching.png rename to notebooks/_static/images/stiching.png diff --git a/notebooks/nb_131_Minari_Overview.ipynb b/notebooks/nb_131_Minari_Overview.ipynb deleted file mode 100644 index abea1684..00000000 --- a/notebooks/nb_131_Minari_Overview.ipynb +++ /dev/null @@ -1,102 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "ae619194", - "metadata": {}, - "source": [ - "## Minari overview\n", - "\n", - "Unfortunately not supported in windows (maybe mainly because of Mujoco dependency)\n", - "\n", - "1 - **[Intro Minari](https://minari.farama.org/main/content/basic_usage/)**\n", - "\n", - "It is use to collect data given a gymnasium environment:\n", - "\n", - " from minari import DataCollectorV0\n", - " import gymnasium as gym\n", - "\n", - " env = gym.make('LunarLander-v2')\n", - " env = DataCollectorV0(env, record_infos=True, max_buffer_steps=100000)\n", - "\n", - "\n", - "Has different functionalities like to add useful metadata to the datasets, fuse data coming from different behavioral policies, custom preprocessing of the collected (observation, action, reward) data, add new data to existing datasets, restore the environments associated with your data, etc. We will use some of these functionalities in our exercises.\n", - "\n", - "However, is not a very robust library so in the possible it is better to have your own functions to restore your environment, load metadata, etc. And rely on Minari as little as possible. 
We will see that in the notebook exercises.\n", - "\n", - "\n", - "Of course one of its main use is to Download the very useful datasets available online as well as to have access to the policies used to generate them: [datasets](https://minari.farama.org/main/datasets/pen/human/)\n", - "\n", - "\n", - "\n", - "\n", - "Many interesting datasets uses the Mujoco C/C++ library.\n", - "\n", - "2 - **Mujoco:**\n", - "\n", - "MuJoCo, short for Multi-Joint dynamics with Contact, is a versatile physics engine designed to support research and development across various fields, including robotics, bio-mechanics, graphics, animation, machine learning, and more. Originally created by Robotic LLC, it was later acquired by DeepMind and made freely accessible to the public in October 2021. Furthermore, it was open-sourced in May 2022. [GitHub](https://github.com/google-deepmind/mujoco)" - ] - }, - { - "cell_type": "markdown", - "id": "f6408068", - "metadata": {}, - "source": [ - "## RL unplugged (Deepmind)\n", - "\n", - "[website](https://www.deepmind.com/blog/rl-unplugged-benchmarks-for-offline-reinforcement-learning) and [blog](https://www.deepmind.com/blog/rl-unplugged-benchmarks-for-offline-reinforcement-learning)" - ] - }, - { - "cell_type": "markdown", - "id": "252dec6b", - "metadata": {}, - "source": [ - "## Open X-Embodiment\n", - "\n", - "[website](https://robotics-transformer-x.github.io/)" - ] - }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "collapsed": false - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": 
false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/nb_150_Imitation_learning_theory.ipynb b/notebooks/nb_150_Imitation_learning_theory.ipynb deleted file mode 100644 index d627d908..00000000 --- a/notebooks/nb_150_Imitation_learning_theory.ipynb +++ /dev/null @@ -1,113 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Imitation Learning\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Imitation learning is a supervise learning approach focuses on learning policies or behaviors by observing and imitating expert demonstrations**. Instead of learning from trial and error, imitation learning leverages existing expert knowledge to train agents.\n", - "\n", - "This makes these algorithms appealing as **you don't need to create a reward function for your task** like in situations where the manual approach becomes essential because creating a reward function directly is not feasible, such as when training a self-driving vehicle." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The easiest imitation learning algorithm is call BC (Behavioral Cloning) and is just supervised learning on the collected expert data, i.e.:\n", - "\n", - "$$ D = \\{(s_0, a_0), (s_1, a_1), \\ldots, (s_T, a_T^o)\\} \\quad \\tag{Dataset} $$\n", - "\n", - "$$ L_{BC}(\\theta) = \\frac{1}{2} \\left(\\pi_\\theta(s_t) - a_t\\right)^2 \\tag{Cost function}$$\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are improve versions of BC like DAgger (Dataset Aggregation) where after BC the policy is being rollout and if new states appear a new feedback is to ask the human expert. 
This could produce a huge improvement, although it could be quite expensive.\n", - "\n", - "Pros and cons of these models:\n", - "\n", - "**pros**: If you have expert dataset, and you are not worry about safety (i.e. unexpected policy behavior in unknown states) this could be a fast approach.\n", - "\n", - "**cons**: In general we don't have access to expert data so this is one of the main issues, but even if we have we will have problems related with distributional shift between our clone policy and the provided dataset. We will see this in a moment in an exercise. Also, many of the properties of the Minari datasets (see exercise notebook) that could appear in reality cannot be handled with simple imitation learning approaches, like for instance the stitching property." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "There are other interesting methods that combine imitation learning and the offline RL methods we will introduce later. Typically, they involve two steps:\n", - "\n", - "1 - Modeling data distribution (Imitation learning).\n", - "\n", - "2 - Applying offline RL for planning.\n", - "\n", - "In the first step, they use more sophisticated techniques for cloning, such as Transformers to generate new trajectories or normalizing flows to fit the state-action data distribution." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### References\n", - "\n", - "[ Ross et al. 2012 - A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning](https://arxiv.org/abs/1011.0686)\n", - "\n", - "[Janner et al. 2021 - Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039)\n", - "\n", - "[Prudencio et al. 
2023 - A Survey on Offline Reinforcement Learning: Taxonomy, Review, and Open Problems ](https://arxiv.org/pdf/2203.01387.pdf)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.0" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/nb_190_Final_Remarks.ipynb b/notebooks/nb_190_Final_Remarks.ipynb deleted file mode 100644 index 697f5ec2..00000000 --- a/notebooks/nb_190_Final_Remarks.ipynb +++ /dev/null @@ -1,78 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "813ba6c9", - "metadata": {}, - "source": [ - "Final remarks\n", - "---" - ] - }, - { - "cell_type": "markdown", - "id": "e7d216cf", - "metadata": {}, - "source": [ - "Offline RL proves valuable in various scenarios, especially when:\n", - "\n", - "a. Robots require intelligent behavior in complex open-world environments demanding extensive training data due to robust visual perception requirements. (complex environment modeling and extensive data collection)\n", - "\n", - "b. Robot grasping tasks, which involve expert data that cannot be accurately simulated, providing an opportunity to assess our BCQ algorithm.\n", - "\n", - "c. 
Robotic navigation tasks, where offline RL aids in crafting effective navigation policies using real-world data.\n", - "\n", - "d. Autonomous driving, where ample expert data and an offline approach enhance safety.\n", - "\n", - "e. Healthcare applications, where safety is paramount due to the potential serious consequences of inaccurate forecasts.\n", - "\n", - "... and many more. \n", - "\n", - "However, if you have access to an environment with abundant data, online Reinforcement Learning (RL) can be a powerful choice due to its potential for exploration and real-time feedback. Nevertheless, the landscape of RL is evolving, and a data-centric approach is gaining prominence, exemplified by vast datasets like X-Embodiment. It's becoming evident that robots trained with diverse data across various scenarios tend to outperform those solely focused on specific tasks. Furthermore, leveraging multitask trained agents for transfer learning can be a valuable strategy for addressing your specific task at hand." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/nb_20_IntroductionToControl.ipynb b/notebooks/nb_20_IntroductionToControl.ipynb index 65d0f2ac..aa052edd 100644 --- a/notebooks/nb_20_IntroductionToControl.ipynb +++ b/notebooks/nb_20_IntroductionToControl.ipynb @@ -118,7 +118,7 @@ } }, "source": [ - "# Introduction\n", + "# Introduction to Control\n", "\n", "Control theory is a field of control engineering and applied mathematics that deals with the control of dynamical systems in engineered processes and machines. The objective is to develop a model or algorithm governing the application of system inputs to drive the system to a desired state, while minimizing any delay, overshoot, or steady-state error and ensuring a level of control stability; often with the aim to achieve a degree of optimality. 
" ] @@ -446,7 +446,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -3891,7 +3890,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" } diff --git a/notebooks/nb_30_ControlAndPlanning.ipynb b/notebooks/nb_30_ControlAndPlanning.ipynb index 11ea9637..8e024c04 100644 --- a/notebooks/nb_30_ControlAndPlanning.ipynb +++ b/notebooks/nb_30_ControlAndPlanning.ipynb @@ -126,7 +126,7 @@ } }, "source": [ - "# Introduction\n", + "# Control and Planning\n", "\n", "In previous sections, we have designed feedback controllers for various systems with the goal of regulating the system output to a desired setpoint. Specifically, we utilized Fullstate Feedback and PID controllers. While these simple controllers can effectively regulate many systems, they have limitations that prevent high performance control for more complex systems.\n", "\n", @@ -1542,7 +1542,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -2166,7 +2165,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" }, @@ -3234,7 +3232,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" }, diff --git a/notebooks/nb_40_RecentDevelopmentsInControl.ipynb b/notebooks/nb_40_RecentDevelopmentsInControl.ipynb index 390e5c25..8f447112 100644 --- a/notebooks/nb_40_RecentDevelopmentsInControl.ipynb +++ b/notebooks/nb_40_RecentDevelopmentsInControl.ipynb @@ -111,7 +111,7 @@ } }, "source": [ - "# Introduction\n", + "# Recent Developments in Control Theory\n", "\n", "- So far we have focused on deterministic systems with no noise or disturbances.\n", "- In this part of the training, we will focus on stochastic systems and how MPC can be used to handle such 
systems.\n", @@ -791,7 +791,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -949,7 +948,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false, "slideshow": { "slide_type": "subslide" } diff --git a/notebooks/nb_50_IntroRL.ipynb b/notebooks/nb_50_IntroRL.ipynb index 98dd6bab..8b540251 100644 --- a/notebooks/nb_50_IntroRL.ipynb +++ b/notebooks/nb_50_IntroRL.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -14,21 +14,9 @@ "remove-output", "remove-input-nbconv", "remove-output-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:41:16.500081692Z", - "start_time": "2023-11-26T23:41:15.344177128Z" - } + ] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "UsageError: Line magic function `%%capture` not found.\n" - ] - } - ], + "outputs": [], "source": [ "from typing import Callable\n", "\n", @@ -39,12 +27,14 @@ "from dataclasses import dataclass\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "import gymnasium as gym" + "import gymnasium as gym\n", + "from functools import partial" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "%%capture\n", @@ -53,18 +43,11 @@ "%autoreload 2\n", "%matplotlib inline\n", "%load_ext training_rl" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-11-26T23:43:03.513569713Z", - "start_time": "2023-11-26T23:43:02.943422281Z" - } - } + ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -74,19 +57,175 @@ "tags": [ "remove-input", "remove-input-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:43:04.526180480Z", - "start_time": "2023-11-26T23:43:04.451400708Z" - } 
+ ] }, "outputs": [ { "data": { - "text/plain": "", - "text/html": "" + "text/html": [ + "" + ], + "text/plain": [ + "" + ] }, - "execution_count": 6, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -109,11 +248,7 @@ "remove-output", "remove-input-nbconv", "remove-output-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:43:08.146406647Z", - "start_time": "2023-11-26T23:43:05.573993067Z" - } + ] }, "outputs": [], "source": [ @@ -124,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -134,17 +269,27 @@ "tags": [ "remove-input-nbconv", "remove-cell" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:43:08.188041370Z", - "start_time": "2023-11-26T23:43:08.137284385Z" - } + ] }, "outputs": [ { "data": { - "text/plain": "", - "text/markdown": "\n$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n$\\newcommand{\\amax}{{\\text{argmax}}}$\n$\\newcommand{\\P}{{\\mathbb{P}}}$\n$\\newcommand{\\E}{{\\mathbb{E}}}$\n$\\newcommand{\\R}{{\\mathbb{R}}}$\n$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n$\\newcommand{\\N}{{\\mathbb{N}}}$\n$\\newcommand{\\C}{{\\mathbb{C}}}$\n$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] }, "metadata": 
{}, "output_type": "display_data" @@ -163,7 +308,7 @@ }, "source": [ "\"Snow\"\n", - "
Include title and greeting with divs
" + "
Intro to Reinforcement Learning
" ] }, { @@ -592,6 +737,65 @@ "cartpole_env" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solution:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_returns(rewards: list[float], gamma: float = 1):\n", + " returns_reverted = [rewards[-1]]\n", + " for rew in rewards[-2::-1]:\n", + " returns_reverted.append(rew + gamma * returns_reverted[-1])\n", + " return list(reversed(returns_reverted))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "traj = []\n", + "rewards = []\n", + "cartpole_env.reset()\n", + "for step_num in range(50):\n", + " action = cartpole_env.action_space.sample()\n", + " obs, reward, terminated, truncated, info = cartpole_env.step(action)\n", + " rewards.append(reward)\n", + " entry = TrajEntry(step=step_num, obs=obs, action=action, reward=reward, next_obs=obs)\n", + " entry.frame = cartpole_env.render()\n", + " traj.append(entry)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "traj_returns = compute_returns(rewards)\n", + "entry_plotter = partial(plot_traj_entry_with_reward_and_return, returns=traj_returns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_trajectory_animation(traj, entry_plotter)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1165,7 +1369,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.6" }, "rise": { "footer": "logo", diff --git a/notebooks/nb_70_TrainingRLAgents.ipynb b/notebooks/nb_70_TrainingRLAgents.ipynb index f61fe0d2..55d6a247 100644 --- a/notebooks/nb_70_TrainingRLAgents.ipynb +++ b/notebooks/nb_70_TrainingRLAgents.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + 
"execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -14,11 +14,7 @@ "remove-output", "remove-input-nbconv", "remove-output-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:56:04.430208235Z", - "start_time": "2023-11-26T23:56:03.267608423Z" - } + ] }, "outputs": [], "source": [ @@ -32,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -42,19 +38,175 @@ "tags": [ "remove-input", "remove-input-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:56:04.433355151Z", - "start_time": "2023-11-26T23:56:04.397407316Z" - } + ] }, "outputs": [ { "data": { - "text/plain": "", - "text/html": "" + "text/html": [ + "" + ], + "text/plain": [ + "" + ] }, - "execution_count": 2, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -65,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -77,11 +229,7 @@ "remove-output", "remove-input-nbconv", "remove-output-nbconv" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:56:06.022148692Z", - "start_time": "2023-11-26T23:56:04.410836798Z" - } + ] }, "outputs": [], "source": [ @@ -92,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "hide_input": true, "init_cell": true, @@ -102,17 +250,27 @@ "tags": [ "remove-input-nbconv", "remove-cell" - ], - "ExecuteTime": { - "end_time": "2023-11-26T23:56:06.042304920Z", - "start_time": "2023-11-26T23:56:06.019350431Z" - } + ] }, "outputs": [ { "data": { - "text/plain": "", - "text/markdown": "\n$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} 
}}$\n$\\newcommand{\\amax}{{\\text{argmax}}}$\n$\\newcommand{\\P}{{\\mathbb{P}}}$\n$\\newcommand{\\E}{{\\mathbb{E}}}$\n$\\newcommand{\\R}{{\\mathbb{R}}}$\n$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n$\\newcommand{\\N}{{\\mathbb{N}}}$\n$\\newcommand{\\C}{{\\mathbb{C}}}$\n$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] }, "metadata": {}, "output_type": "display_data" @@ -131,7 +289,7 @@ }, "source": [ "\"Snow\"\n", - "
Include title and greeting with divs
" + "
Training RL Agents
" ] }, { @@ -156,23 +314,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-26T23:56:21.418951424Z", - "start_time": "2023-11-26T23:56:21.388038633Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The tensorboard extension is already loaded. To reload it, use:\n", - " %reload_ext tensorboard\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "%load_ext tensorboard\n", "\n", @@ -194,22 +338,9 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "ExecuteTime": { - "end_time": "2023-11-26T23:56:26.204074047Z", - "start_time": "2023-11-26T23:56:24.043625888Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": "Launching TensorBoard..." - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "%tensorboard --logdir log --host localhost" ] @@ -734,7 +865,7 @@ "metadata": {}, "outputs": [], "source": [ - "env = get_pendulum_env()\n", + "env = get_pendulum_env(render_mode=\"rgb_array\")\n", "\n", "\n", "demo_model(env, policy.compute_action, num_steps=400)" diff --git a/notebooks/nb_75_EnvironmentEngineering.ipynb b/notebooks/nb_75_EnvironmentEngineering.ipynb index fa98e6b5..0fd302ec 100644 --- a/notebooks/nb_75_EnvironmentEngineering.ipynb +++ b/notebooks/nb_75_EnvironmentEngineering.ipynb @@ -1,5 +1,259 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "\"Snow\"\n", + "
Environments and Feature Engineering
" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": { @@ -80,20 +334,7 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "import sys" - ] - }, - { - "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "pycharm": { "is_executing": true, @@ -104,14 +345,23 @@ "source": [ "import logging\n", "\n", + "from gymnasium import ObservationWrapper\n", + "from gymnasium.spaces import Box\n", + "\n", "from gymnasium.envs.classic_control import PendulumEnv\n", "from gymnasium.wrappers import TimeLimit\n", - "import numpy as np" + "import numpy as np\n", + "\n", + "import gymnasium\n", + "from abc import ABC, abstractmethod\n", + "from functools import reduce\n", + "from numbers import Number\n", + "from typing import TypeVar, Generic, Sequence" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -163,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -171,10 +421,6 @@ }, "outputs": [], "source": [ - "from gymnasium import ObservationWrapper\n", - "from gymnasium.spaces import Box\n", - "\n", - "\n", "class AddThetadotSquaredWrapper(ObservationWrapper):\n", " def __init__(self, env: PendulumEnv):\n", " super().__init__(env)\n", @@ -186,7 +432,7 @@ " low=np.array(low), high=np.array(high), dtype=env.observation_space.dtype\n", " )\n", "\n", - " def observation(self, observation):\n", + " def observation(self, observation: np.ndarray) -> np.ndarray:\n", " thetadot = observation[-1]\n", " result = list(observation)\n", " result.append(thetadot ** 2)\n", @@ -195,32 +441,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", 
- "text": [ - "/home/mpanchen/.cache/pypoetry/virtualenvs/tfl_training_rl-vFKr2zha-py3.11/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001B[33mWARN: Box bound precision lowered by casting to float32\u001B[0m\n", - " gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n" - ] - }, - { - "data": { - "text/plain": [ - "(array([ 0.99859995, -0.05289762, -0.69357789, 0.48105028]), {})" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "extended_env = AddThetadotSquaredWrapper(env)\n", "extended_env.reset()" @@ -259,31 +486,20 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "data": { - "text/plain": [ - ">>" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "extended_env" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -291,11 +507,6 @@ }, "outputs": [], "source": [ - "import gymnasium\n", - "from typing import List\n", - "from abc import ABC, abstractmethod\n", - "\n", - "\n", "class ScalarObservation(ABC):\n", " \"\"\"\n", " Base class for observations based on greyscale images, e.g. 
as provided by ScanningEMEnv\n", @@ -339,7 +550,7 @@ " return f\"<{self.scalar_observation}{self.env}>\"\n", "\n", "\n", - "def concatenate_boxes(boxes: List[Box]):\n", + "def concatenate_boxes(boxes: list[Box]):\n", " result_lows = []\n", " result_highs = []\n", " for b in boxes:\n", @@ -381,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -420,7 +631,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -438,48 +649,26 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "data": { - "text/plain": [ - ">>" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "enhanced_env" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, - "outputs": [ - { - "data": { - "text/plain": [ - "Box([-1. -1. -8. 0. 0.], [ 1. 1. 8. 64. 
512.], (5,), float32)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "enhanced_env.observation_space" ] @@ -503,7 +692,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -553,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -636,7 +825,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -697,7 +886,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" @@ -705,10 +894,6 @@ }, "outputs": [], "source": [ - "from functools import reduce\n", - "from numbers import Number\n", - "from typing import TypeVar, Generic, Sequence\n", - "\n", "log = logging.getLogger(__name__)\n", "\n", "EnvType = TypeVar(\"EnvType\", bound=gymnasium.Env)\n", @@ -769,7 +954,7 @@ " return self._range\n", "\n", " def __init__(\n", - " self, reward_metrics: List[RewardMetric], weights: Sequence[float] = None\n", + " self, reward_metrics: list[RewardMetric], weights: Sequence[float] = None\n", " ):\n", " self.reward_metrics = reward_metrics\n", " n_metrics = len(reward_metrics)\n", @@ -818,7 +1003,7 @@ "\n", "\n", "class RewardMetricProduct(RewardMetric):\n", - " def __init__(self, reward_metrics: List[RewardMetric]):\n", + " def __init__(self, reward_metrics: list[RewardMetric]):\n", " self.reward_metrics = reward_metrics\n", "\n", " # NOTE: finding the right range in the general case requires quite complicated logic b/c of changes in sign.\n", diff --git a/notebooks/nb_80_OfflineRL.ipynb b/notebooks/nb_80_OfflineRL.ipynb deleted file mode 100644 index ea8c7522..00000000 --- a/notebooks/nb_80_OfflineRL.ipynb +++ /dev/null @@ -1,273 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "hide_input": true, - "init_cell": true, - "slideshow": { - "slide_type": "skip" - }, - "tags": [ - "remove-input", - "remove-output", - "remove-input-nbconv", - "remove-output-nbconv" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "%matplotlib inline\n", - "%load_ext training_rl" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "hide_input": true, - "init_cell": true, - "slideshow": { - "slide_type": "skip" - }, - "tags": [ - "remove-input", - "remove-input-nbconv" - ] - }, - "outputs": [], - "source": [ - "%presentation_style" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "hide_input": true, - "init_cell": true, - "slideshow": { - "slide_type": "skip" - }, - "tags": [ - "remove-input", - "remove-output", - "remove-input-nbconv", - "remove-output-nbconv" - ] - }, - "outputs": [], - "source": [ - "%%capture\n", - "\n", - "%set_random_seed 12" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "hide_input": true, - "init_cell": true, - "slideshow": { - "slide_type": "skip" - }, - "tags": [ - "remove-input-nbconv", - "remove-cell" - ] - }, - "outputs": [], - "source": [ - "%load_latex_macros" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "slide" - } - }, - "source": [ - "\"Snow\"\n", - "
Include title and greeting with divs
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "slide" - } - }, - "source": [ - "# Offline Reinforcement Learning\n", - "\n", - "Copy this file and use it as a template for new notebooks." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## What kind of data is used in offline RL?\n", - "\n", - "- only actions: BC, inverse RL\n", - "- perfect trajectories and rewards: modified BC, critic regularized regression, CQL\n", - "- imperfect trajectories: want to improve on behavior policy. BQL, IQL" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Typical problems in offline RL\n", - "\n", - "- Distribution shift -> Show example\n", - "- Argmax in learned critic is problematic (like adversarial attack) -> Show example\n", - "- Poor behavior policy\n", - "- If only actions are collected - what is the reward? Inverse RL is hard and ambiguous\n", - "- Poor examples are not collected at all by good policies (driving car off the road), major distribution shift. Solution - something like Dagger, if access to env is possible somehow (show a Dagger example?)\n", - "- Standard off-policy algorithms don't work well since errors in critic are not corrected by collecting more samples (compare SAC and AWAC)\n", - "- ..." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Why offline RL then?\n", - "- Might be no other choice (no access to env)\n", - "- Way easier to implement (supervised learning, no sampling loop)\n", - "- When no clear reward is given (navigating drones, self-driving cars, robot arms, etc.)\n", - "- As a pre-training step before an actual RL algo (AWAC and follow-up papers) -> Show example" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Software for offline RL\n", - "- d4rl -> minari\n", - "- supervised libraries, but also tianshou" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Examples:\n", - "- how to collect data\n", - "- mujoco, data collected with experts - CRR, BC.\n", - "- mujoco, suboptimal policy - BQL, IQL?\n", - "- using d4rl/minari datasets" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## References" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "slide" - }, - "tags": [ - "remove-cell", - "remove-cell-nbconv" - ] - }, - "source": [ - "\"Snow\"\n", - "
Thank you for the attention!
" - ] - } - ], - "metadata": { - "celltoolbar": "Slideshow", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" - }, - "rise": { - "footer": "logo", - "header": "logo", - "theme": "white" - }, - "toc": { - "base_numbering": 1, - "nav_menu": { - "height": "148px", - "width": "256px" - }, - "number_sections": false, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": { - "height": "563.2px", - "left": "125px", - "top": "116.469px", - "width": "315.6px" - }, - "toc_section_display": true, - "toc_window_display": true - }, - "varInspector": { - "cols": { - "lenName": 16, - "lenType": 16, - "lenVar": 40 - }, - "kernels_config": { - "python": { - "delete_cmd_postfix": "", - "delete_cmd_prefix": "del ", - "library": "var_list.py", - "varRefreshCmd": "print(var_dic_list())" - }, - "r": { - "delete_cmd_postfix": ") ", - "delete_cmd_prefix": "rm(", - "library": "var_list.r", - "varRefreshCmd": "cat(var_dic_list()) " - } - }, - "types_to_exclude": [ - "module", - "function", - "builtin_function_or_method", - "instance", - "_Feature" - ], - "window_display": false - } - }, - "nbformat": 4, - "nbformat_minor": 1 -} diff --git a/notebooks/nb_120_IntroOfflineRL.ipynb b/notebooks/nb_90_IntroOfflineRL.ipynb similarity index 54% rename from notebooks/nb_120_IntroOfflineRL.ipynb rename to notebooks/nb_90_IntroOfflineRL.ipynb index d6515f6c..760f6eec 100644 --- a/notebooks/nb_120_IntroOfflineRL.ipynb +++ b/notebooks/nb_90_IntroOfflineRL.ipynb @@ -1,5 +1,249 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + 
"%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -8,6 +252,73 @@ "---" ] }, + { + "cell_type": "markdown", + "source": [ + "## What kind of data is used in offline RL?\n", + "\n", + "- only actions: BC, inverse RL\n", + "- perfect trajectories and rewards: modified BC, critic regularized regression, CQL\n", + "- imperfect trajectories: want to improve on behavior policy. 
BQL, IQL" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Typical problems in offline RL\n", + "\n", + "- Distribution shift -> Show example\n", + "- Argmax in learned critic is problematic (like adversarial attack) -> Show example\n", + "- Poor behavior policy\n", + "- If only actions are collected - what is the reward? Inverse RL is hard and ambiguous\n", + "- Poor examples are not collected at all by good policies (driving car off the road), major distribution shift. Solution - something like Dagger, if access to env is possible somehow (show a Dagger example?)\n", + "- Standard off-policy algorithms don't work well since errors in critic are not corrected by collecting more samples (compare SAC and AWAC)\n", + "- ..." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Why offline RL then?\n", + "- Might be no other choice (no access to env)\n", + "- Way easier to implement (supervised learning, no sampling loop)\n", + "- When no clear reward is given (navigating drones, self-driving cars, robot arms, etc.)\n", + "- As a pre-training step before an actual RL algo (AWAC and follow-up papers) -> Show example" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Software for offline RL\n", + "- d4rl -> minari\n", + "- supervised libraries, but also tianshou" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Examples:\n", + "- how to collect data\n", + "- mujoco, data collected with experts - CRR, BC.\n", + "- mujoco, suboptimal policy - BQL, IQL?\n", + "- using d4rl/minari datasets" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, diff --git a/notebooks/nb_130_RLOpenSourceDatasets.ipynb b/notebooks/nb_91_RLOpenSourceDatasets.ipynb similarity index 52% rename from notebooks/nb_130_RLOpenSourceDatasets.ipynb rename to 
notebooks/nb_91_RLOpenSourceDatasets.ipynb index 95e22bb0..ab98a059 100644 --- a/notebooks/nb_130_RLOpenSourceDatasets.ipynb +++ b/notebooks/nb_91_RLOpenSourceDatasets.ipynb @@ -1,11 +1,254 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Open Source Datasets for offline RL\n", - "---" + "# Open Source Datasets for offline RL" ] }, { @@ -45,7 +288,7 @@ "\n", "2 - **Undirected and multitask data**: Undirected in the sense that is not 
directed towards the specific task one is trying to accomplish. E.g.: recording user interactions on the internet or recording videos of a car for autonomous driving. The main purpose it to test how good is the offline agent to be used for \"trajectory stitching\", i.e. combining trajectories from different tasks to achieve new objectives, instead of searching for out-of-distribution trajectories.\n", "\n", - "\"stich_traj\"\n", + "\"stich_traj\"\n", "\n", "3 - **Sparse rewards**: Sparse rewards are challenging in online settings due to their close correlation with exploration. In offline RL, we exclusively explore within the dataset, making it an ideal framework to study the algorithm's response to sparse rewards.\n", "Note that crafting effective rewards can be challenging, and overly complex rewards may inadvertently push solutions towards suboptimal outcomes. In contrast, designing sparse rewards is often more straightforward as it merely involves specifying the task's success criteria, making it an attractive property to work with.\n", @@ -123,11 +366,9 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } + "source": [] } ], "metadata": { diff --git a/notebooks/nb_92_Minari_Overview.ipynb b/notebooks/nb_92_Minari_Overview.ipynb new file mode 100644 index 00000000..8baf9aac --- /dev/null +++ b/notebooks/nb_92_Minari_Overview.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + }, + "id": "3c1aa179ce3989c8" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + }, + "id": "77bdd6b7972f09d2" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + }, + "id": "df1956d66b075301" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + }, + "id": "a52f15ccceafd52b" + }, + { + "cell_type": "markdown", + "id": "ae619194", + "metadata": {}, + "source": [ + "## Minari overview\n", + "\n", + "Unfortunately, Minari is not supported on Windows (probably mainly because of its MuJoCo dependency).\n", + "\n", + "1 - **[Intro Minari](https://minari.farama.org/main/content/basic_usage/)**\n", + "\n", + "It is used to collect data from a given Gymnasium environment:\n", + "\n", + " from minari import DataCollectorV0\n", + " import gymnasium as gym\n", + "\n", + " env = gym.make('LunarLander-v2')\n", + " env = DataCollectorV0(env, record_infos=True, max_buffer_steps=10000)\n", + "\n", + "\n", + "It offers various functionalities, for example to add useful metadata to the datasets, fuse data coming from different behavioral policies, apply custom preprocessing to the collected (observation, action, reward) data, add new data to existing datasets, restore the environments 
associated with your data, etc. We will use some of these functionalities in our exercises.\n", + "\n", + "However, it is not a very robust library, so where possible it is better to have your own functions to restore your environment, load metadata, etc., and to rely on Minari as little as possible. We will see that in the notebook exercises.\n", + "\n", + "\n", + "Of course, one of its main uses is to download the very useful datasets available online, as well as to access the policies used to generate them: [datasets](https://minari.farama.org/main/datasets/pen/human/)\n", + "\n", + "\n", + "\n", + "\n", + "Many interesting datasets use the MuJoCo C/C++ library.\n", + "\n", + "2 - **MuJoCo:**\n", + "\n", + "MuJoCo, short for Multi-Joint dynamics with Contact, is a versatile physics engine designed to support research and development across various fields, including robotics, biomechanics, graphics, animation, machine learning, and more. Originally created by Roboti LLC, it was later acquired by DeepMind and made freely accessible to the public in October 2021. Furthermore, it was open-sourced in May 2022. 
[GitHub](https://github.com/google-deepmind/mujoco)" + ] + }, + { + "cell_type": "markdown", + "id": "f6408068", + "metadata": {}, + "source": [ + "## RL unplugged (Deepmind)\n", + "\n", + "[website](https://www.deepmind.com/blog/rl-unplugged-benchmarks-for-offline-reinforcement-learning) and [blog](https://www.deepmind.com/blog/rl-unplugged-benchmarks-for-offline-reinforcement-learning)" + ] + }, + { + "cell_type": "markdown", + "id": "252dec6b", + "metadata": {}, + "source": [ + "## Open X-Embodiment\n", + "\n", + "[website](https://robotics-transformer-x.github.io/)" + ] + }, + { + "cell_type": "markdown", + "id": "62910b6a", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/nb_140_CollectDataWithMinari.ipynb b/notebooks/nb_93_CollectDataWithMinari.ipynb similarity index 62% rename from notebooks/nb_140_CollectDataWithMinari.ipynb rename to notebooks/nb_93_CollectDataWithMinari.ipynb index fff76aeb..4cce5514 100644 --- a/notebooks/nb_140_CollectDataWithMinari.ipynb +++ b/notebooks/nb_93_CollectDataWithMinari.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], 
"metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:52:11.335657Z", - "end_time": "2023-11-24T16:52:13.175313Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -71,12 +310,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:52:36.311793Z", - "end_time": "2023-11-24T16:52:36.454496Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# We have different 2-d grid environments registered.\n", @@ -216,12 +450,7 @@ { "cell_type": "code", "execution_count": 
null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:52:47.643757Z", - "end_time": "2023-11-24T16:53:04.118273Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "DATA_SET_NAME = \"data\"\n", @@ -280,12 +509,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:53:09.030154Z", - "end_time": "2023-11-24T16:53:09.660878Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "NAME_EXPERT_DATA = \"Grid_2D_8x8_discrete-data_rl_workshop-v0\"\n", @@ -319,11 +543,9 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } + "source": [] } ], "metadata": { diff --git a/notebooks/nb_94_Imitation_learning_theory.ipynb b/notebooks/nb_94_Imitation_learning_theory.ipynb new file mode 100644 index 00000000..adf75814 --- /dev/null +++ b/notebooks/nb_94_Imitation_learning_theory.ipynb @@ -0,0 +1,357 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + 
"$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imitation Learning\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Imitation learning is a supervised learning approach that focuses on learning policies or behaviors by observing and imitating expert demonstrations**. Instead of learning from trial and error, imitation learning leverages existing expert knowledge to train agents.\n", + "\n", + "This makes these algorithms appealing because **you don't need to create a reward function for your task**. This is especially useful in situations where creating a reward function directly is not feasible and relying on manual demonstrations becomes essential, such as when training a self-driving vehicle." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The easiest imitation learning algorithm is called BC (Behavioral Cloning) and is just supervised learning on the collected expert data, i.e.:\n", + "\n", + "$$ D = \\{(s_0, a_0), (s_1, a_1), \\ldots, (s_T, a_T)\\} \\quad \\tag{Dataset} $$\n", + "\n", + "$$ L_{BC}(\\theta) = \\frac{1}{2} \\left(\\pi_\\theta(s_t) - a_t\\right)^2 \\tag{Cost function}$$\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are improved versions of BC like DAgger (Dataset Aggregation), where after BC the policy is rolled out and, if new states appear, the human expert is asked for new feedback. 
This could produce a huge improvement, although it could be quite expensive.\n", + "\n", + "Pros and cons of these models:\n", + "\n", + "**pros**: If you have an expert dataset and you are not worried about safety (i.e. unexpected policy behavior in unknown states), this could be a fast approach.\n", + "\n", + "**cons**: In general, we don't have access to expert data, which is one of the main issues. Even if we do, we will have problems related to distributional shift between our cloned policy and the provided dataset. We will see this in a moment in an exercise. Also, many of the properties of the Minari datasets (see exercise notebook) that could appear in reality cannot be handled with simple imitation learning approaches, such as the stitching property." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are other interesting methods that combine imitation learning and the offline RL methods we will introduce later. Typically, they involve two steps:\n", + "\n", + "1 - Modeling data distribution (Imitation learning).\n", + "\n", + "2 - Applying offline RL for planning.\n", + "\n", + "In the first step, they use more sophisticated techniques for cloning, such as Transformers to generate new trajectories or normalizing flows to fit the state-action data distribution." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### References\n", + "\n", + "[Ross et al. 2011 - A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning](https://arxiv.org/abs/1011.0686)\n", + "\n", + "[Janner et al. 2021 - Offline Reinforcement Learning as One Big Sequence Modeling Problem](https://arxiv.org/abs/2106.02039)\n", + "\n", + "[Prudencio et al. 
2023 - A Survey on Offline Reinforcement Learning: Taxonomy, Review, and Open Problems ](https://arxiv.org/pdf/2203.01387.pdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/nb_151_imitation_learning-example.ipynb b/notebooks/nb_95_imitation_learning-example.ipynb similarity index 71% rename from notebooks/nb_151_imitation_learning-example.ipynb rename to notebooks/nb_95_imitation_learning-example.ipynb index 0dfbaeec..05b07de2 100644 --- a/notebooks/nb_151_imitation_learning-example.ipynb +++ b/notebooks/nb_95_imitation_learning-example.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:41:35.307617Z", - "end_time": "2023-11-24T16:41:37.147465Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + 
"source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -70,12 +309,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:41:37.149613Z", - "end_time": "2023-11-24T16:41:37.150649Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ENV_NAME = CustomEnv.Grid_2D_8x8_discrete\n", @@ -126,12 +360,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:42:05.870794Z", - "end_time": "2023-11-24T16:42:06.551861Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "DATA_SET_NAME = \"data\"\n", @@ -167,12 +396,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:42:09.040274Z", - "end_time": 
"2023-11-24T16:42:09.991581Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Data saved in /offline_data\n", @@ -218,13 +442,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": true, - "ExecuteTime": { - "start_time": "2023-11-24T16:42:12.827639Z", - "end_time": "2023-11-24T16:42:47.144146Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "#The model policy to be trained.\n", @@ -269,13 +487,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false, - "ExecuteTime": { - "start_time": "2023-11-24T16:42:49.544152Z", - "end_time": "2023-11-24T16:43:17.519357Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_FILE = \"policy.pth\"\n", @@ -383,12 +595,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:43:53.298645Z", - "end_time": "2023-11-24T16:44:24.256635Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Let's now remove the forbidden region and recreate the environment\n", @@ -471,11 +678,9 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } + "source": [] } ], "metadata": { diff --git a/notebooks/nb_160_Offline_RL_part_I.ipynb b/notebooks/nb_96_Offline_RL_part_I.ipynb similarity index 51% rename from notebooks/nb_160_Offline_RL_part_I.ipynb rename to notebooks/nb_96_Offline_RL_part_I.ipynb index 74695b98..f39378c9 100644 --- a/notebooks/nb_160_Offline_RL_part_I.ipynb +++ b/notebooks/nb_96_Offline_RL_part_I.ipynb @@ -1,10 +1,254 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + 
"text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Offline RL challenges" + "# Offline RL challenges" ] }, { @@ -16,18 +260,18 @@ "\n", "The key distinction between online RL and offline RL lies in their exploration capabilities. In online RL, we actively explore the state-action space, while offline RL operates solely within a fixed dataset (denoted as $D$). Going \"out of distribution\" beyond this dataset in offline RL can lead to severe issues.\n", "\n", - "\"offline_rl\"\n", + "\"offline_rl\"\n", "\n", "Online RL involves interactive exploration to discover the highest-reward regions by gathering environmental feedback. In contrast, offline RL imposes strict limitations on exploration beyond the dataset. 
This constraint results in algorithms overestimating unknown areas and attempting to navigate beyond the dataset, as illustrated in the figure below where the dataset doesn't fully represent high-reward regions.\n", "\n", - "\"offline_rl_1\"\n", + "\"offline_rl_1\"\n", "\n", "As shown in the figure above on the right, once you are out of distribution (o.o.d) (states $s$ and $s'$ in red in the figure), as you don't have any feedback it will be hard to come back to $D$, as the o.o.d errors will propagate. As we will see this is one of the main challenges of offline RL and there are different techniques to mitigate this wrong behavior.\n", "\n", "The o.o.d. issues are not the only distributional shift effect in offline RL.\n", "After computing the optimal policy, it typically operates within a subset of the original dataset distribution, creating a distinct form of distributional shift (D' subset in green in the figure below). Evaluating a policy substantially different from the behavior policy reduces the effective sample size (from D to D'), resulting in increased variance in the estimates. In simpler terms, the limited number of data points may not accurately represent the true data distribution. \n", "\n", - "\"offline_rl_2\"\n", + "\"offline_rl_2\"\n", "\n", "\n", "Can we apply techniques from online RL, known for its effectiveness in solving complex problems, to offline RL? Yes, we can. 
We can adapt concepts from online RL, particularly off-policy RL algorithms.\n", @@ -51,7 +295,7 @@ "$$ Q^\\pi(s, a) = \\mathbb{E}_\\pi \\left[ r_0 + \\gamma r_1 + \\gamma^2 r_2 + \\ldots \\mid s_0 = s, a_0 = a \\right]\n", "$$\n", "\n", - "\"offline_rl\"\n", + "\"offline_rl\"\n", "\n", "\n", "DQN is a simple method to reach the optimal policy by iteratively compute:\n", diff --git a/notebooks/nb_161_Offpolicy_distributional_shift_1.ipynb b/notebooks/nb_97_Offpolicy_distributional_shift_1.ipynb similarity index 67% rename from notebooks/nb_161_Offpolicy_distributional_shift_1.ipynb rename to notebooks/nb_97_Offpolicy_distributional_shift_1.ipynb index 67dffa09..72244c38 100644 --- a/notebooks/nb_161_Offpolicy_distributional_shift_1.ipynb +++ b/notebooks/nb_97_Offpolicy_distributional_shift_1.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:07:12.263613Z", - "end_time": "2023-11-24T16:07:13.835752Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + 
"$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -41,6 +280,15 @@ "register_grid_envs()" ] }, + { + "cell_type": "markdown", + "source": [ + "# Off-Policy Distributional Shift - Pt. 1" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -91,12 +339,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:07:16.840208Z", - "end_time": "2023-11-24T16:07:16.994904Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ENV_NAME = CustomEnv.Grid_2D_8x8_discrete\n", @@ -126,12 +369,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:07:18.383684Z", - "end_time": "2023-11-24T16:07:18.392400Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "BEHAVIOR_POLICY_I = BehaviorPolicyType.behavior_8x8_eps_greedy_4_0_to_7_7\n", @@ -154,12 +392,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:07:19.871136Z", - "end_time": "2023-11-24T16:07:20.379213Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "config_combined_data = create_combined_minari_dataset(\n", @@ -221,12 +454,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": 
"2023-11-24T16:07:24.547105Z", - "end_time": "2023-11-24T16:07:25.189152Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "#Create Buffers with minari datasets\n", @@ -252,12 +480,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:07:26.602656Z", - "end_time": "2023-11-24T16:07:26.611050Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_NAME = PolicyName.dqn\n", @@ -282,12 +505,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:11:22.391320Z", - "end_time": "2023-11-24T16:11:38.215933Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Run the training\n", @@ -321,12 +539,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:11:40.173889Z", - "end_time": "2023-11-24T16:11:40.184722Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_FILE = \"policy_best_reward.pth\"\n", @@ -376,12 +589,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:58:47.521383Z", - "end_time": "2023-11-24T16:00:34.931070Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "NUM_EPISODES = 100 # as more episodes the better\n", diff --git a/notebooks/nb_162_Offpolicy_distributional_shift_2.ipynb b/notebooks/nb_98_Offpolicy_distributional_shift_2.ipynb similarity index 62% rename from notebooks/nb_162_Offpolicy_distributional_shift_2.ipynb rename to notebooks/nb_98_Offpolicy_distributional_shift_2.ipynb index c0ea60aa..437f4385 100644 --- a/notebooks/nb_162_Offpolicy_distributional_shift_2.ipynb +++ b/notebooks/nb_98_Offpolicy_distributional_shift_2.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], 
"metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:14:28.491652Z", - "end_time": "2023-11-24T15:14:30.182328Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -40,6 +279,15 @@ "register_grid_envs()" ] }, + { + "cell_type": "markdown", + "source": [ + "# Off-Policy Distributional Shift - Pt. 
2" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -83,12 +331,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:18:25.747995Z", - "end_time": "2023-11-24T15:18:25.928046Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ENV_NAME = CustomEnv.Grid_2D_8x8_discrete\n", @@ -118,12 +361,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:14:46.193491Z", - "end_time": "2023-11-24T15:14:46.200540Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "DATA_SET_IDENTIFIER_I = \"_downwards_\"\n", @@ -145,12 +383,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:14:47.627210Z", - "end_time": "2023-11-24T15:14:48.267078Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "config_combined_data = create_combined_minari_dataset(\n", @@ -213,12 +446,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:14:59.921826Z", - "end_time": "2023-11-24T15:15:00.836537Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "name_combined_dataset = config_combined_data.data_set_name\n", @@ -244,12 +472,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:18:53.363301Z", - "end_time": "2023-11-24T15:18:53.366455Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_NAME = PolicyName.dqn\n", @@ -274,12 +497,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T15:18:55.331162Z", - "end_time": "2023-11-24T15:19:53.278917Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Run the training\n", @@ -312,12 +530,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": 
"2023-11-24T15:18:08.073981Z", - "end_time": "2023-11-24T15:18:08.084727Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_FILE = \"policy_best_reward.pth\"\n", @@ -368,11 +581,9 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], - "source": [], - "metadata": { - "collapsed": false - } + "source": [] } ], "metadata": { diff --git a/notebooks/nb_170_Offline_RL_part_II.ipynb b/notebooks/nb_990_Offline_RL_part_II.ipynb similarity index 66% rename from notebooks/nb_170_Offline_RL_part_II.ipynb rename to notebooks/nb_990_Offline_RL_part_II.ipynb index 155f52c9..23797a7e 100644 --- a/notebooks/nb_170_Offline_RL_part_II.ipynb +++ b/notebooks/nb_990_Offline_RL_part_II.ipynb @@ -1,5 +1,249 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", 
+ "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -70,7 +314,7 @@ "\n", "To overcome this issues another approach is to constraint the policies but in their support, i.e. in the space of action where they are defined, as see in the figure below.\n", "\n", - "\"offline_rl_4\"\n", + "\"offline_rl_4\"\n", "\n", "ToDo: Give an example of support matching!! --> see 2023 review." ] diff --git a/notebooks/nb_180_offline_rl_algorithms_theory.ipynb b/notebooks/nb_991_offline_rl_algorithms_theory.ipynb similarity index 52% rename from notebooks/nb_180_offline_rl_algorithms_theory.ipynb rename to notebooks/nb_991_offline_rl_algorithms_theory.ipynb index 4bcceef9..e04e4a75 100644 --- a/notebooks/nb_180_offline_rl_algorithms_theory.ipynb +++ b/notebooks/nb_991_offline_rl_algorithms_theory.ipynb @@ -1,10 +1,254 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + 
"$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Review of some famous offline RL algorithms" + "# Review of some famous offline RL algorithms" ] }, { @@ -37,9 +281,7 @@ }, { "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "Batch Constrained deep Q-learning (BCQ) algorithm\n", "---" @@ -47,9 +289,7 @@ }, { "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "BCQ algorithm tries to solve the problem of distributional shift, and in particular the issues mentioned before during the Q-value evaluation process, i.e.: \n", "\n", @@ -96,9 +336,7 @@ }, { "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "Conservative Q-Learning (CQL) algorithm\n", "---" @@ -106,9 +344,7 @@ }, { "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "CQL follows a pessimistic approach by considering a lower bound of the Q-value. 
In the paper they show that the solution of:\n", "\n", @@ -125,9 +361,7 @@ }, { "cell_type": "markdown", - "metadata": { - "collapsed": true - }, + "metadata": {}, "source": [ "IMPLICIT Q-LEARNING (IQL):\n", "---\n", diff --git a/notebooks/nb_181_Offline_rl_algorithms_I.ipynb b/notebooks/nb_992_Offline_rl_algorithms_I.ipynb similarity index 61% rename from notebooks/nb_181_Offline_rl_algorithms_I.ipynb rename to notebooks/nb_992_Offline_rl_algorithms_I.ipynb index ec937533..22bac3c2 100644 --- a/notebooks/nb_181_Offline_rl_algorithms_I.ipynb +++ b/notebooks/nb_992_Offline_rl_algorithms_I.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:19:52.191645Z", - "end_time": "2023-11-24T16:19:53.770312Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ 
\\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -39,6 +278,15 @@ "register_grid_envs()" ] }, + { + "cell_type": "markdown", + "source": [ + "# Offline RL in Practice - Pt. 1" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -50,7 +298,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Exercise" + "## Exercise" ] }, { @@ -70,12 +318,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:20:19.265572Z", - "end_time": "2023-11-24T16:20:19.269338Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ENV_NAME = CustomEnv.Grid_2D_8x8_discrete\n", @@ -105,12 +348,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:25:34.647009Z", - "end_time": "2023-11-24T16:25:34.688059Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "IDENTIFIER_COMBINED_DATASETS = \"_stiching_property_I\"\n", @@ -137,12 +375,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:25:36.461088Z", - "end_time": "2023-11-24T16:25:40.265348Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "config_combined_data = create_combined_minari_dataset(\n", @@ -206,12 +439,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:28:58.181174Z", - "end_time": "2023-11-24T16:28:58.187720Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# The model policy to be trained.\n", @@
-237,12 +465,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:23:11.441535Z", - "end_time": "2023-11-24T16:24:59.092322Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "NUM_EPOCHS = 20\n", @@ -268,6 +491,7 @@ { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "POLICY_FILE = \"policy.pth\"\n", @@ -278,18 +502,12 @@ "log_name = os.path.join(name_expert_data, POLICY_NAME)\n", "log_path = get_trained_policy_path(log_name)\n", "policy.load_state_dict(torch.load(os.path.join(log_path, POLICY_FILE), map_location=\"cpu\"))\n" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2023-11-24T16:28:22.246178Z", - "end_time": "2023-11-24T16:28:22.292105Z" - } - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "offpolicy_rendering(\n", @@ -300,10 +518,7 @@ " num_frames=1000,\n", " imitation_policy_sampling=False\n", ")" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "markdown", diff --git a/notebooks/nb_182_Offline_rl_algorithms_II.ipynb b/notebooks/nb_992_Offline_rl_algorithms_II.ipynb similarity index 54% rename from notebooks/nb_182_Offline_rl_algorithms_II.ipynb rename to notebooks/nb_992_Offline_rl_algorithms_II.ipynb index c1b7d074..9643d6e9 100644 --- a/notebooks/nb_182_Offline_rl_algorithms_II.ipynb +++ b/notebooks/nb_992_Offline_rl_algorithms_II.ipynb @@ -3,12 +3,251 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "%matplotlib inline\n", + "%load_ext training_rl" + ], "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:32:17.063655Z", - "end_time": "2023-11-24T16:32:18.882627Z" + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ 
+ "" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "%presentation_style" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "%%capture\n", + "\n", + "%set_random_seed 12" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [ + { + "data": { + "text/markdown": [ + "\n", + "$\\newcommand{\\vect}[1]{{\\mathbf{\\boldsymbol{#1}} }}$\n", + "$\\newcommand{\\amax}{{\\text{argmax}}}$\n", + "$\\newcommand{\\P}{{\\mathbb{P}}}$\n", + "$\\newcommand{\\E}{{\\mathbb{E}}}$\n", + "$\\newcommand{\\R}{{\\mathbb{R}}}$\n", + "$\\newcommand{\\Z}{{\\mathbb{Z}}}$\n", + "$\\newcommand{\\N}{{\\mathbb{N}}}$\n", + "$\\newcommand{\\C}{{\\mathbb{C}}}$\n", + "$\\newcommand{\\abs}[1]{{ \\left| #1 \\right| }}$\n", + "$\\newcommand{\\simpl}[1]{{\\Delta^{#1} }}$\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%load_latex_macros" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from training_rl.offline_rl.custom_envs.custom_envs_registration import register_grid_envs\n", @@ -38,6 +277,15 @@ "register_grid_envs()" ] }, + { + "cell_type": "markdown", + "source": [ + "# Offline RL in Practice - Pt. 
2" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, @@ -66,12 +314,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:32:29.915543Z", - "end_time": "2023-11-24T16:32:30.082588Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ENV_NAME = CustomEnv.Grid_2D_8x8_discrete\n", @@ -102,12 +345,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:32:31.749744Z", - "end_time": "2023-11-24T16:32:31.754826Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "IDENTIFIER_COMBINED_DATASETS = \"_conservative_test\"\n", @@ -135,12 +373,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:32:32.869322Z", - "end_time": "2023-11-24T16:32:33.481293Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "config_combined_data = create_combined_minari_dataset(\n", @@ -204,12 +437,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:32:37.313786Z", - "end_time": "2023-11-24T16:32:37.318615Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "# The model policy to be trained.\n", @@ -236,13 +464,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false, - "ExecuteTime": { - "start_time": "2023-11-24T16:32:39.722413Z", - "end_time": "2023-11-24T16:34:06.279972Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "NUM_EPOCHS = 20\n", @@ -275,12 +497,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "start_time": "2023-11-24T16:34:31.570542Z", - "end_time": "2023-11-24T16:34:31.616851Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "POLICY_FILE = \"policy_best_reward.pth\"\n", @@ -320,13 +537,45 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Exercise 2:\n", + "## Exercise 
2:\n", "\n", "a) Remove the obstacle. What do you think are going to be the results?\n", "\n", "b) Modify the parameters related to distributional shift in BCQ and CQL, and observe their impact on out-of-distribution behavior." ] }, + { + "cell_type": "markdown", + "source": [ + "## Final remarks" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "Offline RL proves valuable in various scenarios, especially when:\n", + "\n", + "a. Robots require intelligent behavior in complex open-world environments demanding extensive training data due to robust visual perception requirements. (complex environment modeling and extensive data collection)\n", + "\n", + "b. Robot grasping tasks, which involve expert data that cannot be accurately simulated, providing an opportunity to assess our BCQ algorithm.\n", + "\n", + "c. Robotic navigation tasks, where offline RL aids in crafting effective navigation policies using real-world data.\n", + "\n", + "d. Autonomous driving, where ample expert data and an offline approach enhance safety.\n", + "\n", + "e. Healthcare applications, where safety is paramount due to the potential serious consequences of inaccurate forecasts.\n", + "\n", + "... and many more. \n", + "\n", + "However, if you have access to an environment with abundant data, online Reinforcement Learning (RL) can be a powerful choice due to its potential for exploration and real-time feedback. Nevertheless, the landscape of RL is evolving, and a data-centric approach is gaining prominence, exemplified by vast datasets like X-Embodiment. It's becoming evident that robots trained with diverse data across various scenarios tend to outperform those solely focused on specific tasks. Furthermore, leveraging multitask trained agents for transfer learning can be a valuable strategy for addressing your specific task at hand." 
+ ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, diff --git a/poetry.lock b/poetry.lock index 250cc64a..ebd88b31 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11,6 +11,20 @@ files = [ {file = "absl_py-2.0.0-py3-none-any.whl", hash = "sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3"}, ] +[[package]] +name = "accessible-pygments" +version = "0.0.4" +description = "A collection of accessible pygments styles" +optional = false +python-versions = "*" +files = [ + {file = "accessible-pygments-0.0.4.tar.gz", hash = "sha256:e7b57a9b15958e9601c7e9eb07a440c813283545a20973f2574a5f453d0e953e"}, + {file = "accessible_pygments-0.0.4-py2.py3-none-any.whl", hash = "sha256:416c6d8c1ea1c5ad8701903a20fcedf953c6e720d64f33dc47bfb2d3f2fa4e8d"}, +] + +[package.dependencies] +pygments = ">=1.5" + [[package]] name = "alabaster" version = "0.7.13" @@ -1441,6 +1455,77 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "greenlet" +version = "3.0.2" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.7" +files = [ + {file = "greenlet-3.0.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9acd8fd67c248b8537953cb3af8787c18a87c33d4dcf6830e410ee1f95a63fd4"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:339c0272a62fac7e602e4e6ec32a64ff9abadc638b72f17f6713556ed011d493"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:38878744926cec29b5cc3654ef47f3003f14bfbba7230e3c8492393fe29cc28b"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b3f0497db77cfd034f829678b28267eeeeaf2fc21b3f5041600f7617139e6773"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ed1a8a08de7f68506a38f9a2ddb26bbd1480689e66d788fcd4b5f77e2d9ecfcc"}, + {file = "greenlet-3.0.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89a6f6ddcbef4000cda7e205c4c20d319488ff03db961d72d4e73519d2465309"}, + {file = "greenlet-3.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c1f647fe5b94b51488b314c82fdda10a8756d650cee8d3cd29f657c6031bdf73"}, + {file = "greenlet-3.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9560c580c896030ff9c311c603aaf2282234643c90d1dec738a1d93e3e53cd51"}, + {file = "greenlet-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2e9c5423046eec21f6651268cb674dfba97280701e04ef23d312776377313206"}, + {file = "greenlet-3.0.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1fd25dfc5879a82103b3d9e43fa952e3026c221996ff4d32a9c72052544835d"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cecfdc950dd25f25d6582952e58521bca749cf3eeb7a9bad69237024308c8196"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:edf7a1daba1f7c54326291a8cde58da86ab115b78c91d502be8744f0aa8e3ffa"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4cf532bf3c58a862196b06947b1b5cc55503884f9b63bf18582a75228d9950e"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e79fb5a9fb2d0bd3b6573784f5e5adabc0b0566ad3180a028af99523ce8f6138"}, + {file = "greenlet-3.0.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:006c1028ac0cfcc4e772980cfe73f5476041c8c91d15d64f52482fc571149d46"}, + {file = "greenlet-3.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fefd5eb2c0b1adffdf2802ff7df45bfe65988b15f6b972706a0e55d451bffaea"}, + {file = "greenlet-3.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0c0fdb8142742ee68e97c106eb81e7d3e883cc739d9c5f2b28bc38a7bafeb6d1"}, + {file = 
"greenlet-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:8f8d14a0a4e8c670fbce633d8b9a1ee175673a695475acd838e372966845f764"}, + {file = "greenlet-3.0.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:654b84c9527182036747938b81938f1d03fb8321377510bc1854a9370418ab66"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd5bc4fde0842ff2b9cf33382ad0b4db91c2582db836793d58d174c569637144"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27b142a9080bdd5869a2fa7ebf407b3c0b24bd812db925de90e9afe3c417fd6"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0df7eed98ea23b20e9db64d46eb05671ba33147df9405330695bcd81a73bb0c9"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5d60805057d8948065338be6320d35e26b0a72f45db392eb32b70dd6dc9227"}, + {file = "greenlet-3.0.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e0e28f5233d64c693382f66d47c362b72089ebf8ac77df7e12ac705c9fa1163d"}, + {file = "greenlet-3.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e4bfa752b3688d74ab1186e2159779ff4867644d2b1ebf16db14281f0445377"}, + {file = "greenlet-3.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c42bb589e6e9f9d8bdd79f02f044dff020d30c1afa6e84c0b56d1ce8a324553c"}, + {file = "greenlet-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:b2cedf279ca38ef3f4ed0d013a6a84a7fc3d9495a716b84a5fc5ff448965f251"}, + {file = "greenlet-3.0.2-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:6d65bec56a7bc352bcf11b275b838df618651109074d455a772d3afe25390b7d"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0acadbc3f72cb0ee85070e8d36bd2a4673d2abd10731ee73c10222cf2dd4713c"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:14b5d999aefe9ffd2049ad19079f733c3aaa426190ffecadb1d5feacef8fe397"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f27aa32466993c92d326df982c4acccd9530fe354e938d9e9deada563e71ce76"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f34a765c5170c0673eb747213a0275ecc749ab3652bdbec324621ed5b2edaef"}, + {file = "greenlet-3.0.2-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:520fcb53a39ef90f5021c77606952dbbc1da75d77114d69b8d7bded4a8e1a813"}, + {file = "greenlet-3.0.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d1fceb5351ab1601903e714c3028b37f6ea722be6873f46e349a960156c05650"}, + {file = "greenlet-3.0.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:7363756cc439a503505b67983237d1cc19139b66488263eb19f5719a32597836"}, + {file = "greenlet-3.0.2-cp37-cp37m-win32.whl", hash = "sha256:d5547b462b8099b84746461e882a3eb8a6e3f80be46cb6afb8524eeb191d1a30"}, + {file = "greenlet-3.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:950e21562818f9c771989b5b65f990e76f4ac27af66e1bb34634ae67886ede2a"}, + {file = "greenlet-3.0.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d64643317e76b4b41fdba659e7eca29634e5739b8bc394eda3a9127f697ed4b0"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f9ea7c2c9795549653b6f7569f6bc75d2c7d1f6b2854eb8ce0bc6ec3cb2dd88"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db4233358d3438369051a2f290f1311a360d25c49f255a6c5d10b5bcb3aa2b49"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ed9bf77b41798e8417657245b9f3649314218a4a17aefb02bb3992862df32495"}, + {file = "greenlet-3.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4d0df07a38e41a10dfb62c6fc75ede196572b580f48ee49b9282c65639f3965"}, + {file = 
"greenlet-3.0.2-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10d247260db20887ae8857c0cbc750b9170f0b067dd7d38fb68a3f2334393bd3"}, + {file = "greenlet-3.0.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a37ae53cca36823597fd5f65341b6f7bac2dd69ecd6ca01334bb795460ab150b"}, + {file = "greenlet-3.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:80d068e4b6e2499847d916ef64176811ead6bf210a610859220d537d935ec6fd"}, + {file = "greenlet-3.0.2-cp38-cp38-win32.whl", hash = "sha256:b1405614692ac986490d10d3e1a05e9734f473750d4bee3cf7d1286ef7af7da6"}, + {file = "greenlet-3.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:8756a94ed8f293450b0e91119eca2a36332deba69feb2f9ca410d35e74eae1e4"}, + {file = "greenlet-3.0.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:2c93cd03acb1499ee4de675e1a4ed8eaaa7227f7949dc55b37182047b006a7aa"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1dac09e3c0b78265d2e6d3cbac2d7c48bd1aa4b04a8ffeda3adde9f1688df2c3"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ee59c4627c8c4bb3e15949fbcd499abd6b7f4ad9e0bfcb62c65c5e2cabe0ec4"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18fe39d70d482b22f0014e84947c5aaa7211fb8e13dc4cc1c43ed2aa1db06d9a"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84bef3cfb6b6bfe258c98c519811c240dbc5b33a523a14933a252e486797c90"}, + {file = "greenlet-3.0.2-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aecea0442975741e7d69daff9b13c83caff8c13eeb17485afa65f6360a045765"}, + {file = "greenlet-3.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f260e6c2337871a52161824058923df2bbddb38bc11a5cbe71f3474d877c5bd9"}, + {file = "greenlet-3.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = 
"sha256:fc14dd9554f88c9c1fe04771589ae24db76cd56c8f1104e4381b383d6b71aff8"}, + {file = "greenlet-3.0.2-cp39-cp39-win32.whl", hash = "sha256:bfcecc984d60b20ffe30173b03bfe9ba6cb671b0be1e95c3e2056d4fe7006590"}, + {file = "greenlet-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:c235131bf59d2546bb3ebaa8d436126267392f2e51b85ff45ac60f3a26549af0"}, + {file = "greenlet-3.0.2.tar.gz", hash = "sha256:1c1129bc47266d83444c85a8e990ae22688cf05fb20d7951fd2866007c2ba9bc"}, +] + +[package.extras] +docs = ["Sphinx"] +test = ["objgraph", "psutil"] + [[package]] name = "grpcio" version = "1.59.3" @@ -1648,6 +1733,25 @@ files = [ {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, ] +[[package]] +name = "importlib-metadata" +version = "7.0.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-7.0.0-py3-none-any.whl", hash = "sha256:d97503976bb81f40a193d41ee6570868479c69d5068651eb039c40d850c59d67"}, + {file = "importlib_metadata-7.0.0.tar.gz", hash = "sha256:7fc841f8b8332803464e5dc1c63a2e59121f46ca186c0e2e182e80bf8c1319f7"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + [[package]] name = "importlib-resources" version = "6.1.1" @@ -1929,6 +2033,90 @@ files = [ [package.dependencies] referencing = ">=0.31.0" +[[package]] +name = "jupyter" +version = "1.0.0" +description = "Jupyter metapackage. Install all the Jupyter components in one go." 
+optional = false +python-versions = "*" +files = [ + {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, + {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, + {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"}, +] + +[package.dependencies] +ipykernel = "*" +ipywidgets = "*" +jupyter-console = "*" +nbconvert = "*" +notebook = "*" +qtconsole = "*" + +[[package]] +name = "jupyter-book" +version = "0.15.1" +description = "Build a book with Jupyter Notebooks and Sphinx." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyter-book-0.15.1.tar.gz", hash = "sha256:8a1634ec16f7eedee0d116f1e5fb7c48203289ad92da42e09519dc71d956c010"}, + {file = "jupyter_book-0.15.1-py3-none-any.whl", hash = "sha256:7671264952abd1ca3f5e713b03e138dda710c92a985c49154f398817fe089968"}, +] + +[package.dependencies] +click = ">=7.1,<9" +docutils = ">=0.15,<0.19" +Jinja2 = "*" +jsonschema = "<5" +linkify-it-py = ">=2.0.0,<2.1.0" +myst-nb = ">=0.17.1,<0.18.0" +pyyaml = "*" +sphinx = ">=4,<6" +sphinx-book-theme = ">=1.0.0,<1.1.0" +sphinx-comments = "*" +sphinx-copybutton = "*" +sphinx-design = ">=0.3.0,<0.4.0" +sphinx-external-toc = ">=0.3.1,<0.4.0" +sphinx-jupyterbook-latex = ">=0.5.2,<0.6.0" +sphinx-multitoc-numbering = ">=0.1.3,<0.2.0" +sphinx-thebe = ">=0.2.0,<0.3.0" +sphinx_togglebutton = "*" +sphinxcontrib-bibtex = ">=2.2.0,<=2.5.0" + +[package.extras] +code-style = ["pre-commit (>=3.1,<4.0)"] +pdfhtml = ["pyppeteer"] +sphinx = ["altair", "bokeh", "folium", "ipywidgets", "jupytext", "matplotlib", "nbclient", "numpy", "pandas", "plotly", "sphinx-click", "sphinx-examples", "sphinx-proof", "sphinx_inline_tabs", "sphinxext-rediraffe (>=0.2.3,<0.3.0)", "sympy"] +testing = ["altair", "beautifulsoup4", "beautifulsoup4", "cookiecutter", "coverage", "jupytext", "matplotlib", "pyppeteer", 
"pytest (>=6.2.4)", "pytest-cov", "pytest-regressions", "pytest-timeout", "pytest-xdist", "sphinx_click", "sphinx_tabs", "texsoup"] + +[[package]] +name = "jupyter-cache" +version = "0.6.1" +description = "A defined interface for working with a cache of jupyter notebooks." +optional = false +python-versions = "~=3.8" +files = [ + {file = "jupyter-cache-0.6.1.tar.gz", hash = "sha256:26f83901143edf4af2f3ff5a91e2d2ad298e46e2cee03c8071d37a23a63ccbfc"}, + {file = "jupyter_cache-0.6.1-py3-none-any.whl", hash = "sha256:2fce7d4975805c77f75bdfc1bc2e82bc538b8e5b1af27f2f5e06d55b9f996a82"}, +] + +[package.dependencies] +attrs = "*" +click = "*" +importlib-metadata = "*" +nbclient = ">=0.2,<0.8" +nbformat = "*" +pyyaml = "*" +sqlalchemy = ">=1.3.12,<3" +tabulate = "*" + +[package.extras] +cli = ["click-log"] +code-style = ["pre-commit (>=2.12,<4.0)"] +rtd = ["ipykernel", "jupytext", "myst-nb", "nbdime", "sphinx-book-theme", "sphinx-copybutton"] +testing = ["coverage", "ipykernel", "jupytext", "matplotlib", "nbdime", "nbformat (>=5.1)", "numpy", "pandas", "pytest (>=6,<8)", "pytest-cov", "pytest-regressions", "sympy"] + [[package]] name = "jupyter-client" version = "7.4.9" @@ -1953,6 +2141,30 @@ traitlets = "*" doc = ["ipykernel", "myst-parser", "sphinx (>=1.3.6)", "sphinx-rtd-theme", "sphinxcontrib-github-alt"] test = ["codecov", "coverage", "ipykernel (>=6.12)", "ipython", "mypy", "pre-commit", "pytest", "pytest-asyncio (>=0.18)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "jupyter-console" +version = "6.6.3" +description = "Jupyter terminal console" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, + {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, +] + +[package.dependencies] +ipykernel = ">=6.14" +ipython = "*" +jupyter-client = ">=7.0.0" +jupyter-core 
= ">=4.12,<5.0.dev0 || >=5.1.dev0" +prompt-toolkit = ">=3.0.30" +pygments = "*" +pyzmq = ">=17" +traitlets = ">=5.4" + +[package.extras] +test = ["flaky", "pexpect", "pytest"] + [[package]] name = "jupyter-contrib-core" version = "0.4.2" @@ -2279,6 +2491,26 @@ files = [ [package.dependencies] six = ">=1.4.1" +[[package]] +name = "linkify-it-py" +version = "2.0.2" +description = "Links recognition library with FULL unicode support." +optional = false +python-versions = ">=3.7" +files = [ + {file = "linkify-it-py-2.0.2.tar.gz", hash = "sha256:19f3060727842c254c808e99d465c80c49d2c7306788140987a1a7a29b0d6ad2"}, + {file = "linkify_it_py-2.0.2-py3-none-any.whl", hash = "sha256:a3a24428f6c96f27370d7fe61d2ac0be09017be5190d68d8658233171f1b6541"}, +] + +[package.dependencies] +uc-micro-py = "*" + +[package.extras] +benchmark = ["pytest", "pytest-benchmark"] +dev = ["black", "flake8", "isort", "pre-commit", "pyproject-flake8"] +doc = ["myst-parser", "sphinx", "sphinx-book-theme"] +test = ["coverage", "pytest", "pytest-cov"] + [[package]] name = "llvmlite" version = "0.40.1" @@ -2436,13 +2668,13 @@ testing = ["coverage", "pyyaml"] [[package]] name = "markdown-it-py" -version = "3.0.0" +version = "2.2.0" description = "Python port of markdown-it. Markdown parsing, done right!" 
optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, - {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, + {file = "markdown-it-py-2.2.0.tar.gz", hash = "sha256:7c9a5e412688bc771c67432cbfebcdd686c93ce6484913dccf06cb5a0bea35a1"}, + {file = "markdown_it_py-2.2.0-py3-none-any.whl", hash = "sha256:5a35f8d1870171d9acc47b99612dc146129b631baf04970128b568f190d0cc30"}, ] [package.dependencies] @@ -2455,7 +2687,7 @@ compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0 linkify = ["linkify-it-py (>=1,<3)"] plugins = ["mdit-py-plugins"] profiling = ["gprof2dot"] -rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] @@ -2579,6 +2811,25 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mdit-py-plugins" +version = "0.3.5" +description = "Collection of plugins for markdown-it-py" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mdit-py-plugins-0.3.5.tar.gz", hash = "sha256:eee0adc7195e5827e17e02d2a258a2ba159944a0748f59c5099a4a27f78fcf6a"}, + {file = "mdit_py_plugins-0.3.5-py3-none-any.whl", hash = "sha256:ca9a0714ea59a24b2b044a1831f48d817dd0c817e84339f20e7889f392d77c4e"}, +] + +[package.dependencies] +markdown-it-py = ">=1.0.0,<3.0.0" + +[package.extras] +code-style = ["pre-commit"] +rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "mdurl" version = "0.1.2" @@ -2762,6 +3013,60 @@ 
files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "myst-nb" +version = "0.17.2" +description = "A Jupyter Notebook Sphinx reader built on top of the MyST markdown parser." +optional = false +python-versions = ">=3.7" +files = [ + {file = "myst-nb-0.17.2.tar.gz", hash = "sha256:0f61386515fab07c73646adca97fff2f69f41e90d313a260217c5bbe419d858b"}, + {file = "myst_nb-0.17.2-py3-none-any.whl", hash = "sha256:132ca4d0f5c308fdd4b6fdaba077712e28e119ccdafd04d6e41b51aac5483494"}, +] + +[package.dependencies] +importlib_metadata = "*" +ipykernel = "*" +ipython = "*" +jupyter-cache = ">=0.5,<0.7" +myst-parser = ">=0.18.0,<0.19.0" +nbclient = "*" +nbformat = ">=5.0,<6.0" +pyyaml = "*" +sphinx = ">=4,<6" +typing-extensions = "*" + +[package.extras] +code-style = ["pre-commit"] +rtd = ["alabaster", "altair", "bokeh", "coconut (>=1.4.3,<2.3.0)", "ipykernel (>=5.5,<6.0)", "ipywidgets", "jupytext (>=1.11.2,<1.12.0)", "matplotlib", "numpy", "pandas", "plotly", "sphinx-book-theme (>=0.3.0,<0.4.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0,<0.5.0)", "sphinxcontrib-bibtex", "sympy"] +testing = ["beautifulsoup4", "coverage (>=6.4,<8.0)", "ipykernel (>=5.5,<6.0)", "ipython (!=8.1.0,<8.5)", "ipywidgets (>=8)", "jupytext (>=1.11.2,<1.12.0)", "matplotlib (>=3.5.3,<3.6)", "nbdime", "numpy", "pandas", "pytest (>=7.1,<8.0)", "pytest-cov (>=3,<5)", "pytest-param-files (>=0.3.3,<0.4.0)", "pytest-regressions", "sympy (>=1.10.1)"] + +[[package]] +name = "myst-parser" +version = "0.18.1" +description = "An extended commonmark compliant parser, with bridges to docutils & sphinx." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "myst-parser-0.18.1.tar.gz", hash = "sha256:79317f4bb2c13053dd6e64f9da1ba1da6cd9c40c8a430c447a7b146a594c246d"}, + {file = "myst_parser-0.18.1-py3-none-any.whl", hash = "sha256:61b275b85d9f58aa327f370913ae1bec26ebad372cc99f3ab85c8ec3ee8d9fb8"}, +] + +[package.dependencies] +docutils = ">=0.15,<0.20" +jinja2 = "*" +markdown-it-py = ">=1.0.0,<3.0.0" +mdit-py-plugins = ">=0.3.1,<0.4.0" +pyyaml = "*" +sphinx = ">=4,<6" +typing-extensions = "*" + +[package.extras] +code-style = ["pre-commit (>=2.12,<3.0)"] +linkify = ["linkify-it-py (>=1.0,<2.0)"] +rtd = ["ipython", "sphinx-book-theme", "sphinx-design", "sphinxcontrib.mermaid (>=0.7.1,<0.8.0)", "sphinxext-opengraph (>=0.6.3,<0.7.0)", "sphinxext-rediraffe (>=0.2.7,<0.3.0)"] +testing = ["beautifulsoup4", "coverage[toml]", "pytest (>=6,<7)", "pytest-cov", "pytest-param-files (>=0.3.4,<0.4.0)", "pytest-regressions", "sphinx (<5.2)", "sphinx-pytest"] + [[package]] name = "nbclassic" version = "1.0.0" @@ -2799,25 +3104,25 @@ test = ["coverage", "nbval", "pytest", "pytest-cov", "pytest-jupyter", "pytest-p [[package]] name = "nbclient" -version = "0.9.0" +version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." 
optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.7.0" files = [ - {file = "nbclient-0.9.0-py3-none-any.whl", hash = "sha256:a3a1ddfb34d4a9d17fc744d655962714a866639acd30130e9be84191cd97cd15"}, - {file = "nbclient-0.9.0.tar.gz", hash = "sha256:4b28c207877cf33ef3a9838cdc7a54c5ceff981194a82eac59d558f05487295e"}, + {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, + {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, ] [package.dependencies] jupyter-client = ">=6.1.12" jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" nbformat = ">=5.1" -traitlets = ">=5.4" +traitlets = ">=5.3" [package.extras] dev = ["pre-commit"] docs = ["autodoc-traits", "mock", "moto", "myst-parser", "nbclient[test]", "sphinx (>=1.7)", "sphinx-book-theme", "sphinxcontrib-spelling"] -test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] +test = ["flaky", "ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=7.0)", "pytest-asyncio", "pytest-cov (>=4.0)", "testpath", "xmltodict"] [[package]] name = "nbconvert" @@ -3846,6 +4151,33 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +[[package]] +name = "pydata-sphinx-theme" +version = "0.14.4" +description = "Bootstrap-based Sphinx theme from the PyData community" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydata_sphinx_theme-0.14.4-py3-none-any.whl", hash = "sha256:ac15201f4c2e2e7042b0cad8b30251433c1f92be762ddcefdb4ae68811d918d9"}, + {file = "pydata_sphinx_theme-0.14.4.tar.gz", hash = "sha256:f5d7a2cb7a98e35b9b49d3b02cec373ad28958c2ed5c9b1ffe6aff6c56e9de5b"}, +] + +[package.dependencies] +accessible-pygments = "*" +Babel = "*" +beautifulsoup4 = "*" +docutils = 
"!=0.17.0" +packaging = "*" +pygments = ">=2.7" +sphinx = ">=5.0" +typing-extensions = "*" + +[package.extras] +a11y = ["pytest-playwright"] +dev = ["nox", "pre-commit", "pydata-sphinx-theme[doc,test]", "pyyaml"] +doc = ["ablog (>=0.11.0rc2)", "colorama", "ipykernel", "ipyleaflet", "jupyter_sphinx", "jupyterlite-sphinx", "linkify-it-py", "matplotlib", "myst-parser", "nbsphinx", "numpy", "numpydoc", "pandas", "plotly", "rich", "sphinx-autoapi (>=3.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-favicon (>=1.0.1)", "sphinx-sitemap", "sphinx-togglebutton", "sphinxcontrib-youtube (<1.4)", "sphinxext-rediraffe", "xarray"] +test = ["pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "pyenchant" version = "3.2.2" @@ -4293,6 +4625,48 @@ files = [ cffi = {version = "*", markers = "implementation_name == \"pypy\""} py = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "qtconsole" +version = "5.5.1" +description = "Jupyter Qt console" +optional = false +python-versions = ">= 3.8" +files = [ + {file = "qtconsole-5.5.1-py3-none-any.whl", hash = "sha256:8c75fa3e9b4ed884880ff7cea90a1b67451219279ec33deaee1d59e3df1a5d2b"}, + {file = "qtconsole-5.5.1.tar.gz", hash = "sha256:a0e806c6951db9490628e4df80caec9669b65149c7ba40f9bf033c025a5b56bc"}, +] + +[package.dependencies] +ipykernel = ">=4.1" +jupyter-client = ">=4.1" +jupyter-core = "*" +packaging = "*" +pygments = "*" +pyzmq = ">=17.1" +qtpy = ">=2.4.0" +traitlets = "<5.2.1 || >5.2.1,<5.2.2 || >5.2.2" + +[package.extras] +doc = ["Sphinx (>=1.3)"] +test = ["flaky", "pytest", "pytest-qt"] + +[[package]] +name = "qtpy" +version = "2.4.1" +description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "QtPy-2.4.1-py3-none-any.whl", hash = "sha256:1c1d8c4fa2c884ae742b069151b0abe15b3f70491f3972698c683b8e38de839b"}, + {file = "QtPy-2.4.1.tar.gz", hash = "sha256:a5a15ffd519550a1361bdc56ffc07fda56a6af7292f17c7b395d4083af632987"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] + [[package]] name = "referencing" version = "0.31.0" @@ -4740,26 +5114,26 @@ files = [ [[package]] name = "sphinx" -version = "6.2.1" +version = "5.0.2" description = "Python documentation generator" optional = false -python-versions = ">=3.8" +python-versions = ">=3.6" files = [ - {file = "Sphinx-6.2.1.tar.gz", hash = "sha256:6d56a34697bb749ffa0152feafc4b19836c755d90a7c59b72bc7dfd371b9cc6b"}, - {file = "sphinx-6.2.1-py3-none-any.whl", hash = "sha256:97787ff1fa3256a3eef9eda523a63dbf299f7b47e053cfcf684a1c2a8380c912"}, + {file = "Sphinx-5.0.2-py3-none-any.whl", hash = "sha256:d3e57663eed1d7c5c50895d191fdeda0b54ded6f44d5621b50709466c338d1e8"}, + {file = "Sphinx-5.0.2.tar.gz", hash = "sha256:b18e978ea7565720f26019c702cd85c84376e948370f1cd43d60265010e1c7b0"}, ] [package.dependencies] alabaster = ">=0.7,<0.8" -babel = ">=2.9" -colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -docutils = ">=0.18.1,<0.20" -imagesize = ">=1.3" -Jinja2 = ">=3.0" -packaging = ">=21.0" -Pygments = ">=2.13" -requests = ">=2.25.0" -snowballstemmer = ">=2.0" +babel = ">=1.3" +colorama = {version = ">=0.3.5", markers = "sys_platform == \"win32\""} +docutils = ">=0.14,<0.19" +imagesize = "*" +Jinja2 = ">=2.3" +packaging = "*" +Pygments = ">=2.0" +requests = ">=2.5.0" +snowballstemmer = ">=1.1" sphinxcontrib-applehelp = "*" sphinxcontrib-devhelp = "*" sphinxcontrib-htmlhelp = ">=2.0.0" @@ -4769,8 +5143,148 @@ sphinxcontrib-serializinghtml = ">=1.1.5" [package.extras] docs = ["sphinxcontrib-websupport"] -lint = ["docutils-stubs", "flake8 
(>=3.5.0)", "flake8-simplify", "isort", "mypy (>=0.990)", "ruff", "sphinx-lint", "types-requests"] -test = ["cython", "filelock", "html5lib", "pytest (>=4.6)"] +lint = ["docutils-stubs", "flake8 (>=3.5.0)", "isort", "mypy (>=0.950)", "types-requests", "types-typed-ast"] +test = ["cython", "html5lib", "pytest (>=4.6)", "typed-ast"] + +[[package]] +name = "sphinx-book-theme" +version = "1.0.1" +description = "A clean book theme for scientific explanations and documentation with Sphinx" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sphinx_book_theme-1.0.1-py3-none-any.whl", hash = "sha256:d15f8248b3718a9a6be0ba617a32d1591f9fa39c614469bface777ba06a73b75"}, + {file = "sphinx_book_theme-1.0.1.tar.gz", hash = "sha256:927b399a6906be067e49c11ef1a87472f1b1964075c9eea30fb82c64b20aedee"}, +] + +[package.dependencies] +pydata-sphinx-theme = ">=0.13.3" +sphinx = ">=4,<7" + +[package.extras] +code-style = ["pre-commit"] +doc = ["ablog", "docutils (==0.17.1)", "folium", "ipywidgets", "matplotlib", "myst-nb", "nbclient", "numpy", "numpydoc", "pandas", "plotly", "sphinx-copybutton", "sphinx-design", "sphinx-examples", "sphinx-tabs (<=3.4.0)", "sphinx-thebe", "sphinx-togglebutton", "sphinxcontrib-bibtex", "sphinxcontrib-youtube", "sphinxext-opengraph"] +test = ["beautifulsoup4", "coverage", "myst-nb", "pytest", "pytest-cov", "pytest-regressions", "sphinx_thebe"] + +[[package]] +name = "sphinx-comments" +version = "0.0.3" +description = "Add comments and annotation to your documentation." 
+optional = false +python-versions = "*" +files = [ + {file = "sphinx-comments-0.0.3.tar.gz", hash = "sha256:00170afff27019fad08e421da1ae49c681831fb2759786f07c826e89ac94cf21"}, + {file = "sphinx_comments-0.0.3-py3-none-any.whl", hash = "sha256:1e879b4e9bfa641467f83e3441ac4629225fc57c29995177d043252530c21d00"}, +] + +[package.dependencies] +sphinx = ">=1.8" + +[package.extras] +code-style = ["black", "flake8 (>=3.7.0,<3.8.0)", "pre-commit (==1.17.0)"] +sphinx = ["myst-parser", "sphinx (>=2)", "sphinx-book-theme"] +testing = ["beautifulsoup4", "myst-parser", "pytest", "pytest-regressions", "sphinx (>=2)", "sphinx-book-theme"] + +[[package]] +name = "sphinx-copybutton" +version = "0.5.2" +description = "Add a copy button to each of your code cells." +optional = false +python-versions = ">=3.7" +files = [ + {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"}, + {file = "sphinx_copybutton-0.5.2-py3-none-any.whl", hash = "sha256:fb543fd386d917746c9a2c50360c7905b605726b9355cd26e9974857afeae06e"}, +] + +[package.dependencies] +sphinx = ">=1.8" + +[package.extras] +code-style = ["pre-commit (==2.12.1)"] +rtd = ["ipython", "myst-nb", "sphinx", "sphinx-book-theme", "sphinx-examples"] + +[[package]] +name = "sphinx-design" +version = "0.3.0" +description = "A sphinx extension for designing beautiful, view size responsive web components." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "sphinx_design-0.3.0-py3-none-any.whl", hash = "sha256:823c1dd74f31efb3285ec2f1254caefed29d762a40cd676f58413a1e4ed5cc96"}, + {file = "sphinx_design-0.3.0.tar.gz", hash = "sha256:7183fa1fae55b37ef01bda5125a21ee841f5bbcbf59a35382be598180c4cefba"}, +] + +[package.dependencies] +sphinx = ">=4,<6" + +[package.extras] +code-style = ["pre-commit (>=2.12,<3.0)"] +rtd = ["myst-parser (>=0.18.0,<0.19.0)"] +testing = ["myst-parser (>=0.18.0,<0.19.0)", "pytest (>=7.1,<8.0)", "pytest-cov", "pytest-regressions"] +theme-furo = ["furo (>=2022.06.04,<2022.07)"] +theme-pydata = ["pydata-sphinx-theme (>=0.9.0,<0.10.0)"] +theme-rtd = ["sphinx-rtd-theme (>=1.0,<2.0)"] +theme-sbt = ["sphinx-book-theme (>=0.3.0,<0.4.0)"] + +[[package]] +name = "sphinx-external-toc" +version = "0.3.1" +description = "A sphinx extension that allows the site-map to be defined in a single YAML file." +optional = false +python-versions = "~=3.7" +files = [ + {file = "sphinx_external_toc-0.3.1-py3-none-any.whl", hash = "sha256:cd93c1e7599327b2a728db12d9819068ce719c4b037ffc62e47f20ffb6310fb3"}, + {file = "sphinx_external_toc-0.3.1.tar.gz", hash = "sha256:9c8ea9980ea0e57bf3ce98f6a400f9b69eb1df808f7dd796c9c8cc1873d8b355"}, +] + +[package.dependencies] +click = ">=7.1,<9" +pyyaml = "*" +sphinx = ">=4,<6" + +[package.extras] +code-style = ["pre-commit (>=2.12,<3.0)"] +rtd = ["myst-parser (>=0.17.0,<0.18.0)", "sphinx-book-theme (>=0.0.36)"] +testing = ["coverage", "pytest (>=7.1,<8.0)", "pytest-cov", "pytest-regressions"] + +[[package]] +name = "sphinx-jupyterbook-latex" +version = "0.5.2" +description = "Latex specific features for jupyter book" +optional = false +python-versions = ">=3.6" +files = [ + {file = "sphinx_jupyterbook_latex-0.5.2-py3-none-any.whl", hash = "sha256:24de689689ddc27c736b15b91c6b9afdcdc31570938572693bb05bfff8f50758"}, + {file = "sphinx_jupyterbook_latex-0.5.2.tar.gz", hash = 
"sha256:da1d3ad028f55ddbf10b9130bb9f24fc60cafb671cbd39dfd95537aafc90972e"}, +] + +[package.dependencies] +sphinx = ">=4,<5.1" + +[package.extras] +code-style = ["pre-commit (>=2.12,<3.0)"] +myst = ["myst-nb (>=0.13,<0.18)"] +rtd = ["myst-parser (<=0.18)", "sphinx-book-theme", "sphinx-design", "sphinx-jupyterbook-latex"] +testing = ["coverage (<5.0)", "myst-nb (>=0.13,<0.18)", "pytest (>=3.6,<4)", "pytest-cov (>=2.8,<3.0)", "pytest-regressions", "sphinx-external-toc (>=0.1.0,<0.3.0)", "sphinxcontrib-bibtex (>=2.2.1,<2.3.0)", "texsoup"] + +[[package]] +name = "sphinx-multitoc-numbering" +version = "0.1.3" +description = "Supporting continuous HTML section numbering" +optional = false +python-versions = "*" +files = [ + {file = "sphinx-multitoc-numbering-0.1.3.tar.gz", hash = "sha256:c9607671ac511236fa5d61a7491c1031e700e8d498c9d2418e6c61d1251209ae"}, + {file = "sphinx_multitoc_numbering-0.1.3-py3-none-any.whl", hash = "sha256:33d2e707a9b2b8ad636b3d4302e658a008025106fe0474046c651144c26d8514"}, +] + +[package.dependencies] +sphinx = ">=3" + +[package.extras] +code-style = ["black", "flake8 (>=3.7.0,<3.8.0)", "pre-commit (==1.17.0)"] +rtd = ["myst-parser", "sphinx (>=3.0)", "sphinx-book-theme"] +testing = ["coverage (<5.0)", "jupyter-book", "pytest (>=5.4,<6.0)", "pytest-cov (>=2.8,<3.0)", "pytest-regressions"] [[package]] name = "sphinx-rtd-theme" @@ -4791,6 +5305,44 @@ sphinxcontrib-jquery = ">=4,<5" [package.extras] dev = ["bump2version", "sphinxcontrib-httpdomain", "transifex-client", "wheel"] +[[package]] +name = "sphinx-thebe" +version = "0.2.1" +description = "Integrate interactive code blocks into your documentation with Thebe and Binder." 
+optional = false +python-versions = "*" +files = [ + {file = "sphinx-thebe-0.2.1.tar.gz", hash = "sha256:f4c8c1542054f991b73fcb28c4cf21697e42aba2f83f22348c1c851b82766583"}, + {file = "sphinx_thebe-0.2.1-py3-none-any.whl", hash = "sha256:e8af555c90acba3541fa7108ea5981ae9c4bd406b54d9a242ab054d326ab7441"}, +] + +[package.dependencies] +sphinx = ">=4,<7" + +[package.extras] +sphinx = ["matplotlib", "myst-nb", "sphinx-book-theme (>=0.4.0rc1)", "sphinx-copybutton", "sphinx-design"] +testing = ["beautifulsoup4", "matplotlib", "pytest", "pytest-regressions"] + +[[package]] +name = "sphinx-togglebutton" +version = "0.3.2" +description = "Toggle page content and collapse admonitions in Sphinx." +optional = false +python-versions = "*" +files = [ + {file = "sphinx-togglebutton-0.3.2.tar.gz", hash = "sha256:ab0c8b366427b01e4c89802d5d078472c427fa6e9d12d521c34fa0442559dc7a"}, + {file = "sphinx_togglebutton-0.3.2-py3-none-any.whl", hash = "sha256:9647ba7874b7d1e2d43413d8497153a85edc6ac95a3fea9a75ef9c1e08aaae2b"}, +] + +[package.dependencies] +docutils = "*" +setuptools = "*" +sphinx = "*" +wheel = "*" + +[package.extras] +sphinx = ["matplotlib", "myst-nb", "numpy", "sphinx-book-theme", "sphinx-design", "sphinx-examples"] + [[package]] name = "sphinxcontrib-applehelp" version = "1.0.7" @@ -4944,6 +5496,93 @@ Sphinx = ">=3.0.0" [package.extras] test = ["coverage (>=4.0,!=4.4)", "pytest", "pytest-cov"] +[[package]] +name = "sqlalchemy" +version = "2.0.23" +description = "Database Abstraction Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:638c2c0b6b4661a4fd264f6fb804eccd392745c5887f9317feb64bb7cb03b3ea"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e3b5036aa326dc2df50cba3c958e29b291a80f604b1afa4c8ce73e78e1c9f01d"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:787af80107fb691934a01889ca8f82a44adedbf5ef3d6ad7d0f0b9ac557e0c34"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c14eba45983d2f48f7546bb32b47937ee2cafae353646295f0e99f35b14286ab"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0666031df46b9badba9bed00092a1ffa3aa063a5e68fa244acd9f08070e936d3"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89a01238fcb9a8af118eaad3ffcc5dedaacbd429dc6fdc43fe430d3a941ff965"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win32.whl", hash = "sha256:cabafc7837b6cec61c0e1e5c6d14ef250b675fa9c3060ed8a7e38653bd732ff8"}, + {file = "SQLAlchemy-2.0.23-cp310-cp310-win_amd64.whl", hash = "sha256:87a3d6b53c39cd173990de2f5f4b83431d534a74f0e2f88bd16eabb5667e65c6"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d5578e6863eeb998980c212a39106ea139bdc0b3f73291b96e27c929c90cd8e1"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:62d9e964870ea5ade4bc870ac4004c456efe75fb50404c03c5fd61f8bc669a72"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c80c38bd2ea35b97cbf7c21aeb129dcbebbf344ee01a7141016ab7b851464f8e"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75eefe09e98043cff2fb8af9796e20747ae870c903dc61d41b0c2e55128f958d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bd45a5b6c68357578263d74daab6ff9439517f87da63442d244f9f23df56138d"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a86cb7063e2c9fb8e774f77fbf8475516d270a3e989da55fa05d08089d77f8c4"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win32.whl", hash = "sha256:b41f5d65b54cdf4934ecede2f41b9c60c9f785620416e8e6c48349ab18643855"}, + {file = "SQLAlchemy-2.0.23-cp311-cp311-win_amd64.whl", hash = 
"sha256:9ca922f305d67605668e93991aaf2c12239c78207bca3b891cd51a4515c72e22"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0f7fb0c7527c41fa6fcae2be537ac137f636a41b4c5a4c58914541e2f436b45"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c424983ab447dab126c39d3ce3be5bee95700783204a72549c3dceffe0fc8f4"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f508ba8f89e0a5ecdfd3761f82dda2a3d7b678a626967608f4273e0dba8f07ac"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6463aa765cf02b9247e38b35853923edbf2f6fd1963df88706bc1d02410a5577"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e599a51acf3cc4d31d1a0cf248d8f8d863b6386d2b6782c5074427ebb7803bda"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fd54601ef9cc455a0c61e5245f690c8a3ad67ddb03d3b91c361d076def0b4c60"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win32.whl", hash = "sha256:42d0b0290a8fb0165ea2c2781ae66e95cca6e27a2fbe1016ff8db3112ac1e846"}, + {file = "SQLAlchemy-2.0.23-cp312-cp312-win_amd64.whl", hash = "sha256:227135ef1e48165f37590b8bfc44ed7ff4c074bf04dc8d6f8e7f1c14a94aa6ca"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:14aebfe28b99f24f8a4c1346c48bc3d63705b1f919a24c27471136d2f219f02d"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e983fa42164577d073778d06d2cc5d020322425a509a08119bdcee70ad856bf"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0dc9031baa46ad0dd5a269cb7a92a73284d1309228be1d5935dac8fb3cae24"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5f94aeb99f43729960638e7468d4688f6efccb837a858b34574e01143cf11f89"}, + {file = 
"SQLAlchemy-2.0.23-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:63bfc3acc970776036f6d1d0e65faa7473be9f3135d37a463c5eba5efcdb24c8"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-win32.whl", hash = "sha256:f48ed89dd11c3c586f45e9eec1e437b355b3b6f6884ea4a4c3111a3358fd0c18"}, + {file = "SQLAlchemy-2.0.23-cp37-cp37m-win_amd64.whl", hash = "sha256:1e018aba8363adb0599e745af245306cb8c46b9ad0a6fc0a86745b6ff7d940fc"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64ac935a90bc479fee77f9463f298943b0e60005fe5de2aa654d9cdef46c54df"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c4722f3bc3c1c2fcc3702dbe0016ba31148dd6efcd2a2fd33c1b4897c6a19693"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4af79c06825e2836de21439cb2a6ce22b2ca129bad74f359bddd173f39582bf5"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683ef58ca8eea4747737a1c35c11372ffeb84578d3aab8f3e10b1d13d66f2bc4"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d4041ad05b35f1f4da481f6b811b4af2f29e83af253bf37c3c4582b2c68934ab"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aeb397de65a0a62f14c257f36a726945a7f7bb60253462e8602d9b97b5cbe204"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win32.whl", hash = "sha256:42ede90148b73fe4ab4a089f3126b2cfae8cfefc955c8174d697bb46210c8306"}, + {file = "SQLAlchemy-2.0.23-cp38-cp38-win_amd64.whl", hash = "sha256:964971b52daab357d2c0875825e36584d58f536e920f2968df8d581054eada4b"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:616fe7bcff0a05098f64b4478b78ec2dfa03225c23734d83d6c169eb41a93e55"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e680527245895aba86afbd5bef6c316831c02aa988d1aad83c47ffe92655e74"}, + {file = 
"SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9585b646ffb048c0250acc7dad92536591ffe35dba624bb8fd9b471e25212a35"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4895a63e2c271ffc7a81ea424b94060f7b3b03b4ea0cd58ab5bb676ed02f4221"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cc1d21576f958c42d9aec68eba5c1a7d715e5fc07825a629015fe8e3b0657fb0"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:967c0b71156f793e6662dd839da54f884631755275ed71f1539c95bbada9aaab"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win32.whl", hash = "sha256:0a8c6aa506893e25a04233bc721c6b6cf844bafd7250535abb56cb6cc1368884"}, + {file = "SQLAlchemy-2.0.23-cp39-cp39-win_amd64.whl", hash = "sha256:f3420d00d2cb42432c1d0e44540ae83185ccbbc67a6054dcc8ab5387add6620b"}, + {file = "SQLAlchemy-2.0.23-py3-none-any.whl", hash = "sha256:31952bbc527d633b9479f5f81e8b9dfada00b91d6baba021a869095f1a97006d"}, + {file = "SQLAlchemy-2.0.23.tar.gz", hash = "sha256:c1bda93cbbe4aa2aa0aa8655c5aeda505cd219ff3e8da91d1d329e143e4aff69"}, +] + +[package.dependencies] +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +typing-extensions = ">=4.2.0" + +[package.extras] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] +aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] +asyncio = ["greenlet (!=0.4.17)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mssql = ["pyodbc"] +mssql-pymssql = ["pymssql"] +mssql-pyodbc = ["pyodbc"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient 
(>=1.4.0)"] +mysql-connector = ["mysql-connector-python"] +oracle = ["cx-oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] +postgresql = ["psycopg2 (>=2.7)"] +postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] +postgresql-psycopg2binary = ["psycopg2-binary"] +postgresql-psycopg2cffi = ["psycopg2cffi"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] +sqlcipher = ["sqlcipher3-binary"] + [[package]] name = "stack-data" version = "0.6.3" @@ -4977,6 +5616,20 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "tensorboard" version = "2.14.1" @@ -5317,6 +5970,20 @@ files = [ {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, ] +[[package]] +name = "uc-micro-py" +version = "1.0.2" +description = "Micro subset of unicode data files for linkify-it-py projects." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "uc-micro-py-1.0.2.tar.gz", hash = "sha256:30ae2ac9c49f39ac6dce743bd187fcd2b574b16ca095fa74cd9396795c954c54"}, + {file = "uc_micro_py-1.0.2-py3-none-any.whl", hash = "sha256:8c9110c309db9d9e87302e2f4ad2c3152770930d88ab385cd544e7a7e75f3de0"}, +] + +[package.extras] +test = ["coverage", "pytest", "pytest-cov"] + [[package]] name = "uri-template" version = "1.3.0" @@ -5457,6 +6124,20 @@ MarkupSafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "wheel" +version = "0.42.0" +description = "A built-package format for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "wheel-0.42.0-py3-none-any.whl", hash = "sha256:177f9c9b0d45c47873b619f5b650346d632cdc35fb5e4d25058e09c9e581433d"}, + {file = "wheel-0.42.0.tar.gz", hash = "sha256:c45be39f7882c9d34243236f2d63cbd58039e360f85d0913425fbd7ceea617a8"}, +] + +[package.extras] +test = ["pytest (>=6.0.0)", "setuptools (>=65)"] + [[package]] name = "widgetsnbextension" version = "4.0.9" @@ -5486,4 +6167,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "bcfb471d536479389f438454c51ddc35386491ee710713f41ef4334760369186" +content-hash = "5111800c86939adc71b8d95e5cd284b57bf944aaeecc54cc2ea827f6ace93830" diff --git a/pyproject.toml b/pyproject.toml index 46f4214f..9cade725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,14 @@ include=["src/training_rl/assets"] python = "^3.11" ipykernel = "^6.25.2" ipywidgets = "^8.1.1" +jupyter = "^1.0.0" +jupyter-book = "^0.15.1" jupyter-contrib-nbextensions = "^0.7.0" +matplotlib = "^3.8.0" notebook = "<7.0.0" rise = "^5.7.1" -matplotlib = "^3.8.0" seaborn = "^0.13.0" +tianshou = {git = "https://github.com/thu-ml/tianshou.git", rev = "8d3d1f1"} traitlets = "5.9.0" # special sauce b/c of a flaky bug in poetry on windows # see 
https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926 @@ -39,12 +42,10 @@ virtualenv = [ { version = "^20.4.3,!=20.4.5,!=20.4.6" }, { version = "<20.16.4", markers = "sys_platform == 'win32'" }, ] -tianshou = {git = "https://github.com/thu-ml/tianshou.git", rev = "8d3d1f1"} [tool.poetry.group.add1] optional = true [tool.poetry.group.add1.dependencies] -pettingzoo = "^1.24.1" jsonargparse = "^4.25.0" numba = "^0.57.1" # b/c of numba @@ -52,6 +53,7 @@ numpy = "<=1.24" overrides = "^7.4.0" packaging = "*" pandas = {extras = ["performance"], version = "^2.1.0"} +pettingzoo = "^1.24.1" [tool.poetry.group.add2] optional = true @@ -70,18 +72,18 @@ optional = true [tool.poetry.group.control.dependencies] control = "^0.9.4" do-mpc = "^4.6.1" -mediapy = "^1.1.9" gymnasium = {extras = ["classic-control", "mujoco"], version = "^0.28.0"} +mediapy = "^1.1.9" networkx = "^3.1" [tool.poetry.group.offline] optional = true [tool.poetry.group.offline.dependencies] -minari = "^0.4.2" chardet ="*" +minari = "^0.4.2" opencv-python="*" -tensorboardX="*" pygame = "*" +tensorboardX="*" [tool.poetry.group.dev]