Skip to content

Commit

Permalink
Use config files included in pyRDDLGym-jax install directory
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Dec 5, 2024
1 parent e5e4a19 commit c34e675
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
8 changes: 5 additions & 3 deletions notebooks/16_rddl_tuto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"import os\n",
"import shutil\n",
"\n",
"import pyRDDLGym_jax.examples.configs\n",
"from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator\n",
"from pyRDDLGym_rl.core.env import SimplifiedActionRDDLEnv\n",
"from ray.rllib.algorithms.ppo import PPO as RLLIB_PPO\n",
Expand Down Expand Up @@ -472,8 +473,9 @@
"problem_info = manager.get_problem(problem_name)\n",
"problem_visualizer = QuadcopterVisualizer\n",
"\n",
"if not os.path.exists(\"Quadcopter_slp.cfg\"):\n",
" !wget https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/Quadcopter_slp.cfg\n",
"config_name = \"Quadcopter_slp.cfg\"\n",
"config_dir = pyRDDLGym_jax.examples.configs.__path__[0]\n",
"config_path = f\"{config_dir}/{config_name}\"\n",
"\n",
"domain_factory_jax_agent = lambda alg_name=None: RDDLDomain(\n",
" rddl_domain=problem_info.get_domain(),\n",
Expand All @@ -491,7 +493,7 @@
"\n",
"logging.getLogger(\"matplotlib.font_manager\").disabled = True\n",
"with RDDLJaxSolver(\n",
" domain_factory=domain_factory_jax_agent, config=\"Quadcopter_slp.cfg\"\n",
" domain_factory=domain_factory_jax_agent, config=config_path\n",
") as solver:\n",
" solver.solve()\n",
" rollout(\n",
Expand Down
12 changes: 4 additions & 8 deletions tests/solvers/python/test_pyrddlgym_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
from urllib.request import urlcleanup, urlretrieve

import pyRDDLGym_jax.examples.configs
from pyRDDLGym_jax.core.simulator import JaxRDDLSimulator

from skdecide.hub.domain.rddl import RDDLDomain
Expand All @@ -12,13 +13,8 @@
def test_pyrddlgymdomain_jax():
# get solver config
config_name = "Cartpole_Continuous_gym_drp.cfg"
if not os.path.exists(config_name):
url = f"https://raw.githubusercontent.com/pyrddlgym-project/pyRDDLGym-jax/main/pyRDDLGym_jax/examples/configs/{config_name}"
try:
local_file_path, headers = urlretrieve(url)
shutil.move(local_file_path, config_name)
finally:
urlcleanup()
config_dir = pyRDDLGym_jax.examples.configs.__path__[0]
config_path = f"{config_dir}/{config_name}"

# domain factory (with proper backend and vectorized flag)
domain_factory = lambda: RDDLDomain(
Expand All @@ -30,7 +26,7 @@ def test_pyrddlgymdomain_jax():
vectorized=True,
)
solver_factory = lambda: RDDLJaxSolver(
domain_factory=domain_factory, config=config_name
domain_factory=domain_factory, config=config_path
)

# solve
Expand Down

0 comments on commit c34e675

Please sign in to comment.