diff --git a/notebooks/16_rddl_tuto.ipynb b/notebooks/16_rddl_tuto.ipynb
new file mode 100644
index 0000000000..e641bbee1e
--- /dev/null
+++ b/notebooks/16_rddl_tuto.ipynb
@@ -0,0 +1,604 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "source": [
+ "# Using RDDL domains and solvers with scikit-decide"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this notebook, we demonstrate how to use the RDLL scikit-decide wrapper domain in order to solve it with scikit-decide solvers. This domain is built upon the RDDL environment from the excellent pyrddlgym-project GitHub project. Some of the solvers tested here are actually also wrapped from the same project but we will see also how to use other solvers (coded directly within scikit-decide or wrapped from other third party libraries)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Concerning the python kernel to use for this notebook:\n",
+ "- If running locally, be sure to use an environment with scikit-decide[all].\n",
+ "- If running on colab, the next cell does it for you.\n",
+ "- If running on binder, the environment should be ready."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# On Colab: install the library\n",
+ "on_colab = \"google.colab\" in str(get_ipython())\n",
+ "if on_colab:\n",
+ " import glob\n",
+ " import json\n",
+ " import sys\n",
+ "\n",
+ " using_nightly_version = True\n",
+ "\n",
+ " if using_nightly_version:\n",
+ " # look for nightly build download url\n",
+ " release_curl_res = !curl -L -H \"Accept: application/vnd.github+json\" -H \"X-GitHub-Api-Version: 2022-11-28\" https://api.github.com/repos/airbus/scikit-decide/releases/tags/nightly\n",
+ " release_dict = json.loads(release_curl_res.s)\n",
+ " release_download_url = sorted(\n",
+ " release_dict[\"assets\"], key=lambda d: d[\"updated_at\"]\n",
+ " )[-1][\"browser_download_url\"]\n",
+ " print(release_download_url)\n",
+ "\n",
+ " # download and unzip\n",
+ " !wget --output-document=release.zip {release_download_url}\n",
+ " !unzip -o release.zip\n",
+ "\n",
+ " # get proper wheel name according to python version used\n",
+ " wheel_pythonversion_tag = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n",
+ " wheel_path = glob.glob(\n",
+ " f\"dist/scikit_decide*{wheel_pythonversion_tag}*manylinux*.whl\"\n",
+ " )[0]\n",
+ "\n",
+ " skdecide_pip_spec = f\"{wheel_path}[all]\"\n",
+ " else:\n",
+ " skdecide_pip_spec = \"scikit-decide[all]\"\n",
+ "\n",
+ " # uninstall google protobuf conflicting with ray and sb3\n",
+ " ! pip uninstall -y protobuf\n",
+ "\n",
+ " # install scikit-decide with all extras\n",
+ " !pip install {skdecide_pip_spec}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "import os\n",
+ "import shutil\n",
+ "\n",
+ "from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator\n",
+ "from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv\n",
+ "from ray.rllib.algorithms.ppo import PPO as RLLIB_PPO\n",
+ "from rddlrepository.archive.competitions.IPPC2023.MountainCar.MountainCarViz import (\n",
+ " MountainCarVisualizer,\n",
+ ")\n",
+ "from rddlrepository.archive.standalone.Elevators.ElevatorViz import ElevatorVisualizer\n",
+ "from rddlrepository.archive.standalone.Quadcopter.QuadcopterViz import (\n",
+ " QuadcopterVisualizer,\n",
+ ")\n",
+ "from rddlrepository.core.manager import RDDLRepoManager\n",
+ "from stable_baselines3 import PPO as SB3_PPO\n",
+ "\n",
+ "from skdecide.hub.domain.rddl import RDDLDomain, RDDLDomainSimplifiedSpaces\n",
+ "from skdecide.hub.solver.cgp import CGP\n",
+ "from skdecide.hub.solver.ray_rllib import RayRLlib\n",
+ "from skdecide.hub.solver.rddl import RDDLGurobiSolver, RDDLJaxSolver\n",
+ "from skdecide.hub.solver.stable_baselines import StableBaseline\n",
+ "from skdecide.utils import rollout"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Instantiating and visualizing a RDDL domain"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The pyrddlgym-project provides the [rddlrepository](https://github.com/pyrddlgym-project/rddlrepository) library of RDDL benchmarks from past IPPC competitions and third-party contributors. We list below the available problems with our pip installation of the library."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "manager = RDDLRepoManager(rebuild=True)\n",
+ "print(sorted(manager.list_problems()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will use 3 different rddl benchmarks here to demonstrate the scikit-decide integration of pyrddlgym:\n",
+ "- MountainCar_ippc2023\n",
+ "- Quadcopter\n",
+ "- Elevators"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's create the scikit-decide `RDDLDomain` instance and render it.\n",
+ "Note that here we use some options to display within the notebook:\n",
+ "- `display_with_pygame`: True by default (as in pyRDDLGym), here set to False to avoid a pygame window to pop up\n",
+ "- `display_within_jupyter`: useful to display within a jupyter notebook\n",
+ "- `visualizer`: we use a visualizer dedicated to the chosen benchmark\n",
+ "- `movie_name`: if set, a movie will be created at the end of a rollout "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"MountainCar_ippc2023\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = MountainCarVisualizer\n",
+ "domain = RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ " movie_name=None, # here left empty because not used in a roll-out\n",
+ ")\n",
+ "domain.reset()\n",
+ "img = domain.render()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"Quadcopter\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = QuadcopterVisualizer\n",
+ "domain = RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ ")\n",
+ "domain.reset()\n",
+ "img = domain.render()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"Elevators\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = ElevatorVisualizer\n",
+ "domain = RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ ")\n",
+ "domain.reset()\n",
+ "img = domain.render()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Solving the domain with scikit-decide (potentially bridged) solvers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now comes the fun part: solving the domain with scikit-decide solvers, some of them - especially the reinforcement learning ones - being bridged to state-of-the-art existing libraries (e.g. RLlib, SB3). You will see that once the domain is defined, solving it takes very few lines of code."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### RL solvers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "First, we create the domain factory for the benchmark \"MountainCar_ippc2023\". For these RL solvers, we need the underlying rddl env to use the base class `SimplifiedActionRDDLEnv` from [pyRDDLGym-rl](https://github.com/pyrddlgym-project/pyRDDLGym-rl), which uses gym spaces tractable by RL algorithms. This is done thanks to the argument `base_class`, which will be passed directly to `pyRDDLgym.make()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"MountainCar_ippc2023\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = MountainCarVisualizer\n",
+ "\n",
+ "domain_factory_rl = lambda alg_name=None: RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " base_class=SimplifiedActionRDDLEnv,\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### RLlib's PPO algorithm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The code below creates a scikit-decide's `RayRLlib` solver, then it calls the `solver.solve()` method, and it finally rollout the optimized policy by using scikit-decide's `rollout` utility function. The latter function will render the solution and the domain will generate a movie in the `rddl_movies` folder when reaching the termination condition of the rollout episode."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "solver_factory = lambda: RayRLlib(\n",
+ " domain_factory=domain_factory_rl, algo_class=RLLIB_PPO, train_iterations=10\n",
+ ")\n",
+ "\n",
+ "with solver_factory() as solver:\n",
+ " solver.solve()\n",
+ " rollout(\n",
+ " domain_factory_rl(alg_name=\"RLLIB-PPO\"),\n",
+ " solver,\n",
+ " max_steps=300,\n",
+ " render=True,\n",
+ " verbose=False,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is an example of executing the RLlib's PPO policy trained for 100 iterations on the mountain car benchmark:\n",
+ "\n",
+ "![RLLIB PPO example solution](rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### StableBaselines-3's PPO"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Once the domain is defined, very few lines of code are sufficient to test another solver whose capabilities are compatible with the domain. In the cell below, we now test Stablebaselines-3's PPO algorithm."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "solver_factory = lambda: StableBaseline(\n",
+ " domain_factory=domain_factory_rl,\n",
+ " algo_class=SB3_PPO,\n",
+ " baselines_policy=\"MultiInputPolicy\",\n",
+ " learn_config={\"total_timesteps\": 10000},\n",
+ " verbose=0,\n",
+ ")\n",
+ "\n",
+ "with solver_factory() as solver:\n",
+ " solver.solve()\n",
+ " rollout(\n",
+ " domain_factory_rl(alg_name=\"SB3-PPO\"),\n",
+ " solver,\n",
+ " max_steps=1000,\n",
+ " render=True,\n",
+ " verbose=False,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### CGP"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Scikit-decide provides an implementation of [Cartesian Genetic Programming](https://dl.acm.org/doi/10.1145/3205455.3205578) (CGP), a form of Genetic Programming which optimizes a function (e.g. control policy) by learning its best representation as a directed acyclic graph of mathematical operators. One of the great capabilities of scikit-decide is to provide simple high-level means to compare algorithms from different communities (RL, GP, search, planning, etc.) on the same domains with few lines of code.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since our current implementation of CGP in scikit-decide does not handle complex observation spaces such as the dictionary spaces returned by the RDDL simulator, we used instead `RDDLDomainSimplifiedSpaces` where all actions and observations are numpy arrays thanks to the powerful `flatten` and `flatten_space` methods of `gymnasium`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We call the CGP solver on this simplified domain and we render the obtained solution after a few iterations (including the generation of the video in the `rddl_movies` folder)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "editable": true,
+ "slideshow": {
+ "slide_type": ""
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "problem_name = \"MountainCar_ippc2023\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = MountainCarVisualizer\n",
+ "\n",
+ "domain_factory_cgp = lambda alg_name=None: RDDLDomainSimplifiedSpaces(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " base_class=SimplifiedActionRDDLEnv,\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n",
+ " max_frames=200,\n",
+ ")\n",
+ "\n",
+ "if os.path.exists(\"TEMP_CGP\"):\n",
+ " shutil.rmtree(\"TEMP_CGP\")\n",
+ "\n",
+ "solver_factory = lambda: CGP(\n",
+ " domain_factory=domain_factory_cgp, folder_name=\"TEMP_CGP\", n_it=25, verbose=False\n",
+ ")\n",
+ "with solver_factory() as solver:\n",
+ " solver.solve()\n",
+ " rollout(\n",
+ " domain_factory_cgp(\"CGP\"), solver, max_steps=200, render=True, verbose=False\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is an example of executing the CGP policy on the mountain car benchmark:\n",
+ "\n",
+ "![CGP example solution](rddl_images/MountainCar_ippc2023-CGP_example.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Solving the domain with pyRDDLGym solvers wrapped in scikit-decide\n",
+ "\n",
+ "One can also use the solvers implemented in pyRDDLGym project from within scikit-decide like the jax planner (https://github.com/pyrddlgym-project/pyRDDLGym-jax), or the gurobi planner (https://github.com/pyrddlgym-project/pyRDDLGym-gurobi)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### JAX Agent\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The scikit-decide solver `RDDLJaxSolver` wraps the offline version of [JaxPlan](https://openreview.net/forum?id=7IKtmUpLEH) planner which compiles the RDDL model to a Jax computation graph allowing for planning by backpropagation. \n",
+ "The solver constructor takes a configuration file of the `Jax` planner as explained [here](https://github.com/pyrddlgym-project/pyRDDLGym-jax/tree/main?tab=readme-ov-file#writing-a-configuration-file-for-a-custom-domain).\n",
+ "\n",
+ "We apply it to the becnhmark \"Quadcopter\". \n",
+ "\n",
+ "Note that for this solver the domain needs\n",
+ "- to use the simulation backend specific to Jax,\n",
+ "- to be vectorized. \n",
+ "\n",
+ "This is done thanks to the arguments `backend` and `vectorized` which are passed to `pyRDDLGym.make()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"Quadcopter\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = QuadcopterVisualizer\n",
+ "\n",
+ "if not os.path.exists(\"Quadcopter_slp.cfg\"):\n",
+ " !wget https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg\n",
+ "\n",
+ "domain_factory_jax_agent = lambda alg_name=None: RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(1),\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ " backend=JaxRDDLSimulator,\n",
+ " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n",
+ " max_frames=500,\n",
+ " vectorized=True,\n",
+ ")\n",
+ "\n",
+ "assert RDDLJaxSolver.check_domain(domain_factory_jax_agent())\n",
+ "\n",
+ "logging.getLogger(\"matplotlib.font_manager\").disabled = True\n",
+ "with RDDLJaxSolver(\n",
+ " domain_factory=domain_factory_jax_agent, config=\"Quadcopter_slp.cfg\"\n",
+ ") as solver:\n",
+ " solver.solve()\n",
+ " rollout(\n",
+ " domain_factory_jax_agent(alg_name=\"JaxAgent\"),\n",
+ " solver,\n",
+ " max_steps=500,\n",
+ " render=True,\n",
+ " max_framerate=5,\n",
+ " verbose=False,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We obtain the following example execution of the Jax policy, which clearly converges towards the goal (quadcopters flying towards the red triangle):\n",
+ "\n",
+ "![JaxAgent example solution](rddl_images/Quadcopter-JaxAgent_example.gif)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Gurobi Agent"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We finally try the online version of [GurobiPlan](https://openreview.net/forum?id=7IKtmUpLEH) planner which compiles the RDDL model to a [Gurobi](https://www.gurobi.com) MILP model. \n",
+ "\n",
+ "We apply it to \"Elevators\" benchmark. \n",
+ "\n",
+ "\n",
+ "
Note: \n",
+ "The solver needs a real license for Gurobi, as the free license available when installing gurobipy from PyPi is not sufficient to solve this domain.\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "problem_name = \"Elevators\"\n",
+ "problem_info = manager.get_problem(problem_name)\n",
+ "problem_visualizer = ElevatorVisualizer\n",
+ "\n",
+ "domain_factory_gurobi_agent = lambda alg_name=None: RDDLDomain(\n",
+ " rddl_domain=problem_info.get_domain(),\n",
+ " rddl_instance=problem_info.get_instance(0),\n",
+ " visualizer=problem_visualizer,\n",
+ " display_with_pygame=False,\n",
+ " display_within_jupyter=True,\n",
+ " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n",
+ " max_frames=50,\n",
+ ")\n",
+ "\n",
+ "assert RDDLGurobiSolver.check_domain(domain_factory_gurobi_agent())\n",
+ "\n",
+ "with RDDLGurobiSolver(\n",
+ " domain_factory=domain_factory_gurobi_agent, rollout_horizon=10\n",
+ ") as solver:\n",
+ " solver.solve()\n",
+ " rollout(\n",
+ " domain_factory_gurobi_agent(alg_name=\"GurobiAgent\"),\n",
+ " solver,\n",
+ " max_steps=50,\n",
+ " render=True,\n",
+ " max_framerate=5,\n",
+ " verbose=False,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is an example of executing the online `GurobiPlan` strategy on this benchmark:\n",
+ "\n",
+ "![GurobiAgent example solution](rddl_images/Elevators-GurobiAgent_example.gif)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/rddl_images/Elevators-GurobiAgent_example.gif b/notebooks/rddl_images/Elevators-GurobiAgent_example.gif
new file mode 100644
index 0000000000..fccdb9678b
Binary files /dev/null and b/notebooks/rddl_images/Elevators-GurobiAgent_example.gif differ
diff --git a/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif b/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif
new file mode 100644
index 0000000000..548fbafe5c
Binary files /dev/null and b/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif differ
diff --git a/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif b/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif
new file mode 100644
index 0000000000..773cb7a1b0
Binary files /dev/null and b/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif differ
diff --git a/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif b/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif
new file mode 100644
index 0000000000..99b10f84ac
Binary files /dev/null and b/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif differ
diff --git a/notebooks/rddl_images/cgp-sketch.png b/notebooks/rddl_images/cgp-sketch.png
new file mode 100644
index 0000000000..2beec97d8c
Binary files /dev/null and b/notebooks/rddl_images/cgp-sketch.png differ
diff --git a/poetry.lock b/poetry.lock
index 8198224a70..d6a324cfc6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -54,6 +54,21 @@ six = ">=1.12.0"
astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"]
test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"]
+[[package]]
+name = "astunparse"
+version = "1.6.3"
+description = "An AST unparser for Python"
+optional = true
+python-versions = "*"
+files = [
+ {file = "astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8"},
+ {file = "astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872"},
+]
+
+[package.dependencies]
+six = ">=1.6.1,<2.0"
+wheel = ">=0.23.0,<1.0"
+
[[package]]
name = "atomicwrites"
version = "1.4.1"
@@ -83,6 +98,23 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi
tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"]
tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"]
+[[package]]
+name = "bayesian-optimization"
+version = "2.0.0"
+description = "Bayesian Optimization package"
+optional = true
+python-versions = "<4.0,>=3.9"
+files = [
+ {file = "bayesian_optimization-2.0.0-py3-none-any.whl", hash = "sha256:7dc2b1b78dadab9d38c07f3b7c5d7fe91a138bce0ec71e895085642ddf4c85b3"},
+ {file = "bayesian_optimization-2.0.0.tar.gz", hash = "sha256:434abeb87e4a59f285641fbbfd74606fa27603dd831fff7c8984e5d10391e0d0"},
+]
+
+[package.dependencies]
+colorama = ">=0.4.6,<0.5.0"
+numpy = ">=1.25"
+scikit-learn = ">=1.0.0,<2.0.0"
+scipy = ">=1.0.0,<2.0.0"
+
[[package]]
name = "cartopy"
version = "0.23.0"
@@ -329,6 +361,26 @@ files = [
{file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"},
]
+[[package]]
+name = "chex"
+version = "0.1.87"
+description = "Chex: Testing made fun, in JAX!"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16"},
+ {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"},
+]
+
+[package.dependencies]
+absl-py = ">=0.9.0"
+jax = ">=0.4.27"
+jaxlib = ">=0.4.27"
+numpy = ">=1.24.1"
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+toolz = ">=0.9.0"
+typing-extensions = ">=4.2.0"
+
[[package]]
name = "click"
version = "8.1.7"
@@ -876,6 +928,27 @@ typing-extensions = ">=4.4"
quantum = ["qiskit (>=1.0.2)", "qiskit-aer (>=0.14.1)", "qiskit-algorithms (>=0.3.0)", "qiskit-ibm-runtime (>=0.24)", "qiskit-optimization (>=0.6.1)"]
test = ["optuna", "pytest", "pytest-cov", "scikit-learn (>=1.0)"]
+[[package]]
+name = "dm-haiku"
+version = "0.0.13"
+description = "Haiku is a library for building neural networks in JAX."
+optional = true
+python-versions = "*"
+files = [
+ {file = "dm_haiku-0.0.13-py3-none-any.whl", hash = "sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43"},
+ {file = "dm_haiku-0.0.13.tar.gz", hash = "sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d"},
+]
+
+[package.dependencies]
+absl-py = ">=0.7.1"
+jmp = ">=0.0.2"
+numpy = ">=1.18.0"
+tabulate = ">=0.8.9"
+
+[package.extras]
+flax = ["flax (>=0.7.1)"]
+jax = ["jax (>=0.4.28)", "jaxlib (>=0.4.28)"]
+
[[package]]
name = "dm-tree"
version = "0.1.8"
@@ -941,6 +1014,40 @@ files = [
{file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"},
]
+[[package]]
+name = "etils"
+version = "1.5.2"
+description = "Collection of common python utils"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "etils-1.5.2-py3-none-any.whl", hash = "sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b"},
+ {file = "etils-1.5.2.tar.gz", hash = "sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446"},
+]
+
+[package.dependencies]
+typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""}
+
+[package.extras]
+all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"]
+array-types = ["etils[enp]"]
+dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"]
+docs = ["etils[all,dev]", "sphinx-apitree[ext]"]
+eapp = ["absl-py", "etils[epy]", "simple_parsing"]
+ecolab = ["etils[enp]", "etils[epy]", "jupyter", "mediapy", "numpy", "packaging"]
+edc = ["etils[epy]"]
+enp = ["etils[epy]", "numpy"]
+epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"]
+epath-gcs = ["etils[epath]", "gcsfs"]
+epath-s3 = ["etils[epath]", "s3fs"]
+epy = ["typing_extensions"]
+etqdm = ["absl-py", "etils[epy]", "tqdm"]
+etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"]
+etree-dm = ["dm-tree", "etils[etree]"]
+etree-jax = ["etils[etree]", "jax[cpu]"]
+etree-tf = ["etils[etree]", "tensorflow"]
+lazy-imports = ["etils[ecolab]"]
+
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@@ -1010,6 +1117,17 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"]
typing = ["typing-extensions (>=4.12.2)"]
+[[package]]
+name = "flatbuffers"
+version = "24.3.25"
+description = "The FlatBuffers serialization format for Python"
+optional = true
+python-versions = "*"
+files = [
+ {file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"},
+ {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
+]
+
[[package]]
name = "fonttools"
version = "4.54.1"
@@ -1221,6 +1339,99 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,
test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
tqdm = ["tqdm"]
+[[package]]
+name = "gast"
+version = "0.6.0"
+description = "Python AST that abstracts the underlying Python version"
+optional = true
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
+files = [
+ {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"},
+ {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"},
+]
+
+[[package]]
+name = "google-pasta"
+version = "0.2.0"
+description = "pasta is an AST-based Python refactoring library"
+optional = true
+python-versions = "*"
+files = [
+ {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"},
+ {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"},
+ {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"},
+]
+
+[package.dependencies]
+six = "*"
+
+[[package]]
+name = "grpcio"
+version = "1.67.0"
+description = "HTTP/2-based RPC framework"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "grpcio-1.67.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:bd79929b3bb96b54df1296cd3bf4d2b770bd1df6c2bdf549b49bab286b925cdc"},
+ {file = "grpcio-1.67.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:16724ffc956ea42967f5758c2f043faef43cb7e48a51948ab593570570d1e68b"},
+ {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:2b7183c80b602b0ad816315d66f2fb7887614ead950416d60913a9a71c12560d"},
+ {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efe32b45dd6d118f5ea2e5deaed417d8a14976325c93812dd831908522b402c9"},
+ {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe89295219b9c9e47780a0f1c75ca44211e706d1c598242249fe717af3385ec8"},
+ {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa8d025fae1595a207b4e47c2e087cb88d47008494db258ac561c00877d4c8f8"},
+ {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f95e15db43e75a534420e04822df91f645664bf4ad21dfaad7d51773c80e6bb4"},
+ {file = "grpcio-1.67.0-cp310-cp310-win32.whl", hash = "sha256:a6b9a5c18863fd4b6624a42e2712103fb0f57799a3b29651c0e5b8119a519d65"},
+ {file = "grpcio-1.67.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6eb68493a05d38b426604e1dc93bfc0137c4157f7ab4fac5771fd9a104bbaa6"},
+ {file = "grpcio-1.67.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:e91d154689639932305b6ea6f45c6e46bb51ecc8ea77c10ef25aa77f75443ad4"},
+ {file = "grpcio-1.67.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cb204a742997277da678611a809a8409657b1398aaeebf73b3d9563b7d154c13"},
+ {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:ae6de510f670137e755eb2a74b04d1041e7210af2444103c8c95f193340d17ee"},
+ {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74b900566bdf68241118f2918d312d3bf554b2ce0b12b90178091ea7d0a17b3d"},
+ {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e95e43447a02aa603abcc6b5e727d093d161a869c83b073f50b9390ecf0fa8"},
+ {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0bb94e66cd8f0baf29bd3184b6aa09aeb1a660f9ec3d85da615c5003154bc2bf"},
+ {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:82e5bd4b67b17c8c597273663794a6a46a45e44165b960517fe6d8a2f7f16d23"},
+ {file = "grpcio-1.67.0-cp311-cp311-win32.whl", hash = "sha256:7fc1d2b9fd549264ae585026b266ac2db53735510a207381be509c315b4af4e8"},
+ {file = "grpcio-1.67.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac11ecb34a86b831239cc38245403a8de25037b448464f95c3315819e7519772"},
+ {file = "grpcio-1.67.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:227316b5631260e0bef8a3ce04fa7db4cc81756fea1258b007950b6efc90c05d"},
+ {file = "grpcio-1.67.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d90cfdafcf4b45a7a076e3e2a58e7bc3d59c698c4f6470b0bb13a4d869cf2273"},
+ {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:77196216d5dd6f99af1c51e235af2dd339159f657280e65ce7e12c1a8feffd1d"},
+ {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15c05a26a0f7047f720da41dc49406b395c1470eef44ff7e2c506a47ac2c0591"},
+ {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3840994689cc8cbb73d60485c594424ad8adb56c71a30d8948d6453083624b52"},
+ {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a1e03c3102b6451028d5dc9f8591131d6ab3c8a0e023d94c28cb930ed4b5f81"},
+ {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:682968427a63d898759474e3b3178d42546e878fdce034fd7474ef75143b64e3"},
+ {file = "grpcio-1.67.0-cp312-cp312-win32.whl", hash = "sha256:d01793653248f49cf47e5695e0a79805b1d9d4eacef85b310118ba1dfcd1b955"},
+ {file = "grpcio-1.67.0-cp312-cp312-win_amd64.whl", hash = "sha256:985b2686f786f3e20326c4367eebdaed3e7aa65848260ff0c6644f817042cb15"},
+ {file = "grpcio-1.67.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:8c9a35b8bc50db35ab8e3e02a4f2a35cfba46c8705c3911c34ce343bd777813a"},
+ {file = "grpcio-1.67.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:42199e704095b62688998c2d84c89e59a26a7d5d32eed86d43dc90e7a3bd04aa"},
+ {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:c4c425f440fb81f8d0237c07b9322fc0fb6ee2b29fbef5f62a322ff8fcce240d"},
+ {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:323741b6699cd2b04a71cb38f502db98f90532e8a40cb675393d248126a268af"},
+ {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:662c8e105c5e5cee0317d500eb186ed7a93229586e431c1bf0c9236c2407352c"},
+ {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f6bd2ab135c64a4d1e9e44679a616c9bc944547357c830fafea5c3caa3de5153"},
+ {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:2f55c1e0e2ae9bdd23b3c63459ee4c06d223b68aeb1961d83c48fb63dc29bc03"},
+ {file = "grpcio-1.67.0-cp313-cp313-win32.whl", hash = "sha256:fd6bc27861e460fe28e94226e3673d46e294ca4673d46b224428d197c5935e69"},
+ {file = "grpcio-1.67.0-cp313-cp313-win_amd64.whl", hash = "sha256:cf51d28063338608cd8d3cd64677e922134837902b70ce00dad7f116e3998210"},
+ {file = "grpcio-1.67.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:7f200aca719c1c5dc72ab68be3479b9dafccdf03df530d137632c534bb6f1ee3"},
+ {file = "grpcio-1.67.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0892dd200ece4822d72dd0952f7112c542a487fc48fe77568deaaa399c1e717d"},
+ {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:f4d613fbf868b2e2444f490d18af472ccb47660ea3df52f068c9c8801e1f3e85"},
+ {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c69bf11894cad9da00047f46584d5758d6ebc9b5950c0dc96fec7e0bce5cde9"},
+ {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9bca3ca0c5e74dea44bf57d27e15a3a3996ce7e5780d61b7c72386356d231db"},
+ {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:014dfc020e28a0d9be7e93a91f85ff9f4a87158b7df9952fe23cc42d29d31e1e"},
+ {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d4ea4509d42c6797539e9ec7496c15473177ce9abc89bc5c71e7abe50fc25737"},
+ {file = "grpcio-1.67.0-cp38-cp38-win32.whl", hash = "sha256:9d75641a2fca9ae1ae86454fd25d4c298ea8cc195dbc962852234d54a07060ad"},
+ {file = "grpcio-1.67.0-cp38-cp38-win_amd64.whl", hash = "sha256:cff8e54d6a463883cda2fab94d2062aad2f5edd7f06ae3ed030f2a74756db365"},
+ {file = "grpcio-1.67.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:62492bd534979e6d7127b8a6b29093161a742dee3875873e01964049d5250a74"},
+ {file = "grpcio-1.67.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eef1dce9d1a46119fd09f9a992cf6ab9d9178b696382439446ca5f399d7b96fe"},
+ {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f623c57a5321461c84498a99dddf9d13dac0e40ee056d884d6ec4ebcab647a78"},
+ {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54d16383044e681f8beb50f905249e4e7261dd169d4aaf6e52eab67b01cbbbe2"},
+ {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2a44e572fb762c668e4812156b81835f7aba8a721b027e2d4bb29fb50ff4d33"},
+ {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:391df8b0faac84d42f5b8dfc65f5152c48ed914e13c522fd05f2aca211f8bfad"},
+ {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfd9306511fdfc623a1ba1dc3bc07fbd24e6cfbe3c28b4d1e05177baa2f99617"},
+ {file = "grpcio-1.67.0-cp39-cp39-win32.whl", hash = "sha256:30d47dbacfd20cbd0c8be9bfa52fdb833b395d4ec32fe5cff7220afc05d08571"},
+ {file = "grpcio-1.67.0-cp39-cp39-win_amd64.whl", hash = "sha256:f55f077685f61f0fbd06ea355142b71e47e4a26d2d678b3ba27248abfe67163a"},
+ {file = "grpcio-1.67.0.tar.gz", hash = "sha256:e090b2553e0da1c875449c8e75073dd4415dd71c9bde6a406240fdf4c0ee467c"},
+]
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.67.0)"]
+
[[package]]
name = "gymnasium"
version = "0.28.1"
@@ -1254,6 +1465,44 @@ other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-pyt
testing = ["pytest (==7.1.3)", "scipy (==1.7.3)"]
toy-text = ["pygame (==2.1.3)", "pygame (==2.1.3)"]
+[[package]]
+name = "h5py"
+version = "3.12.1"
+description = "Read and write HDF5 files from Python"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "h5py-3.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f0f1a382cbf494679c07b4371f90c70391dedb027d517ac94fa2c05299dacda"},
+ {file = "h5py-3.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cb65f619dfbdd15e662423e8d257780f9a66677eae5b4b3fc9dca70b5fd2d2a3"},
+ {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2"},
+ {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307"},
+ {file = "h5py-3.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:577d618d6b6dea3da07d13cc903ef9634cde5596b13e832476dd861aaf651f3e"},
+ {file = "h5py-3.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ccd9006d92232727d23f784795191bfd02294a4f2ba68708825cb1da39511a93"},
+ {file = "h5py-3.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad8a76557880aed5234cfe7279805f4ab5ce16b17954606cca90d578d3e713ef"},
+ {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e"},
+ {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166"},
+ {file = "h5py-3.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fdf95092d60e8130ba6ae0ef7a9bd4ade8edbe3569c13ebbaf39baefffc5ba4"},
+ {file = "h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed"},
+ {file = "h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351"},
+ {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834"},
+ {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9"},
+ {file = "h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc"},
+ {file = "h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc"},
+ {file = "h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653"},
+ {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32"},
+ {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f"},
+ {file = "h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8"},
+ {file = "h5py-3.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2b8dd64f127d8b324f5d2cd1c0fd6f68af69084e9e47d27efeb9e28e685af3e"},
+ {file = "h5py-3.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4532c7e97fbef3d029735db8b6f5bf01222d9ece41e309b20d63cfaae2fb5c4d"},
+ {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fdf6d7936fa824acfa27305fe2d9f39968e539d831c5bae0e0d83ed521ad1ac"},
+ {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84342bffd1f82d4f036433e7039e241a243531a1d3acd7341b35ae58cdab05bf"},
+ {file = "h5py-3.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:62be1fc0ef195891949b2c627ec06bc8e837ff62d5b911b6e42e38e0f20a897d"},
+ {file = "h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf"},
+]
+
+[package.dependencies]
+numpy = ">=1.19.3"
+
[[package]]
name = "idna"
version = "3.10"
@@ -1438,6 +1687,41 @@ qtconsole = ["qtconsole"]
test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"]
test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"]
+[[package]]
+name = "jax"
+version = "0.4.30"
+description = "Differentiate, compile, and transform Numpy code."
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "jax-0.4.30-py3-none-any.whl", hash = "sha256:289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b"},
+ {file = "jax-0.4.30.tar.gz", hash = "sha256:94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""}
+jaxlib = ">=0.4.27,<=0.4.30"
+ml-dtypes = ">=0.2.0"
+numpy = [
+ {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+ {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
+ {version = ">=1.22", markers = "python_version < \"3.11\""},
+]
+opt-einsum = "*"
+scipy = [
+ {version = ">=1.11.1", markers = "python_version >= \"3.12\""},
+ {version = ">=1.9", markers = "python_version < \"3.12\""},
+]
+
+[package.extras]
+ci = ["jaxlib (==0.4.29)"]
+cuda = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"]
+cuda12 = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"]
+cuda12-local = ["jax-cuda12-plugin (==0.4.30)", "jaxlib (==0.4.30)"]
+cuda12-pip = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"]
+minimum-jaxlib = ["jaxlib (==0.4.27)"]
+tpu = ["jaxlib (==0.4.30)", "libtpu-nightly (==0.1.dev20240617)", "requests"]
+
[[package]]
name = "jax-jumpy"
version = "1.0.0"
@@ -1456,6 +1740,43 @@ numpy = ">=1.18.0"
jax = ["jax (>=0.3.24)", "jaxlib (>=0.3.24)"]
testing = ["pytest (==7.1.3)"]
+[[package]]
+name = "jaxlib"
+version = "0.4.30"
+description = "XLA library for JAX"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "jaxlib-0.4.30-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c40856e28f300938c6824ab1a615166193d6997dec946578823f6d402ad454e5"},
+ {file = "jaxlib-0.4.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bdfda6a3c7a2b0cc0a7131009eb279e98ca4a6f25679fabb5302dd135a5e349"},
+ {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:28e032c9b394ab7624d89b0d9d3bbcf4d1d71694fe8b3e09d3fe64122eda7b0c"},
+ {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d83f36ef42a403bbf7c7f2da526b34ba286988e170f4df5e58b3bb735417868c"},
+ {file = "jaxlib-0.4.30-cp310-cp310-win_amd64.whl", hash = "sha256:a56678b28f96b524ded6da8ef4b38e72a532356d139cfd434da804abf4234e14"},
+ {file = "jaxlib-0.4.30-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:bfb5d85b69c29c3c6e8051a0ea715ac1e532d6e54494c8d9c3813dcc00deac30"},
+ {file = "jaxlib-0.4.30-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:974998cd8a78550402e6c09935c1f8d850cad9cc19ccd7488bde45b6f7f99c12"},
+ {file = "jaxlib-0.4.30-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e93eb0646b41ba213252b51b0b69096b9cd1d81a35ea85c9d06663b5d11efe45"},
+ {file = "jaxlib-0.4.30-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:16b2ab18ea90d2e15941bcf45de37afc2f289a029129c88c8d7aba0404dd0043"},
+ {file = "jaxlib-0.4.30-cp311-cp311-win_amd64.whl", hash = "sha256:3a2e2c11c179f8851a72249ba1ae40ae817dfaee9877d23b3b8f7c6b7a012f76"},
+ {file = "jaxlib-0.4.30-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7704db5962b32a2be3cc07185433cbbcc94ed90ee50c84021a3f8a1ecfd66ee3"},
+ {file = "jaxlib-0.4.30-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:57090d33477fd0f0c99dc686274882ea75c44c7d712ae42dd2460b10f896131d"},
+ {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:0a3850e76278038e21685975a62b622bcf3708485f13125757a0561ee4512940"},
+ {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:c58a8071c4e00898282118169f6a5a97eb15a79c2897858f3a732b17891c99ab"},
+ {file = "jaxlib-0.4.30-cp312-cp312-win_amd64.whl", hash = "sha256:b7079a5b1ab6864a7d4f2afaa963841451186d22c90f39719a3ff85735ce3915"},
+ {file = "jaxlib-0.4.30-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ea3a00005faafbe3c18b178d3b534208b3b4027b2be6230227e7b87ce399fc29"},
+ {file = "jaxlib-0.4.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d31e01191ce8052bd611aaf16ff967d8d0ec0b63f1ea4b199020cecb248d667"},
+ {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:11602d5556e8baa2f16314c36518e9be4dfae0c2c256a361403fb29dc9dc79a4"},
+ {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18"},
+ {file = "jaxlib-0.4.30-cp39-cp39-win_amd64.whl", hash = "sha256:54987e97a22db70f3829b437b9329e4799d653634bacc8b398554d3b90c76b2a"},
+]
+
+[package.dependencies]
+ml-dtypes = ">=0.2.0"
+numpy = ">=1.22"
+scipy = [
+ {version = ">=1.11.1", markers = "python_version >= \"3.12\""},
+ {version = ">=1.9", markers = "python_version < \"3.12\""},
+]
+
[[package]]
name = "jedi"
version = "0.19.1"
@@ -1492,6 +1813,23 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
+[[package]]
+name = "jmp"
+version = "0.0.4"
+description = "JMP is a Mixed Precision library for JAX."
+optional = true
+python-versions = "*"
+files = [
+ {file = "jmp-0.0.4-py3-none-any.whl", hash = "sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d"},
+ {file = "jmp-0.0.4.tar.gz", hash = "sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730"},
+]
+
+[package.dependencies]
+numpy = ">=1.19.5"
+
+[package.extras]
+jax = ["jax (>=0.2.20)", "jaxlib (>=0.1.71)"]
+
[[package]]
name = "joblib"
version = "1.4.2"
@@ -1581,6 +1919,27 @@ traitlets = ">=5.3"
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"]
test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"]
+[[package]]
+name = "keras"
+version = "3.6.0"
+description = "Multi-backend Keras."
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "keras-3.6.0-py3-none-any.whl", hash = "sha256:49585e4577f6e86bd890d96dfbcb1890f5bab5967ef831c07fd63f9d86e4bfe9"},
+ {file = "keras-3.6.0.tar.gz", hash = "sha256:405727525a3522ed8f9ec0b46e0667e4c65fcf714a067322c16a00d902ded41d"},
+]
+
+[package.dependencies]
+absl-py = "*"
+h5py = "*"
+ml-dtypes = "*"
+namex = "*"
+numpy = "*"
+optree = "*"
+packaging = "*"
+rich = "*"
+
[[package]]
name = "kiwisolver"
version = "1.4.7"
@@ -1723,6 +2082,25 @@ dev = ["changelist (==0.5)"]
lint = ["pre-commit (==3.7.0)"]
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
+[[package]]
+name = "libclang"
+version = "18.1.1"
+description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier."
+optional = true
+python-versions = "*"
+files = [
+ {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"},
+ {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"},
+ {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"},
+ {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"},
+ {file = "libclang-18.1.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592"},
+ {file = "libclang-18.1.1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe"},
+ {file = "libclang-18.1.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f"},
+ {file = "libclang-18.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb"},
+ {file = "libclang-18.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8"},
+ {file = "libclang-18.1.1.tar.gz", hash = "sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250"},
+]
+
[[package]]
name = "llvmlite"
version = "0.43.0"
@@ -1803,6 +2181,24 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"]
flake8 = ["flake8"]
tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"]
+[[package]]
+name = "markdown"
+version = "3.7"
+description = "Python implementation of John Gruber's Markdown."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"},
+ {file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"},
+]
+
+[package.dependencies]
+importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
+
+[package.extras]
+docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"]
+testing = ["coverage", "pyyaml"]
+
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@@ -2000,6 +2396,43 @@ files = [
[package.extras]
dzn = ["lark-parser (>=0.12.0,<0.13.0)"]
+[[package]]
+name = "ml-dtypes"
+version = "0.4.1"
+description = ""
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "ml_dtypes-0.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1fe8b5b5e70cd67211db94b05cfd58dace592f24489b038dc6f9fe347d2e07d5"},
+ {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c09a6d11d8475c2a9fd2bc0695628aec105f97cab3b3a3fb7c9660348ff7d24"},
+ {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5e8f75fa371020dd30f9196e7d73babae2abd51cf59bdd56cb4f8de7e13354"},
+ {file = "ml_dtypes-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:15fdd922fea57e493844e5abb930b9c0bd0af217d9edd3724479fc3d7ce70e3f"},
+ {file = "ml_dtypes-0.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2d55b588116a7085d6e074cf0cdb1d6fa3875c059dddc4d2c94a4cc81c23e975"},
+ {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e138a9b7a48079c900ea969341a5754019a1ad17ae27ee330f7ebf43f23877f9"},
+ {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74c6cfb5cf78535b103fde9ea3ded8e9f16f75bc07789054edc7776abfb3d752"},
+ {file = "ml_dtypes-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:274cc7193dd73b35fb26bef6c5d40ae3eb258359ee71cd82f6e96a8c948bdaa6"},
+ {file = "ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b"},
+ {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7"},
+ {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9"},
+ {file = "ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c"},
+ {file = "ml_dtypes-0.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e35e486e97aee577d0890bc3bd9e9f9eece50c08c163304008587ec8cfe7575b"},
+ {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:560be16dc1e3bdf7c087eb727e2cf9c0e6a3d87e9f415079d2491cc419b3ebf5"},
+ {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad0b757d445a20df39035c4cdeed457ec8b60d236020d2560dbc25887533cf50"},
+ {file = "ml_dtypes-0.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef0d7e3fece227b49b544fa69e50e607ac20948f0043e9f76b44f35f229ea450"},
+ {file = "ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+ {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
+ {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
+ {version = ">1.20", markers = "python_version < \"3.10\""},
+]
+
+[package.extras]
+dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"]
+
[[package]]
name = "more-itertools"
version = "10.5.0"
@@ -2129,6 +2562,17 @@ files = [
[package.dependencies]
dill = ">=0.3.9"
+[[package]]
+name = "namex"
+version = "0.0.8"
+description = "A simple utility to separate the implementation of your Python package and its public API surface."
+optional = true
+python-versions = "*"
+files = [
+ {file = "namex-0.0.8-py3-none-any.whl", hash = "sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487"},
+ {file = "namex-0.0.8.tar.gz", hash = "sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b"},
+]
+
[[package]]
name = "nbclient"
version = "0.6.8"
@@ -2469,6 +2913,151 @@ pandas = ">=1.2"
pyyaml = ">=5.1"
scipy = ">=1.7"
+[[package]]
+name = "opt-einsum"
+version = "3.4.0"
+description = "Path optimization of einsum functions."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"},
+ {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"},
+]
+
+[[package]]
+name = "optax"
+version = "0.2.3"
+description = "A gradient processing and optimisation library in JAX."
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "optax-0.2.3-py3-none-any.whl", hash = "sha256:083e603dcd731d7e74d99f71c12f77937dd53f79001b4c09c290e4f47dd2e94f"},
+ {file = "optax-0.2.3.tar.gz", hash = "sha256:ec7ab925440b0c5a512e1f24fba0fb3e7d760a7fd5d2496d7a691e9d37da01d9"},
+]
+
+[package.dependencies]
+absl-py = ">=0.7.1"
+chex = ">=0.1.86"
+etils = {version = "*", extras = ["epy"]}
+jax = ">=0.4.27"
+jaxlib = ">=0.4.27"
+numpy = ">=1.18.0"
+
+[package.extras]
+docs = ["flax", "ipython (>=8.8.0)", "matplotlib (>=3.5.0)", "myst-nb (>=1.0.0)", "sphinx (>=6.0.0)", "sphinx-autodoc-typehints", "sphinx-book-theme (>=1.0.1)", "sphinx-collections (>=0.0.1)", "sphinx-gallery (>=0.14.0)", "sphinx_contributors", "sphinxcontrib-katex", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"]
+dp-accounting = ["absl-py (>=1.0.0)", "attrs (>=21.4.0)", "mpmath (>=1.2.1)", "numpy (>=1.21.4)", "scipy (>=1.7.1)"]
+examples = ["dp_accounting (>=0.4)", "flax", "ipywidgets", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"]
+test = ["dm-tree (>=0.1.7)", "flax (>=0.5.3)", "scikit-learn", "scipy (>=1.7.1)"]
+
+[[package]]
+name = "optree"
+version = "0.13.0"
+description = "Optimized PyTree Utilities."
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "optree-0.13.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7b8fe0442ac5e50b5e6bceb37dcc2cd4908e7716b869cbe6b8901cc0b489884f"},
+ {file = "optree-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a1aab34de5ac7673fbfb94266bf10482be51985c7f899c3e767ce19d13ce3b4"},
+ {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2c79961d5afeb20557c30a0ae899d14ff58cdf1c0e2c8aa3d6807600d00f619"},
+ {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb55eb77541cf280a009829280d5844936dc8a2e4a3eb069c010a1f547dbfe97"},
+ {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44449e3bc5e7530b50c9a1f5bcf2971ffe317e34edd74d8c9778c5d32078114d"},
+ {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b4195a6ba2052c70bac6d73f19aa69644424c5a30fa09f7319cc1b59e15acb6"},
+ {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7fecc701ece0500fe38fc671b5704d904e2dca9a9284b35263b0bd7e5c62527"},
+ {file = "optree-0.13.0-cp310-cp310-win32.whl", hash = "sha256:46a9e66217fdf421e25c133089c94f8f99bc38a2b5a4a2c0c1e0c1b02b01dda4"},
+ {file = "optree-0.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:ef68fdcb3b1743a46210f3c888cd15668a07422aa10b4d4130ba512aac595bf7"},
+ {file = "optree-0.13.0-cp310-cp310-win_arm64.whl", hash = "sha256:d12a5665169abceb878d50b55571d6a7690bf97aaaf9a7f5438b10e474fde3f2"},
+ {file = "optree-0.13.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:92d1c34b6022bedee4b3899f3a9a1105777da11a9abf1a51f4d84bed8f037fa1"},
+ {file = "optree-0.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d05c320af21efbc132fe887640f7a2dbb36cfb38af6d4e62396fe104b78f7b72"},
+ {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53ae0a0eb128a69a74db4165e7e5f24d54e2711678622198f7073dcb991962f"},
+ {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89f08fc3724b2fe7a081b69dfd3ad6625960443e1f61a984cae7c627776f12f4"},
+ {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f22f4e46d85f24b5bc49e68043dd754b258b880ac64d72f4f4b9ac1b11f0fb2f"},
+ {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbc884f3eab894126398c120b7f92a72a5b9f92db6d8c27d39087da871c642cd"},
+ {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c58b94669c9072d645e02c0c65c7455f8f136ef8f7b56a5d9123847421f95b"},
+ {file = "optree-0.13.0-cp311-cp311-win32.whl", hash = "sha256:54be625517ef3cf52905da7fee63795b2f154dbdb02b37e8cfd63e7fb2f266ea"},
+ {file = "optree-0.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:e3d100890a643e12f39de4226ab4f9d0a22842b4f34ae2964d0149419e4d7aff"},
+ {file = "optree-0.13.0-cp311-cp311-win_arm64.whl", hash = "sha256:cb8d9a2cebc5fadde98773bb27809a72ff01d11f1037cb58f8e71e740586223e"},
+ {file = "optree-0.13.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:abeb8acc83d168063b70168ccf8dfd55f5a7ce50f9af2ca025c41285781ecdd4"},
+ {file = "optree-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4771266f05e99e94312add38d45bf97a4d98449aeab100f5c658c521152eb5e5"},
+ {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc95c1d0c7acd534184bf3ba243a454e0942e4a7c8b9edd32d939fc15e33d753"},
+ {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e48491e042f956d4232ebc138e07074100878c0080e3ba10af4c2db1ba4df9f"},
+ {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e001d9c902e98912503eca66c93d4b4b22f5071e4ab777f4db9e140f35288f4"},
+ {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87870346278f46a8c22866ff48716590be35b4aea16e1373e695fb6442c28c41"},
+ {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7797c54a35e9d89b4664ec7d542745b87b5ffa9c1201c1062fdcd488eb583390"},
+ {file = "optree-0.13.0-cp312-cp312-win32.whl", hash = "sha256:fc90a5373c92f4a9babb4c40fe148516f52160c0ba803bc9b2f936367f2f7437"},
+ {file = "optree-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:1bc65743e8edb29e902cab894d1c4665a8fd6f8d10f75db68a2cef6c7246fa5c"},
+ {file = "optree-0.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:de2729e1e4ae47a07ac3c70ff977ed1ebe19e7b44d5089075c94f7a9a2dc6f4f"},
+ {file = "optree-0.13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dda6efabd0621f53eb46a3789ec89c6fd2c90dfb57aebfce3fcda6eab9ed6a7e"},
+ {file = "optree-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5de8da9bbdd08b6200244ee818cd15d1da0f2b06ac926dba0e686260bac7fd40"},
+ {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca1e4854134023ba687a7abf45ed3355f773ca7198b6895d88a89030446a9f2e"},
+ {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1ac5343e921ce21f8f10f91158ad6404a1488c1cc22ddfa6b34cfb9d997cebd"},
+ {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e282212ddf3aafb10ca6ca223772e06ea3c31687c9cae192467b8e0a7dafbfc"},
+ {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:24fcd4cb659bcd9b675bc3401950de891b32a047c4787857fb870cd515fcc315"},
+ {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d735a7d2d2e2eb9a88d932d35b335c10fae9038034f381b6d437dafed46497e"},
+ {file = "optree-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ef01e79224f0ee6cf2ca642884f0bc04e446227b96dc576c312717eb33552d57"},
+ {file = "optree-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:d3f61fb669b36c1a714346b18c9c488ad33a58049b7b229785c241de18c005d7"},
+ {file = "optree-0.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:695b3f1aab50519230e3d8d86abaedaadf91af105b569cce3b8ebe0dc612b312"},
+ {file = "optree-0.13.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:1318434b0740a2325c197e191e6dd53d9df0a8ac0338c67d58b476aad9d07829"},
+ {file = "optree-0.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d58c6e8d4c4fa4e0c31bc4b876960ccba94eb5fcfb045f2b064ce55707034be9"},
+ {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6a290ba771cc9004f9fc194d23ab11ee4aae71550ca874c3dc985af5b5f910b"},
+ {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c95488ecbab2916de094e68f2a2c55c9475b2e979c03d91a6cd3565f9e5ff2f9"},
+ {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f76a65ff322b3d47af2a23f60409d6d8f184804da551c734e355834e69c0dfb"},
+ {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58cc303f982fb0f23644b7f8e98b4f64b0d031365fcc2284da896e96493176d2"},
+ {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6866b6e4154303dc7c48c7ca3b867a8ce31d469334b67976dfc0513455aa1ca0"},
+ {file = "optree-0.13.0-cp313-cp313t-win32.whl", hash = "sha256:f5ce67f81fe3d7ca5fed8fdaf93a762a63e1d125e20e425ca7200f9e54a3e3a6"},
+ {file = "optree-0.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0008cd39169c1fc10870528b2decfea8b79e61042c12d65a964f3b1cf41cc37d"},
+ {file = "optree-0.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:539962675b547957c64b52b7f82178febb9c0f2d47438b810bbc23cfdcf84821"},
+ {file = "optree-0.13.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b08e4873814d11aa25ef3927c848b9e5cf21215b925e83875b9fe11c7a035b0e"},
+ {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6e236c6601480997c6e1dbbd4ab2b7ea0bc82a9a7baa1f681a1b072c9c02677"},
+ {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:557b415b41006cca88d86ad190b795455e9334d3cf5838e63c4c668a65227ccb"},
+ {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11b78c8a18894fe9503515d073a60ebaed366aeb3cfa65e61e7e71ae833f640b"},
+ {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4207f6fa0bd4730f5496772c139f1444b2b69e4eeb0f454e2100b5a380648f70"},
+ {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe9fd84b7d87f365f720699dedd254882ba7e5ef927d3ba1e13413d45963b691"},
+ {file = "optree-0.13.0-cp37-cp37m-win32.whl", hash = "sha256:c0f9f250f617f114061ab718d460be6be8e0a1cbbfdbbfb5541ed1c8fefee150"},
+ {file = "optree-0.13.0-cp37-cp37m-win_amd64.whl", hash = "sha256:5cf612aefe0201a2995763cce82b9cd03cbddd2bfd6f8975f910c091dfa7bb5f"},
+ {file = "optree-0.13.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:46623259b10f6e3565ea0d37e0b313feb20484bccb005459b3504e1aa706b730"},
+ {file = "optree-0.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e7f9184c6040365e79a0b900507c289b6a4e06ade3c9691e501d176d5cf775cf"},
+ {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6201c065791422a73d5aeb4916e00879de9b097cf54526f82b5b3c297126d938"},
+ {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a423897010c6d8490097671d907da1b6ee90d3fa783aaad5e36e46e0a73bc5e"},
+ {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1fb74282fce108e07972e88dbc23f6b7650c2d3bbddbedc2002d3e0becb1c452"},
+ {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94ecab158521225b20e44d67c8afc2d9af6760985a9f489d21bf2aa8bbe467f8"},
+ {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8244d0fbfe1ef15ffb443f3d32a44aff062adbef0a7fd6db3f011999df966223"},
+ {file = "optree-0.13.0-cp38-cp38-win32.whl", hash = "sha256:0a34c11d637cb01217828e28eef382c621c9ec53f981d8ccbfe56e0a11cda501"},
+ {file = "optree-0.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:ebe56c17bf3754335307b17be7f554c5eae47acf738471cf38dba0ec73a42c37"},
+ {file = "optree-0.13.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e9c619a49984212e5f757e10d5e5f95888b0c08d67a7f2b9f395cede30712dc2"},
+ {file = "optree-0.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:50a9e2d9ffff99d45b37289a3422ed3723a45225616f5b48cea606ff0f539c0f"},
+ {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d702dbcafcd16e8925e30c0e780ab3dc81450e19008fd3e77494111fc161a2b2"},
+ {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f44a58f87059161f300e2be66ad3878fff540d27f5dcd69b21feae65c243a02"},
+ {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:954899edc024f13079932418f59bbdadabc52d9dcb49c7b559c382c7be352dfc"},
+ {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c736ce6f4b8857bd171f3682ef849e3d67692c3fc4db42b99c5d2c7cc1bdf11"},
+ {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7941d3bd48d860d0e17ca24827b5233ea27bb4227e822eafb3897df1f43f8342"},
+ {file = "optree-0.13.0-cp39-cp39-win32.whl", hash = "sha256:9f6fc47c9b10d1a9e77163ebd6f2e251af41fab895475d2ce9643423a41899af"},
+ {file = "optree-0.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:246020f0be50fb66791d8a25c4acb59ad0b4bbdea71c998e375eba4c58fbc3e0"},
+ {file = "optree-0.13.0-cp39-cp39-win_arm64.whl", hash = "sha256:069bf166b7aa48ccf8dfe76b920d2115dd8261107c7895d02500b2ce39621b40"},
+ {file = "optree-0.13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:496170a3d093a7fb69be7ce847f5b5b3aa30a6da81457ba6b54268e6e97c6b13"},
+ {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73543a82be71c041d5b169754089a58d02063eb72ac8688533b6fc26ab6beea8"},
+ {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:278e2c620df99f5b1477b375b01cf9658528fa0332c0bc431d3ec65857244094"},
+ {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36b32155dce29edb6f63a99a44d6da2d8fcd1c56353cc2f4af65f793a0b2712f"},
+ {file = "optree-0.13.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c98a43204841cc4698155acb523d7b21a78f8b05666704359e0fddecd5d1043d"},
+ {file = "optree-0.13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c2803d4ef257f2599cffd0e9d60cfb3d4c522abbe8f5a839bd48d8edd26dae7"},
+ {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3b454f98d28a89c15a1170e771c61902cbc53eed126db36138b684dba5a729"},
+ {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b74afed3db289228e0f95a8909835365f644eb69ff31cd6c0b45608ca9e56d78"},
+ {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc3cebfd7d0826d223662f01ed0fa25932edf3f62479be13c4d6ff0fab090c34"},
+ {file = "optree-0.13.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:5703637ede6fba04cbeabbb47aada7d17606c2d4df73305063f4a3c829c21fc7"},
+ {file = "optree-0.13.0.tar.gz", hash = "sha256:1ea493cde8c60f7950ccbd682bd67e787bf67ed2251d6d3e9ad7471b72d37538"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.5.0"
+
+[package.extras]
+benchmark = ["dm-tree (>=0.1,<0.2.0a0)", "jax[cpu] (>=0.4.6,<0.5.0a0)", "pandas", "tabulate", "termcolor", "torch (>=2.0,<2.4.0a0)", "torchvision"]
+docs = ["docutils", "jax[cpu]", "numpy", "sphinx", "sphinx-autoapi", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinx-rtd-theme", "sphinxcontrib-bibtex", "torch"]
+jax = ["jax"]
+lint = ["black", "cpplint", "doc8", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "flake8-pyi", "flake8-simplify", "isort", "mypy", "pre-commit", "pydocstyle", "pyenchant", "pylint[spelling]", "ruff", "xdoctest"]
+numpy = ["numpy"]
+test = ["pytest", "pytest-cov", "pytest-xdist"]
+torch = ["torch"]
+
[[package]]
name = "ortools"
version = "9.10.4067"
@@ -2574,9 +3163,9 @@ files = [
[package.dependencies]
numpy = [
- {version = ">=1.22.4", markers = "python_version < \"3.11\""},
- {version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+ {version = ">=1.23.2", markers = "python_version == \"3.11\""},
+ {version = ">=1.22.4", markers = "python_version < \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@@ -2776,6 +3365,17 @@ files = [
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
+[[package]]
+name = "ply"
+version = "3.11"
+description = "Python Lex & Yacc"
+optional = true
+python-versions = "*"
+files = [
+ {file = "ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce"},
+ {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"},
+]
+
[[package]]
name = "pox"
version = "0.3.5"
@@ -2903,52 +3503,55 @@ files = [
[[package]]
name = "pyarrow"
-version = "17.0.0"
+version = "18.0.0"
description = "Python library for Apache Arrow"
optional = true
-python-versions = ">=3.8"
+python-versions = ">=3.9"
files = [
- {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"},
- {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"},
- {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"},
- {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"},
- {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"},
- {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"},
- {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"},
- {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"},
- {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"},
- {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"},
- {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"},
- {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"},
- {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"},
- {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"},
- {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"},
- {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"},
- {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"},
- {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"},
- {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"},
- {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"},
- {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
- {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
- {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
- {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
- {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
- {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
- {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
- {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
- {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
- {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
- {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
- {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
- {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
- {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
- {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
- {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
+ {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2333f93260674e185cfbf208d2da3007132572e56871f451ba1a556b45dae6e2"},
+ {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4c381857754da44326f3a49b8b199f7f87a51c2faacd5114352fc78de30d3aba"},
+ {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:603cd8ad4976568954598ef0a6d4ed3dfb78aff3d57fa8d6271f470f0ce7d34f"},
+ {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58a62549a3e0bc9e03df32f350e10e1efb94ec6cf63e3920c3385b26663948ce"},
+ {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bc97316840a349485fbb137eb8d0f4d7057e1b2c1272b1a20eebbbe1848f5122"},
+ {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2e549a748fa8b8715e734919923f69318c953e077e9c02140ada13e59d043310"},
+ {file = "pyarrow-18.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:606e9a3dcb0f52307c5040698ea962685fb1c852d72379ee9412be7de9c5f9e2"},
+ {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d5795e37c0a33baa618c5e054cd61f586cf76850a251e2b21355e4085def6280"},
+ {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:5f0510608ccd6e7f02ca8596962afb8c6cc84c453e7be0da4d85f5f4f7b0328a"},
+ {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616ea2826c03c16e87f517c46296621a7c51e30400f6d0a61be645f203aa2b93"},
+ {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1824f5b029ddd289919f354bc285992cb4e32da518758c136271cf66046ef22"},
+ {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd1b52d0d58dd8f685ced9971eb49f697d753aa7912f0a8f50833c7a7426319"},
+ {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:320ae9bd45ad7ecc12ec858b3e8e462578de060832b98fc4d671dee9f10d9954"},
+ {file = "pyarrow-18.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2c992716cffb1088414f2b478f7af0175fd0a76fea80841b1706baa8fb0ebaad"},
+ {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:e7ab04f272f98ebffd2a0661e4e126036f6936391ba2889ed2d44c5006237802"},
+ {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:03f40b65a43be159d2f97fd64dc998f769d0995a50c00f07aab58b0b3da87e1f"},
+ {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be08af84808dff63a76860847c48ec0416928a7b3a17c2f49a072cac7c45efbd"},
+ {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70c1965cde991b711a98448ccda3486f2a336457cf4ec4dca257a926e149c9"},
+ {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:00178509f379415a3fcf855af020e3340254f990a8534294ec3cf674d6e255fd"},
+ {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a71ab0589a63a3e987beb2bc172e05f000a5c5be2636b4b263c44034e215b5d7"},
+ {file = "pyarrow-18.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe92efcdbfa0bcf2fa602e466d7f2905500f33f09eb90bf0bcf2e6ca41b574c8"},
+ {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:907ee0aa8ca576f5e0cdc20b5aeb2ad4d3953a3b4769fc4b499e00ef0266f02f"},
+ {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:66dcc216ebae2eb4c37b223feaf82f15b69d502821dde2da138ec5a3716e7463"},
+ {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc1daf7c425f58527900876354390ee41b0ae962a73ad0959b9d829def583bb1"},
+ {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:871b292d4b696b09120ed5bde894f79ee2a5f109cb84470546471df264cae136"},
+ {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:082ba62bdcb939824ba1ce10b8acef5ab621da1f4c4805e07bfd153617ac19d4"},
+ {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2c664ab88b9766413197733c1720d3dcd4190e8fa3bbdc3710384630a0a7207b"},
+ {file = "pyarrow-18.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc892be34dbd058e8d189b47db1e33a227d965ea8805a235c8a7286f7fd17d3a"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:28f9c39a56d2c78bf6b87dcc699d520ab850919d4a8c7418cd20eda49874a2ea"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:f1a198a50c409ab2d009fbf20956ace84567d67f2c5701511d4dd561fae6f32e"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5bd7fd32e3ace012d43925ea4fc8bd1b02cc6cc1e9813b518302950e89b5a22"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:336addb8b6f5208be1b2398442c703a710b6b937b1a046065ee4db65e782ff5a"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:45476490dd4adec5472c92b4d253e245258745d0ccaabe706f8d03288ed60a79"},
+ {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b46591222c864e7da7faa3b19455196416cd8355ff6c2cc2e65726a760a3c420"},
+ {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb7e3abcda7e1e6b83c2dc2909c8d045881017270a119cc6ee7fdcfe71d02df8"},
+ {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:09f30690b99ce34e0da64d20dab372ee54431745e4efb78ac938234a282d15f9"},
+ {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5ca5d707e158540312e09fd907f9f49bacbe779ab5236d9699ced14d2293b8"},
+ {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6331f280c6e4521c69b201a42dd978f60f7e129511a55da9e0bfe426b4ebb8d"},
+ {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3ac24b2be732e78a5a3ac0b3aa870d73766dd00beba6e015ea2ea7394f8b4e55"},
+ {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b30a927c6dff89ee702686596f27c25160dd6c99be5bcc1513a763ae5b1bfc03"},
+ {file = "pyarrow-18.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:8f40ec677e942374e3d7f2fad6a67a4c2811a8b975e8703c6fd26d3b168a90e2"},
+ {file = "pyarrow-18.0.0.tar.gz", hash = "sha256:a6aa027b1a9d2970cf328ccd6dbe4a996bc13c39fd427f502782f5bdb9ca20f5"},
]
-[package.dependencies]
-numpy = ">=1.16.6"
-
[package.extras]
test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
@@ -3271,6 +3874,61 @@ files = [
[package.dependencies]
certifi = "*"
+[[package]]
+name = "pyrddlgym"
+version = "2.0"
+description = "pyRDDLGym: RDDL automatic generation tool for OpenAI Gym"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyRDDLGym-2.0-py3-none-any.whl", hash = "sha256:dfbd3a17d7b6843ecd2ca211d31384ce28bdcdc568c5f612b5c8e49ea789a340"},
+ {file = "pyRDDLGym-2.0.tar.gz", hash = "sha256:9df637574a40a915f986f55700107ca823542f7e7b7d16d7bcc151b6ec126ad1"},
+]
+
+[package.dependencies]
+gymnasium = "*"
+matplotlib = ">=3.5.0"
+numpy = ">=1.22"
+pillow = ">=9.2.0"
+ply = "*"
+pygame = "*"
+termcolor = "*"
+
+[[package]]
+name = "pyrddlgym-jax"
+version = "0.3"
+description = "pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyRDDLGym_jax-0.3-py3-none-any.whl", hash = "sha256:4908a0f5e3850a24ea5a75455723e0d9c219e4beed82b2fc738287c232dbc18b"},
+ {file = "pyrddlgym_jax-0.3.tar.gz", hash = "sha256:c5dd35ff2cd871135cb035f5266e9247b97863806f2b5b31b74ad6c28377cb24"},
+]
+
+[package.dependencies]
+bayesian-optimization = ">=1.4.3"
+dm-haiku = ">=0.0.10"
+jax = ">=0.4.12"
+optax = ">=0.1.9"
+pyRDDLGym = ">=2.0"
+tensorflow = ">=2.13.0"
+tensorflow-probability = ">=0.21.0"
+tqdm = ">=4.66"
+
+[[package]]
+name = "pyrddlgym-rl"
+version = "0.1"
+description = "pyRDDLGym-rl: Wrappers for reinforcement learning algorithms (i.e. stable baselines 3) to work with pyRDDLGym."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "pyRDDLGym-rl-0.1.tar.gz", hash = "sha256:2f1c9b420bc0f2418e0c44fcf587950471366f6c5fce9587295d5f5dd1cc1c33"},
+ {file = "pyRDDLGym_rl-0.1-py3-none-any.whl", hash = "sha256:1baa8a5bfb72e7a582aeb40803ca9792bc55d2d5fe1e3ee087d6ea1053b719de"},
+]
+
+[package.dependencies]
+pyRDDLGym = ">=2.0"
+
[[package]]
name = "pyshp"
version = "2.3.1"
@@ -3654,6 +4312,20 @@ serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "grpcio
train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"]
tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"]
+[[package]]
+name = "rddlrepository"
+version = "2.0"
+description = "Home for all things RDDL"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "rddlrepository-2.0-py3-none-any.whl", hash = "sha256:ed993b010b535128de15aefb7d4507d220df9b68dbc25c5bf0315830184817cb"},
+ {file = "rddlrepository-2.0.tar.gz", hash = "sha256:570b6b0044f20b013aac59b44519f5a554fc1097a5f0f68a5110fd7c07977cab"},
+]
+
+[package.dependencies]
+numpy = "*"
+
[[package]]
name = "referencing"
version = "0.35.1"
@@ -3869,6 +4541,56 @@ docs = ["PyWavelets (>=1.1.1)", "dask[array] (>=2022.9.2)", "ipykernel", "ipywid
optional = ["PyWavelets (>=1.1.1)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0)", "matplotlib (>=3.6)", "pooch (>=1.6.0)", "pyamg", "scikit-learn (>=1.1)"]
test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"]
+[[package]]
+name = "scikit-learn"
+version = "1.5.2"
+description = "A set of python modules for machine learning and data mining"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"},
+ {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"]
+
[[package]]
name = "scipy"
version = "1.13.1"
@@ -4115,6 +4837,54 @@ mpmath = ">=1.1.0,<1.4"
[package.extras]
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+ {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
+[[package]]
+name = "tensorboard"
+version = "2.18.0"
+description = "TensorBoard lets you watch Tensors Flow"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"},
+]
+
+[package.dependencies]
+absl-py = ">=0.4"
+grpcio = ">=1.48.2"
+markdown = ">=2.6.8"
+numpy = ">=1.12.0"
+packaging = "*"
+protobuf = ">=3.19.6,<4.24.0 || >4.24.0"
+setuptools = ">=41.0.0"
+six = ">1.9"
+tensorboard-data-server = ">=0.7.0,<0.8.0"
+werkzeug = ">=1.0.1"
+
+[[package]]
+name = "tensorboard-data-server"
+version = "0.7.2"
+description = "Fast data loading for TensorBoard"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
+ {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
+ {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
+]
+
[[package]]
name = "tensorboardx"
version = "2.6.2.2"
@@ -4131,6 +4901,139 @@ numpy = "*"
packaging = "*"
protobuf = ">=3.20"
+[[package]]
+name = "tensorflow"
+version = "2.18.0"
+description = "TensorFlow is an open source machine learning framework for everyone."
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "tensorflow-2.18.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:8da90a9388a1f6dd00d626590d2b5810faffbb3e7367f9783d80efff882340ee"},
+ {file = "tensorflow-2.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:589342fb9bdcab2e9af0f946da4ca97757677e297d934fcdc087e87db99d6353"},
+ {file = "tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1eb77fae50d699442726d1b23c7512c97cd688cc7d857b028683d4535bbf3709"},
+ {file = "tensorflow-2.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:46f5a8b4e6273f488dc069fc3ac2211b23acd3d0437d919349c787fa341baa8a"},
+ {file = "tensorflow-2.18.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:453cb60638a02fd26316fb36c8cbcf1569d33671f17c658ca0cf2b4626f851e7"},
+ {file = "tensorflow-2.18.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85f1e7369af6d329b117b52e86093cd1e0458dd5404bf5b665853f873dd00b48"},
+ {file = "tensorflow-2.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b8dd70fa3600bfce66ab529eebb804e1f9d7c863d2f71bc8fe9fc7a1ec3976"},
+ {file = "tensorflow-2.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e8b0f499ef0b7652480a58e358a73844932047f21c42c56f7f3bdcaf0803edc"},
+ {file = "tensorflow-2.18.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ec4133a215c59314e929e7cbe914579d3afbc7874d9fa924873ee633fe4f71d0"},
+ {file = "tensorflow-2.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4822904b3559d8a9c25f0fe5fef191cfc1352ceca42ca64f2a7bc7ae0ff4a1f5"},
+ {file = "tensorflow-2.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfdd65ea7e064064283dd78d529dd621257ee617218f63681935fd15817c6286"},
+ {file = "tensorflow-2.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:a701c2d3dca5f2efcab315b2c217f140ebd3da80410744e87d77016b3aaf53cb"},
+ {file = "tensorflow-2.18.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:336cace378c129c20fee6292f6a541165073d153a9a4c9cf4f14478a81895776"},
+ {file = "tensorflow-2.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcfd32134de8f95515b2d0ced89cdae15484b787d3a21893e9291def06c10c4e"},
+ {file = "tensorflow-2.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada1f7290c75b34748ee7378c1b77927e4044c94b8dc72dc75e7667c4fdaeb94"},
+ {file = "tensorflow-2.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:f8c946df1cb384504578fac1c199a95322373b8e04abd88aa8ae01301df469ea"},
+]
+
+[package.dependencies]
+absl-py = ">=1.0.0"
+astunparse = ">=1.6.0"
+flatbuffers = ">=24.3.25"
+gast = ">=0.2.1,<0.5.0 || >0.5.0,<0.5.1 || >0.5.1,<0.5.2 || >0.5.2"
+google-pasta = ">=0.1.1"
+grpcio = ">=1.24.3,<2.0"
+h5py = ">=3.11.0"
+keras = ">=3.5.0"
+libclang = ">=13.0.0"
+ml-dtypes = ">=0.4.0,<0.5.0"
+numpy = ">=1.26.0,<2.1.0"
+opt-einsum = ">=2.3.2"
+packaging = "*"
+protobuf = ">=3.20.3,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
+requests = ">=2.21.0,<3"
+setuptools = "*"
+six = ">=1.12.0"
+tensorboard = ">=2.18,<2.19"
+tensorflow-io-gcs-filesystem = {version = ">=0.23.1", markers = "python_version < \"3.12\""}
+termcolor = ">=1.1.0"
+typing-extensions = ">=3.6.6"
+wrapt = ">=1.11.0"
+
+[package.extras]
+and-cuda = ["nvidia-cublas-cu12 (==12.5.3.2)", "nvidia-cuda-cupti-cu12 (==12.5.82)", "nvidia-cuda-nvcc-cu12 (==12.5.82)", "nvidia-cuda-nvrtc-cu12 (==12.5.82)", "nvidia-cuda-runtime-cu12 (==12.5.82)", "nvidia-cudnn-cu12 (==9.3.0.75)", "nvidia-cufft-cu12 (==11.2.3.61)", "nvidia-curand-cu12 (==10.3.6.82)", "nvidia-cusolver-cu12 (==11.6.3.83)", "nvidia-cusparse-cu12 (==12.5.1.3)", "nvidia-nccl-cu12 (==2.21.5)", "nvidia-nvjitlink-cu12 (==12.5.82)"]
+
+[[package]]
+name = "tensorflow-io-gcs-filesystem"
+version = "0.37.1"
+description = "TensorFlow IO"
+optional = true
+python-versions = "<3.13,>=3.7"
+files = [
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed"},
+ {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95"},
+]
+
+[package.extras]
+tensorflow = ["tensorflow (>=2.16.0,<2.17.0)"]
+tensorflow-aarch64 = ["tensorflow-aarch64 (>=2.16.0,<2.17.0)"]
+tensorflow-cpu = ["tensorflow-cpu (>=2.16.0,<2.17.0)"]
+tensorflow-gpu = ["tensorflow-gpu (>=2.16.0,<2.17.0)"]
+tensorflow-rocm = ["tensorflow-rocm (>=2.16.0,<2.17.0)"]
+
+[[package]]
+name = "tensorflow-probability"
+version = "0.24.0"
+description = "Probabilistic modeling and statistical inference in TensorFlow"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "tensorflow_probability-0.24.0-py2.py3-none-any.whl", hash = "sha256:8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b"},
+]
+
+[package.dependencies]
+absl-py = "*"
+cloudpickle = ">=1.3"
+decorator = "*"
+dm-tree = "*"
+gast = ">=0.3.2"
+numpy = ">=1.13.3"
+six = ">=1.10.0"
+
+[package.extras]
+jax = ["jax", "jaxlib"]
+tf = ["tensorflow (>=2.16)", "tf-keras (>=2.16)"]
+tfds = ["tensorflow-datasets (>=2.2.0)"]
+
+[[package]]
+name = "termcolor"
+version = "2.5.0"
+description = "ANSI color formatting for output in terminal"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
+ {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
+]
+
+[package.extras]
+tests = ["pytest", "pytest-cov"]
+
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+description = "threadpoolctl"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
+ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
+]
+
[[package]]
name = "tifffile"
version = "2024.8.30"
@@ -4164,6 +5067,17 @@ files = [
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
]
+[[package]]
+name = "toolz"
+version = "1.0.0"
+description = "List processing tools and functional utilities"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"},
+ {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"},
+]
+
[[package]]
name = "torch"
version = "2.5.0"
@@ -4445,6 +5359,23 @@ files = [
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
+[[package]]
+name = "werkzeug"
+version = "3.0.6"
+description = "The comprehensive WSGI web application library."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"},
+ {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.1.1"
+
+[package.extras]
+watchdog = ["watchdog (>=2.3)"]
+
[[package]]
name = "wheel"
version = "0.44.0"
@@ -4459,6 +5390,85 @@ files = [
[package.extras]
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
+[[package]]
+name = "wrapt"
+version = "1.16.0"
+description = "Module for decorators, wrappers and monkey patching."
+optional = true
+python-versions = ">=3.6"
+files = [
+ {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"},
+ {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"},
+ {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"},
+ {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"},
+ {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"},
+ {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"},
+ {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"},
+ {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"},
+ {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"},
+ {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"},
+ {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"},
+ {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"},
+ {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"},
+ {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"},
+ {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"},
+ {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"},
+ {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"},
+ {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"},
+ {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"},
+ {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"},
+ {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"},
+ {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"},
+ {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"},
+ {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"},
+ {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"},
+ {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"},
+ {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"},
+ {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"},
+ {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"},
+ {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"},
+ {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"},
+ {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"},
+ {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"},
+ {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"},
+ {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"},
+ {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"},
+ {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"},
+ {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"},
+ {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"},
+ {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"},
+ {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"},
+ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"},
+]
+
[[package]]
name = "zipp"
version = "3.20.2"
@@ -4479,11 +5489,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
type = ["pytest-mypy"]
[extras]
-all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pygeodesy", "pygrib", "pygrib", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"]
-domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pygeodesy", "pygrib", "pygrib", "scipy", "unified-planning"]
-solvers = ["gymnasium", "joblib", "numpy", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"]
+all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym-jax", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"]
+domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "rddlrepository", "scipy", "unified-planning"]
+solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-jax", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
-content-hash = "0a37a3952616b8038d4f715c75d98aeb903691b1d93b9cfe268b15b09334ebb3"
+content-hash = "55a5472c7d8ca6d30cd149a54055e1a2322d87320c3f970fde56898801645455"
diff --git a/pyproject.toml b/pyproject.toml
index 1c57043c1f..c53fd3d598 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scikit-decide"
-version = "0.0.0" # placeholder for poetry-dynamic-versioning
+version = "1.0.3.dev5+90a1f41c" # placeholder for poetry-dynamic-versioning
description = "The AI framework for Reinforcement Learning, Automated Planning and Scheduling"
authors = ["Airbus AI Research "]
license = "MIT"
@@ -34,7 +34,7 @@ script = "builder.py"
generate-setup-file = true
[tool.poetry-dynamic-versioning]
-enable = true
+enable = false
vcs = "git"
format-jinja = """
{%- if distance == 0 -%}
@@ -71,6 +71,10 @@ pygrib = [
{ version = "<=2.1.5", platform = "linux", optional = true },
{ version = ">=2.1.5", platform = "darwin", optional = true },
]
+pyRDDLGym = { version = ">=2.0, <2.1", optional = true }
+pyRDDLGym-rl = { version = ">=0.1", optional = true }
+pyRDDLGym-jax = { version = ">=0.3", optional = true }
+rddlrepository = {version = ">=2.0", optional = true }
[tool.poetry.extras]
domains = [
@@ -82,7 +86,10 @@ domains = [
"unified-planning",
"cartopy",
"pygrib",
- "scipy"
+ "scipy",
+ "pyRDDLGym",
+ "pyRDDLGym-rl",
+ "rddlrepository"
]
solvers = [
"gymnasium",
@@ -95,7 +102,8 @@ solvers = [
"up-fast-downward",
"up-enhsp",
"up-pyperplan",
- "scipy"
+ "scipy",
+ "pyRDDLGym-jax"
]
all = [
"gymnasium",
@@ -113,7 +121,11 @@ all = [
"up-pyperplan",
"cartopy",
"pygrib",
- "scipy"
+ "scipy",
+ "pyRDDLGym",
+ "pyRDDLGym-rl",
+ "rddlrepository",
+ "pyRDDLGym-jax"
]
[tool.poetry.plugins."skdecide.domains"]
@@ -136,6 +148,9 @@ Stochastic_RCPSP = "skdecide.hub.domain.rcpsp:Stochastic_RCPSP [domains]"
SMRCPSPCalendar = "skdecide.hub.domain.rcpsp:SMRCPSPCalendar [domains]"
MSRCPSP = "skdecide.hub.domain.rcpsp:MSRCPSP [domains]"
MSRCPSPCalendar = "skdecide.hub.domain.rcpsp:MSRCPSPCalendar [domains]"
+RDDLDomain = "skdecide.hub.domain.rddl:RDDLDomain [domains]"
+RDDLDomainRL = "skdecide.hub.domain.rddl:RDDLDomainRL [domains]"
+RDDLDomainSimplifiedSpaces = "skdecide.hub.domain.rddl:RDDLDomainSimplifiedSpaces [domains]"
[tool.poetry.plugins."skdecide.solvers"]
AOstar = "skdecide.hub.solver.aostar:AOstar [solvers]"
@@ -162,6 +177,8 @@ DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]"
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]"
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]"
+RDDLJaxSolver = "skdecide.hub.solver.rddl:RDDLJaxSolver [solvers]"
+RDDLGurobiSolver = "skdecide.hub.solver.rddl:RDDLGurobiSolver [solvers]"
[tool.poetry.dev-dependencies]
pytest = "^6.2.2"
diff --git a/skdecide/hub/domain/rddl/__init__.py b/skdecide/hub/domain/rddl/__init__.py
new file mode 100644
index 0000000000..d9de2e0c34
--- /dev/null
+++ b/skdecide/hub/domain/rddl/__init__.py
@@ -0,0 +1 @@
+from .rddl import RDDLDomain, RDDLDomainRL, RDDLDomainSimplifiedSpaces
diff --git a/skdecide/hub/domain/rddl/rddl.py b/skdecide/hub/domain/rddl/rddl.py
new file mode 100644
index 0000000000..5c9fe549fc
--- /dev/null
+++ b/skdecide/hub/domain/rddl/rddl.py
@@ -0,0 +1,179 @@
+import os
+import shutil
+from datetime import datetime
+from typing import Any
+
+import numpy as np
+import pyRDDLGym
+from gymnasium.spaces.utils import flatten, flatten_space
+from pyRDDLGym import RDDLEnv
+from pyRDDLGym.core.simulator import RDDLSimulator
+from pyRDDLGym.core.visualizer.chart import ChartVisualizer
+from pyRDDLGym.core.visualizer.movie import MovieGenerator
+from pyRDDLGym.core.visualizer.viz import BaseViz
+from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv
+
+from skdecide.builders.domain import FullyObservable, Renderable, UnrestrictedActions
+from skdecide.core import Space, TransitionOutcome, Value
+from skdecide.domains import RLDomain
+from skdecide.hub.space.gym import GymSpace
+
+try:
+ import IPython
+except ImportError:
+ ipython_available = False
+else:
+ ipython_available = True
+ from IPython.display import clear_output, display
+
+
+class D(RLDomain, UnrestrictedActions, FullyObservable, Renderable):
+ T_state = dict[str, Any] # Type of states
+ T_observation = T_state # Type of observations
+ T_event = np.array # Type of events
+ T_value = float # Type of transition values (rewards or costs)
+ T_info = None # Type of additional information in environment outcome
+
+
+class RDDLDomain(D):
+ def __init__(
+ self,
+ rddl_domain: str,
+ rddl_instance: str,
+ base_class: type[RDDLEnv] = RDDLEnv,
+ backend: type[RDDLSimulator] = RDDLSimulator,
+ display_with_pygame: bool = True,
+ display_within_jupyter: bool = False,
+ visualizer: BaseViz = ChartVisualizer,
+ movie_name: str = None,
+ movie_dir: str = "rddl_movies",
+ max_frames=1000,
+ enforce_action_constraints=True,
+ **kwargs
+ ):
+ self.rddl_gym_env = pyRDDLGym.make(
+ rddl_domain,
+ rddl_instance,
+ base_class=base_class,
+ backend=backend,
+ enforce_action_constraints=enforce_action_constraints,
+ **kwargs
+ )
+ self.display_within_jupyter = display_within_jupyter
+ self.display_with_pygame = display_with_pygame
+ self.movie_name = movie_name
+ self._nb_step = 0
+ if movie_name is not None:
+ self.movie_path = os.path.join(movie_dir, movie_name)
+ if not os.path.exists(self.movie_path):
+ os.makedirs(self.movie_path)
+ tmp_pngs = os.path.join(self.movie_path, "tmp_pngs")
+ if os.path.exists(tmp_pngs):
+ shutil.rmtree(tmp_pngs)
+ os.makedirs(tmp_pngs)
+ self.movie_gen = MovieGenerator(tmp_pngs, movie_name, max_frames=max_frames)
+ self.rddl_gym_env.set_visualizer(visualizer, self.movie_gen)
+ else:
+ self.movie_gen = None
+ self.rddl_gym_env.set_visualizer(visualizer)
+
+ def _state_step(
+ self, action: D.T_event
+ ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
+ next_state, reward, terminated, truncated, _ = self.rddl_gym_env.step(action)
+ termination = terminated or truncated
+ if self.movie_gen is not None and (
+ termination or self._nb_step >= self.movie_gen.max_frames - 1
+ ):
+ self.movie_gen.save_animation(self.movie_name)
+ tmp_pngs = os.path.join(self.movie_path, "tmp_pngs")
+ shutil.move(
+ os.path.join(tmp_pngs, self.movie_name + ".gif"),
+ os.path.join(
+ self.movie_path,
+ self.movie_name
+ + "_"
+ + str(datetime.now().strftime("%Y%m%d-%H%M%S"))
+ + ".gif",
+ ),
+ )
+ self._nb_step += 1
+ # TransitionOutcome and Value are scikit-decide types
+ return TransitionOutcome(
+ state=next_state, value=Value(reward=reward), termination=termination
+ )
+
+ def _get_action_space_(self) -> Space[D.T_event]:
+ # Cast to skdecide's GymSpace
+ return GymSpace(self.rddl_gym_env.action_space)
+
+ def _state_reset(self) -> D.T_state:
+ self._nb_step = 0
+ # SkDecide only needs the state, not the info
+ return self.rddl_gym_env.reset()[0]
+
+ def _get_observation_space_(self) -> Space[D.T_observation]:
+ # Cast to skdecide's GymSpace
+ return GymSpace(self.rddl_gym_env.observation_space)
+
+ def _render_from(self, memory: D.T_state = None, **kwargs: Any) -> Any:
+ # We do not want the image to be displayed in a pygame window, but rather in this notebook
+ rddl_gym_img = self.rddl_gym_env.render(to_display=self.display_with_pygame)
+ if self.display_within_jupyter and ipython_available:
+ clear_output(wait=True)
+ display(rddl_gym_img)
+ return rddl_gym_img
+
+
+class RDDLDomainRL(RDDLDomain):
+ def __init__(
+ self,
+ rddl_domain: str,
+ rddl_instance: str,
+ base_class: type[RDDLEnv] = SimplifiedActionRDDLEnv,
+ backend: type[RDDLSimulator] = RDDLSimulator,
+ display_with_pygame: bool = True,
+ display_within_jupyter: bool = False,
+ visualizer: BaseViz = ChartVisualizer,
+ movie_name: str = None,
+ movie_dir: str = "rddl_movies",
+ max_frames=1000,
+ enforce_action_constraints=True,
+ **kwargs
+ ):
+ super().__init__(
+ rddl_domain=rddl_domain,
+ rddl_instance=rddl_instance,
+ base_class=base_class,
+ backend=backend,
+ display_with_pygame=display_with_pygame,
+ display_within_jupyter=display_within_jupyter,
+ visualizer=visualizer,
+ movie_name=movie_name,
+ movie_dir=movie_dir,
+ max_frames=max_frames,
+ enforce_action_constraints=enforce_action_constraints,
+ **kwargs
+ )
+
+
+class RDDLDomainSimplifiedSpaces(RDDLDomainRL):
+ def _state_step(
+ self, action: D.T_event
+ ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
+ outcome = super()._state_step(action)
+ return TransitionOutcome(
+ state=flatten(self.rddl_gym_env.observation_space, outcome.state),
+ value=outcome.value,
+ termination=outcome.termination,
+ )
+
+ def _get_action_space_(self) -> Space[D.T_event]:
+ return GymSpace(flatten_space(self.rddl_gym_env.action_space))
+
+ def _state_reset(self) -> D.T_state:
+ # SkDecide only needs the state, not the info
+ return flatten(self.rddl_gym_env.observation_space, super()._state_reset())
+
+ def _get_observation_space_(self) -> Space[D.T_observation]:
+ return GymSpace(flatten_space(self.rddl_gym_env.observation_space))
diff --git a/skdecide/hub/solver/rddl/__init__.py b/skdecide/hub/solver/rddl/__init__.py
new file mode 100644
index 0000000000..052f974b5e
--- /dev/null
+++ b/skdecide/hub/solver/rddl/__init__.py
@@ -0,0 +1 @@
+from .rddl import RDDLGurobiSolver, RDDLJaxSolver
diff --git a/skdecide/hub/solver/rddl/rddl.py b/skdecide/hub/solver/rddl/rddl.py
new file mode 100644
index 0000000000..8775256c13
--- /dev/null
+++ b/skdecide/hub/solver/rddl/rddl.py
@@ -0,0 +1,117 @@
+from collections.abc import Callable
+from typing import Any, Optional
+
+from pyRDDLGym_jax.core.planner import (
+ JaxBackpropPlanner,
+ JaxOfflineController,
+ JaxOnlineController,
+ load_config,
+)
+from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+from skdecide import Solver
+from skdecide.builders.solver import FromInitialState, Policies
+from skdecide.hub.domain.rddl import RDDLDomain
+
+try:
+ from pyRDDLGym_gurobi.core.planner import (
+ GurobiOnlineController,
+ GurobiPlan,
+ GurobiStraightLinePlan,
+ )
+except ImportError:
+ pyrddlgym_gurobi_available = False
+else:
+ pyrddlgym_gurobi_available = True
+
+
+class D(RDDLDomain):
+ pass
+
+
+class RDDLJaxSolver(Solver, Policies, FromInitialState):
+ T_domain = D
+
+ def __init__(
+ self, domain_factory: Callable[[], RDDLDomain], config: Optional[str] = None
+ ):
+ Solver.__init__(self, domain_factory=domain_factory)
+ self._domain = domain_factory()
+ if config is not None:
+ self.planner_args, _, self.train_args = load_config(config)
+
+ @classmethod
+ def _check_domain_additional(cls, domain: D) -> bool:
+ return hasattr(domain, "rddl_gym_env")
+
+ def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
+ planner = JaxBackpropPlanner(
+ rddl=self._domain.rddl_gym_env.model,
+ **(self.planner_args if self.planner_args is not None else {})
+ )
+ self.controller = JaxOfflineController(
+ planner, **(self.train_args if self.train_args is not None else {})
+ )
+
+ def _sample_action(self, observation: D.T_observation) -> D.T_event:
+ return self.controller.sample_action(observation)
+
+ def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
+ return True
+
+
+if pyrddlgym_gurobi_available:
+
+ class D(RDDLDomain):
+ pass
+
+ class RDDLGurobiSolver(Solver, Policies, FromInitialState):
+ T_domain = D
+
+ def __init__(
+ self,
+ domain_factory: Callable[[], RDDLDomain],
+ plan: Optional[GurobiPlan] = None,
+ rollout_horizon=5,
+ model_params: Optional[dict[str, Any]] = None,
+ ):
+ Solver.__init__(self, domain_factory=domain_factory)
+ self._domain = domain_factory()
+ self.rollout_horizon = rollout_horizon
+ if plan is None:
+ self.plan = GurobiStraightLinePlan()
+ else:
+ self.plan = plan
+ if model_params is None:
+ self.model_params = {"NonConvex": 2, "OutputFlag": 0}
+ else:
+ self.model_params = model_params
+
+ @classmethod
+ def _check_domain_additional(cls, domain: D) -> bool:
+ return hasattr(domain, "rddl_gym_env")
+
+ def _solve(self, from_memory: Optional[D.T_state] = None) -> None:
+ self.controller = GurobiOnlineController(
+ rddl=self._domain.rddl_gym_env.model,
+ plan=self.plan,
+ rollout_horizon=self.rollout_horizon,
+ model_params=self.model_params,
+ )
+
+ def _sample_action(self, observation: D.T_observation) -> D.T_event:
+ return self.controller.sample_action(observation)
+
+ def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
+ return True
+
+else:
+
+ class RDDLGurobiSolver(Solver, Policies, FromInitialState):
+ T_domain = D
+
+ def __init__(self, domain_factory: Callable[[], RDDLDomain], rollout_horizon=5):
+ raise RuntimeError(
+ "You need pyRDDLGym-gurobi installed for this solver. "
+ "See https://github.com/pyrddlgym-project/pyRDDLGym-gurobi for more information."
+ )
diff --git a/tests/domains/python/test_pyrddlgym.py b/tests/domains/python/test_pyrddlgym.py
new file mode 100644
index 0000000000..48a2761af4
--- /dev/null
+++ b/tests/domains/python/test_pyrddlgym.py
@@ -0,0 +1,80 @@
+import os
+import shutil
+
+from stable_baselines3 import PPO as SB3_PPO
+
+from skdecide.hub.domain.rddl import (
+ RDDLDomain,
+ RDDLDomainRL,
+ RDDLDomainSimplifiedSpaces,
+)
+from skdecide.hub.solver.cgp import CGP
+from skdecide.hub.solver.stable_baselines import StableBaseline
+from skdecide.utils import load_registered_domain, rollout
+
+
+def test_pyrddlgymdomain_sb3():
+ movie_name = "test-sb3"
+ movie_path = f"rddl_movies/{movie_name}"
+ domain_factory = lambda: RDDLDomainRL(
+ rddl_domain="Cartpole_Continuous_gym",
+ rddl_instance="0",
+ movie_name=movie_name,
+ display_with_pygame=False,
+ display_within_jupyter=False,
+ )
+ domain = domain_factory()
+ domain.reset()
+ domain.render()
+
+ solver_factory = lambda: StableBaseline(
+ domain_factory=domain_factory,
+ algo_class=SB3_PPO,
+ baselines_policy="MultiInputPolicy",
+ learn_config={"total_timesteps": 100},
+ verbose=0,
+ )
+
+ shutil.rmtree(movie_path)
+
+ with solver_factory() as solver:
+ solver.solve()
+ rollout(domain_factory(), solver, max_steps=100, render=True, verbose=False)
+
+ assert os.path.isdir(movie_path)
+
+
+def test_pyrddlgymdomainsimp_cgp():
+ movie_name = "test-cgp"
+ movie_path = f"rddl_movies/{movie_name}"
+ domain_factory = lambda: RDDLDomainSimplifiedSpaces(
+ rddl_domain="Cartpole_Continuous_gym",
+ rddl_instance="0",
+ movie_name=movie_name,
+ display_with_pygame=False,
+ display_within_jupyter=False,
+ )
+ domain = domain_factory()
+ domain.reset()
+ domain.render()
+
+ solver_factory = lambda: CGP(
+ domain_factory=domain_factory, folder_name="TEMP_CGP", n_it=5, verbose=False
+ )
+
+ shutil.rmtree(movie_path)
+
+ with solver_factory() as solver:
+ solver.solve()
+ rollout(domain_factory(), solver, max_steps=100, render=True, verbose=False)
+
+ assert os.path.isdir(movie_path)
+
+
+def test_load_rddldomain():
+ cls = load_registered_domain("RDDLDomain")
+ assert cls is RDDLDomain
+ cls = load_registered_domain("RDDLDomainRL")
+ assert cls is RDDLDomainRL
+ cls = load_registered_domain("RDDLDomainSimplifiedSpaces")
+ assert cls is RDDLDomainSimplifiedSpaces
diff --git a/tests/solvers/python/test_pyrddlgym.py b/tests/solvers/python/test_pyrddlgym.py
new file mode 100644
index 0000000000..73490b169f
--- /dev/null
+++ b/tests/solvers/python/test_pyrddlgym.py
@@ -0,0 +1,71 @@
+import os
+import shutil
+from urllib.request import urlcleanup, urlretrieve
+
+import pytest
+from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator
+
+from skdecide.hub.domain.rddl import RDDLDomain
+from skdecide.hub.solver.rddl.rddl import (
+ RDDLGurobiSolver,
+ RDDLJaxSolver,
+ pyrddlgym_gurobi_available,
+)
+from skdecide.utils import load_registered_solver, rollout
+
+
+def test_pyrddlgymdomain_jax():
+ # get solver config
+ config_name = "Cartpole_Continuous_gym_drp.cfg"
+ if not os.path.exists(config_name):
+ url = f"https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/{config_name}"
+ try:
+ local_file_path, headers = urlretrieve(url)
+ shutil.move(local_file_path, config_name)
+ finally:
+ urlcleanup()
+
+ # domain factory (with proper backend and vectorized flag)
+ domain_factory = lambda: RDDLDomain(
+ rddl_domain="Cartpole_Continuous_gym",
+ rddl_instance="0",
+ backend=JaxRDDLSimulator,
+ display_with_pygame=False,
+ display_within_jupyter=False,
+ vectorized=True,
+ )
+ solver_factory = lambda: RDDLJaxSolver(
+ domain_factory=domain_factory, config=config_name
+ )
+
+ # solve
+ with solver_factory() as solver:
+ solver.solve()
+ rollout(domain_factory(), solver, max_steps=100, render=False, verbose=False)
+
+
+@pytest.mark.skipif(
+ not pyrddlgym_gurobi_available,
+ reason="You need to install pyRDDL_gurobi for this solver",
+)
+def test_pyrddlgymdomain_gurobi():
+ # domain factory (with proper backend and vectorized flag)
+ domain_factory = lambda: RDDLDomain(
+ rddl_domain="Cartpole_Continuous_gym",
+ rddl_instance="0",
+ display_with_pygame=False,
+ display_within_jupyter=False,
+ )
+ solver_factory = lambda: RDDLGurobiSolver(domain_factory=domain_factory)
+
+ # solve
+ with solver_factory() as solver:
+ solver.solve()
+ rollout(domain_factory(), solver, max_steps=100, render=False, verbose=False)
+
+
+def test_load_solvers():
+ cls = load_registered_solver("RDDLJaxSolver")
+ assert cls is RDDLJaxSolver
+ cls = load_registered_solver("RDDLGurobiSolver")
+ assert cls is RDDLGurobiSolver