From 88b72f99171e130050013ac7e366ab3dac8840e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 16 Oct 2024 15:31:39 -0400 Subject: [PATCH] Properly close java process upon ScienceWorldEnv termination. --- scienceworld/scienceworld.py | 21 +++++++++++++++++---- tests/test_scienceworld.py | 9 +++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/scienceworld/scienceworld.py b/scienceworld/scienceworld.py index 60b5847..806240f 100644 --- a/scienceworld/scienceworld.py +++ b/scienceworld/scienceworld.py @@ -36,20 +36,21 @@ def __init__(self, taskName: str = None, serverPath: str = None, envStepLimit: i if DEBUG_MODE: import sys import time - port = launch_gateway( + port, proc = launch_gateway( classpath=serverPath, die_on_exit=True, cwd=BASEPATH, javaopts=['-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=5005,quiet=y'], - redirect_stdout=sys.stdout, redirect_stderr=sys.stderr) + redirect_stdout=sys.stdout, redirect_stderr=sys.stderr, return_proc=True) logger.debug("Attach debugger within the next 10 seconds") time.sleep(10) # Give time for user to attach debugger else: - port = launch_gateway(classpath=serverPath, die_on_exit=True, cwd=BASEPATH) + port, proc = launch_gateway(classpath=serverPath, die_on_exit=True, cwd=BASEPATH, return_proc=True) # Connect python side to Java side with Java dynamic port and start python # callback server with a dynamic port self._gateway = JavaGateway( gateway_parameters=GatewayParameters(auto_field=True, port=port), - callback_server_parameters=CallbackServerParameters(port=0, daemonize=True)) + callback_server_parameters=CallbackServerParameters(port=0, daemonize=True), + java_process=proc) # Retrieve the port on which the python callback server was bound to. python_port = self._gateway.get_callback_server().get_listening_port() @@ -147,6 +148,18 @@ def reset(self) -> Tuple[str, Dict[str, Any]]: # Return a tuple that looks like the Jericho signature for reset return observation, info + def close(self) -> None: + self._gateway.shutdown() + + # According to https://github.com/py4j/py4j/issues/320#issuecomment-553599210 + # we need to send a newline to the process to make it exit. + if self._gateway.java_process.poll() is None: + self._gateway.java_process.stdin.write("\n".encode("utf-8")) + self._gateway.java_process.stdin.flush() + + def __del__(self): + self.close() + # Simplifications def get_simplifications_used(self) -> str: ''' Gets the simplifications being used by the simulator. ''' diff --git a/tests/test_scienceworld.py b/tests/test_scienceworld.py index a4fcedd..f6a6a97 100644 --- a/tests/test_scienceworld.py +++ b/tests/test_scienceworld.py @@ -44,6 +44,15 @@ def test_multiple_instances(): assert obs1_2 == obs2_2 +def test_closing_env(): + env = ScienceWorldEnv() + env.task_names # Load task names. + assert env._gateway.java_process.poll() is None + env.close() + env._gateway.java_process.wait(1) + assert env._gateway.java_process.poll() is not None + + def test_variation_sets_are_disjoint(): env = ScienceWorldEnv()