Skip to content

Commit

Permalink
Properly close java process upon ScienceWorldEnv termination.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcCote committed Oct 16, 2024
1 parent 60df29a commit 88b72f9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
21 changes: 17 additions & 4 deletions scienceworld/scienceworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,21 @@ def __init__(self, taskName: str = None, serverPath: str = None, envStepLimit: i
if DEBUG_MODE:
import sys
import time
port = launch_gateway(
port, proc = launch_gateway(
classpath=serverPath, die_on_exit=True, cwd=BASEPATH,
javaopts=['-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=5005,quiet=y'],
redirect_stdout=sys.stdout, redirect_stderr=sys.stderr)
redirect_stdout=sys.stdout, redirect_stderr=sys.stderr, return_proc=True)
logger.debug("Attach debugger within the next 10 seconds")
time.sleep(10) # Give time for user to attach debugger
else:
port = launch_gateway(classpath=serverPath, die_on_exit=True, cwd=BASEPATH)
port, proc = launch_gateway(classpath=serverPath, die_on_exit=True, cwd=BASEPATH, return_proc=True)

# Connect python side to Java side with Java dynamic port and start python
# callback server with a dynamic port
self._gateway = JavaGateway(
gateway_parameters=GatewayParameters(auto_field=True, port=port),
callback_server_parameters=CallbackServerParameters(port=0, daemonize=True))
callback_server_parameters=CallbackServerParameters(port=0, daemonize=True),
java_process=proc)

# Retrieve the port on which the python callback server was bound to.
python_port = self._gateway.get_callback_server().get_listening_port()
Expand Down Expand Up @@ -147,6 +148,18 @@ def reset(self) -> Tuple[str, Dict[str, Any]]:
# Return a tuple that looks like the Jericho signature for reset
return observation, info

def close(self) -> None:
self._gateway.shutdown()

# According to https://github.com/py4j/py4j/issues/320#issuecomment-553599210
# we need to send a newline to the process to make it exit.
if self._gateway.java_process.poll() is None:
self._gateway.java_process.stdin.write("\n".encode("utf-8"))
self._gateway.java_process.stdin.flush()

def __del__(self):
self.close()

# Simplifications
def get_simplifications_used(self) -> str:
''' Gets the simplifications being used by the simulator. '''
Expand Down
9 changes: 9 additions & 0 deletions tests/test_scienceworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def test_multiple_instances():
assert obs1_2 == obs2_2


def test_closing_env():
env = ScienceWorldEnv()
env.task_names # Load task names.
assert env._gateway.java_process.poll() is None
env.close()
env._gateway.java_process.wait(1)
assert env._gateway.java_process.poll() is not None


def test_variation_sets_are_disjoint():
env = ScienceWorldEnv()

Expand Down

0 comments on commit 88b72f9

Please sign in to comment.