Skip to content

Commit

Permalink
added tmp_dir stuff to maze_v3 as well
Browse files Browse the repository at this point in the history
  • Loading branch information
mishmish66 committed Apr 27, 2024
1 parent 4425b24 commit 6606f59
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions gymnasium_robotics/envs/maze/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,19 @@ def make_maze(
maze._unique_reset_locations += maze._combined_locations

# Save new xml with maze to a temporary file
with tempfile.TemporaryDirectory() as tmp_dir:
temp_xml_path = path.join(path.dirname(tmp_dir), "ant_maze.xml")
tree.write(temp_xml_path)

return maze, temp_xml_path
# Make temporary file object and make the string path to our new file
tmp_dir = tempfile.TemporaryDirectory()
temp_xml_path = path.join(tmp_dir.name, "ant_maze.xml")

# Write the new xml to the temporary file
with open(temp_xml_path, "wb") as xml_file:
tree.write(xml_file)

return (
maze,
temp_xml_path,
tmp_dir, # The tmp_dir object is returned to keep it alive
)


class MazeEnv(GoalEnv):
Expand All @@ -172,7 +180,7 @@ def __init__(

self.reward_type = reward_type
self.continuing_task = continuing_task
self.maze, self.tmp_xml_file_path = Maze.make_maze(
self.maze, self.tmp_xml_file_path, self.tmp_dir = Maze.make_maze(
agent_xml_path, maze_map, maze_size_scaling, maze_height
)

Expand Down Expand Up @@ -308,3 +316,7 @@ def compute_truncated(

def update_target_site_pos(self, pos):
raise NotImplementedError

def __del__(self):
self.tmp_dir.cleanup()
super().__del__()

0 comments on commit 6606f59

Please sign in to comment.