Skip to content

Commit

Permalink
[BugFix] solve PDP tour issues #231
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Nov 17, 2024
1 parent b4a49b5 commit eecec61
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions rl4co/envs/routing/pdp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class PDPEnv(RL4COEnvBase):
Args:
generator: PDPGenerator instance as the data generator
generator_params: parameters for the generator
force_start_at_depot: whether to force the agent to start at the depot
If False (default), the agent won't consider the depot, which is added in the `get_reward` method
If True, the only valid action at the first step is to visit the depot (=0)
"""

name = "pdp"
Expand All @@ -50,12 +53,14 @@ def __init__(
self,
generator: PDPGenerator = None,
generator_params: dict = {},  # NOTE(review): mutable default arg — shared across calls; consider None + "params or {}"
force_start_at_depot: bool = False,
**kwargs,
):
# Initialize the PDP environment: forward env kwargs to the RL4CO base
# class, build a default PDPGenerator when none is supplied, record the
# new force_start_at_depot flag (added by this commit so _reset can mask
# all non-depot actions on the first step), and build the env specs.
super().__init__(**kwargs)
if generator is None:
generator = PDPGenerator(**generator_params)
self.generator = generator
self.force_start_at_depot = force_start_at_depot
self._make_spec(self.generator)

@staticmethod
Expand Down Expand Up @@ -85,7 +90,7 @@ def _step(td: TensorDict) -> TensorDict:

# The reward is calculated outside via get_reward for efficiency, so we set it to 0 here
reward = torch.zeros_like(done)

# Update step
td.update(
{
Expand Down Expand Up @@ -122,12 +127,17 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict
dim=-1,
)

# Cannot visit depot at first step # [0,1...1] so set not available
# Masking variables
available = torch.ones(
(*batch_size, self.generator.num_loc + 1), dtype=torch.bool
).to(device)
action_mask = ~available.contiguous() # [batch_size, graph_size+1]
action_mask[..., 0] = 1 # First step is always the depot
action_mask = torch.ones_like(available) # [batch_size, graph_size+1]
if self.force_start_at_depot:
action_mask[..., 1:] = False # can only visit the depot at the first step
else:
action_mask = action_mask & to_deliver
available[..., 0] = False # depot is already visited (during reward calculation)
action_mask[..., 0] = False # depot is not available to visit

# Other variables
current_node = torch.zeros((*batch_size, 1), dtype=torch.int64).to(device)
Expand Down Expand Up @@ -194,13 +204,18 @@ def _get_reward(td, actions) -> TensorDict:
return -get_tour_length(locs_ordered)

def check_solution_validity(self, td, actions):
# assert (actions[:, 0] == 0).all(), "Not starting at depot"
if not self.force_start_at_depot:
actions = torch.cat((torch.zeros_like(actions[:, 0:1]), actions), dim=-1)

assert (
torch.arange(actions.size(1), out=actions.data.new())
(torch.arange(actions.size(1), out=actions.data.new()))
.view(1, -1)
.expand_as(actions)
== actions.data.sort(1)[0]
).all(), "Not visiting all nodes"

# make sure we don't go back to the depot in the middle of the tour
assert (actions[:, 1:-1] != 0).all(), "Going back to depot in the middle of the tour (not allowed)"

visited_time = torch.argsort(
actions, 1
Expand Down

0 comments on commit eecec61

Please sign in to comment.