From eecec61168a37c59e8d005cd45a2ec44b7aa342b Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Sun, 17 Nov 2024 17:58:14 +0900
Subject: [PATCH] [BugFix] solve PDP tour issues #231

---
 rl4co/envs/routing/pdp/env.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/rl4co/envs/routing/pdp/env.py b/rl4co/envs/routing/pdp/env.py
index 4ff0016b..5996415f 100644
--- a/rl4co/envs/routing/pdp/env.py
+++ b/rl4co/envs/routing/pdp/env.py
@@ -42,6 +42,9 @@ class PDPEnv(RL4COEnvBase):
     Args:
         generator: PDPGenerator instance as the data generator
         generator_params: parameters for the generator
+        force_start_at_depot: whether to force the agent to start at the depot
+            If False (default), the agent won't consider the depot, which is added in the `get_reward` method
+            If True, the only valid action at the first step is to visit the depot (=0)
     """
 
     name = "pdp"
@@ -50,12 +53,14 @@ def __init__(
         self,
         generator: PDPGenerator = None,
         generator_params: dict = {},
+        force_start_at_depot: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         if generator is None:
             generator = PDPGenerator(**generator_params)
         self.generator = generator
+        self.force_start_at_depot = force_start_at_depot
         self._make_spec(self.generator)
 
     @staticmethod
@@ -85,7 +90,7 @@ def _step(td: TensorDict) -> TensorDict:
 
         # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
         reward = torch.zeros_like(done)
-        
+
         # Update step
         td.update(
             {
@@ -122,12 +127,17 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
             dim=-1,
         )
 
-        # Cannot visit depot at first step # [0,1...1] so set not available
+        # Masking variables
         available = torch.ones(
             (*batch_size, self.generator.num_loc + 1), dtype=torch.bool
         ).to(device)
-        action_mask = ~available.contiguous()  # [batch_size, graph_size+1]
-        action_mask[..., 0] = 1  # First step is always the depot
+        action_mask = torch.ones_like(available)  # [batch_size, graph_size+1]
+        if self.force_start_at_depot:
+            action_mask[..., 1:] = False  # can only visit the depot at the first step
+        else:
+            action_mask = action_mask & to_deliver
+            available[..., 0] = False  # depot is already visited (during reward calculation)
+            action_mask[..., 0] = False  # depot is not available to visit
 
         # Other variables
         current_node = torch.zeros((*batch_size, 1), dtype=torch.int64).to(device)
@@ -194,13 +204,18 @@ def _get_reward(td, actions) -> TensorDict:
         return -get_tour_length(locs_ordered)
 
     def check_solution_validity(self, td, actions):
-        # assert (actions[:, 0] == 0).all(), "Not starting at depot"
+        if not self.force_start_at_depot:
+            actions = torch.cat((torch.zeros_like(actions[:, 0:1]), actions), dim=-1)
+
         assert (
-            torch.arange(actions.size(1), out=actions.data.new())
+            (torch.arange(actions.size(1), out=actions.data.new()))
             .view(1, -1)
             .expand_as(actions)
             == actions.data.sort(1)[0]
         ).all(), "Not visiting all nodes"
+
+        # make sure we don't go back to the depot in the middle of the tour
+        assert (actions[:, 1:-1] != 0).all(), "Going back to depot in the middle of the tour (not allowed)"
         visited_time = torch.argsort(
             actions, 1