diff --git a/notebooks/16_rddl_tuto.ipynb b/notebooks/16_rddl_tuto.ipynb new file mode 100644 index 0000000000..e641bbee1e --- /dev/null +++ b/notebooks/16_rddl_tuto.ipynb @@ -0,0 +1,604 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Using RDDL domains and solvers with scikit-decide" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we demonstrate how to use the RDLL scikit-decide wrapper domain in order to solve it with scikit-decide solvers. This domain is built upon the RDDL environment from the excellent pyrddlgym-project GitHub project. Some of the solvers tested here are actually also wrapped from the same project but we will see also how to use other solvers (coded directly within scikit-decide or wrapped from other third party libraries)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Concerning the python kernel to use for this notebook:\n", + "- If running locally, be sure to use an environment with scikit-decide[all].\n", + "- If running on colab, the next cell does it for you.\n", + "- If running on binder, the environment should be ready." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# On Colab: install the library\n", + "on_colab = \"google.colab\" in str(get_ipython())\n", + "if on_colab:\n", + " import glob\n", + " import json\n", + " import sys\n", + "\n", + " using_nightly_version = True\n", + "\n", + " if using_nightly_version:\n", + " # look for nightly build download url\n", + " release_curl_res = !curl -L -H \"Accept: application/vnd.github+json\" -H \"X-GitHub-Api-Version: 2022-11-28\" https://api.github.com/repos/airbus/scikit-decide/releases/tags/nightly\n", + " release_dict = json.loads(release_curl_res.s)\n", + " release_download_url = sorted(\n", + " release_dict[\"assets\"], key=lambda d: d[\"updated_at\"]\n", + " )[-1][\"browser_download_url\"]\n", + " print(release_download_url)\n", + "\n", + " # download and unzip\n", + " !wget --output-document=release.zip {release_download_url}\n", + " !unzip -o release.zip\n", + "\n", + " # get proper wheel name according to python version used\n", + " wheel_pythonversion_tag = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n", + " wheel_path = glob.glob(\n", + " f\"dist/scikit_decide*{wheel_pythonversion_tag}*manylinux*.whl\"\n", + " )[0]\n", + "\n", + " skdecide_pip_spec = f\"{wheel_path}[all]\"\n", + " else:\n", + " skdecide_pip_spec = \"scikit-decide[all]\"\n", + "\n", + " # uninstall google protobuf conflicting with ray and sb3\n", + " ! pip uninstall -y protobuf\n", + "\n", + " # install scikit-decide with all extras\n", + " !pip install {skdecide_pip_spec}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "import shutil\n", + "\n", + "from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator\n", + "from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv\n", + "from ray.rllib.algorithms.ppo import PPO as RLLIB_PPO\n", + "from rddlrepository.archive.competitions.IPPC2023.MountainCar.MountainCarViz import (\n", + " MountainCarVisualizer,\n", + ")\n", + "from rddlrepository.archive.standalone.Elevators.ElevatorViz import ElevatorVisualizer\n", + "from rddlrepository.archive.standalone.Quadcopter.QuadcopterViz import (\n", + " QuadcopterVisualizer,\n", + ")\n", + "from rddlrepository.core.manager import RDDLRepoManager\n", + "from stable_baselines3 import PPO as SB3_PPO\n", + "\n", + "from skdecide.hub.domain.rddl import RDDLDomain, RDDLDomainSimplifiedSpaces\n", + "from skdecide.hub.solver.cgp import CGP\n", + "from skdecide.hub.solver.ray_rllib import RayRLlib\n", + "from skdecide.hub.solver.rddl import RDDLGurobiSolver, RDDLJaxSolver\n", + "from skdecide.hub.solver.stable_baselines import StableBaseline\n", + "from skdecide.utils import rollout" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiating and visualizing a RDDL domain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The pyrddlgym-project provides the [rddlrepository](https://github.com/pyrddlgym-project/rddlrepository) library of RDDL benchmarks from past IPPC competitions and third-party contributors. We list below the available problems with our pip installation of the library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "manager = RDDLRepoManager(rebuild=True)\n", + "print(sorted(manager.list_problems()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will use 3 different rddl benchmarks here to demonstrate the scikit-decide integration of pyrddlgym:\n", + "- MountainCar_ippc2023\n", + "- Quadcopter\n", + "- Elevators" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create the scikit-decide `RDDLDomain` instance and render it.\n", + "Note that here we use some options to display within the notebook:\n", + "- `display_with_pygame`: True by default (as in pyRDDLGym), here set to False to avoid a pygame window to pop up\n", + "- `display_within_jupyter`: useful to display within a jupyter notebook\n", + "- `visualizer`: we use a visualizer dedicated to the chosen benchmark\n", + "- `movie_name`: if set, a movie will be created at the end of a rollout " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"MountainCar_ippc2023\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = MountainCarVisualizer\n", + "domain = RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + " movie_name=None, # here left empty because not used in a roll-out\n", + ")\n", + "domain.reset()\n", + "img = domain.render()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"Quadcopter\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = QuadcopterVisualizer\n", + "domain = RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + ")\n", + "domain.reset()\n", + "img = domain.render()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"Elevators\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = ElevatorVisualizer\n", + "domain = RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + ")\n", + "domain.reset()\n", + "img = domain.render()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solving the domain with scikit-decide (potentially bridged) solvers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now comes the fun part: solving the domain with scikit-decide solvers, some of them - especially the reinforcement learning ones - being bridged to state-of-the-art existing libraries (e.g. RLlib, SB3). You will see that once the domain is defined, solving it takes very few lines of code." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### RL solvers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we create the domain factory for the benchmark \"MountainCar_ippc2023\". For these RL solvers, we need the underlying rddl env to use the base class `SimplifiedActionRDDLEnv` from [pyRDDLGym-rl](https://github.com/pyrddlgym-project/pyRDDLGym-rl), which uses gym spaces tractable by RL algorithms. This is done thanks to the argument `base_class`, which will be passed directly to `pyRDDLgym.make()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"MountainCar_ippc2023\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = MountainCarVisualizer\n", + "\n", + "domain_factory_rl = lambda alg_name=None: RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " base_class=SimplifiedActionRDDLEnv,\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### RLlib's PPO algorithm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The code below creates a scikit-decide's `RayRLlib` solver, then it calls the `solver.solve()` method, and it finally rollout the optimized policy by using scikit-decide's `rollout` utility function. The latter function will render the solution and the domain will generate a movie in the `rddl_movies` folder when reaching the termination condition of the rollout episode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver_factory = lambda: RayRLlib(\n", + " domain_factory=domain_factory_rl, algo_class=RLLIB_PPO, train_iterations=10\n", + ")\n", + "\n", + "with solver_factory() as solver:\n", + " solver.solve()\n", + " rollout(\n", + " domain_factory_rl(alg_name=\"RLLIB-PPO\"),\n", + " solver,\n", + " max_steps=300,\n", + " render=True,\n", + " verbose=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is an example of executing the RLlib's PPO policy trained for 100 iterations on the mountain car benchmark:\n", + "\n", + "![RLLIB PPO example solution](rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### StableBaselines-3's PPO" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the domain is defined, very few lines of code are sufficient to test another solver whose capabilities are compatible with the domain. In the cell below, we now test Stablebaselines-3's PPO algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver_factory = lambda: StableBaseline(\n", + " domain_factory=domain_factory_rl,\n", + " algo_class=SB3_PPO,\n", + " baselines_policy=\"MultiInputPolicy\",\n", + " learn_config={\"total_timesteps\": 10000},\n", + " verbose=0,\n", + ")\n", + "\n", + "with solver_factory() as solver:\n", + " solver.solve()\n", + " rollout(\n", + " domain_factory_rl(alg_name=\"SB3-PPO\"),\n", + " solver,\n", + " max_steps=1000,\n", + " render=True,\n", + " verbose=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CGP" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Scikit-decide provides an implementation of [Cartesian Genetic Programming](https://dl.acm.org/doi/10.1145/3205455.3205578) (CGP), a form of Genetic Programming which optimizes a function (e.g. control policy) by learning its best representation as a directed acyclic graph of mathematical operators. One of the great capabilities of scikit-decide is to provide simple high-level means to compare algorithms from different communities (RL, GP, search, planning, etc.) on the same domains with few lines of code.\n", + "\n", + "\"Cartesian" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since our current implementation of CGP in scikit-decide does not handle complex observation spaces such as the dictionary spaces returned by the RDDL simulator, we used instead `RDDLDomainSimplifiedSpaces` where all actions and observations are numpy arrays thanks to the powerful `flatten` and `flatten_space` methods of `gymnasium`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We call the CGP solver on this simplified domain and we render the obtained solution after a few iterations (including the generation of the video in the `rddl_movies` folder)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "problem_name = \"MountainCar_ippc2023\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = MountainCarVisualizer\n", + "\n", + "domain_factory_cgp = lambda alg_name=None: RDDLDomainSimplifiedSpaces(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " base_class=SimplifiedActionRDDLEnv,\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n", + " max_frames=200,\n", + ")\n", + "\n", + "if os.path.exists(\"TEMP_CGP\"):\n", + " shutil.rmtree(\"TEMP_CGP\")\n", + "\n", + "solver_factory = lambda: CGP(\n", + " domain_factory=domain_factory_cgp, folder_name=\"TEMP_CGP\", n_it=25, verbose=False\n", + ")\n", + "with solver_factory() as solver:\n", + " solver.solve()\n", + " rollout(\n", + " domain_factory_cgp(\"CGP\"), solver, max_steps=200, render=True, verbose=False\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is an example of executing the CGP policy on the mountain car benchmark:\n", + "\n", + "![CGP example solution](rddl_images/MountainCar_ippc2023-CGP_example.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solving the domain with pyRDDLGym solvers wrapped in scikit-decide\n", + "\n", + "One can also use the solvers implemented in pyRDDLGym project from within scikit-decide like the jax planner (https://github.com/pyrddlgym-project/pyRDDLGym-jax), or the gurobi planner (https://github.com/pyrddlgym-project/pyRDDLGym-gurobi)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JAX Agent\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The scikit-decide solver `RDDLJaxSolver` wraps the offline version of [JaxPlan](https://openreview.net/forum?id=7IKtmUpLEH) planner which compiles the RDDL model to a Jax computation graph allowing for planning by backpropagation. \n", + "The solver constructor takes a configuration file of the `Jax` planner as explained [here](https://github.com/pyrddlgym-project/pyRDDLGym-jax/tree/main?tab=readme-ov-file#writing-a-configuration-file-for-a-custom-domain).\n", + "\n", + "We apply it to the becnhmark \"Quadcopter\". \n", + "\n", + "Note that for this solver the domain needs\n", + "- to use the simulation backend specific to Jax,\n", + "- to be vectorized. \n", + "\n", + "This is done thanks to the arguments `backend` and `vectorized` which are passed to `pyRDDLGym.make()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"Quadcopter\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = QuadcopterVisualizer\n", + "\n", + "if not os.path.exists(\"Quadcopter_slp.cfg\"):\n", + " !wget https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg\n", + "\n", + "domain_factory_jax_agent = lambda alg_name=None: RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(1),\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + " backend=JaxRDDLSimulator,\n", + " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n", + " max_frames=500,\n", + " vectorized=True,\n", + ")\n", + "\n", + "assert RDDLJaxSolver.check_domain(domain_factory_jax_agent())\n", + "\n", + "logging.getLogger(\"matplotlib.font_manager\").disabled = True\n", + "with RDDLJaxSolver(\n", + " domain_factory=domain_factory_jax_agent, config=\"Quadcopter_slp.cfg\"\n", + ") as solver:\n", + " solver.solve()\n", + " rollout(\n", + " domain_factory_jax_agent(alg_name=\"JaxAgent\"),\n", + " solver,\n", + " max_steps=500,\n", + " render=True,\n", + " max_framerate=5,\n", + " verbose=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We obtain the following example execution of the Jax policy, which clearly converges towards the goal (quadcopters flying towards the red triangle):\n", + "\n", + "![JaxAgent example solution](rddl_images/Quadcopter-JaxAgent_example.gif)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Gurobi Agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We finally try the online version of [GurobiPlan](https://openreview.net/forum?id=7IKtmUpLEH) planner which compiles the RDDL model to a [Gurobi](https://www.gurobi.com) MILP model. \n", + "\n", + "We apply it to \"Elevators\" benchmark. \n", + "\n", + "\n", + "
Note: \n", + "The solver needs a real license for Gurobi, as the free license available when installing gurobipy from PyPi is not sufficient to solve this domain.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem_name = \"Elevators\"\n", + "problem_info = manager.get_problem(problem_name)\n", + "problem_visualizer = ElevatorVisualizer\n", + "\n", + "domain_factory_gurobi_agent = lambda alg_name=None: RDDLDomain(\n", + " rddl_domain=problem_info.get_domain(),\n", + " rddl_instance=problem_info.get_instance(0),\n", + " visualizer=problem_visualizer,\n", + " display_with_pygame=False,\n", + " display_within_jupyter=True,\n", + " movie_name=f\"{problem_name}-{alg_name}\" if alg_name is not None else None,\n", + " max_frames=50,\n", + ")\n", + "\n", + "assert RDDLGurobiSolver.check_domain(domain_factory_gurobi_agent())\n", + "\n", + "with RDDLGurobiSolver(\n", + " domain_factory=domain_factory_gurobi_agent, rollout_horizon=10\n", + ") as solver:\n", + " solver.solve()\n", + " rollout(\n", + " domain_factory_gurobi_agent(alg_name=\"GurobiAgent\"),\n", + " solver,\n", + " max_steps=50,\n", + " render=True,\n", + " max_framerate=5,\n", + " verbose=False,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is an example of executing the online `GurobiPlan` strategy on this benchmark:\n", + "\n", + "![GurobiAgent example solution](rddl_images/Elevators-GurobiAgent_example.gif)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/rddl_images/Elevators-GurobiAgent_example.gif b/notebooks/rddl_images/Elevators-GurobiAgent_example.gif new file mode 100644 index 0000000000..fccdb9678b Binary files /dev/null and b/notebooks/rddl_images/Elevators-GurobiAgent_example.gif differ diff --git a/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif b/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif new file mode 100644 index 0000000000..548fbafe5c Binary files /dev/null and b/notebooks/rddl_images/MountainCar_ippc2023-CGP_example.gif differ diff --git a/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif b/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif new file mode 100644 index 0000000000..773cb7a1b0 Binary files /dev/null and b/notebooks/rddl_images/MountainCar_ippc2023-RLLIB-PPO_example.gif differ diff --git a/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif b/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif new file mode 100644 index 0000000000..99b10f84ac Binary files /dev/null and b/notebooks/rddl_images/Quadcopter-JaxAgent_example.gif differ diff --git a/notebooks/rddl_images/cgp-sketch.png b/notebooks/rddl_images/cgp-sketch.png new file mode 100644 index 0000000000..2beec97d8c Binary files /dev/null and b/notebooks/rddl_images/cgp-sketch.png differ diff --git a/poetry.lock b/poetry.lock index 8198224a70..d6a324cfc6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -54,6 +54,21 @@ six = ">=1.12.0" astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] +[[package]] +name = "astunparse" +version = "1.6.3" +description = "An AST unparser for Python" +optional = true +python-versions = "*" +files = [ + {file = "astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8"}, + {file = "astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872"}, +] + +[package.dependencies] +six = ">=1.6.1,<2.0" +wheel = ">=0.23.0,<1.0" + [[package]] name = "atomicwrites" version = "1.4.1" @@ -83,6 +98,23 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "bayesian-optimization" +version = "2.0.0" +description = "Bayesian Optimization package" +optional = true +python-versions = "<4.0,>=3.9" +files = [ + {file = "bayesian_optimization-2.0.0-py3-none-any.whl", hash = "sha256:7dc2b1b78dadab9d38c07f3b7c5d7fe91a138bce0ec71e895085642ddf4c85b3"}, + {file = "bayesian_optimization-2.0.0.tar.gz", hash = "sha256:434abeb87e4a59f285641fbbfd74606fa27603dd831fff7c8984e5d10391e0d0"}, +] + +[package.dependencies] +colorama = ">=0.4.6,<0.5.0" +numpy = ">=1.25" +scikit-learn = ">=1.0.0,<2.0.0" +scipy = ">=1.0.0,<2.0.0" + [[package]] name = "cartopy" version = "0.23.0" @@ -329,6 +361,26 @@ files = [ {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, ] +[[package]] +name = "chex" +version = "0.1.87" +description = "Chex: Testing made fun, in JAX!" +optional = true +python-versions = ">=3.9" +files = [ + {file = "chex-0.1.87-py3-none-any.whl", hash = "sha256:ce536475661fd96d21be0c1728ecdbedd03f8ff950c662dfc338c92ea782cb16"}, + {file = "chex-0.1.87.tar.gz", hash = "sha256:0096d89cc8d898bb521ef4bfbf5c24549022b0e5b301f529ab57238896fe6c5d"}, +] + +[package.dependencies] +absl-py = ">=0.9.0" +jax = ">=0.4.27" +jaxlib = ">=0.4.27" +numpy = ">=1.24.1" +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +toolz = ">=0.9.0" +typing-extensions = ">=4.2.0" + [[package]] name = "click" version = "8.1.7" @@ -876,6 +928,27 @@ typing-extensions = ">=4.4" quantum = ["qiskit (>=1.0.2)", "qiskit-aer (>=0.14.1)", "qiskit-algorithms (>=0.3.0)", "qiskit-ibm-runtime (>=0.24)", "qiskit-optimization (>=0.6.1)"] test = ["optuna", "pytest", "pytest-cov", "scikit-learn (>=1.0)"] +[[package]] +name = "dm-haiku" +version = "0.0.13" +description = "Haiku is a library for building neural networks in JAX." +optional = true +python-versions = "*" +files = [ + {file = "dm_haiku-0.0.13-py3-none-any.whl", hash = "sha256:ee9562c68a059f146ad07f555ca591cb8c11ef751afecc38353863562bd23f43"}, + {file = "dm_haiku-0.0.13.tar.gz", hash = "sha256:029bb91b5b1edb0d3fe23304d3bf12a545ea6e485041f7f5d8c8d85ebcf6e17d"}, +] + +[package.dependencies] +absl-py = ">=0.7.1" +jmp = ">=0.0.2" +numpy = ">=1.18.0" +tabulate = ">=0.8.9" + +[package.extras] +flax = ["flax (>=0.7.1)"] +jax = ["jax (>=0.4.28)", "jaxlib (>=0.4.28)"] + [[package]] name = "dm-tree" version = "0.1.8" @@ -941,6 +1014,40 @@ files = [ {file = "docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491"}, ] +[[package]] +name = "etils" +version = "1.5.2" +description = "Collection of common python utils" +optional = true +python-versions = ">=3.9" +files = [ + {file = "etils-1.5.2-py3-none-any.whl", hash = "sha256:6dc882d355e1e98a5d1a148d6323679dc47c9a5792939b9de72615aa4737eb0b"}, + {file = "etils-1.5.2.tar.gz", hash = "sha256:ba6a3e1aff95c769130776aa176c11540637f5dd881f3b79172a5149b6b1c446"}, +] + +[package.dependencies] +typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""} + +[package.extras] +all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] +array-types = ["etils[enp]"] +dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] +docs = ["etils[all,dev]", "sphinx-apitree[ext]"] +eapp = ["absl-py", "etils[epy]", "simple_parsing"] +ecolab = ["etils[enp]", "etils[epy]", "jupyter", "mediapy", "numpy", "packaging"] +edc = ["etils[epy]"] +enp = ["etils[epy]", "numpy"] +epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] +epath-gcs = ["etils[epath]", "gcsfs"] +epath-s3 = ["etils[epath]", "s3fs"] +epy = ["typing_extensions"] +etqdm = ["absl-py", "etils[epy]", "tqdm"] +etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] +etree-dm = ["dm-tree", "etils[etree]"] +etree-jax = ["etils[etree]", "jax[cpu]"] +etree-tf = ["etils[etree]", "tensorflow"] +lazy-imports = ["etils[ecolab]"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1010,6 +1117,17 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "flatbuffers" +version = "24.3.25" +description = "The FlatBuffers serialization format for Python" +optional = true +python-versions = "*" +files = [ + {file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"}, + {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, +] + [[package]] name = "fonttools" version = "4.54.1" @@ -1221,6 +1339,99 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "gast" +version = "0.6.0" +description = "Python AST that abstracts the underlying Python version" +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +files = [ + {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"}, + {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, +] + +[[package]] +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" +optional = true +python-versions = "*" +files = [ + {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"}, + {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"}, + {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, +] + +[package.dependencies] +six = "*" + +[[package]] +name = "grpcio" +version = "1.67.0" +description = "HTTP/2-based RPC framework" +optional = true +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.67.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:bd79929b3bb96b54df1296cd3bf4d2b770bd1df6c2bdf549b49bab286b925cdc"}, + {file = "grpcio-1.67.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:16724ffc956ea42967f5758c2f043faef43cb7e48a51948ab593570570d1e68b"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:2b7183c80b602b0ad816315d66f2fb7887614ead950416d60913a9a71c12560d"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efe32b45dd6d118f5ea2e5deaed417d8a14976325c93812dd831908522b402c9"}, + {file = "grpcio-1.67.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe89295219b9c9e47780a0f1c75ca44211e706d1c598242249fe717af3385ec8"}, + {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:aa8d025fae1595a207b4e47c2e087cb88d47008494db258ac561c00877d4c8f8"}, + {file = "grpcio-1.67.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f95e15db43e75a534420e04822df91f645664bf4ad21dfaad7d51773c80e6bb4"}, + {file = "grpcio-1.67.0-cp310-cp310-win32.whl", hash = "sha256:a6b9a5c18863fd4b6624a42e2712103fb0f57799a3b29651c0e5b8119a519d65"}, + {file = "grpcio-1.67.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6eb68493a05d38b426604e1dc93bfc0137c4157f7ab4fac5771fd9a104bbaa6"}, + {file = "grpcio-1.67.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:e91d154689639932305b6ea6f45c6e46bb51ecc8ea77c10ef25aa77f75443ad4"}, + {file = "grpcio-1.67.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cb204a742997277da678611a809a8409657b1398aaeebf73b3d9563b7d154c13"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:ae6de510f670137e755eb2a74b04d1041e7210af2444103c8c95f193340d17ee"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74b900566bdf68241118f2918d312d3bf554b2ce0b12b90178091ea7d0a17b3d"}, + {file = "grpcio-1.67.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e95e43447a02aa603abcc6b5e727d093d161a869c83b073f50b9390ecf0fa8"}, + {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0bb94e66cd8f0baf29bd3184b6aa09aeb1a660f9ec3d85da615c5003154bc2bf"}, + {file = "grpcio-1.67.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:82e5bd4b67b17c8c597273663794a6a46a45e44165b960517fe6d8a2f7f16d23"}, + {file = "grpcio-1.67.0-cp311-cp311-win32.whl", hash = "sha256:7fc1d2b9fd549264ae585026b266ac2db53735510a207381be509c315b4af4e8"}, + {file = "grpcio-1.67.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac11ecb34a86b831239cc38245403a8de25037b448464f95c3315819e7519772"}, + {file = "grpcio-1.67.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:227316b5631260e0bef8a3ce04fa7db4cc81756fea1258b007950b6efc90c05d"}, + {file = "grpcio-1.67.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d90cfdafcf4b45a7a076e3e2a58e7bc3d59c698c4f6470b0bb13a4d869cf2273"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:77196216d5dd6f99af1c51e235af2dd339159f657280e65ce7e12c1a8feffd1d"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15c05a26a0f7047f720da41dc49406b395c1470eef44ff7e2c506a47ac2c0591"}, + {file = "grpcio-1.67.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3840994689cc8cbb73d60485c594424ad8adb56c71a30d8948d6453083624b52"}, + {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5a1e03c3102b6451028d5dc9f8591131d6ab3c8a0e023d94c28cb930ed4b5f81"}, + {file = "grpcio-1.67.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:682968427a63d898759474e3b3178d42546e878fdce034fd7474ef75143b64e3"}, + {file = "grpcio-1.67.0-cp312-cp312-win32.whl", hash = "sha256:d01793653248f49cf47e5695e0a79805b1d9d4eacef85b310118ba1dfcd1b955"}, + {file = "grpcio-1.67.0-cp312-cp312-win_amd64.whl", hash = "sha256:985b2686f786f3e20326c4367eebdaed3e7aa65848260ff0c6644f817042cb15"}, + {file = "grpcio-1.67.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:8c9a35b8bc50db35ab8e3e02a4f2a35cfba46c8705c3911c34ce343bd777813a"}, + {file = "grpcio-1.67.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:42199e704095b62688998c2d84c89e59a26a7d5d32eed86d43dc90e7a3bd04aa"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:c4c425f440fb81f8d0237c07b9322fc0fb6ee2b29fbef5f62a322ff8fcce240d"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:323741b6699cd2b04a71cb38f502db98f90532e8a40cb675393d248126a268af"}, + {file = "grpcio-1.67.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:662c8e105c5e5cee0317d500eb186ed7a93229586e431c1bf0c9236c2407352c"}, + {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f6bd2ab135c64a4d1e9e44679a616c9bc944547357c830fafea5c3caa3de5153"}, + {file = "grpcio-1.67.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:2f55c1e0e2ae9bdd23b3c63459ee4c06d223b68aeb1961d83c48fb63dc29bc03"}, + {file = "grpcio-1.67.0-cp313-cp313-win32.whl", hash = "sha256:fd6bc27861e460fe28e94226e3673d46e294ca4673d46b224428d197c5935e69"}, + {file = "grpcio-1.67.0-cp313-cp313-win_amd64.whl", hash = "sha256:cf51d28063338608cd8d3cd64677e922134837902b70ce00dad7f116e3998210"}, + {file = "grpcio-1.67.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:7f200aca719c1c5dc72ab68be3479b9dafccdf03df530d137632c534bb6f1ee3"}, + {file = "grpcio-1.67.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0892dd200ece4822d72dd0952f7112c542a487fc48fe77568deaaa399c1e717d"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:f4d613fbf868b2e2444f490d18af472ccb47660ea3df52f068c9c8801e1f3e85"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c69bf11894cad9da00047f46584d5758d6ebc9b5950c0dc96fec7e0bce5cde9"}, + {file = "grpcio-1.67.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9bca3ca0c5e74dea44bf57d27e15a3a3996ce7e5780d61b7c72386356d231db"}, + {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:014dfc020e28a0d9be7e93a91f85ff9f4a87158b7df9952fe23cc42d29d31e1e"}, + {file = "grpcio-1.67.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d4ea4509d42c6797539e9ec7496c15473177ce9abc89bc5c71e7abe50fc25737"}, + {file = "grpcio-1.67.0-cp38-cp38-win32.whl", hash = "sha256:9d75641a2fca9ae1ae86454fd25d4c298ea8cc195dbc962852234d54a07060ad"}, + {file = "grpcio-1.67.0-cp38-cp38-win_amd64.whl", hash = "sha256:cff8e54d6a463883cda2fab94d2062aad2f5edd7f06ae3ed030f2a74756db365"}, + {file = "grpcio-1.67.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:62492bd534979e6d7127b8a6b29093161a742dee3875873e01964049d5250a74"}, + {file = "grpcio-1.67.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eef1dce9d1a46119fd09f9a992cf6ab9d9178b696382439446ca5f399d7b96fe"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:f623c57a5321461c84498a99dddf9d13dac0e40ee056d884d6ec4ebcab647a78"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54d16383044e681f8beb50f905249e4e7261dd169d4aaf6e52eab67b01cbbbe2"}, + {file = "grpcio-1.67.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2a44e572fb762c668e4812156b81835f7aba8a721b027e2d4bb29fb50ff4d33"}, + {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:391df8b0faac84d42f5b8dfc65f5152c48ed914e13c522fd05f2aca211f8bfad"}, + {file = "grpcio-1.67.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cfd9306511fdfc623a1ba1dc3bc07fbd24e6cfbe3c28b4d1e05177baa2f99617"}, + {file = "grpcio-1.67.0-cp39-cp39-win32.whl", hash = "sha256:30d47dbacfd20cbd0c8be9bfa52fdb833b395d4ec32fe5cff7220afc05d08571"}, + {file = "grpcio-1.67.0-cp39-cp39-win_amd64.whl", hash = "sha256:f55f077685f61f0fbd06ea355142b71e47e4a26d2d678b3ba27248abfe67163a"}, + {file = "grpcio-1.67.0.tar.gz", hash = "sha256:e090b2553e0da1c875449c8e75073dd4415dd71c9bde6a406240fdf4c0ee467c"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.67.0)"] + [[package]] name = "gymnasium" version = "0.28.1" @@ -1254,6 +1465,44 @@ other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-pyt testing = ["pytest (==7.1.3)", "scipy (==1.7.3)"] toy-text = ["pygame (==2.1.3)", "pygame (==2.1.3)"] +[[package]] +name = "h5py" +version = "3.12.1" +description = "Read and write HDF5 files from Python" +optional = true +python-versions = ">=3.9" +files = [ + {file = "h5py-3.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f0f1a382cbf494679c07b4371f90c70391dedb027d517ac94fa2c05299dacda"}, + {file = "h5py-3.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cb65f619dfbdd15e662423e8d257780f9a66677eae5b4b3fc9dca70b5fd2d2a3"}, + {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2"}, + {file = "h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307"}, + {file = "h5py-3.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:577d618d6b6dea3da07d13cc903ef9634cde5596b13e832476dd861aaf651f3e"}, + {file = "h5py-3.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ccd9006d92232727d23f784795191bfd02294a4f2ba68708825cb1da39511a93"}, + {file = "h5py-3.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad8a76557880aed5234cfe7279805f4ab5ce16b17954606cca90d578d3e713ef"}, + {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e"}, + {file = "h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166"}, + {file = "h5py-3.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fdf95092d60e8130ba6ae0ef7a9bd4ade8edbe3569c13ebbaf39baefffc5ba4"}, + {file = "h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed"}, + {file = "h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834"}, + {file = "h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9"}, + {file = "h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc"}, + {file = "h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32"}, + {file = "h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f"}, + {file = "h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8"}, + {file = "h5py-3.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d2b8dd64f127d8b324f5d2cd1c0fd6f68af69084e9e47d27efeb9e28e685af3e"}, + {file = "h5py-3.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4532c7e97fbef3d029735db8b6f5bf01222d9ece41e309b20d63cfaae2fb5c4d"}, + {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fdf6d7936fa824acfa27305fe2d9f39968e539d831c5bae0e0d83ed521ad1ac"}, + {file = "h5py-3.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84342bffd1f82d4f036433e7039e241a243531a1d3acd7341b35ae58cdab05bf"}, + {file = "h5py-3.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:62be1fc0ef195891949b2c627ec06bc8e837ff62d5b911b6e42e38e0f20a897d"}, + {file = "h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf"}, +] + +[package.dependencies] +numpy = ">=1.19.3" + [[package]] name = "idna" version = "3.10" @@ -1438,6 +1687,41 @@ qtconsole = ["qtconsole"] test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"] +[[package]] +name = "jax" +version = "0.4.30" +description = "Differentiate, compile, and transform Numpy code." +optional = true +python-versions = ">=3.9" +files = [ + {file = "jax-0.4.30-py3-none-any.whl", hash = "sha256:289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b"}, + {file = "jax-0.4.30.tar.gz", hash = "sha256:94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} +jaxlib = ">=0.4.27,<=0.4.30" +ml-dtypes = ">=0.2.0" +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.22", markers = "python_version < \"3.11\""}, +] +opt-einsum = "*" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, +] + +[package.extras] +ci = ["jaxlib (==0.4.29)"] +cuda = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] +cuda12 = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] +cuda12-local = ["jax-cuda12-plugin (==0.4.30)", "jaxlib (==0.4.30)"] +cuda12-pip = ["jax-cuda12-plugin[with-cuda] (==0.4.30)", "jaxlib (==0.4.30)"] +minimum-jaxlib = ["jaxlib (==0.4.27)"] +tpu = ["jaxlib (==0.4.30)", "libtpu-nightly (==0.1.dev20240617)", "requests"] + [[package]] name = "jax-jumpy" version = "1.0.0" @@ -1456,6 +1740,43 @@ numpy = ">=1.18.0" jax = ["jax (>=0.3.24)", "jaxlib (>=0.3.24)"] testing = ["pytest (==7.1.3)"] +[[package]] +name = "jaxlib" +version = "0.4.30" +description = "XLA library for JAX" +optional = true +python-versions = ">=3.9" +files = [ + {file = "jaxlib-0.4.30-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c40856e28f300938c6824ab1a615166193d6997dec946578823f6d402ad454e5"}, + {file = "jaxlib-0.4.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bdfda6a3c7a2b0cc0a7131009eb279e98ca4a6f25679fabb5302dd135a5e349"}, + {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:28e032c9b394ab7624d89b0d9d3bbcf4d1d71694fe8b3e09d3fe64122eda7b0c"}, + {file = "jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d83f36ef42a403bbf7c7f2da526b34ba286988e170f4df5e58b3bb735417868c"}, + {file = "jaxlib-0.4.30-cp310-cp310-win_amd64.whl", hash = "sha256:a56678b28f96b524ded6da8ef4b38e72a532356d139cfd434da804abf4234e14"}, + {file = "jaxlib-0.4.30-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:bfb5d85b69c29c3c6e8051a0ea715ac1e532d6e54494c8d9c3813dcc00deac30"}, + {file = "jaxlib-0.4.30-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:974998cd8a78550402e6c09935c1f8d850cad9cc19ccd7488bde45b6f7f99c12"}, + {file = "jaxlib-0.4.30-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e93eb0646b41ba213252b51b0b69096b9cd1d81a35ea85c9d06663b5d11efe45"}, + {file = "jaxlib-0.4.30-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:16b2ab18ea90d2e15941bcf45de37afc2f289a029129c88c8d7aba0404dd0043"}, + {file = "jaxlib-0.4.30-cp311-cp311-win_amd64.whl", hash = "sha256:3a2e2c11c179f8851a72249ba1ae40ae817dfaee9877d23b3b8f7c6b7a012f76"}, + {file = "jaxlib-0.4.30-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7704db5962b32a2be3cc07185433cbbcc94ed90ee50c84021a3f8a1ecfd66ee3"}, + {file = "jaxlib-0.4.30-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:57090d33477fd0f0c99dc686274882ea75c44c7d712ae42dd2460b10f896131d"}, + {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:0a3850e76278038e21685975a62b622bcf3708485f13125757a0561ee4512940"}, + {file = "jaxlib-0.4.30-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:c58a8071c4e00898282118169f6a5a97eb15a79c2897858f3a732b17891c99ab"}, + {file = "jaxlib-0.4.30-cp312-cp312-win_amd64.whl", hash = "sha256:b7079a5b1ab6864a7d4f2afaa963841451186d22c90f39719a3ff85735ce3915"}, + {file = "jaxlib-0.4.30-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ea3a00005faafbe3c18b178d3b534208b3b4027b2be6230227e7b87ce399fc29"}, + {file = "jaxlib-0.4.30-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3d31e01191ce8052bd611aaf16ff967d8d0ec0b63f1ea4b199020cecb248d667"}, + {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:11602d5556e8baa2f16314c36518e9be4dfae0c2c256a361403fb29dc9dc79a4"}, + {file = "jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:f74a6b0e09df4b5e2ee399ebb9f0e01190e26e84ccb0a758fadb516415c07f18"}, + {file = "jaxlib-0.4.30-cp39-cp39-win_amd64.whl", hash = "sha256:54987e97a22db70f3829b437b9329e4799d653634bacc8b398554d3b90c76b2a"}, +] + +[package.dependencies] +ml-dtypes = ">=0.2.0" +numpy = ">=1.22" +scipy = [ + {version = ">=1.11.1", markers = "python_version >= \"3.12\""}, + {version = ">=1.9", markers = "python_version < \"3.12\""}, +] + [[package]] name = "jedi" version = "0.19.1" @@ -1492,6 +1813,23 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmp" +version = "0.0.4" +description = "JMP is a Mixed Precision library for JAX." +optional = true +python-versions = "*" +files = [ + {file = "jmp-0.0.4-py3-none-any.whl", hash = "sha256:6aa7adbddf2bd574b28c7faf6e81a735eb11f53386447896909c6968dc36807d"}, + {file = "jmp-0.0.4.tar.gz", hash = "sha256:5dfeb0fd7c7a9f72a70fff0aab9d0cbfae32a809c02f4037ff3485ceb33e1730"}, +] + +[package.dependencies] +numpy = ">=1.19.5" + +[package.extras] +jax = ["jax (>=0.2.20)", "jaxlib (>=0.1.71)"] + [[package]] name = "joblib" version = "1.4.2" @@ -1581,6 +1919,27 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest (<8)", "pytest-cov", "pytest-timeout"] +[[package]] +name = "keras" +version = "3.6.0" +description = "Multi-backend Keras." +optional = true +python-versions = ">=3.9" +files = [ + {file = "keras-3.6.0-py3-none-any.whl", hash = "sha256:49585e4577f6e86bd890d96dfbcb1890f5bab5967ef831c07fd63f9d86e4bfe9"}, + {file = "keras-3.6.0.tar.gz", hash = "sha256:405727525a3522ed8f9ec0b46e0667e4c65fcf714a067322c16a00d902ded41d"}, +] + +[package.dependencies] +absl-py = "*" +h5py = "*" +ml-dtypes = "*" +namex = "*" +numpy = "*" +optree = "*" +packaging = "*" +rich = "*" + [[package]] name = "kiwisolver" version = "1.4.7" @@ -1723,6 +2082,25 @@ dev = ["changelist (==0.5)"] lint = ["pre-commit (==3.7.0)"] test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] +[[package]] +name = "libclang" +version = "18.1.1" +description = "Clang Python Bindings, mirrored from the official LLVM repo: https://github.com/llvm/llvm-project/tree/main/clang/bindings/python, to make the installation process easier." +optional = true +python-versions = "*" +files = [ + {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"}, + {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"}, + {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"}, + {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"}, + {file = "libclang-18.1.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592"}, + {file = "libclang-18.1.1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe"}, + {file = "libclang-18.1.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f"}, + {file = "libclang-18.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb"}, + {file = "libclang-18.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8"}, + {file = "libclang-18.1.1.tar.gz", hash = "sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250"}, +] + [[package]] name = "llvmlite" version = "0.43.0" @@ -1803,6 +2181,24 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] flake8 = ["flake8"] tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] +[[package]] +name = "markdown" +version = "3.7" +description = "Python implementation of John Gruber's Markdown." +optional = true +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"}, + {file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2000,6 +2396,43 @@ files = [ [package.extras] dzn = ["lark-parser (>=0.12.0,<0.13.0)"] +[[package]] +name = "ml-dtypes" +version = "0.4.1" +description = "" +optional = true +python-versions = ">=3.9" +files = [ + {file = "ml_dtypes-0.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1fe8b5b5e70cd67211db94b05cfd58dace592f24489b038dc6f9fe347d2e07d5"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c09a6d11d8475c2a9fd2bc0695628aec105f97cab3b3a3fb7c9660348ff7d24"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5e8f75fa371020dd30f9196e7d73babae2abd51cf59bdd56cb4f8de7e13354"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:15fdd922fea57e493844e5abb930b9c0bd0af217d9edd3724479fc3d7ce70e3f"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2d55b588116a7085d6e074cf0cdb1d6fa3875c059dddc4d2c94a4cc81c23e975"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e138a9b7a48079c900ea969341a5754019a1ad17ae27ee330f7ebf43f23877f9"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74c6cfb5cf78535b103fde9ea3ded8e9f16f75bc07789054edc7776abfb3d752"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:274cc7193dd73b35fb26bef6c5d40ae3eb258359ee71cd82f6e96a8c948bdaa6"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e35e486e97aee577d0890bc3bd9e9f9eece50c08c163304008587ec8cfe7575b"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:560be16dc1e3bdf7c087eb727e2cf9c0e6a3d87e9f415079d2491cc419b3ebf5"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad0b757d445a20df39035c4cdeed457ec8b60d236020d2560dbc25887533cf50"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef0d7e3fece227b49b544fa69e50e607ac20948f0043e9f76b44f35f229ea450"}, + {file = "ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">1.20", markers = "python_version < \"3.10\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "more-itertools" version = "10.5.0" @@ -2129,6 +2562,17 @@ files = [ [package.dependencies] dill = ">=0.3.9" +[[package]] +name = "namex" +version = "0.0.8" +description = "A simple utility to separate the implementation of your Python package and its public API surface." +optional = true +python-versions = "*" +files = [ + {file = "namex-0.0.8-py3-none-any.whl", hash = "sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487"}, + {file = "namex-0.0.8.tar.gz", hash = "sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b"}, +] + [[package]] name = "nbclient" version = "0.6.8" @@ -2469,6 +2913,151 @@ pandas = ">=1.2" pyyaml = ">=5.1" scipy = ">=1.7" +[[package]] +name = "opt-einsum" +version = "3.4.0" +description = "Path optimization of einsum functions." +optional = true +python-versions = ">=3.8" +files = [ + {file = "opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd"}, + {file = "opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac"}, +] + +[[package]] +name = "optax" +version = "0.2.3" +description = "A gradient processing and optimisation library in JAX." +optional = true +python-versions = ">=3.9" +files = [ + {file = "optax-0.2.3-py3-none-any.whl", hash = "sha256:083e603dcd731d7e74d99f71c12f77937dd53f79001b4c09c290e4f47dd2e94f"}, + {file = "optax-0.2.3.tar.gz", hash = "sha256:ec7ab925440b0c5a512e1f24fba0fb3e7d760a7fd5d2496d7a691e9d37da01d9"}, +] + +[package.dependencies] +absl-py = ">=0.7.1" +chex = ">=0.1.86" +etils = {version = "*", extras = ["epy"]} +jax = ">=0.4.27" +jaxlib = ">=0.4.27" +numpy = ">=1.18.0" + +[package.extras] +docs = ["flax", "ipython (>=8.8.0)", "matplotlib (>=3.5.0)", "myst-nb (>=1.0.0)", "sphinx (>=6.0.0)", "sphinx-autodoc-typehints", "sphinx-book-theme (>=1.0.1)", "sphinx-collections (>=0.0.1)", "sphinx-gallery (>=0.14.0)", "sphinx_contributors", "sphinxcontrib-katex", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] +dp-accounting = ["absl-py (>=1.0.0)", "attrs (>=21.4.0)", "mpmath (>=1.2.1)", "numpy (>=1.21.4)", "scipy (>=1.7.1)"] +examples = ["dp_accounting (>=0.4)", "flax", "ipywidgets", "tensorflow (>=2.4.0)", "tensorflow-datasets (>=4.2.0)"] +test = ["dm-tree (>=0.1.7)", "flax (>=0.5.3)", "scikit-learn", "scipy (>=1.7.1)"] + +[[package]] +name = "optree" +version = "0.13.0" +description = "Optimized PyTree Utilities." +optional = true +python-versions = ">=3.7" +files = [ + {file = "optree-0.13.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7b8fe0442ac5e50b5e6bceb37dcc2cd4908e7716b869cbe6b8901cc0b489884f"}, + {file = "optree-0.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a1aab34de5ac7673fbfb94266bf10482be51985c7f899c3e767ce19d13ce3b4"}, + {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2c79961d5afeb20557c30a0ae899d14ff58cdf1c0e2c8aa3d6807600d00f619"}, + {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb55eb77541cf280a009829280d5844936dc8a2e4a3eb069c010a1f547dbfe97"}, + {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44449e3bc5e7530b50c9a1f5bcf2971ffe317e34edd74d8c9778c5d32078114d"}, + {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b4195a6ba2052c70bac6d73f19aa69644424c5a30fa09f7319cc1b59e15acb6"}, + {file = "optree-0.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7fecc701ece0500fe38fc671b5704d904e2dca9a9284b35263b0bd7e5c62527"}, + {file = "optree-0.13.0-cp310-cp310-win32.whl", hash = "sha256:46a9e66217fdf421e25c133089c94f8f99bc38a2b5a4a2c0c1e0c1b02b01dda4"}, + {file = "optree-0.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:ef68fdcb3b1743a46210f3c888cd15668a07422aa10b4d4130ba512aac595bf7"}, + {file = "optree-0.13.0-cp310-cp310-win_arm64.whl", hash = "sha256:d12a5665169abceb878d50b55571d6a7690bf97aaaf9a7f5438b10e474fde3f2"}, + {file = "optree-0.13.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:92d1c34b6022bedee4b3899f3a9a1105777da11a9abf1a51f4d84bed8f037fa1"}, + {file = "optree-0.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d05c320af21efbc132fe887640f7a2dbb36cfb38af6d4e62396fe104b78f7b72"}, + {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53ae0a0eb128a69a74db4165e7e5f24d54e2711678622198f7073dcb991962f"}, + {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89f08fc3724b2fe7a081b69dfd3ad6625960443e1f61a984cae7c627776f12f4"}, + {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f22f4e46d85f24b5bc49e68043dd754b258b880ac64d72f4f4b9ac1b11f0fb2f"}, + {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbc884f3eab894126398c120b7f92a72a5b9f92db6d8c27d39087da871c642cd"}, + {file = "optree-0.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c58b94669c9072d645e02c0c65c7455f8f136ef8f7b56a5d9123847421f95b"}, + {file = "optree-0.13.0-cp311-cp311-win32.whl", hash = "sha256:54be625517ef3cf52905da7fee63795b2f154dbdb02b37e8cfd63e7fb2f266ea"}, + {file = "optree-0.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:e3d100890a643e12f39de4226ab4f9d0a22842b4f34ae2964d0149419e4d7aff"}, + {file = "optree-0.13.0-cp311-cp311-win_arm64.whl", hash = "sha256:cb8d9a2cebc5fadde98773bb27809a72ff01d11f1037cb58f8e71e740586223e"}, + {file = "optree-0.13.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:abeb8acc83d168063b70168ccf8dfd55f5a7ce50f9af2ca025c41285781ecdd4"}, + {file = "optree-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4771266f05e99e94312add38d45bf97a4d98449aeab100f5c658c521152eb5e5"}, + {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc95c1d0c7acd534184bf3ba243a454e0942e4a7c8b9edd32d939fc15e33d753"}, + {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e48491e042f956d4232ebc138e07074100878c0080e3ba10af4c2db1ba4df9f"}, + {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e001d9c902e98912503eca66c93d4b4b22f5071e4ab777f4db9e140f35288f4"}, + {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:87870346278f46a8c22866ff48716590be35b4aea16e1373e695fb6442c28c41"}, + {file = "optree-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7797c54a35e9d89b4664ec7d542745b87b5ffa9c1201c1062fdcd488eb583390"}, + {file = "optree-0.13.0-cp312-cp312-win32.whl", hash = "sha256:fc90a5373c92f4a9babb4c40fe148516f52160c0ba803bc9b2f936367f2f7437"}, + {file = "optree-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:1bc65743e8edb29e902cab894d1c4665a8fd6f8d10f75db68a2cef6c7246fa5c"}, + {file = "optree-0.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:de2729e1e4ae47a07ac3c70ff977ed1ebe19e7b44d5089075c94f7a9a2dc6f4f"}, + {file = "optree-0.13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dda6efabd0621f53eb46a3789ec89c6fd2c90dfb57aebfce3fcda6eab9ed6a7e"}, + {file = "optree-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5de8da9bbdd08b6200244ee818cd15d1da0f2b06ac926dba0e686260bac7fd40"}, + {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca1e4854134023ba687a7abf45ed3355f773ca7198b6895d88a89030446a9f2e"}, + {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1ac5343e921ce21f8f10f91158ad6404a1488c1cc22ddfa6b34cfb9d997cebd"}, + {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e282212ddf3aafb10ca6ca223772e06ea3c31687c9cae192467b8e0a7dafbfc"}, + {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:24fcd4cb659bcd9b675bc3401950de891b32a047c4787857fb870cd515fcc315"}, + {file = "optree-0.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d735a7d2d2e2eb9a88d932d35b335c10fae9038034f381b6d437dafed46497e"}, + {file = "optree-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ef01e79224f0ee6cf2ca642884f0bc04e446227b96dc576c312717eb33552d57"}, + {file = "optree-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:d3f61fb669b36c1a714346b18c9c488ad33a58049b7b229785c241de18c005d7"}, + {file = "optree-0.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:695b3f1aab50519230e3d8d86abaedaadf91af105b569cce3b8ebe0dc612b312"}, + {file = "optree-0.13.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:1318434b0740a2325c197e191e6dd53d9df0a8ac0338c67d58b476aad9d07829"}, + {file = "optree-0.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d58c6e8d4c4fa4e0c31bc4b876960ccba94eb5fcfb045f2b064ce55707034be9"}, + {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6a290ba771cc9004f9fc194d23ab11ee4aae71550ca874c3dc985af5b5f910b"}, + {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c95488ecbab2916de094e68f2a2c55c9475b2e979c03d91a6cd3565f9e5ff2f9"}, + {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f76a65ff322b3d47af2a23f60409d6d8f184804da551c734e355834e69c0dfb"}, + {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58cc303f982fb0f23644b7f8e98b4f64b0d031365fcc2284da896e96493176d2"}, + {file = "optree-0.13.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6866b6e4154303dc7c48c7ca3b867a8ce31d469334b67976dfc0513455aa1ca0"}, + {file = "optree-0.13.0-cp313-cp313t-win32.whl", hash = "sha256:f5ce67f81fe3d7ca5fed8fdaf93a762a63e1d125e20e425ca7200f9e54a3e3a6"}, + {file = "optree-0.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0008cd39169c1fc10870528b2decfea8b79e61042c12d65a964f3b1cf41cc37d"}, + {file = "optree-0.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:539962675b547957c64b52b7f82178febb9c0f2d47438b810bbc23cfdcf84821"}, + {file = "optree-0.13.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b08e4873814d11aa25ef3927c848b9e5cf21215b925e83875b9fe11c7a035b0e"}, + {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6e236c6601480997c6e1dbbd4ab2b7ea0bc82a9a7baa1f681a1b072c9c02677"}, + {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:557b415b41006cca88d86ad190b795455e9334d3cf5838e63c4c668a65227ccb"}, + {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11b78c8a18894fe9503515d073a60ebaed366aeb3cfa65e61e7e71ae833f640b"}, + {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4207f6fa0bd4730f5496772c139f1444b2b69e4eeb0f454e2100b5a380648f70"}, + {file = "optree-0.13.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe9fd84b7d87f365f720699dedd254882ba7e5ef927d3ba1e13413d45963b691"}, + {file = "optree-0.13.0-cp37-cp37m-win32.whl", hash = "sha256:c0f9f250f617f114061ab718d460be6be8e0a1cbbfdbbfb5541ed1c8fefee150"}, + {file = "optree-0.13.0-cp37-cp37m-win_amd64.whl", hash = "sha256:5cf612aefe0201a2995763cce82b9cd03cbddd2bfd6f8975f910c091dfa7bb5f"}, + {file = "optree-0.13.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:46623259b10f6e3565ea0d37e0b313feb20484bccb005459b3504e1aa706b730"}, + {file = "optree-0.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e7f9184c6040365e79a0b900507c289b6a4e06ade3c9691e501d176d5cf775cf"}, + {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6201c065791422a73d5aeb4916e00879de9b097cf54526f82b5b3c297126d938"}, + {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a423897010c6d8490097671d907da1b6ee90d3fa783aaad5e36e46e0a73bc5e"}, + {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1fb74282fce108e07972e88dbc23f6b7650c2d3bbddbedc2002d3e0becb1c452"}, + {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94ecab158521225b20e44d67c8afc2d9af6760985a9f489d21bf2aa8bbe467f8"}, + {file = "optree-0.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8244d0fbfe1ef15ffb443f3d32a44aff062adbef0a7fd6db3f011999df966223"}, + {file = "optree-0.13.0-cp38-cp38-win32.whl", hash = "sha256:0a34c11d637cb01217828e28eef382c621c9ec53f981d8ccbfe56e0a11cda501"}, + {file = "optree-0.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:ebe56c17bf3754335307b17be7f554c5eae47acf738471cf38dba0ec73a42c37"}, + {file = "optree-0.13.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e9c619a49984212e5f757e10d5e5f95888b0c08d67a7f2b9f395cede30712dc2"}, + {file = "optree-0.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:50a9e2d9ffff99d45b37289a3422ed3723a45225616f5b48cea606ff0f539c0f"}, + {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d702dbcafcd16e8925e30c0e780ab3dc81450e19008fd3e77494111fc161a2b2"}, + {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f44a58f87059161f300e2be66ad3878fff540d27f5dcd69b21feae65c243a02"}, + {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:954899edc024f13079932418f59bbdadabc52d9dcb49c7b559c382c7be352dfc"}, + {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c736ce6f4b8857bd171f3682ef849e3d67692c3fc4db42b99c5d2c7cc1bdf11"}, + {file = "optree-0.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7941d3bd48d860d0e17ca24827b5233ea27bb4227e822eafb3897df1f43f8342"}, + {file = "optree-0.13.0-cp39-cp39-win32.whl", hash = "sha256:9f6fc47c9b10d1a9e77163ebd6f2e251af41fab895475d2ce9643423a41899af"}, + {file = "optree-0.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:246020f0be50fb66791d8a25c4acb59ad0b4bbdea71c998e375eba4c58fbc3e0"}, + {file = "optree-0.13.0-cp39-cp39-win_arm64.whl", hash = "sha256:069bf166b7aa48ccf8dfe76b920d2115dd8261107c7895d02500b2ce39621b40"}, + {file = "optree-0.13.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:496170a3d093a7fb69be7ce847f5b5b3aa30a6da81457ba6b54268e6e97c6b13"}, + {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73543a82be71c041d5b169754089a58d02063eb72ac8688533b6fc26ab6beea8"}, + {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:278e2c620df99f5b1477b375b01cf9658528fa0332c0bc431d3ec65857244094"}, + {file = "optree-0.13.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36b32155dce29edb6f63a99a44d6da2d8fcd1c56353cc2f4af65f793a0b2712f"}, + {file = "optree-0.13.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c98a43204841cc4698155acb523d7b21a78f8b05666704359e0fddecd5d1043d"}, + {file = "optree-0.13.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5c2803d4ef257f2599cffd0e9d60cfb3d4c522abbe8f5a839bd48d8edd26dae7"}, + {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3b454f98d28a89c15a1170e771c61902cbc53eed126db36138b684dba5a729"}, + {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b74afed3db289228e0f95a8909835365f644eb69ff31cd6c0b45608ca9e56d78"}, + {file = "optree-0.13.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc3cebfd7d0826d223662f01ed0fa25932edf3f62479be13c4d6ff0fab090c34"}, + {file = "optree-0.13.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:5703637ede6fba04cbeabbb47aada7d17606c2d4df73305063f4a3c829c21fc7"}, + {file = "optree-0.13.0.tar.gz", hash = "sha256:1ea493cde8c60f7950ccbd682bd67e787bf67ed2251d6d3e9ad7471b72d37538"}, +] + +[package.dependencies] +typing-extensions = ">=4.5.0" + +[package.extras] +benchmark = ["dm-tree (>=0.1,<0.2.0a0)", "jax[cpu] (>=0.4.6,<0.5.0a0)", "pandas", "tabulate", "termcolor", "torch (>=2.0,<2.4.0a0)", "torchvision"] +docs = ["docutils", "jax[cpu]", "numpy", "sphinx", "sphinx-autoapi", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx-copybutton", "sphinx-rtd-theme", "sphinxcontrib-bibtex", "torch"] +jax = ["jax"] +lint = ["black", "cpplint", "doc8", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "flake8-pyi", "flake8-simplify", "isort", "mypy", "pre-commit", "pydocstyle", "pyenchant", "pylint[spelling]", "ruff", "xdoctest"] +numpy = ["numpy"] +test = ["pytest", "pytest-cov", "pytest-xdist"] +torch = ["torch"] + [[package]] name = "ortools" version = "9.10.4067" @@ -2574,9 +3163,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -2776,6 +3365,17 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "ply" +version = "3.11" +description = "Python Lex & Yacc" +optional = true +python-versions = "*" +files = [ + {file = "ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce"}, + {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"}, +] + [[package]] name = "pox" version = "0.3.5" @@ -2903,52 +3503,55 @@ files = [ [[package]] name = "pyarrow" -version = "17.0.0" +version = "18.0.0" description = "Python library for Apache Arrow" optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, - {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, - {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, - {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, - {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, - {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, - {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, - {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, + {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2333f93260674e185cfbf208d2da3007132572e56871f451ba1a556b45dae6e2"}, + {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4c381857754da44326f3a49b8b199f7f87a51c2faacd5114352fc78de30d3aba"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:603cd8ad4976568954598ef0a6d4ed3dfb78aff3d57fa8d6271f470f0ce7d34f"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58a62549a3e0bc9e03df32f350e10e1efb94ec6cf63e3920c3385b26663948ce"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bc97316840a349485fbb137eb8d0f4d7057e1b2c1272b1a20eebbbe1848f5122"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2e549a748fa8b8715e734919923f69318c953e077e9c02140ada13e59d043310"}, + {file = "pyarrow-18.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:606e9a3dcb0f52307c5040698ea962685fb1c852d72379ee9412be7de9c5f9e2"}, + {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d5795e37c0a33baa618c5e054cd61f586cf76850a251e2b21355e4085def6280"}, + {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:5f0510608ccd6e7f02ca8596962afb8c6cc84c453e7be0da4d85f5f4f7b0328a"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616ea2826c03c16e87f517c46296621a7c51e30400f6d0a61be645f203aa2b93"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1824f5b029ddd289919f354bc285992cb4e32da518758c136271cf66046ef22"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd1b52d0d58dd8f685ced9971eb49f697d753aa7912f0a8f50833c7a7426319"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:320ae9bd45ad7ecc12ec858b3e8e462578de060832b98fc4d671dee9f10d9954"}, + {file = "pyarrow-18.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2c992716cffb1088414f2b478f7af0175fd0a76fea80841b1706baa8fb0ebaad"}, + {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:e7ab04f272f98ebffd2a0661e4e126036f6936391ba2889ed2d44c5006237802"}, + {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:03f40b65a43be159d2f97fd64dc998f769d0995a50c00f07aab58b0b3da87e1f"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be08af84808dff63a76860847c48ec0416928a7b3a17c2f49a072cac7c45efbd"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70c1965cde991b711a98448ccda3486f2a336457cf4ec4dca257a926e149c9"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:00178509f379415a3fcf855af020e3340254f990a8534294ec3cf674d6e255fd"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a71ab0589a63a3e987beb2bc172e05f000a5c5be2636b4b263c44034e215b5d7"}, + {file = "pyarrow-18.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe92efcdbfa0bcf2fa602e466d7f2905500f33f09eb90bf0bcf2e6ca41b574c8"}, + {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:907ee0aa8ca576f5e0cdc20b5aeb2ad4d3953a3b4769fc4b499e00ef0266f02f"}, + {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:66dcc216ebae2eb4c37b223feaf82f15b69d502821dde2da138ec5a3716e7463"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc1daf7c425f58527900876354390ee41b0ae962a73ad0959b9d829def583bb1"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:871b292d4b696b09120ed5bde894f79ee2a5f109cb84470546471df264cae136"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:082ba62bdcb939824ba1ce10b8acef5ab621da1f4c4805e07bfd153617ac19d4"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2c664ab88b9766413197733c1720d3dcd4190e8fa3bbdc3710384630a0a7207b"}, + {file = "pyarrow-18.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc892be34dbd058e8d189b47db1e33a227d965ea8805a235c8a7286f7fd17d3a"}, + {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:28f9c39a56d2c78bf6b87dcc699d520ab850919d4a8c7418cd20eda49874a2ea"}, + {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:f1a198a50c409ab2d009fbf20956ace84567d67f2c5701511d4dd561fae6f32e"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5bd7fd32e3ace012d43925ea4fc8bd1b02cc6cc1e9813b518302950e89b5a22"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:336addb8b6f5208be1b2398442c703a710b6b937b1a046065ee4db65e782ff5a"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:45476490dd4adec5472c92b4d253e245258745d0ccaabe706f8d03288ed60a79"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b46591222c864e7da7faa3b19455196416cd8355ff6c2cc2e65726a760a3c420"}, + {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb7e3abcda7e1e6b83c2dc2909c8d045881017270a119cc6ee7fdcfe71d02df8"}, + {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:09f30690b99ce34e0da64d20dab372ee54431745e4efb78ac938234a282d15f9"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5ca5d707e158540312e09fd907f9f49bacbe779ab5236d9699ced14d2293b8"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6331f280c6e4521c69b201a42dd978f60f7e129511a55da9e0bfe426b4ebb8d"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3ac24b2be732e78a5a3ac0b3aa870d73766dd00beba6e015ea2ea7394f8b4e55"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b30a927c6dff89ee702686596f27c25160dd6c99be5bcc1513a763ae5b1bfc03"}, + {file = "pyarrow-18.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:8f40ec677e942374e3d7f2fad6a67a4c2811a8b975e8703c6fd26d3b168a90e2"}, + {file = "pyarrow-18.0.0.tar.gz", hash = "sha256:a6aa027b1a9d2970cf328ccd6dbe4a996bc13c39fd427f502782f5bdb9ca20f5"}, ] -[package.dependencies] -numpy = ">=1.16.6" - [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] @@ -3271,6 +3874,61 @@ files = [ [package.dependencies] certifi = "*" +[[package]] +name = "pyrddlgym" +version = "2.0" +description = "pyRDDLGym: RDDL automatic generation tool for OpenAI Gym" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyRDDLGym-2.0-py3-none-any.whl", hash = "sha256:dfbd3a17d7b6843ecd2ca211d31384ce28bdcdc568c5f612b5c8e49ea789a340"}, + {file = "pyRDDLGym-2.0.tar.gz", hash = "sha256:9df637574a40a915f986f55700107ca823542f7e7b7d16d7bcc151b6ec126ad1"}, +] + +[package.dependencies] +gymnasium = "*" +matplotlib = ">=3.5.0" +numpy = ">=1.22" +pillow = ">=9.2.0" +ply = "*" +pygame = "*" +termcolor = "*" + +[[package]] +name = "pyrddlgym-jax" +version = "0.3" +description = "pyRDDLGym-jax: JAX compilation of RDDL description files, and a differentiable planner in JAX." +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyRDDLGym_jax-0.3-py3-none-any.whl", hash = "sha256:4908a0f5e3850a24ea5a75455723e0d9c219e4beed82b2fc738287c232dbc18b"}, + {file = "pyrddlgym_jax-0.3.tar.gz", hash = "sha256:c5dd35ff2cd871135cb035f5266e9247b97863806f2b5b31b74ad6c28377cb24"}, +] + +[package.dependencies] +bayesian-optimization = ">=1.4.3" +dm-haiku = ">=0.0.10" +jax = ">=0.4.12" +optax = ">=0.1.9" +pyRDDLGym = ">=2.0" +tensorflow = ">=2.13.0" +tensorflow-probability = ">=0.21.0" +tqdm = ">=4.66" + +[[package]] +name = "pyrddlgym-rl" +version = "0.1" +description = "pyRDDLGym-rl: Wrappers for reinforcement learning algorithms (i.e. stable baselines 3) to work with pyRDDLGym." +optional = true +python-versions = ">=3.8" +files = [ + {file = "pyRDDLGym-rl-0.1.tar.gz", hash = "sha256:2f1c9b420bc0f2418e0c44fcf587950471366f6c5fce9587295d5f5dd1cc1c33"}, + {file = "pyRDDLGym_rl-0.1-py3-none-any.whl", hash = "sha256:1baa8a5bfb72e7a582aeb40803ca9792bc55d2d5fe1e3ee087d6ea1053b719de"}, +] + +[package.dependencies] +pyRDDLGym = ">=2.0" + [[package]] name = "pyshp" version = "2.3.1" @@ -3654,6 +4312,20 @@ serve-grpc = ["aiohttp (>=3.7)", "aiohttp-cors", "colorful", "fastapi", "grpcio train = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] tune = ["fsspec", "pandas", "pyarrow (>=6.0.1)", "requests", "tensorboardX (>=1.9)"] +[[package]] +name = "rddlrepository" +version = "2.0" +description = "Home for all things RDDL" +optional = true +python-versions = ">=3.8" +files = [ + {file = "rddlrepository-2.0-py3-none-any.whl", hash = "sha256:ed993b010b535128de15aefb7d4507d220df9b68dbc25c5bf0315830184817cb"}, + {file = "rddlrepository-2.0.tar.gz", hash = "sha256:570b6b0044f20b013aac59b44519f5a554fc1097a5f0f68a5110fd7c07977cab"}, +] + +[package.dependencies] +numpy = "*" + [[package]] name = "referencing" version = "0.35.1" @@ -3869,6 +4541,56 @@ docs = ["PyWavelets (>=1.1.1)", "dask[array] (>=2022.9.2)", "ipykernel", "ipywid optional = ["PyWavelets (>=1.1.1)", "SimpleITK", "astropy (>=5.0)", "cloudpickle (>=0.2.1)", "dask[array] (>=2021.1.0)", "matplotlib (>=3.6)", "pooch (>=1.6.0)", "pyamg", "scikit-learn (>=1.1)"] test = ["asv", "numpydoc (>=1.7)", "pooch (>=1.6.0)", "pytest (>=7.0)", "pytest-cov (>=2.11.0)", "pytest-doctestplus", "pytest-faulthandler", "pytest-localserver"] +[[package]] +name = "scikit-learn" +version = "1.5.2" +description = "A set of python modules for machine learning and data mining" +optional = true +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + [[package]] name = "scipy" version = "1.13.1" @@ -4115,6 +4837,54 @@ mpmath = ">=1.1.0,<1.4" [package.extras] dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + +[[package]] +name = "tensorboard" +version = "2.18.0" +description = "TensorBoard lets you watch Tensors Flow" +optional = true +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +packaging = "*" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + [[package]] name = "tensorboardx" version = "2.6.2.2" @@ -4131,6 +4901,139 @@ numpy = "*" packaging = "*" protobuf = ">=3.20" +[[package]] +name = "tensorflow" +version = "2.18.0" +description = "TensorFlow is an open source machine learning framework for everyone." +optional = true +python-versions = ">=3.9" +files = [ + {file = "tensorflow-2.18.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:8da90a9388a1f6dd00d626590d2b5810faffbb3e7367f9783d80efff882340ee"}, + {file = "tensorflow-2.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:589342fb9bdcab2e9af0f946da4ca97757677e297d934fcdc087e87db99d6353"}, + {file = "tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1eb77fae50d699442726d1b23c7512c97cd688cc7d857b028683d4535bbf3709"}, + {file = "tensorflow-2.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:46f5a8b4e6273f488dc069fc3ac2211b23acd3d0437d919349c787fa341baa8a"}, + {file = "tensorflow-2.18.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:453cb60638a02fd26316fb36c8cbcf1569d33671f17c658ca0cf2b4626f851e7"}, + {file = "tensorflow-2.18.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85f1e7369af6d329b117b52e86093cd1e0458dd5404bf5b665853f873dd00b48"}, + {file = "tensorflow-2.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b8dd70fa3600bfce66ab529eebb804e1f9d7c863d2f71bc8fe9fc7a1ec3976"}, + {file = "tensorflow-2.18.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e8b0f499ef0b7652480a58e358a73844932047f21c42c56f7f3bdcaf0803edc"}, + {file = "tensorflow-2.18.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ec4133a215c59314e929e7cbe914579d3afbc7874d9fa924873ee633fe4f71d0"}, + {file = "tensorflow-2.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4822904b3559d8a9c25f0fe5fef191cfc1352ceca42ca64f2a7bc7ae0ff4a1f5"}, + {file = "tensorflow-2.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfdd65ea7e064064283dd78d529dd621257ee617218f63681935fd15817c6286"}, + {file = "tensorflow-2.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:a701c2d3dca5f2efcab315b2c217f140ebd3da80410744e87d77016b3aaf53cb"}, + {file = "tensorflow-2.18.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:336cace378c129c20fee6292f6a541165073d153a9a4c9cf4f14478a81895776"}, + {file = "tensorflow-2.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcfd32134de8f95515b2d0ced89cdae15484b787d3a21893e9291def06c10c4e"}, + {file = "tensorflow-2.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada1f7290c75b34748ee7378c1b77927e4044c94b8dc72dc75e7667c4fdaeb94"}, + {file = "tensorflow-2.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:f8c946df1cb384504578fac1c199a95322373b8e04abd88aa8ae01301df469ea"}, +] + +[package.dependencies] +absl-py = ">=1.0.0" +astunparse = ">=1.6.0" +flatbuffers = ">=24.3.25" +gast = ">=0.2.1,<0.5.0 || >0.5.0,<0.5.1 || >0.5.1,<0.5.2 || >0.5.2" +google-pasta = ">=0.1.1" +grpcio = ">=1.24.3,<2.0" +h5py = ">=3.11.0" +keras = ">=3.5.0" +libclang = ">=13.0.0" +ml-dtypes = ">=0.4.0,<0.5.0" +numpy = ">=1.26.0,<2.1.0" +opt-einsum = ">=2.3.2" +packaging = "*" +protobuf = ">=3.20.3,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" +requests = ">=2.21.0,<3" +setuptools = "*" +six = ">=1.12.0" +tensorboard = ">=2.18,<2.19" +tensorflow-io-gcs-filesystem = {version = ">=0.23.1", markers = "python_version < \"3.12\""} +termcolor = ">=1.1.0" +typing-extensions = ">=3.6.6" +wrapt = ">=1.11.0" + +[package.extras] +and-cuda = ["nvidia-cublas-cu12 (==12.5.3.2)", "nvidia-cuda-cupti-cu12 (==12.5.82)", "nvidia-cuda-nvcc-cu12 (==12.5.82)", "nvidia-cuda-nvrtc-cu12 (==12.5.82)", "nvidia-cuda-runtime-cu12 (==12.5.82)", "nvidia-cudnn-cu12 (==9.3.0.75)", "nvidia-cufft-cu12 (==11.2.3.61)", "nvidia-curand-cu12 (==10.3.6.82)", "nvidia-cusolver-cu12 (==11.6.3.83)", "nvidia-cusparse-cu12 (==12.5.1.3)", "nvidia-nccl-cu12 (==2.21.5)", "nvidia-nvjitlink-cu12 (==12.5.82)"] + +[[package]] +name = "tensorflow-io-gcs-filesystem" +version = "0.37.1" +description = "TensorFlow IO" +optional = true +python-versions = "<3.13,>=3.7" +files = [ + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed"}, + {file = "tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95"}, +] + +[package.extras] +tensorflow = ["tensorflow (>=2.16.0,<2.17.0)"] +tensorflow-aarch64 = ["tensorflow-aarch64 (>=2.16.0,<2.17.0)"] +tensorflow-cpu = ["tensorflow-cpu (>=2.16.0,<2.17.0)"] +tensorflow-gpu = ["tensorflow-gpu (>=2.16.0,<2.17.0)"] +tensorflow-rocm = ["tensorflow-rocm (>=2.16.0,<2.17.0)"] + +[[package]] +name = "tensorflow-probability" +version = "0.24.0" +description = "Probabilistic modeling and statistical inference in TensorFlow" +optional = true +python-versions = ">=3.9" +files = [ + {file = "tensorflow_probability-0.24.0-py2.py3-none-any.whl", hash = "sha256:8c1774683e38359dbcaf3697e79b7e6a4e69b9c7b3679e78ee18f43e59e5759b"}, +] + +[package.dependencies] +absl-py = "*" +cloudpickle = ">=1.3" +decorator = "*" +dm-tree = "*" +gast = ">=0.3.2" +numpy = ">=1.13.3" +six = ">=1.10.0" + +[package.extras] +jax = ["jax", "jaxlib"] +tf = ["tensorflow (>=2.16)", "tf-keras (>=2.16)"] +tfds = ["tensorflow-datasets (>=2.2.0)"] + +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = true +python-versions = ">=3.9" +files = [ + {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = true +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tifffile" version = "2024.8.30" @@ -4164,6 +5067,17 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] +[[package]] +name = "toolz" +version = "1.0.0" +description = "List processing tools and functional utilities" +optional = true +python-versions = ">=3.8" +files = [ + {file = "toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236"}, + {file = "toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02"}, +] + [[package]] name = "torch" version = "2.5.0" @@ -4445,6 +5359,23 @@ files = [ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] +[[package]] +name = "werkzeug" +version = "3.0.6" +description = "The comprehensive WSGI web application library." +optional = true +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, + {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "wheel" version = "0.44.0" @@ -4459,6 +5390,85 @@ files = [ [package.extras] test = ["pytest (>=6.0.0)", "setuptools (>=65)"] +[[package]] +name = "wrapt" +version = "1.16.0" +description = "Module for decorators, wrappers and monkey patching." +optional = true +python-versions = ">=3.6" +files = [ + {file = "wrapt-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4"}, + {file = "wrapt-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487"}, + {file = "wrapt-1.16.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0"}, + {file = "wrapt-1.16.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136"}, + {file = "wrapt-1.16.0-cp310-cp310-win32.whl", hash = "sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d"}, + {file = "wrapt-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09"}, + {file = "wrapt-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060"}, + {file = "wrapt-1.16.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956"}, + {file = "wrapt-1.16.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d"}, + {file = "wrapt-1.16.0-cp311-cp311-win32.whl", hash = "sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362"}, + {file = "wrapt-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"}, + {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"}, + {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"}, + {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"}, + {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"}, + {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"}, + {file = "wrapt-1.16.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c"}, + {file = "wrapt-1.16.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e"}, + {file = "wrapt-1.16.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465"}, + {file = "wrapt-1.16.0-cp36-cp36m-win32.whl", hash = "sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e"}, + {file = "wrapt-1.16.0-cp36-cp36m-win_amd64.whl", hash = "sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966"}, + {file = "wrapt-1.16.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5"}, + {file = "wrapt-1.16.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f"}, + {file = "wrapt-1.16.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win32.whl", hash = "sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c"}, + {file = "wrapt-1.16.0-cp37-cp37m-win_amd64.whl", hash = "sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0"}, + {file = "wrapt-1.16.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e"}, + {file = "wrapt-1.16.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca"}, + {file = "wrapt-1.16.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6"}, + {file = "wrapt-1.16.0-cp38-cp38-win32.whl", hash = "sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b"}, + {file = "wrapt-1.16.0-cp38-cp38-win_amd64.whl", hash = "sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2"}, + {file = "wrapt-1.16.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c"}, + {file = "wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f"}, + {file = "wrapt-1.16.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537"}, + {file = "wrapt-1.16.0-cp39-cp39-win32.whl", hash = "sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3"}, + {file = "wrapt-1.16.0-cp39-cp39-win_amd64.whl", hash = "sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35"}, + {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"}, + {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, +] + [[package]] name = "zipp" version = "3.20.2" @@ -4479,11 +5489,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pygeodesy", "pygrib", "pygrib", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] -domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pygeodesy", "pygrib", "pygrib", "scipy", "unified-planning"] -solvers = ["gymnasium", "joblib", "numpy", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +all = ["cartopy", "gymnasium", "joblib", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym-jax", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "ray", "rddlrepository", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] +domains = ["cartopy", "gymnasium", "matplotlib", "numpy", "openap", "pyRDDLGym", "pyRDDLGym-rl", "pygeodesy", "pygrib", "pygrib", "rddlrepository", "scipy", "unified-planning"] +solvers = ["gymnasium", "joblib", "numpy", "pyRDDLGym-jax", "ray", "scipy", "stable-baselines3", "unified-planning", "up-enhsp", "up-fast-downward", "up-pyperplan", "up-tamer"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "0a37a3952616b8038d4f715c75d98aeb903691b1d93b9cfe268b15b09334ebb3" +content-hash = "55a5472c7d8ca6d30cd149a54055e1a2322d87320c3f970fde56898801645455" diff --git a/pyproject.toml b/pyproject.toml index 1c57043c1f..c53fd3d598 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scikit-decide" -version = "0.0.0" # placeholder for poetry-dynamic-versioning +version = "1.0.3.dev5+90a1f41c" # placeholder for poetry-dynamic-versioning description = "The AI framework for Reinforcement Learning, Automated Planning and Scheduling" authors = ["Airbus AI Research "] license = "MIT" @@ -34,7 +34,7 @@ script = "builder.py" generate-setup-file = true [tool.poetry-dynamic-versioning] -enable = true +enable = false vcs = "git" format-jinja = """ {%- if distance == 0 -%} @@ -71,6 +71,10 @@ pygrib = [ { version = "<=2.1.5", platform = "linux", optional = true }, { version = ">=2.1.5", platform = "darwin", optional = true }, ] +pyRDDLGym = { version = ">=2.0, <2.1", optional = true } +pyRDDLGym-rl = { version = ">=0.1", optional = true } +pyRDDLGym-jax = { version = ">=0.3", optional = true } +rddlrepository = {version = ">=2.0", optional = true } [tool.poetry.extras] domains = [ @@ -82,7 +86,10 @@ domains = [ "unified-planning", "cartopy", "pygrib", - "scipy" + "scipy", + "pyRDDLGym", + "pyRDDLGym-rl", + "rddlrepository" ] solvers = [ "gymnasium", @@ -95,7 +102,8 @@ solvers = [ "up-fast-downward", "up-enhsp", "up-pyperplan", - "scipy" + "scipy", + "pyRDDLGym-jax" ] all = [ "gymnasium", @@ -113,7 +121,11 @@ all = [ "up-pyperplan", "cartopy", "pygrib", - "scipy" + "scipy", + "pyRDDLGym", + "pyRDDLGym-rl", + "rddlrepository", + "pyRDDLGym-jax" ] [tool.poetry.plugins."skdecide.domains"] @@ -136,6 +148,9 @@ Stochastic_RCPSP = "skdecide.hub.domain.rcpsp:Stochastic_RCPSP [domains]" SMRCPSPCalendar = "skdecide.hub.domain.rcpsp:SMRCPSPCalendar [domains]" MSRCPSP = "skdecide.hub.domain.rcpsp:MSRCPSP [domains]" MSRCPSPCalendar = "skdecide.hub.domain.rcpsp:MSRCPSPCalendar [domains]" +RDDLDomain = "skdecide.hub.domain.rddl:RDDLDomain [domains]" +RDDLDomainRL = "skdecide.hub.domain.rddl:RDDLDomainRL [domains]" +RDDLDomainSimplifiedSpaces = "skdecide.hub.domain.rddl:RDDLDomainSimplifiedSpaces [domains]" [tool.poetry.plugins."skdecide.solvers"] AOstar = "skdecide.hub.solver.aostar:AOstar [solvers]" @@ -162,6 +177,8 @@ DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]" GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]" PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]" UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]" +RDDLJaxSolver = "skdecide.hub.solver.rddl:RDDLJaxSolver [solvers]" +RDDLGurobiSolver = "skdecide.hub.solver.rddl:RDDLGurobiSolver [solvers]" [tool.poetry.dev-dependencies] pytest = "^6.2.2" diff --git a/skdecide/hub/domain/rddl/__init__.py b/skdecide/hub/domain/rddl/__init__.py new file mode 100644 index 0000000000..d9de2e0c34 --- /dev/null +++ b/skdecide/hub/domain/rddl/__init__.py @@ -0,0 +1 @@ +from .rddl import RDDLDomain, RDDLDomainRL, RDDLDomainSimplifiedSpaces diff --git a/skdecide/hub/domain/rddl/rddl.py b/skdecide/hub/domain/rddl/rddl.py new file mode 100644 index 0000000000..5c9fe549fc --- /dev/null +++ b/skdecide/hub/domain/rddl/rddl.py @@ -0,0 +1,179 @@ +import os +import shutil +from datetime import datetime +from typing import Any + +import numpy as np +import pyRDDLGym +from gymnasium.spaces.utils import flatten, flatten_space +from pyRDDLGym import RDDLEnv +from pyRDDLGym.core.simulator import RDDLSimulator +from pyRDDLGym.core.visualizer.chart import ChartVisualizer +from pyRDDLGym.core.visualizer.movie import MovieGenerator +from pyRDDLGym.core.visualizer.viz import BaseViz +from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv + +from skdecide.builders.domain import FullyObservable, Renderable, UnrestrictedActions +from skdecide.core import Space, TransitionOutcome, Value +from skdecide.domains import RLDomain +from skdecide.hub.space.gym import GymSpace + +try: + import IPython +except ImportError: + ipython_available = False +else: + ipython_available = True + from IPython.display import clear_output, display + + +class D(RLDomain, UnrestrictedActions, FullyObservable, Renderable): + T_state = dict[str, Any] # Type of states + T_observation = T_state # Type of observations + T_event = np.array # Type of events + T_value = float # Type of transition values (rewards or costs) + T_info = None # Type of additional information in environment outcome + + +class RDDLDomain(D): + def __init__( + self, + rddl_domain: str, + rddl_instance: str, + base_class: type[RDDLEnv] = RDDLEnv, + backend: type[RDDLSimulator] = RDDLSimulator, + display_with_pygame: bool = True, + display_within_jupyter: bool = False, + visualizer: BaseViz = ChartVisualizer, + movie_name: str = None, + movie_dir: str = "rddl_movies", + max_frames=1000, + enforce_action_constraints=True, + **kwargs + ): + self.rddl_gym_env = pyRDDLGym.make( + rddl_domain, + rddl_instance, + base_class=base_class, + backend=backend, + enforce_action_constraints=enforce_action_constraints, + **kwargs + ) + self.display_within_jupyter = display_within_jupyter + self.display_with_pygame = display_with_pygame + self.movie_name = movie_name + self._nb_step = 0 + if movie_name is not None: + self.movie_path = os.path.join(movie_dir, movie_name) + if not os.path.exists(self.movie_path): + os.makedirs(self.movie_path) + tmp_pngs = os.path.join(self.movie_path, "tmp_pngs") + if os.path.exists(tmp_pngs): + shutil.rmtree(tmp_pngs) + os.makedirs(tmp_pngs) + self.movie_gen = MovieGenerator(tmp_pngs, movie_name, max_frames=max_frames) + self.rddl_gym_env.set_visualizer(visualizer, self.movie_gen) + else: + self.movie_gen = None + self.rddl_gym_env.set_visualizer(visualizer) + + def _state_step( + self, action: D.T_event + ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: + next_state, reward, terminated, truncated, _ = self.rddl_gym_env.step(action) + termination = terminated or truncated + if self.movie_gen is not None and ( + termination or self._nb_step >= self.movie_gen.max_frames - 1 + ): + self.movie_gen.save_animation(self.movie_name) + tmp_pngs = os.path.join(self.movie_path, "tmp_pngs") + shutil.move( + os.path.join(tmp_pngs, self.movie_name + ".gif"), + os.path.join( + self.movie_path, + self.movie_name + + "_" + + str(datetime.now().strftime("%Y%m%d-%H%M%S")) + + ".gif", + ), + ) + self._nb_step += 1 + # TransitionOutcome and Value are scikit-decide types + return TransitionOutcome( + state=next_state, value=Value(reward=reward), termination=termination + ) + + def _get_action_space_(self) -> Space[D.T_event]: + # Cast to skdecide's GymSpace + return GymSpace(self.rddl_gym_env.action_space) + + def _state_reset(self) -> D.T_state: + self._nb_step = 0 + # SkDecide only needs the state, not the info + return self.rddl_gym_env.reset()[0] + + def _get_observation_space_(self) -> Space[D.T_observation]: + # Cast to skdecide's GymSpace + return GymSpace(self.rddl_gym_env.observation_space) + + def _render_from(self, memory: D.T_state = None, **kwargs: Any) -> Any: + # We do not want the image to be displayed in a pygame window, but rather in this notebook + rddl_gym_img = self.rddl_gym_env.render(to_display=self.display_with_pygame) + if self.display_within_jupyter and ipython_available: + clear_output(wait=True) + display(rddl_gym_img) + return rddl_gym_img + + +class RDDLDomainRL(RDDLDomain): + def __init__( + self, + rddl_domain: str, + rddl_instance: str, + base_class: type[RDDLEnv] = SimplifiedActionRDDLEnv, + backend: type[RDDLSimulator] = RDDLSimulator, + display_with_pygame: bool = True, + display_within_jupyter: bool = False, + visualizer: BaseViz = ChartVisualizer, + movie_name: str = None, + movie_dir: str = "rddl_movies", + max_frames=1000, + enforce_action_constraints=True, + **kwargs + ): + super().__init__( + rddl_domain=rddl_domain, + rddl_instance=rddl_instance, + base_class=base_class, + backend=backend, + display_with_pygame=display_with_pygame, + display_within_jupyter=display_within_jupyter, + visualizer=visualizer, + movie_name=movie_name, + movie_dir=movie_dir, + max_frames=max_frames, + enforce_action_constraints=enforce_action_constraints, + **kwargs + ) + + +class RDDLDomainSimplifiedSpaces(RDDLDomainRL): + def _state_step( + self, action: D.T_event + ) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]: + outcome = super()._state_step(action) + return TransitionOutcome( + state=flatten(self.rddl_gym_env.observation_space, outcome.state), + value=outcome.value, + termination=outcome.termination, + ) + + def _get_action_space_(self) -> Space[D.T_event]: + return GymSpace(flatten_space(self.rddl_gym_env.action_space)) + + def _state_reset(self) -> D.T_state: + # SkDecide only needs the state, not the info + return flatten(self.rddl_gym_env.observation_space, super()._state_reset()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + return GymSpace(flatten_space(self.rddl_gym_env.observation_space)) diff --git a/skdecide/hub/solver/rddl/__init__.py b/skdecide/hub/solver/rddl/__init__.py new file mode 100644 index 0000000000..052f974b5e --- /dev/null +++ b/skdecide/hub/solver/rddl/__init__.py @@ -0,0 +1 @@ +from .rddl import RDDLGurobiSolver, RDDLJaxSolver diff --git a/skdecide/hub/solver/rddl/rddl.py b/skdecide/hub/solver/rddl/rddl.py new file mode 100644 index 0000000000..8775256c13 --- /dev/null +++ b/skdecide/hub/solver/rddl/rddl.py @@ -0,0 +1,117 @@ +from collections.abc import Callable +from typing import Any, Optional + +from pyRDDLGym_jax.core.planner import ( + JaxBackpropPlanner, + JaxOfflineController, + JaxOnlineController, + load_config, +) +from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator + +from skdecide import Solver +from skdecide.builders.solver import FromInitialState, Policies +from skdecide.hub.domain.rddl import RDDLDomain + +try: + from pyRDDLGym_gurobi.core.planner import ( + GurobiOnlineController, + GurobiPlan, + GurobiStraightLinePlan, + ) +except ImportError: + pyrddlgym_gurobi_available = False +else: + pyrddlgym_gurobi_available = True + + +class D(RDDLDomain): + pass + + +class RDDLJaxSolver(Solver, Policies, FromInitialState): + T_domain = D + + def __init__( + self, domain_factory: Callable[[], RDDLDomain], config: Optional[str] = None + ): + Solver.__init__(self, domain_factory=domain_factory) + self._domain = domain_factory() + if config is not None: + self.planner_args, _, self.train_args = load_config(config) + + @classmethod + def _check_domain_additional(cls, domain: D) -> bool: + return hasattr(domain, "rddl_gym_env") + + def _solve(self, from_memory: Optional[D.T_state] = None) -> None: + planner = JaxBackpropPlanner( + rddl=self._domain.rddl_gym_env.model, + **(self.planner_args if self.planner_args is not None else {}) + ) + self.controller = JaxOfflineController( + planner, **(self.train_args if self.train_args is not None else {}) + ) + + def _sample_action(self, observation: D.T_observation) -> D.T_event: + return self.controller.sample_action(observation) + + def _is_policy_defined_for(self, observation: D.T_observation) -> bool: + return True + + +if pyrddlgym_gurobi_available: + + class D(RDDLDomain): + pass + + class RDDLGurobiSolver(Solver, Policies, FromInitialState): + T_domain = D + + def __init__( + self, + domain_factory: Callable[[], RDDLDomain], + plan: Optional[GurobiPlan] = None, + rollout_horizon=5, + model_params: Optional[dict[str, Any]] = None, + ): + Solver.__init__(self, domain_factory=domain_factory) + self._domain = domain_factory() + self.rollout_horizon = rollout_horizon + if plan is None: + self.plan = GurobiStraightLinePlan() + else: + self.plan = plan + if model_params is None: + self.model_params = {"NonConvex": 2, "OutputFlag": 0} + else: + self.model_params = model_params + + @classmethod + def _check_domain_additional(cls, domain: D) -> bool: + return hasattr(domain, "rddl_gym_env") + + def _solve(self, from_memory: Optional[D.T_state] = None) -> None: + self.controller = GurobiOnlineController( + rddl=self._domain.rddl_gym_env.model, + plan=self.plan, + rollout_horizon=self.rollout_horizon, + model_params=self.model_params, + ) + + def _sample_action(self, observation: D.T_observation) -> D.T_event: + return self.controller.sample_action(observation) + + def _is_policy_defined_for(self, observation: D.T_observation) -> bool: + return True + +else: + + class RDDLGurobiSolver(Solver, Policies, FromInitialState): + T_domain = D + + def __init__(self, domain_factory: Callable[[], RDDLDomain], rollout_horizon=5): + raise RuntimeError( + "You need pyRDDLGym-gurobi installed for this solver. " + "See https://github.com/pyrddlgym-project/pyRDDLGym-gurobi for more information." + ) diff --git a/tests/domains/python/test_pyrddlgym.py b/tests/domains/python/test_pyrddlgym.py new file mode 100644 index 0000000000..48a2761af4 --- /dev/null +++ b/tests/domains/python/test_pyrddlgym.py @@ -0,0 +1,80 @@ +import os +import shutil + +from stable_baselines3 import PPO as SB3_PPO + +from skdecide.hub.domain.rddl import ( + RDDLDomain, + RDDLDomainRL, + RDDLDomainSimplifiedSpaces, +) +from skdecide.hub.solver.cgp import CGP +from skdecide.hub.solver.stable_baselines import StableBaseline +from skdecide.utils import load_registered_domain, rollout + + +def test_pyrddlgymdomain_sb3(): + movie_name = "test-sb3" + movie_path = f"rddl_movies/{movie_name}" + domain_factory = lambda: RDDLDomainRL( + rddl_domain="Cartpole_Continuous_gym", + rddl_instance="0", + movie_name=movie_name, + display_with_pygame=False, + display_within_jupyter=False, + ) + domain = domain_factory() + domain.reset() + domain.render() + + solver_factory = lambda: StableBaseline( + domain_factory=domain_factory, + algo_class=SB3_PPO, + baselines_policy="MultiInputPolicy", + learn_config={"total_timesteps": 100}, + verbose=0, + ) + + shutil.rmtree(movie_path) + + with solver_factory() as solver: + solver.solve() + rollout(domain_factory(), solver, max_steps=100, render=True, verbose=False) + + assert os.path.isdir(movie_path) + + +def test_pyrddlgymdomainsimp_cgp(): + movie_name = "test-cgp" + movie_path = f"rddl_movies/{movie_name}" + domain_factory = lambda: RDDLDomainSimplifiedSpaces( + rddl_domain="Cartpole_Continuous_gym", + rddl_instance="0", + movie_name=movie_name, + display_with_pygame=False, + display_within_jupyter=False, + ) + domain = domain_factory() + domain.reset() + domain.render() + + solver_factory = lambda: CGP( + domain_factory=domain_factory, folder_name="TEMP_CGP", n_it=5, verbose=False + ) + + shutil.rmtree(movie_path) + + with solver_factory() as solver: + solver.solve() + rollout(domain_factory(), solver, max_steps=100, render=True, verbose=False) + + assert os.path.isdir(movie_path) + + +def test_load_rddldomain(): + cls = load_registered_domain("RDDLDomain") + assert cls is RDDLDomain + cls = load_registered_domain("RDDLDomainRL") + assert cls is RDDLDomainRL + cls = load_registered_domain("RDDLDomainSimplifiedSpaces") + assert cls is RDDLDomainSimplifiedSpaces diff --git a/tests/solvers/python/test_pyrddlgym.py b/tests/solvers/python/test_pyrddlgym.py new file mode 100644 index 0000000000..73490b169f --- /dev/null +++ b/tests/solvers/python/test_pyrddlgym.py @@ -0,0 +1,71 @@ +import os +import shutil +from urllib.request import urlcleanup, urlretrieve + +import pytest +from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator + +from skdecide.hub.domain.rddl import RDDLDomain +from skdecide.hub.solver.rddl.rddl import ( + RDDLGurobiSolver, + RDDLJaxSolver, + pyrddlgym_gurobi_available, +) +from skdecide.utils import load_registered_solver, rollout + + +def test_pyrddlgymdomain_jax(): + # get solver config + config_name = "Cartpole_Continuous_gym_drp.cfg" + if not os.path.exists(config_name): + url = f"https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/{config_name}" + try: + local_file_path, headers = urlretrieve(url) + shutil.move(local_file_path, config_name) + finally: + urlcleanup() + + # domain factory (with proper backend and vectorized flag) + domain_factory = lambda: RDDLDomain( + rddl_domain="Cartpole_Continuous_gym", + rddl_instance="0", + backend=JaxRDDLSimulator, + display_with_pygame=False, + display_within_jupyter=False, + vectorized=True, + ) + solver_factory = lambda: RDDLJaxSolver( + domain_factory=domain_factory, config=config_name + ) + + # solve + with solver_factory() as solver: + solver.solve() + rollout(domain_factory(), solver, max_steps=100, render=False, verbose=False) + + +@pytest.mark.skipif( + not pyrddlgym_gurobi_available, + reason="You need to install pyRDDL_gurobi for this solver", +) +def test_pyrddlgymdomain_gurobi(): + # domain factory (with proper backend and vectorized flag) + domain_factory = lambda: RDDLDomain( + rddl_domain="Cartpole_Continuous_gym", + rddl_instance="0", + display_with_pygame=False, + display_within_jupyter=False, + ) + solver_factory = lambda: RDDLGurobiSolver(domain_factory=domain_factory) + + # solve + with solver_factory() as solver: + solver.solve() + rollout(domain_factory(), solver, max_steps=100, render=False, verbose=False) + + +def test_load_solvers(): + cls = load_registered_solver("RDDLJaxSolver") + assert cls is RDDLJaxSolver + cls = load_registered_solver("RDDLGurobiSolver") + assert cls is RDDLGurobiSolver