From d742f5b119585c8f289916349c467d62f6a7b23a Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Tue, 9 Apr 2024 21:23:24 +0800 Subject: [PATCH 1/6] four new routing envs --- rl4co/envs/__init__.py | 10 +- rl4co/envs/routing/__init__.py | 5 +- rl4co/envs/routing/ovrp.py | 440 +++++++++++++++ rl4co/envs/routing/vrpb.py | 441 +++++++++++++++ rl4co/envs/routing/vrpl.py | 511 ++++++++++++++++++ rl4co/envs/routing/vrptw.py | 358 ++++++++++++ rl4co/models/nn/env_embeddings/context.py | 71 +-- rl4co/models/nn/env_embeddings/init.py | 81 +-- .../zoo/common/autoregressive/decoder.py | 11 +- rl4co/utils/ops.py | 26 +- 10 files changed, 1826 insertions(+), 128 deletions(-) create mode 100644 rl4co/envs/routing/ovrp.py create mode 100644 rl4co/envs/routing/vrpb.py create mode 100644 rl4co/envs/routing/vrpl.py create mode 100644 rl4co/envs/routing/vrptw.py diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py index 0a6ad5cd..eac73b81 100644 --- a/rl4co/envs/__init__.py +++ b/rl4co/envs/__init__.py @@ -17,7 +17,10 @@ SVRPEnv, SPCTSPEnv, TSPEnv, - MDCPDPEnv, + VRPLEnv, + OVRPEnv, + VRPTWEnv, + VRPBEnv, ) # Scheduling @@ -40,7 +43,10 @@ "spctsp": SPCTSPEnv, "tsp": TSPEnv, "smtwtp": SMTWTPEnv, - "mdcpdp": MDCPDPEnv, + "vrpl": VRPLEnv, + "ovrp": OVRPEnv, + "vrptw": VRPTWEnv, + "vrpb": VRPBEnv, } diff --git a/rl4co/envs/routing/__init__.py b/rl4co/envs/routing/__init__.py index 392f2640..9c5285f5 100644 --- a/rl4co/envs/routing/__init__.py +++ b/rl4co/envs/routing/__init__.py @@ -9,4 +9,7 @@ from rl4co.envs.routing.spctsp import SPCTSPEnv from rl4co.envs.routing.svrp import SVRPEnv from rl4co.envs.routing.tsp import TSPEnv -from rl4co.envs.routing.mdcpdp import MDCPDPEnv +from rl4co.envs.routing.vrpl import VRPLEnv +from rl4co.envs.routing.ovrp import OVRPEnv +from rl4co.envs.routing.vrptw import VRPTWEnv +from rl4co.envs.routing.vrpb import VRPBEnv diff --git a/rl4co/envs/routing/ovrp.py b/rl4co/envs/routing/ovrp.py new file mode 100644 index 00000000..7f49be86 --- /dev/null +++ b/rl4co/envs/routing/ovrp.py @@ -0,0 +1,440 @@ +from typing import Optional + +import torch + +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.utils.ops import gather_by_index, get_open_tour_length +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 +CAPACITIES = { + 10: 20.0, + 15: 25.0, + 20: 30.0, + 30: 33.0, + 40: 37.0, + 50: 40.0, + 60: 43.0, + 75: 45.0, + 100: 50.0, + 125: 55.0, + 150: 60.0, + 200: 70.0, + 500: 100.0, + 1000: 150.0, +} + + +class OVRPEnv(RL4COEnvBase): + """Capacitated Vehicle Routing Problem (CVRP) environment. + At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. + When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to + visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. + In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + + Args: + num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 
10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + vehicle_capacity: capacity of the vehicle + td_params: parameters of the environment + """ + + name = "ovrp" + + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0, + max_loc: float = 1, + min_demand: float = 1, + max_demand: float = 10, + vehicle_capacity: float = 1.0, + capacity: float = None, + td_params: TensorDict = None, + **kwargs, + ): + super().__init__(**kwargs) + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_demand = min_demand + self.max_demand = max_demand + self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity + if self.capacity is None: + raise ValueError( + f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." + ) + self.vehicle_capacity = vehicle_capacity + self._make_spec(td_params) + + def _step(self, td: TensorDict) -> TensorDict: + current_node = td["action"][:, None] # Add dimension for step + n_loc = td["demand"].size(-1) # Excludes depot + + # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! + selected_demand = gather_by_index( + td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False + ) + + # Increase capacity if depot is not visited, otherwise set to 0 + used_capacity = (td["used_capacity"] + selected_demand) * ( + current_node != 0 + ).float() + + # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot + # Add one dimension since we write a single value + visited = td["visited"].scatter(-1, current_node[..., None], 1) + + # SECTION: get done + done = visited.sum(-1) == visited.size(-1) + reward = torch.zeros_like(done) + + td.update( + { + "current_node": current_node, + "used_capacity": used_capacity, + "visited": visited, + "reward": reward, + "done": done, + } + ) + td.set("action_mask", self.get_action_mask(td)) + return td + + def _reset( + self, + td: Optional[TensorDict] = None, + batch_size: Optional[list] = None, + ) -> TensorDict: + if batch_size is None: + batch_size = self.batch_size if td is None else td["locs"].shape[:-2] + if td is None or td.is_empty(): + td = self.generate_data(batch_size=batch_size) + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + self.to(td.device) + + # Create reset TensorDict + td_reset = TensorDict( + { + "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), + "demand": td["demand"], + "current_node": torch.zeros( + *batch_size, 1, dtype=torch.long, device=self.device + ), + "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "vehicle_capacity": torch.full( + (*batch_size, 1), self.vehicle_capacity, device=self.device + ), + "visited": torch.zeros( + (*batch_size, 1, td["locs"].shape[-2] + 1), + dtype=torch.uint8, + device=self.device, + ), + }, + batch_size=batch_size, + ) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + return td_reset + + @staticmethod + def get_action_mask(td: TensorDict) -> torch.Tensor: + # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting + exceeds_cap = ( + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] + ) + + # Nodes that cannot be visited are already 
visited or too much demand to be served now + mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap + + # Cannot visit the depot if just visited and still unserved nodes + mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) + return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) + + def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + # Check that the solution is valid + if self.check_solution: + self.check_solution_validity(td, actions) + + # Gather dataset in order of tour + batch_size = td["locs"].shape[0] + depot = td["locs"][..., 0:1, :] + locs_ordered = torch.cat( + [ + depot, + gather_by_index(td["locs"], actions).reshape( + [batch_size, actions.size(-1), 2] + ), + ], + dim=1, + ) + return -get_open_tour_length(locs_ordered) + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" + # Check if tour is valid, i.e. contain 0 to n-1 + batch_size, graph_size = td["demand"].size() + sorted_pi = actions.data.sort(1)[0] + + # Sorting it should give all zeros at front and then 1...n + assert ( + torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) + .view(1, -1) + .expand(batch_size, graph_size) + == sorted_pi[:, -graph_size:] + ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" + + # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) + demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) + d = demand_with_depot.gather(1, actions) + + used_cap = torch.zeros_like(td["demand"][:, 0]) + for i in range(actions.size(1)): + used_cap += d[ + :, i + ] # This will reset/make capacity negative if i == 0, e.g. depot visited + # Cannot use less than 0 + used_cap[used_cap < 0] = 0 + assert ( + used_cap <= td["vehicle_capacity"] + 1e-5 + ).all(), "Used more than capacity" + + def generate_data(self, batch_size) -> TensorDict: + # Batch size input check + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + # Initialize the locations (including the depot which is always the first node) + locs_with_depot = ( + torch.FloatTensor(*batch_size, self.num_loc + 1, 2) + .uniform_(self.min_loc, self.max_loc) + .to(self.device) + ) + + # Initialize the demand for nodes except the depot + # Demand sampling Following Kool et al. 
(2019) + # Generates a slightly different distribution than using torch.randint + demand = ( + ( + torch.FloatTensor(*batch_size, self.num_loc) + .uniform_(self.min_demand - 1, self.max_demand - 1) + .int() + + 1 + ) + .float() + .to(self.device) + ) + + # Support for heterogeneous capacity if provided + if not isinstance(self.capacity, torch.Tensor): + capacity = torch.full((*batch_size,), self.capacity, device=self.device) + else: + capacity = self.capacity + + return TensorDict( + { + "locs": locs_with_depot[..., 1:, :], + "depot": locs_with_depot[..., 0, :], + "demand": demand / self.capacity, + "capacity": capacity, + }, + batch_size=batch_size, + device=self.device, + ) + + @staticmethod + def load_data(fpath, batch_size=[]): + """Dataset loading from file + Normalize demand by capacity to be in [0, 1] + """ + td_load = load_npz_to_tensordict(fpath) + td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) + return td_load + + def _make_spec(self, td_params: TensorDict): + """Make the observation and action specs from the parameters.""" + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=self.min_loc, + high=self.max_loc, + shape=(self.num_loc + 1, 2), + dtype=torch.float32, + ), + current_node=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + demand=BoundedTensorSpec( + low=-self.capacity, + high=self.max_demand, + shape=(self.num_loc, 1), # demand is only for customers + dtype=torch.float32, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(self.num_loc + 1, 1), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=self.num_loc + 1, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + @staticmethod + def render( + td: TensorDict, + actions=None, + ax=None, + scale_xy: bool = True, + ): + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import cm, colormaps + + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) + demands = td["demand"] * scale_demand + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # plot demand bars + for node_idx in range(1, len(locs)): + ax.add_patch( + plt.Rectangle( + (locs[node_idx, 0] - 0.005, 
locs[node_idx, 1] + 0.015), + 0.01, + demands[node_idx - 1] / (scale_demand * 10), + edgecolor=cm.Set2(0), + facecolor=cm.Set2(0), + fill=True, + ) + ) + + # text demand + for node_idx in range(1, len(locs)): + ax.text( + locs[node_idx, 0], + locs[node_idx, 1] - 0.025, + f"{demands[node_idx-1].item():.2f}", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(0), + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx+1] == 0: + continue + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + # Setup limits and show + if scale_xy: + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) + plt.show() diff --git a/rl4co/envs/routing/vrpb.py b/rl4co/envs/routing/vrpb.py new file mode 100644 index 00000000..df5da1b1 --- /dev/null +++ b/rl4co/envs/routing/vrpb.py @@ -0,0 +1,441 @@ +from typing import Optional + +import torch + +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.utils.ops import gather_by_index, get_tour_length +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 +CAPACITIES = { + 10: 20.0, + 15: 25.0, + 20: 30.0, + 30: 33.0, + 40: 37.0, + 50: 40.0, + 60: 43.0, + 75: 45.0, + 100: 50.0, + 125: 55.0, + 150: 60.0, + 200: 70.0, + 500: 100.0, + 1000: 150.0, +} + + +class VRPBEnv(RL4COEnvBase): + """Capacitated Vehicle Routing Problem (CVRP) environment. + At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. + When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to + visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. + In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + + Args: + num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 
10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + vehicle_capacity: capacity of the vehicle + td_params: parameters of the environment + """ + + name = "vrpb" + + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0, + max_loc: float = 1, + min_demand: float = 1, + max_demand: float = 10, + vehicle_capacity: float = 1.0, + capacity: float = None, + td_params: TensorDict = None, + **kwargs, + ): + super().__init__(**kwargs) + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_demand = min_demand + self.max_demand = max_demand + self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity + if self.capacity is None: + raise ValueError( + f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." + ) + self.vehicle_capacity = vehicle_capacity + self._make_spec(td_params) + + def _step(self, td: TensorDict) -> TensorDict: + current_node = td["action"][:, None] # Add dimension for step + n_loc = td["demand"].size(-1) # Excludes depot + + # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! + selected_demand = gather_by_index( + td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False + ) + + # Increase capacity if depot is not visited, otherwise set to 0 + used_capacity = (td["used_capacity"] + selected_demand) * ( + current_node != 0 + ).float() + + # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot + # Add one dimension since we write a single value + visited = td["visited"].scatter(-1, current_node[..., None], 1) + + # SECTION: get done + done = visited.sum(-1) == visited.size(-1) + reward = torch.zeros_like(done) + + td.update( + { + "current_node": current_node, + "used_capacity": used_capacity, + "visited": visited, + "reward": reward, + "done": done, + } + ) + td.set("action_mask", self.get_action_mask(td)) + return td + + def _reset( + self, + td: Optional[TensorDict] = None, + batch_size: Optional[list] = None, + ) -> TensorDict: + if batch_size is None: + batch_size = self.batch_size if td is None else td["locs"].shape[:-2] + if td is None or td.is_empty(): + td = self.generate_data(batch_size=batch_size) + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + self.to(td.device) + + # Create reset TensorDict + td_reset = TensorDict( + { + "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), + "demand": td["demand"], + "current_node": torch.zeros( + *batch_size, 1, dtype=torch.long, device=self.device + ), + "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "vehicle_capacity": torch.full( + (*batch_size, 1), self.vehicle_capacity, device=self.device + ), + "visited": torch.zeros( + (*batch_size, 1, td["locs"].shape[-2] + 1), + dtype=torch.uint8, + device=self.device, + ), + }, + batch_size=batch_size, + ) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + return td_reset + + @staticmethod + def get_action_mask(td: TensorDict) -> torch.Tensor: + # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting + exceeds_cap = ( + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] + ) + + # Nodes that cannot be visited are already 
visited or too much demand to be served now + mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap + + # Cannot visit the depot if just visited and still unserved nodes + mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) + return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) + + def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + # Check that the solution is valid + # if self.check_solution: + # self.check_solution_validity(td, actions) + + # Gather dataset in order of tour + batch_size = td["locs"].shape[0] + depot = td["locs"][..., 0:1, :] + locs_ordered = torch.cat( + [ + depot, + gather_by_index(td["locs"], actions).reshape( + [batch_size, actions.size(-1), 2] + ), + ], + dim=1, + ) + return -get_tour_length(locs_ordered) + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" + # Check if tour is valid, i.e. contain 0 to n-1 + batch_size, graph_size = td["demand"].size() + sorted_pi = actions.data.sort(1)[0] + + # Sorting it should give all zeros at front and then 1...n + assert ( + torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) + .view(1, -1) + .expand(batch_size, graph_size) + == sorted_pi[:, -graph_size:] + ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" + + # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) + demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) + d = demand_with_depot.gather(1, actions) + + used_cap = torch.zeros_like(td["demand"][:, 0]) + for i in range(actions.size(1)): + used_cap += d[ + :, i + ] # This will reset/make capacity negative if i == 0, e.g. depot visited + # Cannot use less than 0 + used_cap[used_cap < 0] = 0 + assert ( + used_cap <= td["vehicle_capacity"] + 1e-5 + ).all(), "Used more than capacity" + + def generate_data(self, batch_size) -> TensorDict: + # Batch size input check + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + # Initialize the locations (including the depot which is always the first node) + locs_with_depot = ( + torch.FloatTensor(*batch_size, self.num_loc + 1, 2) + .uniform_(self.min_loc, self.max_loc) + .to(self.device) + ) + + # Initialize the demand for nodes except the depot + # Demand sampling Following Kool et al. 
(2019) + # Generates a slightly different distribution than using torch.randint + demand = ( + ( + torch.FloatTensor(*batch_size, self.num_loc) + .uniform_(self.min_demand - 1, self.max_demand - 1) + .int() + + 1 + ) + .float() + .to(self.device) + ) + + # set 20% to backhaul + linehaul = int(0.8*self.num_loc) + demand[:,linehaul:] = -demand[:,linehaul:] + + # Support for heterogeneous capacity if provided + if not isinstance(self.capacity, torch.Tensor): + capacity = torch.full((*batch_size,), self.capacity, device=self.device) + else: + capacity = self.capacity + + return TensorDict( + { + "locs": locs_with_depot[..., 1:, :], + "depot": locs_with_depot[..., 0, :], + "demand": demand / self.capacity, + "capacity": capacity, + }, + batch_size=batch_size, + device=self.device, + ) + + @staticmethod + def load_data(fpath, batch_size=[]): + """Dataset loading from file + Normalize demand by capacity to be in [0, 1] + """ + td_load = load_npz_to_tensordict(fpath) + td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) + return td_load + + def _make_spec(self, td_params: TensorDict): + """Make the observation and action specs from the parameters.""" + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=self.min_loc, + high=self.max_loc, + shape=(self.num_loc + 1, 2), + dtype=torch.float32, + ), + current_node=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + demand=BoundedTensorSpec( + low=-self.capacity, + high=self.max_demand, + shape=(self.num_loc, 1), # demand is only for customers + dtype=torch.float32, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(self.num_loc + 1, 1), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=self.num_loc + 1, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + @staticmethod + def render( + td: TensorDict, + actions=None, + ax=None, + scale_xy: bool = True, + ): + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import cm, colormaps + + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) + demands = td["demand"] * scale_demand + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # plot demand bars 
+ for node_idx in range(1, len(locs)): + ax.add_patch( + plt.Rectangle( + (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), + 0.01, + demands[node_idx - 1] / (scale_demand * 10), + edgecolor=cm.Set2(0), + facecolor=cm.Set2(0), + fill=True, + ) + ) + + # text demand + for node_idx in range(1, len(locs)): + ax.text( + locs[node_idx, 0], + locs[node_idx, 1] - 0.025, + f"{demands[node_idx-1].item():.2f}", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(0), + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + # Setup limits and show + if scale_xy: + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) + plt.show() diff --git a/rl4co/envs/routing/vrpl.py b/rl4co/envs/routing/vrpl.py new file mode 100644 index 00000000..c5e1d4d9 --- /dev/null +++ b/rl4co/envs/routing/vrpl.py @@ -0,0 +1,511 @@ +from typing import Optional + +import torch + +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.utils.ops import gather_by_index, get_tour_length, get_distance +from rl4co.utils.pylogger import get_pylogger + + +log = get_pylogger(__name__) + + +# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 +CAPACITIES = { + 10: 20.0, + 15: 25.0, + 20: 30.0, + 30: 33.0, + 40: 37.0, + 50: 40.0, + 60: 43.0, + 75: 45.0, + 100: 50.0, + 125: 55.0, + 150: 60.0, + 200: 70.0, + 500: 100.0, + 1000: 150.0, +} + + +class VRPLEnv(RL4COEnvBase): + """Capacitated Vehicle Routing Problem (CVRP) environment. + At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. + When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to + visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. + In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. + + Args: + num_loc: number of + + locations (cities) in the VRP, without the depot. (e.g. 
10 means 10 locs + 1 depot) + min_loc: minimum value for the location coordinates + max_loc: maximum value for the location coordinates + min_demand: minimum value for the demand of each customer + max_demand: maximum value for the demand of each customer + vehicle_capacity: capacity of the vehicle + td_params: parameters of the environment + """ + + name = "vrpl" + + def __init__( + self, + num_loc: int = 20, + min_loc: float = 0, + max_loc: float = 1, + min_demand: float = 1, + max_demand: float = 10, + vehicle_capacity: float = 1.0, + capacity: float = None, + duration_limit: float = None, + #selected_node_list: torch.Tensor = None, + td_params: TensorDict = None, + **kwargs, + ): + super().__init__(**kwargs) + self.num_loc = num_loc + self.min_loc = min_loc + self.max_loc = max_loc + self.min_demand = min_demand + self.max_demand = max_demand + self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity + if self.capacity is None: + raise ValueError( + f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." + ) + self.vehicle_capacity = vehicle_capacity + self.duration_limit = 3.0 if duration_limit is None else duration_limit + self.selected_node_list = None + self._make_spec(td_params) + + def _step(self, td: TensorDict) -> TensorDict: + current_node = td["action"][:, None] # Add dimension for step + n_loc = td["demand"].size(-1) # Excludes depot + + self.selected_node_list = torch.cat((self.selected_node_list, current_node), dim=1) + # shape: (batch, pomo, 0~) + + # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! + selected_demand = gather_by_index( + td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False + ) + + # Increase capacity if depot is not visited, otherwise set to 0 + used_capacity = (td["used_capacity"] + selected_demand) * ( + current_node != 0 + ).float() + + # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot + # Add one dimension since we write a single value + visited = td["visited"].scatter(-1, current_node[..., None], 1) + + #### update distance information ### + + selected_xy = gather_by_index( + td["locs"], current_node, squeeze=False + ) + + gathering_index_last = self.selected_node_list[:, -2][:,None,None].expand(-1,-1,2) + + last_xy = gather_by_index( + td["locs"], gathering_index_last, squeeze=False + ) + + selected_distance = ((selected_xy - last_xy)**2).sum(dim=2).sqrt() + + td["duration_limit"] -= selected_distance + + td["duration_limit"][current_node == 0] = self.duration_limit # refill length at the depot + + + # SECTION: get done + done = visited.sum(-1) == visited.size(-1) + # print(done[0]) + # print(done) + reward = torch.zeros_like(done) + + td.update( + { + "current_node": current_node, + "used_capacity": used_capacity, + "visited": visited, + "reward": reward, + "done": done, + } + ) + td.set("action_mask", self.get_action_mask(td)) + return td + + def _reset( + self, + td: Optional[TensorDict] = None, + batch_size: Optional[list] = None, + ) -> TensorDict: + if batch_size is None: + batch_size = self.batch_size if td is None else td["locs"].shape[:-2] + if td is None or td.is_empty(): + td = self.generate_data(batch_size=batch_size) + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + self.to(td.device) + + self.selected_node_list = torch.zeros((*batch_size,1),dtype=torch.int64,device=self.device) + + # Create reset TensorDict + td_reset = TensorDict( + { 
+ "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), + "demand": td["demand"], + "current_node": torch.zeros( + *batch_size, 1, dtype=torch.long, device=self.device + ), + "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "duration_limit": torch.full( + (*batch_size, 1), self.duration_limit, device=self.device + ), + "vehicle_capacity": torch.full( + (*batch_size, 1), self.vehicle_capacity, device=self.device + ), + + "visited": torch.zeros( + (*batch_size, 1, td["locs"].shape[-2] + 1), + dtype=torch.uint8, + device=self.device, + ), + }, + batch_size=batch_size, + ) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + + return td_reset + + @staticmethod + def get_action_mask(td: TensorDict) -> torch.Tensor: + # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting + exceeds_cap = ( + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] + ) + + if "action" not in td.keys(): + current_node =torch.zeros((td["locs"].shape[0],1),dtype=torch.int64, device=td["locs"].device) + else: + current_node = td["action"][:, None] # Add dimension for step + + selected_xy = gather_by_index( + td["locs"], current_node, squeeze=False + ) + + length_to_next = ((selected_xy - td["locs"])**2).sum(dim=2).sqrt() + + # length_to_next = ((selected_xy[:,None,:].expand(-1,-1,self.problem_size+1,-1) - xy_list)**2).sum(dim=3).sqrt() + # # shape: (batch, pomo, problem+1) + depot_xy = td["locs"][:,0,:] + next_to_depot = ((depot_xy[:,None,:].expand(td["locs"].shape) - td["locs"])**2).sum(dim=2).sqrt() + # shape: (batch, pomo, problem+1) + + length_too_small = td["duration_limit"] - 1E-6 < (length_to_next + next_to_depot ) + + # Nodes that cannot be visited are already visited or too much demand to be served now + + mask_loc = (td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap) | length_too_small[:,1:][:,None,:] + + # print(td["visited"][..., 1:].to(exceeds_cap.dtype).shape) + # print(exceeds_cap) + # print(length_too_small[:,1:][:,None,:]) + # print(mask_loc) + # print(td["visited"][..., 1:][0]) + # print(td["visited"][..., 1:][-1]) + #input() + # print(td["visited"][..., 1:].to(exceeds_cap.dtype).shape) + + + # Cannot visit the depot if just visited and still unserved nodes + mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) + + #print(~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2)) + return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) + + def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + # Check that the solution is valid + # if self.check_solution: + # self.check_solution_validity(td, actions) + + # Gather dataset in order of tour + batch_size = td["locs"].shape[0] + depot = td["locs"][..., 0:1, :] + locs_ordered = torch.cat( + [ + depot, + gather_by_index(td["locs"], actions).reshape( + [batch_size, actions.size(-1), 2] + ), + ], + dim=1, + ) + return -get_tour_length(locs_ordered) + + # @staticmethod + # def check_solution_validity(td: TensorDict, actions: torch.Tensor): + # """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" + # # Check if tour is valid, i.e. 
contain 0 to n-1 + # batch_size, graph_size = td["demand"].size() + # sorted_pi = actions.data.sort(1)[0] + + # # Sorting it should give all zeros at front and then 1...n + # assert ( + # torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) + # .view(1, -1) + # .expand(batch_size, graph_size) + # == sorted_pi[:, -graph_size:] + # ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" + + # # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) + # demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) + # d = demand_with_depot.gather(1, actions) + + # used_cap = torch.zeros_like(td["demand"][:, 0]) + # for i in range(actions.size(1)): + # used_cap += d[ + # :, i + # ] # This will reset/make capacity negative if i == 0, e.g. depot visited + # # Cannot use less than 0 + # used_cap[used_cap < 0] = 0 + # assert ( + # used_cap <= td["vehicle_capacity"] + 1e-5 + # ).all(), "Used more than capacity" + + def generate_data(self, batch_size) -> TensorDict: + # Batch size input check + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + # Initialize the locations (including the depot which is always the first node) + locs_with_depot = ( + torch.FloatTensor(*batch_size, self.num_loc + 1, 2) + .uniform_(self.min_loc, self.max_loc) + .to(self.device) + ) + + # Initialize the demand for nodes except the depot + # Demand sampling Following Kool et al. (2019) + # Generates a slightly different distribution than using torch.randint + demand = ( + ( + torch.FloatTensor(*batch_size, self.num_loc) + .uniform_(self.min_demand - 1, self.max_demand - 1) + .int() + + 1 + ) + .float() + .to(self.device) + ) + + # Support for heterogeneous capacity if provided + if not isinstance(self.capacity, torch.Tensor): + capacity = torch.full((*batch_size,), self.capacity, device=self.device) + else: + capacity = self.capacity + + # duration limit + duration_limit = torch.full ((*batch_size,), self.duration_limit, device=self.device) + + return TensorDict( + { + "locs": locs_with_depot[..., 1:, :], + "depot": locs_with_depot[..., 0, :], + "demand": demand / self.capacity, + "capacity": capacity, + "duration_limit":duration_limit, + }, + batch_size=batch_size, + device=self.device, + ) + + @staticmethod + def load_data(fpath, batch_size=[]): + """Dataset loading from file + Normalize demand by capacity to be in [0, 1] + """ + td_load = load_npz_to_tensordict(fpath) + td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) + return td_load + + def _make_spec(self, td_params: TensorDict): + """Make the observation and action specs from the parameters.""" + self.observation_spec = CompositeSpec( + locs=BoundedTensorSpec( + low=self.min_loc, + high=self.max_loc, + shape=(self.num_loc + 1, 2), + dtype=torch.float32, + ), + current_node=UnboundedDiscreteTensorSpec( + shape=(1), + dtype=torch.int64, + ), + demand=BoundedTensorSpec( + low=-self.capacity, + high=self.max_demand, + shape=(self.num_loc, 1), # demand is only for customers + dtype=torch.float32, + ), + action_mask=UnboundedDiscreteTensorSpec( + shape=(self.num_loc + 1, 1), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=0, + high=self.num_loc + 1, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + @staticmethod + def render( + td: TensorDict, + actions=None, + ax=None, + 
scale_xy: bool = True, + ): + import matplotlib.pyplot as plt + import numpy as np + + from matplotlib import cm, colormaps + + num_routine = (actions == 0).sum().item() + 2 + base = colormaps["nipy_spectral"] + color_list = base(np.linspace(0, 1, num_routine)) + cmap_name = base.name + str(num_routine) + out = base.from_list(cmap_name, color_list, num_routine) + + if ax is None: + # Create a plot of the nodes + _, ax = plt.subplots() + + td = td.detach().cpu() + + if actions is None: + actions = td.get("action", None) + + # if batch_size greater than 0 , we need to select the first batch element + if td.batch_size != torch.Size([]): + td = td[0] + actions = actions[0] + + locs = td["locs"] + scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) + demands = td["demand"] * scale_demand + + # add the depot at the first action and the end action + actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) + + # gather locs in order of action if available + if actions is None: + log.warning("No action in TensorDict, rendering unsorted locs") + else: + locs = locs + + # Cat the first node to the end to complete the tour + x, y = locs[:, 0], locs[:, 1] + + # plot depot + ax.scatter( + locs[0, 0], + locs[0, 1], + edgecolors=cm.Set2(2), + facecolors="none", + s=100, + linewidths=2, + marker="s", + alpha=1, + ) + + # plot visited nodes + ax.scatter( + x[1:], + y[1:], + edgecolors=cm.Set2(0), + facecolors="none", + s=50, + linewidths=2, + marker="o", + alpha=1, + ) + + # plot demand bars + for node_idx in range(1, len(locs)): + ax.add_patch( + plt.Rectangle( + (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), + 0.01, + demands[node_idx - 1] / (scale_demand * 10), + edgecolor=cm.Set2(0), + facecolor=cm.Set2(0), + fill=True, + ) + ) + + # text demand + for node_idx in range(1, len(locs)): + ax.text( + locs[node_idx, 0], + locs[node_idx, 1] - 0.025, + f"{demands[node_idx-1].item():.2f}", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(0), + ) + + # text depot + ax.text( + locs[0, 0], + locs[0, 1] - 0.025, + "Depot", + horizontalalignment="center", + verticalalignment="top", + fontsize=10, + color=cm.Set2(2), + ) + + # plot actions + color_idx = 0 + for action_idx in range(len(actions) - 1): + if actions[action_idx] == 0: + color_idx += 1 + from_loc = locs[actions[action_idx]] + to_loc = locs[actions[action_idx + 1]] + ax.plot( + [from_loc[0], to_loc[0]], + [from_loc[1], to_loc[1]], + color=out(color_idx), + lw=1, + ) + ax.annotate( + "", + xy=(to_loc[0], to_loc[1]), + xytext=(from_loc[0], from_loc[1]), + arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), + size=15, + annotation_clip=False, + ) + + # Setup limits and show + if scale_xy: + ax.set_xlim(-0.05, 1.05) + ax.set_ylim(-0.05, 1.05) + plt.show() diff --git a/rl4co/envs/routing/vrptw.py b/rl4co/envs/routing/vrptw.py new file mode 100644 index 00000000..51aedc25 --- /dev/null +++ b/rl4co/envs/routing/vrptw.py @@ -0,0 +1,358 @@ +from math import sqrt +from typing import Optional +import torch +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, +) + +from rl4co.envs.routing.cvrp import CVRPEnv, CAPACITIES +from rl4co.utils.ops import gather_by_index, get_distance +from rl4co.data.utils import ( + load_npz_to_tensordict, + load_solomon_instance, + load_solomon_solution, +) + + +class VRPTWEnv(CVRPEnv): + """Capacitated Vehicle Routing Problem with Time Windows (CVRPTW) environment. 
+    Inherits from the CVRPEnv class in which capacities are considered.
+    Additionally considers time windows within which a service has to be started.
+
+    Args:
+        num_loc (int): number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot)
+        min_loc (float): minimum value for the location coordinates
+        max_loc (float): maximum value for the location coordinates. Defaults to 1.
+        min_demand (float): minimum value for the demand of each customer
+        max_demand (float): maximum value for the demand of each customer
+        max_time (float): maximum route duration (time horizon). Defaults to 4.6.
+        vehicle_capacity (float): capacity of the vehicle
+        capacity (float): capacity of the vehicle
+        scale (bool): if True, the time windows and service durations are scaled to [0, 1]. Defaults to False.
+        td_params: parameters of the environment
+    """
+
+    name = "vrptw"
+
+    def __init__(
+        self,
+        max_loc: float = 1,  # different default value to CVRPEnv to match max_time, will be scaled
+        max_time: float = 4.6,
+        scale: bool = False,
+        **kwargs,
+    ):
+        self.min_time = 0  # always 0
+        self.max_time = max_time
+        self.scale = scale
+        super().__init__(max_loc=max_loc, **kwargs)
+
+    def _make_spec(self, td_params: TensorDict):
+        super()._make_spec(td_params)
+
+        current_time = UnboundedContinuousTensorSpec(
+            shape=(1), dtype=torch.float32, device=self.device
+        )
+
+        current_loc = UnboundedContinuousTensorSpec(
+            shape=(2), dtype=torch.float32, device=self.device
+        )
+
+        durations = BoundedTensorSpec(
+            low=self.min_time,
+            high=self.max_time,
+            shape=(self.num_loc, 1),
+            dtype=torch.float,
+            device=self.device,
+        )
+
+        time_windows = BoundedTensorSpec(
+            low=self.min_time,
+            high=self.max_time,
+            shape=(
+                self.num_loc,
+                2,
+            ),  # each location has a 2D time window (start, end)
+            dtype=torch.float,
+            device=self.device,
+        )
+
+        # extend observation specs
+        self.observation_spec = CompositeSpec(
+            **self.observation_spec,
+            current_time=current_time,
+            current_loc=current_loc,
+            durations=durations,
+            time_windows=time_windows,
+            # vehicle_idx=vehicle_idx,
+        )
+
+    def generate_data(self, batch_size) -> TensorDict:
+        """
+        Generates time windows and service durations for the locations. The depot has a time window of [0, self.max_time].
+        The time windows define the time span within which a service has to be started. To reach the depot in time from the last node,
+        the end time of each node is bounded by the service duration and the distance back to the depot.
+        The start times of the time windows are bounded by how long it takes to travel there from the depot.
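+        Concretely, service durations and time-window lengths are both drawn uniformly from [0.15, 0.2)
+        (for the default max_time = 4.6 and unit vehicle speed), and each window start is sampled as
+        e_i * d_0i with e_i in [1, (max_time - duration_i - tw_length_i) / d_0i - 1], where d_0i is the
+        distance from the depot to node i. This guarantees that every customer can be reached after its
+        window opens and that the vehicle can still return to the depot before max_time, so a feasible
+        solution always exists.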
+        """
+        td = super().generate_data(batch_size)
+
+        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
+
+        ## define service durations
+        durations = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device) * 0.05 + 0.15
+        # shape: (batch, problem)
+        # range: (0.15, 0.2) for T=4.6
+
+        ## define time window lengths
+        node_lengthTW = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device) * 0.05 + 0.15
+        # shape: (batch, problem)
+        # range: (0.15, 0.2) for T=4.6
+
+        # distance from the depot to every node (depot itself included as node 0)
+        d0i = ((torch.cat((td["depot"][..., None, :], td["locs"]), -2) - td["depot"][..., None, :].expand(size=(*batch_size, self.num_loc + 1, 2))) ** 2).sum(2).sqrt()
+        # shape: (batch, problem)
+
+        # sample window-start multipliers so that each node stays reachable from the depot
+        # and the vehicle can still return to the depot before max_time
+        ei = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device).mul((torch.div((self.max_time * torch.ones(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device) - durations - node_lengthTW), d0i) - 1) - 1) + 1
+        # shape: (batch, problem)
+        # default velocity = 1.0
+
+        min_times = ei.mul(d0i)
+        # shape: (batch, problem)
+
+        max_times = min_times + node_lengthTW
+        # shape: (batch, problem)
+
+        # the depot gets the full horizon [0, max_time] as its time window
+        min_times[..., :, 0] = 0.0
+        max_times[..., :, 0] = self.max_time
+
+        # scale to [0, 1]
+        if self.scale:
+            durations = durations / self.max_time
+            min_times = min_times / self.max_time
+            max_times = max_times / self.max_time
+            td["depot"] = td["depot"] / self.max_time
+            td["locs"] = td["locs"] / self.max_time
+
+        # stack (start, end) into a single time_windows tensor
+        time_windows = torch.stack((min_times, max_times), dim=-1)
+
+        assert torch.all(
+            min_times < max_times
+        ), "Please make sure the relation between max_loc and max_time allows for feasible solutions."
+
+        # reset duration at depot to 0
+        durations[:, 0] = 0.0
+
+        td.update(
+            {
+                "durations": durations,
+                "time_windows": time_windows,
+            }
+        )
+        return td
+
+    @staticmethod
+    def get_action_mask(td: TensorDict) -> torch.Tensor:
+        """In addition to the constraints considered in the CVRPEnv, the time windows are considered.
+        The vehicle can only visit a location if it can reach it in time, i.e. before its time window ends.
+        """
+        not_masked = CVRPEnv.get_action_mask(td)
+        batch_size = td["locs"].shape[0]
+        current_loc = gather_by_index(td["locs"], td["current_node"]).reshape(
+            [batch_size, 2]
+        )
+        dist = get_distance(current_loc, td["locs"].transpose(0, 1)).transpose(0, 1)
+        td.update({"current_loc": current_loc, "distances": dist})
+        can_reach_in_time = (
+            td["current_time"] + dist <= td["time_windows"][..., 1]
+        )  # the service only has to start before the time window ends, not finish
+
+        return not_masked & can_reach_in_time
+
+    def _step(self, td: TensorDict) -> TensorDict:
+        """In addition to the calculations in the CVRPEnv, the current time is
+        updated to keep track of which nodes are still reachable in time.
+        The current_node is updated in the parent class' _step() function.
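+        Whenever the depot (action 0) is selected, the current time is reset to 0; otherwise it becomes
+        max(current_time + travel_distance, window_start) + service_duration of the selected node.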
+ """ + + batch_size = td["locs"].shape[0] + # update current_time + distance = gather_by_index(td["distances"], td["action"]).reshape([batch_size, 1]) + duration = gather_by_index(td["durations"], td["action"]).reshape([batch_size, 1]) + start_times = gather_by_index(td["time_windows"], td["action"])[..., 0].reshape( + [batch_size, 1] + ) + td["current_time"] = (td["action"][:, None] != 0) * ( + torch.max(td["current_time"] + distance, start_times) + duration + ) + # current_node is updated to the selected action + td = super()._step(td) + return td + + def _reset( + self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None + ) -> TensorDict: + if batch_size is None: + batch_size = self.batch_size if td is None else td["locs"].shape[:-2] + if td is None or td.is_empty(): + td = self.generate_data(batch_size=batch_size) + batch_size = [batch_size] if isinstance(batch_size, int) else batch_size + + self.to(td.device) + # Create reset TensorDict + td_reset = TensorDict( + { + "locs": torch.cat((td["depot"][..., None, :], td["locs"]), -2), + "demand": td["demand"], + "current_node": torch.zeros( + *batch_size, 1, dtype=torch.long, device=self.device + ), + "current_time": torch.zeros( + *batch_size, 1, dtype=torch.float32, device=self.device + ), + "used_capacity": torch.zeros((*batch_size, 1), device=self.device), + "vehicle_capacity": torch.full( + (*batch_size, 1), self.vehicle_capacity, device=self.device + ), + "visited": torch.zeros( + (*batch_size, 1, td["locs"].shape[-2] + 1), + dtype=torch.uint8, + device=self.device, + ), + "durations": td["durations"], + "time_windows": td["time_windows"], + }, + batch_size=batch_size, + ) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + return td_reset + + def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: + """The reward is the negative tour length. Time windows + are not considered for the calculation of the reward.""" + return super().get_reward(td, actions) + + @staticmethod + def check_solution_validity(td: TensorDict, actions: torch.Tensor): + CVRPEnv.check_solution_validity(td, actions) + batch_size = td["locs"].shape[0] + # distances to depot + distances = get_distance( + td["locs"][..., 0, :], td["locs"].transpose(0, 1) + ).transpose(0, 1) + # basic checks on time windows + assert torch.all(distances >= 0.0), "Distances must be non-negative." + assert torch.all(td["time_windows"] >= 0.0), "Time windows must be non-negative." + assert torch.all( + td["time_windows"][..., :, 0] + distances + td["durations"] + <= td["time_windows"][..., 0, 1][0] # max_time is the same for all batches + ), "vehicle cannot perform service and get back to depot in time." + assert torch.all( + td["durations"] >= 0.0 + ), "Service durations must be non-negative." 
+ assert torch.all( + td["time_windows"][..., 0] < td["time_windows"][..., 1] + ), "there are unfeasible time windows" + # check vehicles can meet deadlines + curr_time = torch.zeros(batch_size, 1, dtype=torch.float32, device=td.device) + curr_node = torch.zeros_like(curr_time, dtype=torch.int64, device=td.device) + for ii in range(actions.size(1)): + next_node = actions[:, ii] + dist = get_distance( + gather_by_index(td["locs"], curr_node).reshape([batch_size, 2]), + gather_by_index(td["locs"], next_node).reshape([batch_size, 2]), + ).reshape([batch_size, 1]) + curr_time = torch.max( + (curr_time + dist).int(), + gather_by_index(td["time_windows"], next_node)[..., 0].reshape( + [batch_size, 1] + ), + ) + assert torch.all( + curr_time + <= gather_by_index(td["time_windows"], next_node)[..., 1].reshape( + [batch_size, 1] + ) + ), "vehicle cannot start service before deadline" + curr_time = curr_time + gather_by_index(td["durations"], next_node).reshape( + [batch_size, 1] + ) + curr_node = next_node + curr_time[curr_node == 0] = 0.0 # reset time for depot + + @staticmethod + def render(td: TensorDict, actions=None, ax=None, scale_xy: bool = False, **kwargs): + CVRPEnv.render(td=td, actions=actions, ax=ax, scale_xy=scale_xy, **kwargs) + + @staticmethod + def load_data( + name: str, + solomon=False, + path_instances: str = None, + type: str = None, + compute_edge_weights: bool = False, + ): + if solomon == True: + assert type in [ + "instance", + "solution", + ], "type must be either 'instance' or 'solution'" + if type == "instance": + instance = load_solomon_instance( + name=name, path=path_instances, edge_weights=compute_edge_weights + ) + elif type == "solution": + instance = load_solomon_solution(name=name, path=path_instances) + return instance + return load_npz_to_tensordict(filename=name) + + def extract_from_solomon(self, instance: dict, batch_size: int = 1): + # extract parameters for the environment from the Solomon instance + self.min_demand = instance["demand"][1:].min() + self.max_demand = instance["demand"][1:].max() + self.vehicle_capacity = instance["capacity"] + self.min_loc = instance["node_coord"][1:].min() + self.max_loc = instance["node_coord"][1:].max() + self.min_time = instance["time_window"][:, 0].min() + self.max_time = instance["time_window"][:, 1].max() + # assert the time window of the depot starts at 0 and ends at max_time + assert self.min_time == 0, "Time window of depot must start at 0." + assert ( + self.max_time == instance["time_window"][0, 1] + ), "Depot must have latest end time." 
+ # convert to format used in CVRPTWEnv + td = TensorDict( + { + "depot": torch.tensor( + instance["node_coord"][0], + dtype=torch.float32, + device=self.device, + ).repeat(batch_size, 1), + "locs": torch.tensor( + instance["node_coord"][1:], + dtype=torch.float32, + device=self.device, + ).repeat(batch_size, 1, 1), + "demand": torch.tensor( + instance["demand"][1:], + dtype=torch.float32, + device=self.device, + ).repeat(batch_size, 1), + "durations": torch.tensor( + instance["service_time"], + dtype=torch.int64, + device=self.device, + ).repeat(batch_size, 1), + "time_windows": torch.tensor( + instance["time_window"], + dtype=torch.int64, + device=self.device, + ).repeat(batch_size, 1, 1), + }, + batch_size=1, # we assume batch_size will always be 1 for loaded instances + ) + return self.reset(td, batch_size=batch_size) diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py index ab7510c3..75843c42 100644 --- a/rl4co/models/nn/env_embeddings/context.py +++ b/rl4co/models/nn/env_embeddings/context.py @@ -1,8 +1,6 @@ import torch import torch.nn as nn -from tensordict import TensorDict - from rl4co.utils.ops import gather_by_index @@ -19,8 +17,11 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module: "tsp": TSPContext, "atsp": TSPContext, "cvrp": VRPContext, + "vrpb": VRPContext, + "ovrp": VRPContext, + "vrpl": VRPContext, "cvrptw": VRPTWContext, - "ffsp": FFSPContext, + "vrptw": VRPTWContext, "svrp": SVRPContext, "sdvrp": VRPContext, "pctsp": PCTSPContext, @@ -31,7 +32,6 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module: "pdp": PDPContext, "mtsp": MTSPContext, "smtwtp": SMTWTPContext, - "mdcpdp": MDCPDPContext, } if env_name not in embedding_registry: @@ -73,34 +73,6 @@ def forward(self, embeddings, td): return self.project_context(context_embedding) -class FFSPContext(EnvContext): - def __init__(self, embedding_dim, stage_cnt=None): - self.has_stage_emb = stage_cnt is not None - step_context_dim = (1 + int(self.has_stage_emb)) * embedding_dim - super().__init__(embedding_dim=embedding_dim, step_context_dim=step_context_dim) - if self.has_stage_emb: - self.stage_emb = nn.Parameter(torch.rand(stage_cnt, embedding_dim)) - - def _cur_node_embedding(self, embeddings: TensorDict, td): - cur_node_embedding = gather_by_index( - embeddings["machine_embeddings"], td["stage_machine_idx"] - ) - return cur_node_embedding - - def forward(self, embeddings, td): - cur_node_embedding = self._cur_node_embedding(embeddings, td) - if self.has_stage_emb: - state_embedding = self._state_embedding(embeddings, td) - context_embedding = torch.cat([cur_node_embedding, state_embedding], -1) - return self.project_context(context_embedding) - else: - return self.project_context(cur_node_embedding) - - def _state_embedding(self, _, td): - cur_stage_emb = self.stage_emb[td["stage_idx"]] - return cur_stage_emb - - class TSPContext(EnvContext): """Context embedding for the Traveling Salesman Problem (TSP). Project the following to the embedding space: @@ -150,6 +122,27 @@ def _state_embedding(self, embeddings, td): state_embedding = td["vehicle_capacity"] - td["used_capacity"] return state_embedding +class VRPLContext(VRPContext): + """Context embedding for the Capacitated Vehicle Routing Problem (CVRP). 
+ Project the following to the embedding space: + - current node embedding + - remaining capacity (vehicle_capacity - used_capacity) + - current time + """ + + def __init__(self, embedding_dim): + super(VRPContext, self).__init__( + embedding_dim=embedding_dim, step_context_dim=embedding_dim + 2 + ) + + def _cur_node_embedding(self, embeddings, td): + return super()._cur_node_embedding(embeddings, td).reshape(embeddings.size(0), -1) + + def _state_embedding(self, embeddings, td): + capacity = super()._state_embedding(embeddings, td) + duration_limit = td["duration_limit"] + return torch.cat([capacity, duration_limit], -1) + class VRPTWContext(VRPContext): """Context embedding for the Capacitated Vehicle Routing Problem (CVRP). @@ -308,17 +301,3 @@ def _cur_node_embedding(self, embeddings, td): def _state_embedding(self, embeddings, td): state_embedding = td["current_time"] return state_embedding - - -class MDCPDPContext(EnvContext): - """Context embedding for the MDCPDP. - Project the following to the embedding space: - - current node embedding - """ - - def __init__(self, embedding_dim): - super(MDCPDPContext, self).__init__(embedding_dim, embedding_dim) - - def forward(self, embeddings, td): - cur_node_embedding = self._cur_node_embedding(embeddings, td).squeeze() - return self.project_context(cur_node_embedding) diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py index cf934e1a..d72253ec 100644 --- a/rl4co/models/nn/env_embeddings/init.py +++ b/rl4co/models/nn/env_embeddings/init.py @@ -1,8 +1,6 @@ import torch import torch.nn as nn -from tensordict.tensordict import TensorDict - def env_init_embedding(env_name: str, config: dict) -> nn.Module: """Get environment initial embedding. The init embedding is used to initialize the @@ -16,9 +14,12 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: embedding_registry = { "tsp": TSPInitEmbedding, "atsp": TSPInitEmbedding, - "matnet": MatNetInitEmbedding, "cvrp": VRPInitEmbedding, + "vrpb": VRPInitEmbedding, + "vrpl": VRPInitEmbedding, + "ovrp": VRPInitEmbedding, "cvrptw": VRPTWInitEmbedding, + "vrptw": VRPTWInitEmbedding, "svrp": SVRPInitEmbedding, "sdvrp": VRPInitEmbedding, "pctsp": PCTSPInitEmbedding, @@ -29,7 +30,6 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: "pdp": PDPInitEmbedding, "mtsp": MTSPInitEmbedding, "smtwtp": SMTWTPInitEmbedding, - "mdcpdp": MDCPDPInitEmbedding, } if env_name not in embedding_registry: @@ -56,50 +56,6 @@ def forward(self, td): return out -class MatNetInitEmbedding(nn.Module): - """ - Preparing the initial row and column embeddings for FFSP. 
- - Reference: - https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L51 - - - """ - - def __init__(self, embedding_dim: int, mode: str = "RandomOneHot") -> None: - super().__init__() - - self.embedding_dim = embedding_dim - assert mode in { - "RandomOneHot", - "Random", - }, "mode must be one of ['RandomOneHot', 'Random']" - self.mode = mode - - def forward(self, td: TensorDict): - dmat = td["cost_matrix"] - b, r, c = dmat.shape - - row_emb = torch.zeros(b, r, self.embedding_dim, device=dmat.device) - - if self.mode == "RandomOneHot": - # MatNet uses one-hot encoding for column embeddings - # https://github.com/yd-kwon/MatNet/blob/782698b60979effe2e7b61283cca155b7cdb727f/ATSP/ATSP_MatNet/ATSPModel.py#L60 - col_emb = torch.zeros(b, c, self.embedding_dim, device=dmat.device) - rand = torch.rand(b, c) - rand_idx = rand.argsort(dim=1) - b_idx = torch.arange(b)[:, None].expand(b, c) - n_idx = torch.arange(c)[None, :].expand(b, c) - col_emb[b_idx, n_idx, rand_idx] = 1.0 - - elif self.mode == "Random": - col_emb = torch.rand(b, r, self.embedding_dim, device=dmat.device) - else: - raise NotImplementedError - - return row_emb, col_emb, dmat - - class VRPInitEmbedding(nn.Module): """Initial embedding for the Vehicle Routing Problems (VRP). Embed the following node features to the embedding space: @@ -358,32 +314,3 @@ def forward(self, td): feat = torch.stack((job_due_time, job_weight, job_process_time), dim=-1) out = self.init_embed(feat) return out - - -class MDCPDPInitEmbedding(nn.Module): - """Initial embedding for the MDCPDP environment - Embed the following node features to the embedding space: - - locs: x, y coordinates of the nodes (depot, pickups and deliveries separately) - Note that pickups and deliveries are interleaved in the input. 
- """ - - def __init__(self, embedding_dim, linear_bias=True): - super(MDCPDPInitEmbedding, self).__init__() - node_dim = 2 # x, y - self.init_embed_depot = nn.Linear(2, embedding_dim, linear_bias) - self.init_embed_pick = nn.Linear(node_dim * 2, embedding_dim, linear_bias) - self.init_embed_delivery = nn.Linear(node_dim, embedding_dim, linear_bias) - - def forward(self, td): - num_depots = td["capacity"].size(-1) - depot, locs = td["locs"][..., 0:num_depots, :], td["locs"][..., num_depots:, :] - num_locs = locs.size(-2) - pick_feats = torch.cat( - [locs[:, : num_locs // 2, :], locs[:, num_locs // 2 :, :]], -1 - ) # [batch_size, graph_size//2, 4] - delivery_feats = locs[:, num_locs // 2 :, :] # [batch_size, graph_size//2, 2] - depot_embeddings = self.init_embed_depot(depot) - pick_embeddings = self.init_embed_pick(pick_feats) - delivery_embeddings = self.init_embed_delivery(delivery_feats) - # concatenate on graph size dimension - return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2) diff --git a/rl4co/models/zoo/common/autoregressive/decoder.py b/rl4co/models/zoo/common/autoregressive/decoder.py index 250191fb..0917d034 100644 --- a/rl4co/models/zoo/common/autoregressive/decoder.py +++ b/rl4co/models/zoo/common/autoregressive/decoder.py @@ -14,7 +14,7 @@ from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding from rl4co.models.nn.env_embeddings.dynamic import StaticEmbedding from rl4co.models.nn.utils import get_log_likelihood -from rl4co.utils.ops import batchify, unbatchify +from rl4co.utils.ops import batchify, select_start_nodes, unbatchify from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) @@ -52,6 +52,7 @@ class AutoregressiveDecoder(nn.Module): embedding_dim: Dimension of the embeddings num_heads: Number of heads for the attention use_graph_context: Whether to use the initial graph context to modify the query + select_start_nodes_fn: Function to select the start nodes for multi-start decoding linear_bias: Whether to use a bias in the linear projection of the embeddings context_embedding: Module to compute the context embedding. If None, the default is used dynamic_embedding: Module to compute the dynamic embedding. 
If None, the default is used @@ -63,6 +64,7 @@ def __init__( embedding_dim: int, num_heads: int, use_graph_context: bool = True, + select_start_nodes_fn: callable = select_start_nodes, linear_bias: bool = False, context_embedding: nn.Module = None, dynamic_embedding: nn.Module = None, @@ -107,6 +109,8 @@ def __init__( embedding_dim, num_heads, **logit_attn_kwargs ) + self.select_start_nodes_fn = select_start_nodes_fn + def forward( self, td: TensorDict, @@ -148,6 +152,10 @@ def forward( # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step cached_embeds = self._precompute_cache(embeddings, td=td) + # If `select_start_nodes_fn` is not being passed, we use the class attribute + if "select_start_nodes_fn" not in strategy_kwargs: + strategy_kwargs["select_start_nodes_fn"] = self.select_start_nodes_fn + # Setup decoding strategy decode_strategy: DecodingStrategy = get_decoding_strategy( decode_type, **strategy_kwargs @@ -159,6 +167,7 @@ def forward( # Main decoding: loop until all sequences are done while not td["done"].all(): log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts) + #print(mask) td = decode_strategy.step(log_p, mask, td) td = env.step(td)["next"] diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py index cff8a70f..3a104c5e 100644 --- a/rl4co/utils/ops.py +++ b/rl4co/utils/ops.py @@ -93,6 +93,31 @@ def get_tour_length(ordered_locs): ordered_locs_next = torch.roll(ordered_locs, 1, dims=-2) return get_distance(ordered_locs_next, ordered_locs).sum(-1) +@torch.jit.script +def get_open_tour_length(ordered_locs): + """Compute the total tour distance for a batch of ordered tours. + Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up. + + Args: + ordered_locs: Tensor of shape [batch_size, num_nodes, 2] containing the ordered locations of the tour + """ + ordered_locs_next = torch.roll(ordered_locs, 1, dims=-2) + + segment_lengths = ((ordered_locs_next-ordered_locs)**2).sum(-1).sqrt() + + # Get the first value of ordered_locs + first_loc = ordered_locs[:, 0, :][:,None,:].expand(ordered_locs_next.shape) + + # Check the ids where the location is the same as the first value + same_loc_ids = torch.all(ordered_locs_next == first_loc, dim=-1) + + # for open VRP, the distance between last customer and the depot is not counted + segment_lengths[same_loc_ids] = 0 + + travel_distances = segment_lengths.sum(1) + + return travel_distances + @torch.jit.script def get_distance_matrix(locs: Tensor): @@ -114,7 +139,6 @@ def get_num_starts(td, env_name=None): ) // 2 # only half of the nodes (i.e. 
pickup nodes) can be start nodes elif env_name in ["cvrp", "sdvrp", "mtsp", "op", "pctsp", "spctsp"]: num_starts = num_starts - 1 # depot cannot be a start node - return num_starts From 044c7899799dc3c76a496549b022f98fea554df0 Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Mon, 13 May 2024 17:33:51 +0800 Subject: [PATCH 2/6] update mtvrp --- rl4co/envs/__init__.py | 13 +- rl4co/envs/routing/__init__.py | 42 +- rl4co/envs/routing/mtvrp/env.py | 26 +- rl4co/envs/routing/mtvrp/generator.py | 3 +- rl4co/envs/routing/ovrp.py | 440 ----------------- rl4co/envs/routing/vrpb.py | 441 ----------------- rl4co/envs/routing/vrpl.py | 511 -------------------- rl4co/envs/routing/vrptw.py | 358 -------------- rl4co/envs/scheduling/__init__.py | 1 - rl4co/envs/scheduling/fjsp/__init__.py | 2 - rl4co/envs/scheduling/fjsp/env.py | 424 ----------------- rl4co/envs/scheduling/fjsp/generator.py | 216 --------- rl4co/envs/scheduling/fjsp/parser.py | 180 ------- rl4co/envs/scheduling/fjsp/render.py | 72 --- rl4co/envs/scheduling/fjsp/utils.py | 333 ------------- rl4co/models/__init__.py | 1 - rl4co/models/nn/env_embeddings/context.py | 49 +- rl4co/models/nn/env_embeddings/init.py | 92 +--- rl4co/models/nn/ops.py | 30 -- rl4co/models/rl/reinforce/reinforce.py | 2 +- rl4co/models/zoo/__init__.py | 1 - rl4co/models/zoo/hetgnn/__init__.py | 1 - rl4co/models/zoo/hetgnn/decoder.py | 51 -- rl4co/models/zoo/hetgnn/encoder.py | 132 ----- rl4co/models/zoo/hetgnn/model.py | 38 -- rl4co/models/zoo/hetgnn/policy.py | 99 ---- rl4co/tasks/__init__.py | 0 rl4co/tasks/eval.py | 405 ---------------- rl4co/tasks/train.py | 117 ----- rl4co/utils/__init__.py | 11 - rl4co/utils/callbacks/speed_monitor.py | 123 ----- rl4co/utils/decoding.py | 555 ---------------------- rl4co/utils/instantiators.py | 51 -- rl4co/utils/lightning.py | 76 --- rl4co/utils/ops.py | 268 ----------- rl4co/utils/optim_helpers.py | 38 -- rl4co/utils/param_grouping.py | 138 ------ rl4co/utils/pylogger.py | 25 - rl4co/utils/rich_utils.py | 97 ---- rl4co/utils/test_utils.py | 62 --- rl4co/utils/trainer.py | 152 ------ rl4co/utils/utils.py | 287 ----------- 42 files changed, 123 insertions(+), 5840 deletions(-) delete mode 100644 rl4co/envs/routing/ovrp.py delete mode 100644 rl4co/envs/routing/vrpb.py delete mode 100644 rl4co/envs/routing/vrpl.py delete mode 100644 rl4co/envs/routing/vrptw.py delete mode 100644 rl4co/envs/scheduling/fjsp/__init__.py delete mode 100644 rl4co/envs/scheduling/fjsp/env.py delete mode 100644 rl4co/envs/scheduling/fjsp/generator.py delete mode 100644 rl4co/envs/scheduling/fjsp/parser.py delete mode 100644 rl4co/envs/scheduling/fjsp/render.py delete mode 100644 rl4co/envs/scheduling/fjsp/utils.py delete mode 100644 rl4co/models/zoo/hetgnn/__init__.py delete mode 100644 rl4co/models/zoo/hetgnn/decoder.py delete mode 100644 rl4co/models/zoo/hetgnn/encoder.py delete mode 100644 rl4co/models/zoo/hetgnn/model.py delete mode 100644 rl4co/models/zoo/hetgnn/policy.py delete mode 100644 rl4co/tasks/__init__.py delete mode 100644 rl4co/tasks/eval.py delete mode 100644 rl4co/tasks/train.py delete mode 100644 rl4co/utils/__init__.py delete mode 100644 rl4co/utils/callbacks/speed_monitor.py delete mode 100644 rl4co/utils/decoding.py delete mode 100644 rl4co/utils/instantiators.py delete mode 100644 rl4co/utils/lightning.py delete mode 100644 rl4co/utils/ops.py delete mode 100644 rl4co/utils/optim_helpers.py delete mode 100644 rl4co/utils/param_grouping.py delete mode 100644 rl4co/utils/pylogger.py delete mode 100644 
rl4co/utils/rich_utils.py delete mode 100644 rl4co/utils/test_utils.py delete mode 100644 rl4co/utils/trainer.py delete mode 100644 rl4co/utils/utils.py diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py index dbad713c..17961f6f 100644 --- a/rl4co/envs/__init__.py +++ b/rl4co/envs/__init__.py @@ -9,19 +9,16 @@ ATSPEnv, CVRPEnv, CVRPTWEnv, + MDCPDPEnv, MTSPEnv, + MTVRPEnv, OPEnv, PCTSPEnv, PDPEnv, SDVRPEnv, - SVRPEnv, SPCTSPEnv, + SVRPEnv, TSPEnv, - MDCPDPEnv, - VRPLEnv, - OVRPEnv, - VRPTWEnv, - VRPBEnv, ) # Scheduling @@ -45,10 +42,6 @@ "tsp": TSPEnv, "smtwtp": SMTWTPEnv, "mdcpdp": MDCPDPEnv, - "vrpl": VRPLEnv, - "ovrp": OVRPEnv, - "vrptw": VRPTWEnv, - "vrpb": VRPBEnv, } diff --git a/rl4co/envs/routing/__init__.py b/rl4co/envs/routing/__init__.py index 50870257..9c16f758 100644 --- a/rl4co/envs/routing/__init__.py +++ b/rl4co/envs/routing/__init__.py @@ -1,16 +1,26 @@ -from rl4co.envs.routing.atsp import ATSPEnv -from rl4co.envs.routing.cvrp import CVRPEnv -from rl4co.envs.routing.cvrptw import CVRPTWEnv -from rl4co.envs.routing.mtsp import MTSPEnv -from rl4co.envs.routing.op import OPEnv -from rl4co.envs.routing.pctsp import PCTSPEnv -from rl4co.envs.routing.pdp import PDPEnv -from rl4co.envs.routing.sdvrp import SDVRPEnv -from rl4co.envs.routing.spctsp import SPCTSPEnv -from rl4co.envs.routing.svrp import SVRPEnv -from rl4co.envs.routing.tsp import TSPEnv -from rl4co.envs.routing.mdcpdp import MDCPDPEnv -from rl4co.envs.routing.vrpl import VRPLEnv -from rl4co.envs.routing.ovrp import OVRPEnv -from rl4co.envs.routing.vrptw import VRPTWEnv -from rl4co.envs.routing.vrpb import VRPBEnv +from rl4co.envs.routing.atsp.env import ATSPEnv +from rl4co.envs.routing.cvrp.env import CVRPEnv +from rl4co.envs.routing.cvrptw.env import CVRPTWEnv +from rl4co.envs.routing.mdcpdp.env import MDCPDPEnv +from rl4co.envs.routing.mtsp.env import MTSPEnv +from rl4co.envs.routing.mtvrp.env import MTVRPEnv +from rl4co.envs.routing.op.env import OPEnv +from rl4co.envs.routing.pctsp.env import PCTSPEnv +from rl4co.envs.routing.pdp.env import PDPEnv +from rl4co.envs.routing.sdvrp.env import SDVRPEnv +from rl4co.envs.routing.spctsp.env import SPCTSPEnv +from rl4co.envs.routing.svrp.env import SVRPEnv +from rl4co.envs.routing.tsp.env import TSPEnv + +from rl4co.envs.routing.atsp.generator import ATSPGenerator +from rl4co.envs.routing.cvrp.generator import CVRPGenerator +from rl4co.envs.routing.cvrptw.generator import CVRPTWGenerator +from rl4co.envs.routing.mtsp.generator import MTSPGenerator +from rl4co.envs.routing.mtvrp.generator import MTVRPGenerator +from rl4co.envs.routing.op.generator import OPGenerator +from rl4co.envs.routing.pctsp.generator import PCTSPGenerator +from rl4co.envs.routing.pdp.generator import PDPGenerator +from rl4co.envs.routing.svrp.generator import SVRPGenerator +from rl4co.envs.routing.tsp.generator import TSPGenerator +from rl4co.envs.routing.mdcpdp.generator import MDCPDPGenerator + diff --git a/rl4co/envs/routing/mtvrp/env.py b/rl4co/envs/routing/mtvrp/env.py index ce9d34a7..e46ccc1e 100644 --- a/rl4co/envs/routing/mtvrp/env.py +++ b/rl4co/envs/routing/mtvrp/env.py @@ -281,7 +281,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: & ~exceeds_dist_limit & ~td["visited"] ) - + #print(can_visit) # Mask depot: don't visit depot if coming from there and there are still customer nodes I can visit can_visit[:, 0] = ~((curr_node == 0) & (can_visit[:, 1:].sum(-1) > 0)) return can_visit @@ -349,9 +349,29 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): 
curr_time = torch.max( curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0] ) + # if not torch.all( + # curr_time-1E-6 <= gather_by_index(td["time_windows"], next_node)[..., 1] + # ): + # unsatisfied_indices = torch.nonzero(~(curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] + # ), as_tuple=True) + # print() + new_shape = curr_time.size() + skip_open_end = td["open_route"].view(*new_shape) & (next_node == 0).view(*new_shape) + if not torch.all( + (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end + ): + unsatisfied_indices = torch.nonzero(~((curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end + ), as_tuple=True) + print(skip_open_end) + print(unsatisfied_indices) + print(curr_time) + print(curr_time[unsatisfied_indices]) + print(next_node[unsatisfied_indices]) + input() assert torch.all( - curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] + (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end ), "vehicle cannot start service before deadline" + curr_time = curr_time + gather_by_index(td["service_time"], next_node) curr_node = next_node curr_time[curr_node == 0] = 0.0 # reset time for depot @@ -450,7 +470,7 @@ def _make_spec(self, td_params: TensorDict): def check_variants(td): """Check if the problem has the variants""" has_open = td["open_route"].squeeze(-1) - has_tw = (td["time_windows"][:, :, 1] != float("inf")).any(-1) + has_tw = (td["time_windows"][:, :, 1] != 4.6).any(-1) has_limit = (td["distance_limit"] != float("inf")).squeeze(-1) has_backhaul = (td["demand_backhaul"] != 0).any(-1) return has_open, has_tw, has_limit, has_backhaul diff --git a/rl4co/envs/routing/mtvrp/generator.py b/rl4co/envs/routing/mtvrp/generator.py index 81692ade..d441b306 100644 --- a/rl4co/envs/routing/mtvrp/generator.py +++ b/rl4co/envs/routing/mtvrp/generator.py @@ -256,7 +256,8 @@ def _default_open(td, remove): @staticmethod def _default_time_window(td, remove): default_tw = torch.zeros_like(td["time_windows"]) - default_tw[..., 1] = float("inf") + #default_tw[..., 1] = float("inf") + default_tw[..., 1] = 4.6 # max tw td["time_windows"][remove] = default_tw[remove] td["service_time"][remove] = torch.zeros_like(td["service_time"][remove]) return td diff --git a/rl4co/envs/routing/ovrp.py b/rl4co/envs/routing/ovrp.py deleted file mode 100644 index 7f49be86..00000000 --- a/rl4co/envs/routing/ovrp.py +++ /dev/null @@ -1,440 +0,0 @@ -from typing import Optional - -import torch - -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.data.utils import load_npz_to_tensordict -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.ops import gather_by_index, get_open_tour_length -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 -CAPACITIES = { - 10: 20.0, - 15: 25.0, - 20: 30.0, - 30: 33.0, - 40: 37.0, - 50: 40.0, - 60: 43.0, - 75: 45.0, - 100: 50.0, - 125: 55.0, - 150: 60.0, - 200: 70.0, - 500: 100.0, - 1000: 150.0, -} - - -class OVRPEnv(RL4COEnvBase): - """Capacitated Vehicle Routing Problem (CVRP) environment. - At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. - When the agent visits a customer, the remaining capacity is updated. 
If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. - In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. - - Args: - num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) - min_loc: minimum value for the location coordinates - max_loc: maximum value for the location coordinates - min_demand: minimum value for the demand of each customer - max_demand: maximum value for the demand of each customer - vehicle_capacity: capacity of the vehicle - td_params: parameters of the environment - """ - - name = "ovrp" - - def __init__( - self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_demand: float = 1, - max_demand: float = 10, - vehicle_capacity: float = 1.0, - capacity: float = None, - td_params: TensorDict = None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_demand = min_demand - self.max_demand = max_demand - self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity - if self.capacity is None: - raise ValueError( - f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." - ) - self.vehicle_capacity = vehicle_capacity - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - current_node = td["action"][:, None] # Add dimension for step - n_loc = td["demand"].size(-1) # Excludes depot - - # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! - selected_demand = gather_by_index( - td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False - ) - - # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() - - # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot - # Add one dimension since we write a single value - visited = td["visited"].scatter(-1, current_node[..., None], 1) - - # SECTION: get done - done = visited.sum(-1) == visited.size(-1) - reward = torch.zeros_like(done) - - td.update( - { - "current_node": current_node, - "used_capacity": used_capacity, - "visited": visited, - "reward": reward, - "done": done, - } - ) - td.set("action_mask", self.get_action_mask(td)) - return td - - def _reset( - self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, - ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - - # Create reset TensorDict - td_reset = TensorDict( - { - "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), - "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device - ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), - "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device - ), - "visited": torch.zeros( - (*batch_size, 1, td["locs"].shape[-2] + 1), - dtype=torch.uint8, - device=self.device, - ), - }, - batch_size=batch_size, - ) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - 
return td_reset - - @staticmethod - def get_action_mask(td: TensorDict) -> torch.Tensor: - # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting - exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] - ) - - # Nodes that cannot be visited are already visited or too much demand to be served now - mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap - - # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) - return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) - - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - # Check that the solution is valid - if self.check_solution: - self.check_solution_validity(td, actions) - - # Gather dataset in order of tour - batch_size = td["locs"].shape[0] - depot = td["locs"][..., 0:1, :] - locs_ordered = torch.cat( - [ - depot, - gather_by_index(td["locs"], actions).reshape( - [batch_size, actions.size(-1), 2] - ), - ], - dim=1, - ) - return -get_open_tour_length(locs_ordered) - - @staticmethod - def check_solution_validity(td: TensorDict, actions: torch.Tensor): - """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" - # Check if tour is valid, i.e. contain 0 to n-1 - batch_size, graph_size = td["demand"].size() - sorted_pi = actions.data.sort(1)[0] - - # Sorting it should give all zeros at front and then 1...n - assert ( - torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) - .view(1, -1) - .expand(batch_size, graph_size) - == sorted_pi[:, -graph_size:] - ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" - - # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) - demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) - d = demand_with_depot.gather(1, actions) - - used_cap = torch.zeros_like(td["demand"][:, 0]) - for i in range(actions.size(1)): - used_cap += d[ - :, i - ] # This will reset/make capacity negative if i == 0, e.g. depot visited - # Cannot use less than 0 - used_cap[used_cap < 0] = 0 - assert ( - used_cap <= td["vehicle_capacity"] + 1e-5 - ).all(), "Used more than capacity" - - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize the demand for nodes except the depot - # Demand sampling Following Kool et al. 
(2019) - # Generates a slightly different distribution than using torch.randint - demand = ( - ( - torch.FloatTensor(*batch_size, self.num_loc) - .uniform_(self.min_demand - 1, self.max_demand - 1) - .int() - + 1 - ) - .float() - .to(self.device) - ) - - # Support for heterogeneous capacity if provided - if not isinstance(self.capacity, torch.Tensor): - capacity = torch.full((*batch_size,), self.capacity, device=self.device) - else: - capacity = self.capacity - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "demand": demand / self.capacity, - "capacity": capacity, - }, - batch_size=batch_size, - device=self.device, - ) - - @staticmethod - def load_data(fpath, batch_size=[]): - """Dataset loading from file - Normalize demand by capacity to be in [0, 1] - """ - td_load = load_npz_to_tensordict(fpath) - td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) - return td_load - - def _make_spec(self, td_params: TensorDict): - """Make the observation and action specs from the parameters.""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), - dtype=torch.float32, - ), - current_node=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - demand=BoundedTensorSpec( - low=-self.capacity, - high=self.max_demand, - shape=(self.num_loc, 1), # demand is only for customers - dtype=torch.float32, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.num_loc + 1, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render( - td: TensorDict, - actions=None, - ax=None, - scale_xy: bool = True, - ): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) - demands = td["demand"] * scale_demand - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # plot demand bars - for node_idx in range(1, len(locs)): - ax.add_patch( - plt.Rectangle( - (locs[node_idx, 0] - 0.005, 
locs[node_idx, 1] + 0.015), - 0.01, - demands[node_idx - 1] / (scale_demand * 10), - edgecolor=cm.Set2(0), - facecolor=cm.Set2(0), - fill=True, - ) - ) - - # text demand - for node_idx in range(1, len(locs)): - ax.text( - locs[node_idx, 0], - locs[node_idx, 1] - 0.025, - f"{demands[node_idx-1].item():.2f}", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(0), - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx+1] == 0: - continue - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - - # Setup limits and show - if scale_xy: - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) - plt.show() diff --git a/rl4co/envs/routing/vrpb.py b/rl4co/envs/routing/vrpb.py deleted file mode 100644 index df5da1b1..00000000 --- a/rl4co/envs/routing/vrpb.py +++ /dev/null @@ -1,441 +0,0 @@ -from typing import Optional - -import torch - -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.data.utils import load_npz_to_tensordict -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.ops import gather_by_index, get_tour_length -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 -CAPACITIES = { - 10: 20.0, - 15: 25.0, - 20: 30.0, - 30: 33.0, - 40: 37.0, - 50: 40.0, - 60: 43.0, - 75: 45.0, - 100: 50.0, - 125: 55.0, - 150: 60.0, - 200: 70.0, - 500: 100.0, - 1000: 150.0, -} - - -class VRPBEnv(RL4COEnvBase): - """Capacitated Vehicle Routing Problem (CVRP) environment. - At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. - When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. - In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. - - Args: - num_loc: number of locations (cities) in the VRP, without the depot. (e.g. 
10 means 10 locs + 1 depot) - min_loc: minimum value for the location coordinates - max_loc: maximum value for the location coordinates - min_demand: minimum value for the demand of each customer - max_demand: maximum value for the demand of each customer - vehicle_capacity: capacity of the vehicle - td_params: parameters of the environment - """ - - name = "vrpb" - - def __init__( - self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_demand: float = 1, - max_demand: float = 10, - vehicle_capacity: float = 1.0, - capacity: float = None, - td_params: TensorDict = None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_demand = min_demand - self.max_demand = max_demand - self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity - if self.capacity is None: - raise ValueError( - f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." - ) - self.vehicle_capacity = vehicle_capacity - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - current_node = td["action"][:, None] # Add dimension for step - n_loc = td["demand"].size(-1) # Excludes depot - - # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! - selected_demand = gather_by_index( - td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False - ) - - # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() - - # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot - # Add one dimension since we write a single value - visited = td["visited"].scatter(-1, current_node[..., None], 1) - - # SECTION: get done - done = visited.sum(-1) == visited.size(-1) - reward = torch.zeros_like(done) - - td.update( - { - "current_node": current_node, - "used_capacity": used_capacity, - "visited": visited, - "reward": reward, - "done": done, - } - ) - td.set("action_mask", self.get_action_mask(td)) - return td - - def _reset( - self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, - ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - - # Create reset TensorDict - td_reset = TensorDict( - { - "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), - "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device - ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), - "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device - ), - "visited": torch.zeros( - (*batch_size, 1, td["locs"].shape[-2] + 1), - dtype=torch.uint8, - device=self.device, - ), - }, - batch_size=batch_size, - ) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - return td_reset - - @staticmethod - def get_action_mask(td: TensorDict) -> torch.Tensor: - # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting - exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] - ) - - # Nodes that cannot be visited are already 
visited or too much demand to be served now - mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap - - # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) - return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) - - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - # Check that the solution is valid - # if self.check_solution: - # self.check_solution_validity(td, actions) - - # Gather dataset in order of tour - batch_size = td["locs"].shape[0] - depot = td["locs"][..., 0:1, :] - locs_ordered = torch.cat( - [ - depot, - gather_by_index(td["locs"], actions).reshape( - [batch_size, actions.size(-1), 2] - ), - ], - dim=1, - ) - return -get_tour_length(locs_ordered) - - @staticmethod - def check_solution_validity(td: TensorDict, actions: torch.Tensor): - """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" - # Check if tour is valid, i.e. contain 0 to n-1 - batch_size, graph_size = td["demand"].size() - sorted_pi = actions.data.sort(1)[0] - - # Sorting it should give all zeros at front and then 1...n - assert ( - torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) - .view(1, -1) - .expand(batch_size, graph_size) - == sorted_pi[:, -graph_size:] - ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" - - # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) - demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) - d = demand_with_depot.gather(1, actions) - - used_cap = torch.zeros_like(td["demand"][:, 0]) - for i in range(actions.size(1)): - used_cap += d[ - :, i - ] # This will reset/make capacity negative if i == 0, e.g. depot visited - # Cannot use less than 0 - used_cap[used_cap < 0] = 0 - assert ( - used_cap <= td["vehicle_capacity"] + 1e-5 - ).all(), "Used more than capacity" - - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize the demand for nodes except the depot - # Demand sampling Following Kool et al. 
(2019) - # Generates a slightly different distribution than using torch.randint - demand = ( - ( - torch.FloatTensor(*batch_size, self.num_loc) - .uniform_(self.min_demand - 1, self.max_demand - 1) - .int() - + 1 - ) - .float() - .to(self.device) - ) - - # set 20% to backhaul - linehaul = int(0.8*self.num_loc) - demand[:,linehaul:] = -demand[:,linehaul:] - - # Support for heterogeneous capacity if provided - if not isinstance(self.capacity, torch.Tensor): - capacity = torch.full((*batch_size,), self.capacity, device=self.device) - else: - capacity = self.capacity - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "demand": demand / self.capacity, - "capacity": capacity, - }, - batch_size=batch_size, - device=self.device, - ) - - @staticmethod - def load_data(fpath, batch_size=[]): - """Dataset loading from file - Normalize demand by capacity to be in [0, 1] - """ - td_load = load_npz_to_tensordict(fpath) - td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) - return td_load - - def _make_spec(self, td_params: TensorDict): - """Make the observation and action specs from the parameters.""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), - dtype=torch.float32, - ), - current_node=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - demand=BoundedTensorSpec( - low=-self.capacity, - high=self.max_demand, - shape=(self.num_loc, 1), # demand is only for customers - dtype=torch.float32, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.num_loc + 1, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render( - td: TensorDict, - actions=None, - ax=None, - scale_xy: bool = True, - ): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) - demands = td["demand"] * scale_demand - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # plot demand bars 
- for node_idx in range(1, len(locs)): - ax.add_patch( - plt.Rectangle( - (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), - 0.01, - demands[node_idx - 1] / (scale_demand * 10), - edgecolor=cm.Set2(0), - facecolor=cm.Set2(0), - fill=True, - ) - ) - - # text demand - for node_idx in range(1, len(locs)): - ax.text( - locs[node_idx, 0], - locs[node_idx, 1] - 0.025, - f"{demands[node_idx-1].item():.2f}", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(0), - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - - # Setup limits and show - if scale_xy: - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) - plt.show() diff --git a/rl4co/envs/routing/vrpl.py b/rl4co/envs/routing/vrpl.py deleted file mode 100644 index c5e1d4d9..00000000 --- a/rl4co/envs/routing/vrpl.py +++ /dev/null @@ -1,511 +0,0 @@ -from typing import Optional - -import torch - -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.data.utils import load_npz_to_tensordict -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.utils.ops import gather_by_index, get_tour_length, get_distance -from rl4co.utils.pylogger import get_pylogger - - -log = get_pylogger(__name__) - - -# From Kool et al. 2019, Hottung et al. 2022, Kim et al. 2023 -CAPACITIES = { - 10: 20.0, - 15: 25.0, - 20: 30.0, - 30: 33.0, - 40: 37.0, - 50: 40.0, - 60: 43.0, - 75: 45.0, - 100: 50.0, - 125: 55.0, - 150: 60.0, - 200: 70.0, - 500: 100.0, - 1000: 150.0, -} - - -class VRPLEnv(RL4COEnvBase): - """Capacitated Vehicle Routing Problem (CVRP) environment. - At each step, the agent chooses a customer to visit depending on the current location and the remaining capacity. - When the agent visits a customer, the remaining capacity is updated. If the remaining capacity is not enough to - visit any customer, the agent must go back to the depot. The reward is 0 unless the agent visits all the cities. - In that case, the reward is (-)length of the path: maximizing the reward is equivalent to minimizing the path length. - - Args: - num_loc: number of - - locations (cities) in the VRP, without the depot. (e.g. 
10 means 10 locs + 1 depot) - min_loc: minimum value for the location coordinates - max_loc: maximum value for the location coordinates - min_demand: minimum value for the demand of each customer - max_demand: maximum value for the demand of each customer - vehicle_capacity: capacity of the vehicle - td_params: parameters of the environment - """ - - name = "vrpl" - - def __init__( - self, - num_loc: int = 20, - min_loc: float = 0, - max_loc: float = 1, - min_demand: float = 1, - max_demand: float = 10, - vehicle_capacity: float = 1.0, - capacity: float = None, - duration_limit: float = None, - #selected_node_list: torch.Tensor = None, - td_params: TensorDict = None, - **kwargs, - ): - super().__init__(**kwargs) - self.num_loc = num_loc - self.min_loc = min_loc - self.max_loc = max_loc - self.min_demand = min_demand - self.max_demand = max_demand - self.capacity = CAPACITIES.get(num_loc, None) if capacity is None else capacity - if self.capacity is None: - raise ValueError( - f"Capacity for {num_loc} locations is not defined. Please provide a capacity manually." - ) - self.vehicle_capacity = vehicle_capacity - self.duration_limit = 3.0 if duration_limit is None else duration_limit - self.selected_node_list = None - self._make_spec(td_params) - - def _step(self, td: TensorDict) -> TensorDict: - current_node = td["action"][:, None] # Add dimension for step - n_loc = td["demand"].size(-1) # Excludes depot - - self.selected_node_list = torch.cat((self.selected_node_list, current_node), dim=1) - # shape: (batch, pomo, 0~) - - # Not selected_demand is demand of first node (by clamp) so incorrect for nodes that visit depot! - selected_demand = gather_by_index( - td["demand"], torch.clamp(current_node - 1, 0, n_loc - 1), squeeze=False - ) - - # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() - - # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot - # Add one dimension since we write a single value - visited = td["visited"].scatter(-1, current_node[..., None], 1) - - #### update distance information ### - - selected_xy = gather_by_index( - td["locs"], current_node, squeeze=False - ) - - gathering_index_last = self.selected_node_list[:, -2][:,None,None].expand(-1,-1,2) - - last_xy = gather_by_index( - td["locs"], gathering_index_last, squeeze=False - ) - - selected_distance = ((selected_xy - last_xy)**2).sum(dim=2).sqrt() - - td["duration_limit"] -= selected_distance - - td["duration_limit"][current_node == 0] = self.duration_limit # refill length at the depot - - - # SECTION: get done - done = visited.sum(-1) == visited.size(-1) - # print(done[0]) - # print(done) - reward = torch.zeros_like(done) - - td.update( - { - "current_node": current_node, - "used_capacity": used_capacity, - "visited": visited, - "reward": reward, - "done": done, - } - ) - td.set("action_mask", self.get_action_mask(td)) - return td - - def _reset( - self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, - ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - - self.selected_node_list = torch.zeros((*batch_size,1),dtype=torch.int64,device=self.device) - - # Create reset TensorDict - td_reset = TensorDict( - { 
- "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), - "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device - ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), - "duration_limit": torch.full( - (*batch_size, 1), self.duration_limit, device=self.device - ), - "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device - ), - - "visited": torch.zeros( - (*batch_size, 1, td["locs"].shape[-2] + 1), - dtype=torch.uint8, - device=self.device, - ), - }, - batch_size=batch_size, - ) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - - return td_reset - - @staticmethod - def get_action_mask(td: TensorDict) -> torch.Tensor: - # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting - exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] - ) - - if "action" not in td.keys(): - current_node =torch.zeros((td["locs"].shape[0],1),dtype=torch.int64, device=td["locs"].device) - else: - current_node = td["action"][:, None] # Add dimension for step - - selected_xy = gather_by_index( - td["locs"], current_node, squeeze=False - ) - - length_to_next = ((selected_xy - td["locs"])**2).sum(dim=2).sqrt() - - # length_to_next = ((selected_xy[:,None,:].expand(-1,-1,self.problem_size+1,-1) - xy_list)**2).sum(dim=3).sqrt() - # # shape: (batch, pomo, problem+1) - depot_xy = td["locs"][:,0,:] - next_to_depot = ((depot_xy[:,None,:].expand(td["locs"].shape) - td["locs"])**2).sum(dim=2).sqrt() - # shape: (batch, pomo, problem+1) - - length_too_small = td["duration_limit"] - 1E-6 < (length_to_next + next_to_depot ) - - # Nodes that cannot be visited are already visited or too much demand to be served now - - mask_loc = (td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap) | length_too_small[:,1:][:,None,:] - - # print(td["visited"][..., 1:].to(exceeds_cap.dtype).shape) - # print(exceeds_cap) - # print(length_too_small[:,1:][:,None,:]) - # print(mask_loc) - # print(td["visited"][..., 1:][0]) - # print(td["visited"][..., 1:][-1]) - #input() - # print(td["visited"][..., 1:].to(exceeds_cap.dtype).shape) - - - # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0) - - #print(~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2)) - return ~torch.cat((mask_depot[..., None], mask_loc), -1).squeeze(-2) - - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - # Check that the solution is valid - # if self.check_solution: - # self.check_solution_validity(td, actions) - - # Gather dataset in order of tour - batch_size = td["locs"].shape[0] - depot = td["locs"][..., 0:1, :] - locs_ordered = torch.cat( - [ - depot, - gather_by_index(td["locs"], actions).reshape( - [batch_size, actions.size(-1), 2] - ), - ], - dim=1, - ) - return -get_tour_length(locs_ordered) - - # @staticmethod - # def check_solution_validity(td: TensorDict, actions: torch.Tensor): - # """Check that solution is valid: nodes are not visited twice except depot and capacity is not exceeded""" - # # Check if tour is valid, i.e. 
contain 0 to n-1 - # batch_size, graph_size = td["demand"].size() - # sorted_pi = actions.data.sort(1)[0] - - # # Sorting it should give all zeros at front and then 1...n - # assert ( - # torch.arange(1, graph_size + 1, out=sorted_pi.data.new()) - # .view(1, -1) - # .expand(batch_size, graph_size) - # == sorted_pi[:, -graph_size:] - # ).all() and (sorted_pi[:, :-graph_size] == 0).all(), "Invalid tour" - - # # Visiting depot resets capacity so we add demand = -capacity (we make sure it does not become negative) - # demand_with_depot = torch.cat((-td["vehicle_capacity"], td["demand"]), 1) - # d = demand_with_depot.gather(1, actions) - - # used_cap = torch.zeros_like(td["demand"][:, 0]) - # for i in range(actions.size(1)): - # used_cap += d[ - # :, i - # ] # This will reset/make capacity negative if i == 0, e.g. depot visited - # # Cannot use less than 0 - # used_cap[used_cap < 0] = 0 - # assert ( - # used_cap <= td["vehicle_capacity"] + 1e-5 - # ).all(), "Used more than capacity" - - def generate_data(self, batch_size) -> TensorDict: - # Batch size input check - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - # Initialize the locations (including the depot which is always the first node) - locs_with_depot = ( - torch.FloatTensor(*batch_size, self.num_loc + 1, 2) - .uniform_(self.min_loc, self.max_loc) - .to(self.device) - ) - - # Initialize the demand for nodes except the depot - # Demand sampling Following Kool et al. (2019) - # Generates a slightly different distribution than using torch.randint - demand = ( - ( - torch.FloatTensor(*batch_size, self.num_loc) - .uniform_(self.min_demand - 1, self.max_demand - 1) - .int() - + 1 - ) - .float() - .to(self.device) - ) - - # Support for heterogeneous capacity if provided - if not isinstance(self.capacity, torch.Tensor): - capacity = torch.full((*batch_size,), self.capacity, device=self.device) - else: - capacity = self.capacity - - # duration limit - duration_limit = torch.full ((*batch_size,), self.duration_limit, device=self.device) - - return TensorDict( - { - "locs": locs_with_depot[..., 1:, :], - "depot": locs_with_depot[..., 0, :], - "demand": demand / self.capacity, - "capacity": capacity, - "duration_limit":duration_limit, - }, - batch_size=batch_size, - device=self.device, - ) - - @staticmethod - def load_data(fpath, batch_size=[]): - """Dataset loading from file - Normalize demand by capacity to be in [0, 1] - """ - td_load = load_npz_to_tensordict(fpath) - td_load.set("demand", td_load["demand"] / td_load["capacity"][:, None]) - return td_load - - def _make_spec(self, td_params: TensorDict): - """Make the observation and action specs from the parameters.""" - self.observation_spec = CompositeSpec( - locs=BoundedTensorSpec( - low=self.min_loc, - high=self.max_loc, - shape=(self.num_loc + 1, 2), - dtype=torch.float32, - ), - current_node=UnboundedDiscreteTensorSpec( - shape=(1), - dtype=torch.int64, - ), - demand=BoundedTensorSpec( - low=-self.capacity, - high=self.max_demand, - shape=(self.num_loc, 1), # demand is only for customers - dtype=torch.float32, - ), - action_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_loc + 1, 1), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=0, - high=self.num_loc + 1, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render( - td: TensorDict, - actions=None, - ax=None, - 
scale_xy: bool = True, - ): - import matplotlib.pyplot as plt - import numpy as np - - from matplotlib import cm, colormaps - - num_routine = (actions == 0).sum().item() + 2 - base = colormaps["nipy_spectral"] - color_list = base(np.linspace(0, 1, num_routine)) - cmap_name = base.name + str(num_routine) - out = base.from_list(cmap_name, color_list, num_routine) - - if ax is None: - # Create a plot of the nodes - _, ax = plt.subplots() - - td = td.detach().cpu() - - if actions is None: - actions = td.get("action", None) - - # if batch_size greater than 0 , we need to select the first batch element - if td.batch_size != torch.Size([]): - td = td[0] - actions = actions[0] - - locs = td["locs"] - scale_demand = CAPACITIES.get(td["locs"].size(-2) - 1, 1) - demands = td["demand"] * scale_demand - - # add the depot at the first action and the end action - actions = torch.cat([torch.tensor([0]), actions, torch.tensor([0])]) - - # gather locs in order of action if available - if actions is None: - log.warning("No action in TensorDict, rendering unsorted locs") - else: - locs = locs - - # Cat the first node to the end to complete the tour - x, y = locs[:, 0], locs[:, 1] - - # plot depot - ax.scatter( - locs[0, 0], - locs[0, 1], - edgecolors=cm.Set2(2), - facecolors="none", - s=100, - linewidths=2, - marker="s", - alpha=1, - ) - - # plot visited nodes - ax.scatter( - x[1:], - y[1:], - edgecolors=cm.Set2(0), - facecolors="none", - s=50, - linewidths=2, - marker="o", - alpha=1, - ) - - # plot demand bars - for node_idx in range(1, len(locs)): - ax.add_patch( - plt.Rectangle( - (locs[node_idx, 0] - 0.005, locs[node_idx, 1] + 0.015), - 0.01, - demands[node_idx - 1] / (scale_demand * 10), - edgecolor=cm.Set2(0), - facecolor=cm.Set2(0), - fill=True, - ) - ) - - # text demand - for node_idx in range(1, len(locs)): - ax.text( - locs[node_idx, 0], - locs[node_idx, 1] - 0.025, - f"{demands[node_idx-1].item():.2f}", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(0), - ) - - # text depot - ax.text( - locs[0, 0], - locs[0, 1] - 0.025, - "Depot", - horizontalalignment="center", - verticalalignment="top", - fontsize=10, - color=cm.Set2(2), - ) - - # plot actions - color_idx = 0 - for action_idx in range(len(actions) - 1): - if actions[action_idx] == 0: - color_idx += 1 - from_loc = locs[actions[action_idx]] - to_loc = locs[actions[action_idx + 1]] - ax.plot( - [from_loc[0], to_loc[0]], - [from_loc[1], to_loc[1]], - color=out(color_idx), - lw=1, - ) - ax.annotate( - "", - xy=(to_loc[0], to_loc[1]), - xytext=(from_loc[0], from_loc[1]), - arrowprops=dict(arrowstyle="-|>", color=out(color_idx)), - size=15, - annotation_clip=False, - ) - - # Setup limits and show - if scale_xy: - ax.set_xlim(-0.05, 1.05) - ax.set_ylim(-0.05, 1.05) - plt.show() diff --git a/rl4co/envs/routing/vrptw.py b/rl4co/envs/routing/vrptw.py deleted file mode 100644 index 51aedc25..00000000 --- a/rl4co/envs/routing/vrptw.py +++ /dev/null @@ -1,358 +0,0 @@ -from math import sqrt -from typing import Optional -import torch -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, -) - -from rl4co.envs.routing.cvrp import CVRPEnv, CAPACITIES -from rl4co.utils.ops import gather_by_index, get_distance -from rl4co.data.utils import ( - load_npz_to_tensordict, - load_solomon_instance, - load_solomon_solution, -) - - -class VRPTWEnv(CVRPEnv): - """Capacitated Vehicle Routing Problem with Time Windows (CVRPTW) environment. 
- Inherits from the CVRPEnv class in which capacities are considered. - Additionally considers time windows within which a service has to be started. - - Args: - num_loc (int): number of locations (cities) in the VRP, without the depot. (e.g. 10 means 10 locs + 1 depot) - min_loc (float): minimum value for the location coordinates - max_loc (float): maximum value for the location coordinates. Defaults to 150. - min_demand (float): minimum value for the demand of each customer - max_demand (float): maximum value for the demand of each customer - max_time (float): maximum time for the environment. Defaults to 480. - vehicle_capacity (float): capacity of the vehicle - capacity (float): capacity of the vehicle - scale (bool): if True, the time windows and service durations are scaled to [0, 1]. Defaults to False. - td_params: parameters of the environment - """ - - name = "cvrptw" - - def __init__( - self, - max_loc: float = 1, # different default value to CVRPEnv to match max_time, will be scaled - max_time: float = 4.6, - scale: bool = False, - **kwargs, - ): - self.min_time = 0 # always 0 - self.max_time = max_time - self.scale = scale - super().__init__(max_loc=max_loc, **kwargs) - - def _make_spec(self, td_params: TensorDict): - super()._make_spec(td_params) - - current_time = UnboundedContinuousTensorSpec( - shape=(1), dtype=torch.float32, device=self.device - ) - - current_loc = UnboundedContinuousTensorSpec( - shape=(2), dtype=torch.float32, device=self.device - ) - - durations = BoundedTensorSpec( - low=self.min_time, - high=self.max_time, - shape=(self.num_loc, 1), - dtype=torch.float, - device=self.device, - ) - - time_windows = BoundedTensorSpec( - low=self.min_time, - high=self.max_time, - shape=( - self.num_loc, - 2, - ), # each location has a 2D time window (start, end) - dtype=torch.float, - device=self.device, - ) - - # extend observation specs - self.observation_spec = CompositeSpec( - **self.observation_spec, - current_time=current_time, - current_loc=current_loc, - durations=durations, - time_windows=time_windows, - # vehicle_idx=vehicle_idx, - ) - - def generate_data(self, batch_size) -> TensorDict: - """ - Generates time windows and service durations for the locations. The depot has a time window of [0, self.max_time]. - The time windows define the time span within which a service has to be started. To reach the depot in time from the last node, - the end time of each node is bounded by the service duration and the distance back to the depot. - The start times of the time windows are bounded by how long it takes to travel there from the depot. 
- """ - td = super().generate_data(batch_size) - - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - ## define service durations - durations = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device) * 0.05 +0.15 - # shape: (batch, problem) - # range: (0.15,0.2) for T=4.6 - - node_lengthTW = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device) * 0.05 +0.15 - # shape: (batch, problem) - # range: (0.15,0.2) for T=4.6 - - - d0i = ((torch.cat((td["depot"][..., None, :], td["locs"]), -2) - td["depot"][..., None, :].expand(size=(*batch_size,self.num_loc + 1,2)))**2).sum(2).sqrt() - # shape: (batch, problem) - - - ei = torch.rand(size=(*batch_size, self.num_loc + 1), dtype=torch.float32, device=self.device).mul((torch.div((self.max_time*torch.ones(size=(*batch_size,self.num_loc + 1), dtype=torch.float32, device=self.device) - durations - node_lengthTW),d0i) - 1)-1)+1 - # shape: (batch, problem) - # default velocity = 1.0 - - min_times = ei.mul(d0i) - # shape: (batch, problem) - # default velocity = 1.0 - - max_times = min_times + node_lengthTW - # shape: (batch, problem) - - min_times[..., :, 0] = 0.0 - max_times[..., :, 0] = self.max_time - - - # scale to [0, 1] - if self.scale: - durations = durations / self.max_time - min_times = min_times / self.max_time - max_times = max_times / self.max_time - td["depot"] = td["depot"] / self.max_time - td["locs"] = td["locs"] / self.max_time - - # 8. stack to tensor time_windows - time_windows = torch.stack((min_times, max_times), dim=-1) - - assert torch.all( - min_times < max_times - ), "Please make sure the relation between max_loc and max_time allows for feasible solutions." - - # reset duration at depot to 0 - durations[:, 0] = 0.0 - - td.update( - { - "durations": durations, - "time_windows": time_windows, - } - ) - return td - - @staticmethod - def get_action_mask(td: TensorDict) -> torch.Tensor: - """In addition to the constraints considered in the CVRPEnv, the time windows are considered. - The vehicle can only visit a location if it can reach it in time, i.e. before its time window ends. - """ - not_masked = CVRPEnv.get_action_mask(td) - batch_size = td["locs"].shape[0] - current_loc = gather_by_index(td["locs"], td["current_node"]).reshape( - [batch_size, 2] - ) - dist = get_distance(current_loc, td["locs"].transpose(0, 1)).transpose(0, 1) - td.update({"current_loc": current_loc, "distances": dist}) - can_reach_in_time = ( - td["current_time"] + dist <= td["time_windows"][..., 1] - ) # I only need to start the service before the time window ends, not finish it. - - return not_masked & can_reach_in_time - - def _step(self, td: TensorDict) -> TensorDict: - """In addition to the calculations in the CVRPEnv, the current time is - updated to keep track of which nodes are still reachable in time. - The current_node is updeted in the parent class' _step() function. 
- """ - - batch_size = td["locs"].shape[0] - # update current_time - distance = gather_by_index(td["distances"], td["action"]).reshape([batch_size, 1]) - duration = gather_by_index(td["durations"], td["action"]).reshape([batch_size, 1]) - start_times = gather_by_index(td["time_windows"], td["action"])[..., 0].reshape( - [batch_size, 1] - ) - td["current_time"] = (td["action"][:, None] != 0) * ( - torch.max(td["current_time"] + distance, start_times) + duration - ) - # current_node is updated to the selected action - td = super()._step(td) - return td - - def _reset( - self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None - ) -> TensorDict: - if batch_size is None: - batch_size = self.batch_size if td is None else td["locs"].shape[:-2] - if td is None or td.is_empty(): - td = self.generate_data(batch_size=batch_size) - batch_size = [batch_size] if isinstance(batch_size, int) else batch_size - - self.to(td.device) - # Create reset TensorDict - td_reset = TensorDict( - { - "locs": torch.cat((td["depot"][..., None, :], td["locs"]), -2), - "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=self.device - ), - "current_time": torch.zeros( - *batch_size, 1, dtype=torch.float32, device=self.device - ), - "used_capacity": torch.zeros((*batch_size, 1), device=self.device), - "vehicle_capacity": torch.full( - (*batch_size, 1), self.vehicle_capacity, device=self.device - ), - "visited": torch.zeros( - (*batch_size, 1, td["locs"].shape[-2] + 1), - dtype=torch.uint8, - device=self.device, - ), - "durations": td["durations"], - "time_windows": td["time_windows"], - }, - batch_size=batch_size, - ) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - return td_reset - - def get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: - """The reward is the negative tour length. Time windows - are not considered for the calculation of the reward.""" - return super().get_reward(td, actions) - - @staticmethod - def check_solution_validity(td: TensorDict, actions: torch.Tensor): - CVRPEnv.check_solution_validity(td, actions) - batch_size = td["locs"].shape[0] - # distances to depot - distances = get_distance( - td["locs"][..., 0, :], td["locs"].transpose(0, 1) - ).transpose(0, 1) - # basic checks on time windows - assert torch.all(distances >= 0.0), "Distances must be non-negative." - assert torch.all(td["time_windows"] >= 0.0), "Time windows must be non-negative." - assert torch.all( - td["time_windows"][..., :, 0] + distances + td["durations"] - <= td["time_windows"][..., 0, 1][0] # max_time is the same for all batches - ), "vehicle cannot perform service and get back to depot in time." - assert torch.all( - td["durations"] >= 0.0 - ), "Service durations must be non-negative." 
- assert torch.all( - td["time_windows"][..., 0] < td["time_windows"][..., 1] - ), "there are unfeasible time windows" - # check vehicles can meet deadlines - curr_time = torch.zeros(batch_size, 1, dtype=torch.float32, device=td.device) - curr_node = torch.zeros_like(curr_time, dtype=torch.int64, device=td.device) - for ii in range(actions.size(1)): - next_node = actions[:, ii] - dist = get_distance( - gather_by_index(td["locs"], curr_node).reshape([batch_size, 2]), - gather_by_index(td["locs"], next_node).reshape([batch_size, 2]), - ).reshape([batch_size, 1]) - curr_time = torch.max( - (curr_time + dist).int(), - gather_by_index(td["time_windows"], next_node)[..., 0].reshape( - [batch_size, 1] - ), - ) - assert torch.all( - curr_time - <= gather_by_index(td["time_windows"], next_node)[..., 1].reshape( - [batch_size, 1] - ) - ), "vehicle cannot start service before deadline" - curr_time = curr_time + gather_by_index(td["durations"], next_node).reshape( - [batch_size, 1] - ) - curr_node = next_node - curr_time[curr_node == 0] = 0.0 # reset time for depot - - @staticmethod - def render(td: TensorDict, actions=None, ax=None, scale_xy: bool = False, **kwargs): - CVRPEnv.render(td=td, actions=actions, ax=ax, scale_xy=scale_xy, **kwargs) - - @staticmethod - def load_data( - name: str, - solomon=False, - path_instances: str = None, - type: str = None, - compute_edge_weights: bool = False, - ): - if solomon == True: - assert type in [ - "instance", - "solution", - ], "type must be either 'instance' or 'solution'" - if type == "instance": - instance = load_solomon_instance( - name=name, path=path_instances, edge_weights=compute_edge_weights - ) - elif type == "solution": - instance = load_solomon_solution(name=name, path=path_instances) - return instance - return load_npz_to_tensordict(filename=name) - - def extract_from_solomon(self, instance: dict, batch_size: int = 1): - # extract parameters for the environment from the Solomon instance - self.min_demand = instance["demand"][1:].min() - self.max_demand = instance["demand"][1:].max() - self.vehicle_capacity = instance["capacity"] - self.min_loc = instance["node_coord"][1:].min() - self.max_loc = instance["node_coord"][1:].max() - self.min_time = instance["time_window"][:, 0].min() - self.max_time = instance["time_window"][:, 1].max() - # assert the time window of the depot starts at 0 and ends at max_time - assert self.min_time == 0, "Time window of depot must start at 0." - assert ( - self.max_time == instance["time_window"][0, 1] - ), "Depot must have latest end time." 
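    # A hypothetical usage sketch of the Solomon helpers in this class (instance name
    # and path are placeholders, not files shipped with the repo):
    #
    #     env = VRPTWEnv()
    #     inst = env.load_data("R101", solomon=True, type="instance", path_instances="data/solomon")
    #     td = env.extract_from_solomon(inst, batch_size=1)   # resets the env on that instance
    #
    # extract_from_solomon() overwrites the sampling bounds (capacity, demand, locations,
    # time horizon) with the values found in the instance before building the state below.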
- # convert to format used in CVRPTWEnv - td = TensorDict( - { - "depot": torch.tensor( - instance["node_coord"][0], - dtype=torch.float32, - device=self.device, - ).repeat(batch_size, 1), - "locs": torch.tensor( - instance["node_coord"][1:], - dtype=torch.float32, - device=self.device, - ).repeat(batch_size, 1, 1), - "demand": torch.tensor( - instance["demand"][1:], - dtype=torch.float32, - device=self.device, - ).repeat(batch_size, 1), - "durations": torch.tensor( - instance["service_time"], - dtype=torch.int64, - device=self.device, - ).repeat(batch_size, 1), - "time_windows": torch.tensor( - instance["time_window"], - dtype=torch.int64, - device=self.device, - ).repeat(batch_size, 1, 1), - }, - batch_size=1, # we assume batch_size will always be 1 for loaded instances - ) - return self.reset(td, batch_size=batch_size) diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py index 897ee755..1c63820f 100644 --- a/rl4co/envs/scheduling/__init__.py +++ b/rl4co/envs/scheduling/__init__.py @@ -1,3 +1,2 @@ from rl4co.envs.scheduling.ffsp.env import FFSPEnv -from rl4co.envs.scheduling.fjsp.env import FJSPEnv from rl4co.envs.scheduling.smtwtp.env import SMTWTPEnv diff --git a/rl4co/envs/scheduling/fjsp/__init__.py b/rl4co/envs/scheduling/fjsp/__init__.py deleted file mode 100644 index 4eb6d9df..00000000 --- a/rl4co/envs/scheduling/fjsp/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -NO_OP_ID = -1 -INIT_FINISH = 9999.0 diff --git a/rl4co/envs/scheduling/fjsp/env.py b/rl4co/envs/scheduling/fjsp/env.py deleted file mode 100644 index 4a6a217f..00000000 --- a/rl4co/envs/scheduling/fjsp/env.py +++ /dev/null @@ -1,424 +0,0 @@ -import torch - -from einops import rearrange, reduce -from tensordict.tensordict import TensorDict -from torchrl.data import ( - BoundedTensorSpec, - CompositeSpec, - UnboundedContinuousTensorSpec, - UnboundedDiscreteTensorSpec, -) - -from rl4co.envs.common.base import RL4COEnvBase as EnvBase -from rl4co.utils.ops import gather_by_index, sample_n_random_actions - -from . import INIT_FINISH, NO_OP_ID -from .generator import FJSPFileGenerator, FJSPGenerator -from .render import render -from .utils import calc_lower_bound, get_job_ops_mapping, op_is_ready - - -class FJSPEnv(EnvBase): - """Flexible Job-Shop Scheduling Problem (FJSP) environment - At each step, the agent chooses a job-machine combination. The operation to be processed next for the selected job is - then executed on the selected machine. The reward is 0 unless the agent scheduled all operations of all jobs. - In that case, the reward is (-)makespan of the schedule: maximizing the reward is equivalent to minimizing the makespan. 
- - Observations: - - time: current time - - next_op: next operation per job - - proc_times: processing time of operation-machine pairs - - pad_mask: specifies padded operations - - start_op_per_job: id of first operation per job - - end_op_per_job: id of last operation per job - - start_times: start time of operation (defaults to 0 if not scheduled) - - finish_times: finish time of operation (defaults to INIT_FINISH if not scheduled) - - job_ops_adj: adjacency matrix specifying job-operation affiliation - - ops_job_map: same as above but using ids of jobs to indicate affiliation - - ops_sequence_order: specifies the order in which operations have to be processed - - ma_assignment: specifies which operation has been scheduled on which machine - - busy_until: specifies until when the machine will be busy - - num_eligible: number of machines that can process an operation - - job_in_process: whether job is currently being processed - - job_done: whether the job is done - - Constrains: - the agent may not select: - - machines that are currently busy - - jobs that are done already - - jobs that are currently processed - - job-machine combinations, where the machine cannot process the next operation of the job - - Finish condition: - - the agent has scheduled all operations of all jobs - - Reward: - - the negative makespan of the final schedule - - Args: - generator: FJSPGenerator instance as the data generator - generator_params: parameters for the generator - mask_no_ops: if True, agent may not select waiting operation (unless instance is done) - """ - - name = "fjsp" - - def __init__( - self, - generator: FJSPGenerator = None, - generator_params: dict = {}, - mask_no_ops: bool = True, - **kwargs, - ): - super().__init__(check_solution=False, **kwargs) - if generator is None: - if generator_params.get("file_path", None) is not None: - generator = FJSPFileGenerator(**generator_params) - else: - generator = FJSPGenerator(**generator_params) - self.generator = generator - self.num_mas = generator.num_mas - self.num_jobs = generator.num_jobs - self.n_ops_max = generator.max_ops_per_job * self.num_jobs - self.mask_no_ops = mask_no_ops - self._make_spec(self.generator) - - def _decode_graph_structure(self, td: TensorDict): - batch_size = td.batch_size - start_op_per_job = td["start_op_per_job"] - end_op_per_job = td["end_op_per_job"] - pad_mask = td["pad_mask"] - n_ops_max = td["pad_mask"].size(-1) - - # here we will generate the operations-job mapping: - ops_job_map, ops_job_bin_map = get_job_ops_mapping( - start_op_per_job, end_op_per_job, n_ops_max - ) - - # mask invalid edges (caused by padding) - ops_job_bin_map[pad_mask.unsqueeze(1).expand_as(ops_job_bin_map)] = 0 - - # generate for each batch a sequence specifying the position of all operations in their respective jobs, - # e.g. 
[0,1,0,0,1,2,0,1,2,3,0,0] for jops with n_ops=[2,1,3,4,1,1] - # (bs, max_ops) - ops_seq_order = torch.sum( - ops_job_bin_map * (ops_job_bin_map.cumsum(2) - 1), dim=1 - ) - - # predecessor and successor adjacency matrices - pred = torch.diag_embed(torch.ones(n_ops_max - 1), offset=-1)[None].expand( - *batch_size, -1, -1 - ) - # the start of the sequence (of each job) does not have a predecessor, therefore we can - # mask all first ops of a job in the predecessor matrix - pred = pred * ops_seq_order.gt(0).unsqueeze(-1).expand_as(pred).to(pred) - succ = torch.diag_embed(torch.ones(n_ops_max - 1), offset=1)[None].expand( - *batch_size, -1, -1 - ) - # apply the same logic as above to mask the last op of a job, which does not have a successor. The last job of a job - # always comes before the 1st op of the next job, therefore performing a left shift of the ops seq tensor here - succ = succ * torch.cat( - (ops_seq_order[:, 1:], ops_seq_order.new_full((*batch_size, 1), 0)), dim=1 - ).gt(0).to(succ).unsqueeze(-1).expand_as(succ) - - # adjacency matrix = predecessors, successors and self loops - # (bs, max_ops, max_ops, 2) - ops_adj = torch.stack((pred, succ), dim=3) - - td = td.update( - { - "ops_adj": ops_adj, - "job_ops_adj": ops_job_bin_map, - "ops_job_map": ops_job_map, - # "op_spatial_enc": ops_spatial_enc, - "ops_sequence_order": ops_seq_order, - } - ) - - return td, n_ops_max - - def _reset(self, td: TensorDict = None, batch_size=None) -> TensorDict: - td_reset = td.clone() - - td_reset, n_ops_max = self._decode_graph_structure(td_reset) - - # schedule - start_op_per_job = td_reset["start_op_per_job"] - start_times = torch.zeros((*batch_size, n_ops_max)) - finish_times = torch.full((*batch_size, n_ops_max), INIT_FINISH) - ma_assignment = torch.zeros((*batch_size, self.num_mas, n_ops_max)) - - # reset feature space - busy_until = torch.zeros((*batch_size, self.num_mas)) - # (bs, ma, ops) - ops_ma_adj = (td_reset["proc_times"] > 0).to(torch.float32) - # (bs, ops) - num_eligible = torch.sum(ops_ma_adj, dim=1) - - td_reset = td_reset.update( - { - "start_times": start_times, - "finish_times": finish_times, - "ma_assignment": ma_assignment, - "busy_until": busy_until, - "num_eligible": num_eligible, - "next_op": start_op_per_job.clone().to(torch.int64), - "ops_ma_adj": ops_ma_adj, - "op_scheduled": torch.full((*batch_size, n_ops_max), False), - "job_in_process": torch.full((*batch_size, self.num_jobs), False), - "reward": torch.zeros((*batch_size,), dtype=torch.float32), - "time": torch.zeros((*batch_size,)), - "job_done": torch.full((*batch_size, self.num_jobs), False), - "done": torch.full((*batch_size, 1), False), - }, - ) - - td_reset.set("lbs", calc_lower_bound(td_reset)) - td_reset.set("is_ready", op_is_ready(td_reset)) - td_reset.set("action_mask", self.get_action_mask(td_reset)) - - return td_reset - - def get_action_mask(self, td: TensorDict) -> torch.Tensor: - batch_size = td.size(0) - - # (bs, jobs, machines) - action_mask = torch.full((batch_size, self.num_jobs, self.num_mas), False).to( - td.device - ) - - # mask jobs that are done already - action_mask.add_(td["job_done"].unsqueeze(2)) - # as well as jobs that are currently processed - action_mask.add_(td["job_in_process"].unsqueeze(2)) - - # mask machines that are currently busy - action_mask.add_(td["busy_until"].gt(td["time"].unsqueeze(1)).unsqueeze(1)) - - # exclude job-machine combinations, where the machine cannot process the next op of the job - next_ops_proc_times = gather_by_index( - td["proc_times"], 
td["next_op"].unsqueeze(1), dim=2, squeeze=False - ).transpose(1, 2) - action_mask.add_(next_ops_proc_times == 0) - if self.mask_no_ops: - no_op_mask = ~td["done"] - else: - no_op_mask = ~td["job_in_process"].any(1, keepdims=True) & ~td["done"] - # flatten action mask to correspond with logit shape - action_mask = rearrange(action_mask, "bs j m -> bs (j m)") - # NOTE: 1 means feasible action, 0 means infeasible action - mask = torch.cat((~no_op_mask, ~action_mask), dim=1) - return mask - - def _step(self, td: TensorDict): - # cloning required to avoid inplace operation which avoids gradient backtracking - td = td.clone() - td["action"].subtract_(1) - # (bs) - dones = td["done"].squeeze(1) - # specify which batch instances require which operation - no_op = td["action"].eq(NO_OP_ID) - no_op = no_op & ~dones - req_op = ~no_op & ~dones - - # transition to next time for no op instances - if no_op.any(): - td, dones = self._transit_to_next_time(no_op, td) - - td_op = td.masked_select(req_op) - - # (#req_op) - selected_job = td_op["action"] // self.num_mas - # (#req_op) - selected_machine = td_op["action"] % self.num_mas - td_op = self._make_step(td_op, selected_job, selected_machine) - - td[req_op] = td_op - - # action mask - td.set("action_mask", self.get_action_mask(td)) - - step_complete = self._check_step_complete(td, dones) - while step_complete.any(): - td, dones = self._transit_to_next_time(step_complete, td) - td.set("action_mask", self.get_action_mask(td)) - step_complete = self._check_step_complete(td, dones) - - # after we have transitioned to a next time step, we determine which operations are ready - td["is_ready"] = op_is_ready(td) - - td["lbs"] = calc_lower_bound(td) - - return td - - @staticmethod - def _check_step_complete(td, dones): - """check whether there a feasible actions left to be taken during the current - time step. If this is not the case (and the instance is not done), - we need to adance the timer of the repsective instance - """ - return ~reduce(td["action_mask"], "bs ... -> bs", "any") & ~dones - - def _make_step(self, td: TensorDict, selected_job, selected_machine) -> TensorDict: - """ - Environment transition function - """ - - batch_idx = torch.arange(td.size(0)) - - td["job_in_process"][batch_idx, selected_job] = 1 - - # (#req_op) - selected_op = td["next_op"].gather(1, selected_job[:, None]).squeeze(1) - - # mark op as schedules - td["op_scheduled"][batch_idx, selected_op] = True - - # update machine state - proc_time_of_action = td["proc_times"][batch_idx, selected_machine, selected_op] - # we may not select a machine that is busy - assert torch.all(td["busy_until"][batch_idx, selected_machine] <= td["time"]) - - # update schedule - td["start_times"][batch_idx, selected_op] = td["time"] - td["finish_times"][batch_idx, selected_op] = td["time"] + proc_time_of_action - td["ma_assignment"][batch_idx, selected_machine, selected_op] = 1 - # update the state of the selected machine - td["busy_until"][batch_idx, selected_machine] = td["time"] + proc_time_of_action - - return td - - def _transit_to_next_time(self, step_complete, td: TensorDict) -> TensorDict: - """ - Transit to the next time - """ - - # we need a transition to a next time step if either - # 1.) all machines are busy - # 2.) all operations are already currently in process (can only happen if num_jobs < num_machines) - # 3.) idle machines can not process any of the not yet scheduled operations - # 4.) 
no_op is choosen - available_time_ma = td["busy_until"] - end_op_per_job = td["end_op_per_job"] - # we want to transition to the next time step where a machine becomes idle again. This time step must be - # in the future, therefore we mask all machine idle times lying in the past / present - available_time = ( - torch.where( - available_time_ma > td["time"][:, None], available_time_ma, torch.inf - ) - .min(1) - .values - ) - - assert not torch.any(available_time[step_complete].isinf()) - td["time"] = torch.where(step_complete, available_time, td["time"]) - - # this may only be set when the operation is finished, not when it is scheduled - # operation of job is finished, set next operation and flag job as being idle - curr_ops_end = td["finish_times"].gather(1, td["next_op"]) - op_finished = td["job_in_process"] & (curr_ops_end <= td["time"][:, None]) - # check whether a job is finished, which is the case when the last operation of the job is finished - job_finished = op_finished & (td["next_op"] == end_op_per_job) - # determine the next operation for a job that is not done, but whose latest operation is finished - td["next_op"] = torch.where( - op_finished & ~job_finished, - td["next_op"] + 1, - td["next_op"], - ) - td["job_in_process"][op_finished] = False - - td["job_done"] = td["job_done"] + job_finished - td["done"] = td["job_done"].all(1, keepdim=True) - - return td, td["done"].squeeze(1) - - def _get_reward(self, td, actions=None) -> TensorDict: - return -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values - - def _make_spec(self, generator: FJSPGenerator): - self.observation_spec = CompositeSpec( - time=UnboundedDiscreteTensorSpec( - shape=(1,), - dtype=torch.int64, - ), - next_op=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs,), - dtype=torch.int64, - ), - proc_times=UnboundedDiscreteTensorSpec( - shape=(self.num_mas, self.n_ops_max), - dtype=torch.float32, - ), - pad_mask=UnboundedDiscreteTensorSpec( - shape=(self.num_mas, self.n_ops_max), - dtype=torch.bool, - ), - start_op_per_job=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs,), - dtype=torch.bool, - ), - end_op_per_job=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs,), - dtype=torch.bool, - ), - start_times=UnboundedDiscreteTensorSpec( - shape=(self.n_ops_max,), - dtype=torch.int64, - ), - finish_times=UnboundedDiscreteTensorSpec( - shape=(self.n_ops_max,), - dtype=torch.int64, - ), - job_ops_adj=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs, self.n_ops_max), - dtype=torch.int64, - ), - ops_job_map=UnboundedDiscreteTensorSpec( - shape=(self.n_ops_max), - dtype=torch.int64, - ), - ops_sequence_order=UnboundedDiscreteTensorSpec( - shape=(self.n_ops_max), - dtype=torch.int64, - ), - ma_assignment=UnboundedDiscreteTensorSpec( - shape=(self.num_mas, self.n_ops_max), - dtype=torch.int64, - ), - busy_until=UnboundedDiscreteTensorSpec( - shape=(self.num_mas,), - dtype=torch.int64, - ), - num_eligible=UnboundedDiscreteTensorSpec( - shape=(self.n_ops_max,), - dtype=torch.int64, - ), - job_in_process=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs,), - dtype=torch.bool, - ), - job_done=UnboundedDiscreteTensorSpec( - shape=(self.num_jobs,), - dtype=torch.bool, - ), - shape=(), - ) - self.action_spec = BoundedTensorSpec( - shape=(1,), - dtype=torch.int64, - low=-1, - high=self.n_ops_max, - ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) - self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) - - @staticmethod - def render(td, idx): - return render(td, idx) - 
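    # Worked example of the flat action layout used by get_action_mask() and _step()
    # above, with hypothetical sizes: for num_jobs = 3 and num_mas = 5 the mask (and
    # the policy logits) have length 1 + 3 * 5 = 16, index 0 being the waiting no-op.
    # A raw action of 8 is first shifted by -1 in _step() (so the no-op becomes
    # NO_OP_ID = -1) and then decoded as job = 7 // 5 = 1 and machine = 7 % 5 = 2.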
- def select_start_nodes(self, td: TensorDict, num_starts: int): - return sample_n_random_actions(td, num_starts) - - def get_num_starts(self, td): - # NOTE in the paper they use N_s = 100 - return 100 diff --git a/rl4co/envs/scheduling/fjsp/generator.py b/rl4co/envs/scheduling/fjsp/generator.py deleted file mode 100644 index f1ae6202..00000000 --- a/rl4co/envs/scheduling/fjsp/generator.py +++ /dev/null @@ -1,216 +0,0 @@ -from functools import partial -from typing import List - -import numpy as np -import torch - -from tensordict.tensordict import TensorDict - -from rl4co.envs.common.utils import Generator -from rl4co.utils.pylogger import get_pylogger - -from .parser import get_max_ops_from_files, read - -log = get_pylogger(__name__) - - -class FJSPGenerator(Generator): - - """Data generator for the Flexible Job-Shop Scheduling Problem (FJSP). - - Args: - num_stage: number of stages - num_machine: number of machines - num_job: number of jobs - min_time: minimum running time of each job on each machine - max_time: maximum running time of each job on each machine - flatten_stages: whether to flatten the stages - - Returns: - A TensorDict with the following key: - start_op_per_job [batch_size, num_jobs]: first operation of each job - end_op_per_job [batch_size, num_jobs]: last operation of each job - proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines - pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used - - """ - - def __init__( - self, - num_jobs: int = 10, - num_machines: int = 5, - min_ops_per_job: int = 4, - max_ops_per_job: int = 6, - min_processing_time: int = 1, - max_processing_time: int = 20, - min_eligible_ma_per_op: int = 1, - max_eligible_ma_per_op: int = None, - **unused_kwargs, - ): - self.num_jobs = num_jobs - self.num_mas = num_machines - self.min_ops_per_job = min_ops_per_job - self.max_ops_per_job = max_ops_per_job - self.min_processing_time = min_processing_time - self.max_processing_time = max_processing_time - self.min_eligible_ma_per_op = min_eligible_ma_per_op - self.max_eligible_ma_per_op = max_eligible_ma_per_op or num_machines - # determines whether to use a fixed number of total operations or let it vary between instances - # NOTE: due to the way rl4co builds datasets, we need a fixed size here - self.n_ops_max = max_ops_per_job * num_jobs - - # FFSP environment doen't have any other kwargs - if len(unused_kwargs) > 0: - log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") - - def _simulate_processing_times( - self, n_eligible_per_ops: torch.Tensor - ) -> torch.Tensor: - bs, n_ops_max = n_eligible_per_ops.shape - - # (bs, max_ops, machines) - ma_seq_per_ops = torch.arange(1, self.num_mas + 1)[None, None].expand( - bs, n_ops_max, self.num_mas - ) - # generate a matrix of size (ops, mas) per batch, each row having as many ones as the operation eligible machines - # E.g. 
n_eligible_per_ops=[1,3,2]; num_mas=4 - # [[1,0,0,0], - # 1,1,1,0], - # 1,1,0,0]] - # This will be shuffled randomly to generate a machine-operation mapping - ma_ops_edges_unshuffled = torch.Tensor.float( - ma_seq_per_ops <= n_eligible_per_ops[..., None] - ) - # random shuffling - idx = torch.rand_like(ma_ops_edges_unshuffled).argsort() - ma_ops_edges = ma_ops_edges_unshuffled.gather(2, idx).transpose(1, 2) - - # (bs, max_ops, machines) - proc_times = torch.ones((bs, n_ops_max, self.num_mas)) - proc_times = torch.randint( - self.min_processing_time, - self.max_processing_time + 1, - size=(bs, self.num_mas, n_ops_max), - ) - - # remove proc_times for which there is no corresponding ma-ops connection - proc_times = proc_times * ma_ops_edges - return proc_times - - def _generate(self, batch_size) -> TensorDict: - # simulate how many operations each job has - n_ope_per_job = torch.randint( - self.min_ops_per_job, - self.max_ops_per_job + 1, - size=(*batch_size, self.num_jobs), - ) - - # determine the total number of operations per batch instance (which may differ) - n_ops_batch = n_ope_per_job.sum(1) # (bs) - # determine the maximum total number of operations over all batch instances - n_ops_max = self.n_ops_max or n_ops_batch.max() - - # generate a mask, specifying which operations are padded - pad_mask = torch.arange(n_ops_max).unsqueeze(0).expand(*batch_size, -1) - pad_mask = pad_mask.ge(n_ops_batch[:, None].expand_as(pad_mask)) - - # determine the id of the end operation for each job - end_op_per_job = n_ope_per_job.cumsum(1) - 1 - - # determine the id of the starting operation for each job - # (bs, num_jobs) - start_op_per_job = torch.cat( - ( - torch.zeros((*batch_size, 1)).to(end_op_per_job), - end_op_per_job[:, :-1] + 1, - ), - dim=1, - ) - - # here we simulate the eligible machines per operation and the processing times - n_eligible_per_ops = torch.randint( - self.min_eligible_ma_per_op, - self.max_eligible_ma_per_op + 1, - (*batch_size, n_ops_max), - ) - n_eligible_per_ops[pad_mask] = 0 - - # simulate processing times for machine-operation pairs - # (bs, num_mas, n_ops_max) - proc_times = self._simulate_processing_times(n_eligible_per_ops) - - td = TensorDict( - { - "start_op_per_job": start_op_per_job, - "end_op_per_job": end_op_per_job, - "proc_times": proc_times, - "pad_mask": pad_mask, - }, - batch_size=batch_size, - ) - - return td - - -class FJSPFileGenerator(Generator): - """Data generator for the Flexible Job-Shop Scheduling Problem (FJSP) using instance files - - Args: - path: path to files - - Returns: - A TensorDict with the following key: - start_op_per_job [batch_size, num_jobs]: first operation of each job - end_op_per_job [batch_size, num_jobs]: last operation of each job - proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines - pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used - - """ - - def __init__(self, file_path: str, n_ops_max: int = None, **unused_kwargs): - self.files = self.list_files(file_path) - self.num_samples = len(self.files) - - if len(unused_kwargs) > 0: - log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") - - if len(self.files) > 1: - n_ops_max = get_max_ops_from_files(self.files) - - ret = map(partial(read, max_ops=n_ops_max), self.files) - - td_list, num_jobs, num_machines, max_ops_per_job = list(zip(*list(ret))) - num_jobs, num_machines = map(lambda x: x[0], (num_jobs, num_machines)) - max_ops_per_job = max(max_ops_per_job) - - self.td = 
torch.cat(td_list, dim=0) - self.num_mas = num_machines - self.num_jobs = num_jobs - self.max_ops_per_job = max_ops_per_job - self.start_idx = 0 - - def _generate(self, batch_size: List[int]) -> TensorDict: - batch_size = np.prod(batch_size) - if batch_size > self.num_samples: - log.warning( - f"Only found {self.num_samples} instance files, but specified dataset size is {batch_size}" - ) - end_idx = self.start_idx + batch_size - td = self.td[self.start_idx : end_idx] - self.start_idx += batch_size - return td - - @staticmethod - def list_files(path): - import os - - files = [ - os.path.join(path, f) - for f in os.listdir(path) - if os.path.isfile(os.path.join(path, f)) - ] - assert len(files) > 0 - files = sorted( - files, key=lambda f: int(os.path.splitext(os.path.basename(f))[0][:4]) - ) - return files diff --git a/rl4co/envs/scheduling/fjsp/parser.py b/rl4co/envs/scheduling/fjsp/parser.py deleted file mode 100644 index f05c8fca..00000000 --- a/rl4co/envs/scheduling/fjsp/parser.py +++ /dev/null @@ -1,180 +0,0 @@ -import os - -from functools import partial -from pathlib import Path -from typing import List, Tuple, Union - -import torch - -from tensordict import TensorDict - -ProcessingData = List[Tuple[int, int]] - - -def list_files(path): - import os - - files = [ - os.path.join(path, f) - for f in os.listdir(path) - if os.path.isfile(os.path.join(path, f)) - ] - return files - - -def parse_job_line(line: Tuple[int]) -> Tuple[ProcessingData]: - """ - Parses a FJSPLIB job data line of the following form: - - * ( * ( )) - - In words, the first value is the number of operations. Then, for each - operation, the first number represents the number of machines that can - process the operation, followed by, the machine index and processing time - for each eligible machine. - - Note that the machine indices start from 1, so we subtract 1 to make them - zero-based. - """ - num_operations = line[0] - operations = [] - idx = 1 - - for _ in range(num_operations): - num_pairs = int(line[idx]) * 2 - machines = line[idx + 1 : idx + 1 + num_pairs : 2] - durations = line[idx + 2 : idx + 2 + num_pairs : 2] - operations.append([(m, d) for m, d in zip(machines, durations)]) - - idx += 1 + num_pairs - - return operations - - -def get_n_ops_of_instance(file): - lines = file2lines(file) - jobs = [parse_job_line(line) for line in lines[1:]] - n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) - total_ops = int(n_ope_per_job.sum()) - return total_ops - - -def get_max_ops_from_files(files): - return max(map(get_n_ops_of_instance, files)) - - -def read(loc: Path, max_ops=None): - """ - Reads an FJSPLIB instance. - - Args: - loc: location of instance file - max_ops: optionally specify the maximum number of total operations (will be filled by padding) - - Returns: - instance: the parsed instance - """ - lines = file2lines(loc) - - # First line contains metadata. - num_jobs, num_machines = lines[0][0], lines[0][1] - - # The remaining lines contain the job-operation data, where each line - # represents a job and its operations. 
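    # Example of a single FJSPLIB job line and how parse_job_line() above reads it
    # (machine indices in the file are 1-based; they are shifted to 0-based further
    # below when the processing-time matrix is filled):
    #
    #     2   1 3 4   2 1 2 5 6
    #
    # -> the job has 2 operations; operation 1 can run on 1 machine (machine 3,
    #    duration 4), operation 2 on 2 machines, either (machine 1, duration 2)
    #    or (machine 5, duration 6).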
- jobs = [parse_job_line(line) for line in lines[1:]] - n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) - total_ops = int(n_ope_per_job.sum()) - if max_ops is not None: - assert total_ops <= max_ops, "got more operations then specified through max_ops" - max_ops = max_ops or total_ops - max_ops_per_job = int(n_ope_per_job.max()) - - end_op_per_job = n_ope_per_job.cumsum(1) - 1 - start_op_per_job = torch.cat((torch.zeros((1, 1)), end_op_per_job[:, :-1] + 1), dim=1) - - pad_mask = torch.arange(max_ops) - pad_mask = pad_mask.ge(total_ops).unsqueeze(0) - - proc_times = torch.zeros((num_machines, max_ops)) - op_cnt = 0 - for job in jobs: - for op in job: - for ma, dur in op: - # subtract one to let indices start from zero - proc_times[ma - 1, op_cnt] = dur - op_cnt += 1 - proc_times = proc_times.unsqueeze(0) - - td = TensorDict( - { - "start_op_per_job": start_op_per_job, - "end_op_per_job": end_op_per_job, - "proc_times": proc_times, - "pad_mask": pad_mask, - }, - batch_size=[1], - ) - - return td, num_jobs, num_machines, max_ops_per_job - - -def file2lines(loc: Union[Path, str]) -> List[List[int]]: - with open(loc, "r") as fh: - lines = [line for line in fh.readlines() if line.strip()] - - def parse_num(word: str): - return int(word) if "." not in word else int(float(word)) - - return [[parse_num(x) for x in line.split()] for line in lines] - - -def write_one(args, where=None): - id, instance = args - assert ( - len(instance["proc_times"].shape) == 2 - ), "no batch dimension allowed in write operation" - lines = [] - - # The flexibility is the average number of eligible machines per operation. - num_eligible = (instance["proc_times"] > 0).sum() - n_ops = (~instance["pad_mask"]).sum() - num_jobs = instance["next_op"].size(0) - num_machines = instance["proc_times"].size(0) - flexibility = round(int(num_eligible) / int(n_ops), 5) - - metadata = f"{num_jobs}\t{num_machines}\t{flexibility}" - lines.append(metadata) - - for i in range(num_jobs): - ops_of_job = instance["job_ops_adj"][i].nonzero().squeeze(1) - job = [len(ops_of_job)] # number of operations of the job - - for op in ops_of_job: - eligible_ma = instance["proc_times"][:, op].nonzero().squeeze(1) - job.append(eligible_ma.size(0)) # num_eligible - - for machine in eligible_ma: - duration = instance["proc_times"][machine, op] - assert duration > 0, "something is wrong" - # add one since in song instances ma indices start from one - job.extend([int(machine.item()) + 1, int(duration.item())]) - - line = " ".join(str(num) for num in job) - lines.append(line) - - formatted = "\n".join(lines) - - file_name = f"{str(id+1).rjust(4, '0')}_{num_jobs}j_{num_machines}m.txt" - full_path = os.path.join(where, file_name) - - with open(full_path, "w") as fh: - fh.write(formatted) - - return formatted - - -def write(where: Union[Path, str], instances: TensorDict): - if not os.path.exists(where): - os.makedirs(where) - - return list(map(partial(write_one, where=where), enumerate(iter(instances)))) diff --git a/rl4co/envs/scheduling/fjsp/render.py b/rl4co/envs/scheduling/fjsp/render.py deleted file mode 100644 index bfb86bf4..00000000 --- a/rl4co/envs/scheduling/fjsp/render.py +++ /dev/null @@ -1,72 +0,0 @@ -from collections import defaultdict - -import matplotlib.pyplot as plt -import numpy as np - -from matplotlib.colors import ListedColormap -from tensordict.tensordict import TensorDict - -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -def render(td: TensorDict, idx: int): - inst = td[idx] - num_jobs 
= inst["job_ops_adj"].size(0) - - # Define a colormap with a color for each job - colors = plt.cm.tab10(np.linspace(0, 1, num_jobs)) - cmap = ListedColormap(colors) - - assign = inst["ma_assignment"].nonzero() - - schedule = defaultdict(list) - - for val in assign: - machine = val[0].item() - op = val[1].item() - # get start and end times of operation - start = inst["start_times"][val[1]] - end = inst["finish_times"][val[1]] - # write information to schedule dictionary - schedule[machine].append((op, start, end)) - - _, ax = plt.subplots() - - # Plot horizontal bars for each task - for ma, ops in schedule.items(): - for op, start, end in ops: - job = inst["job_ops_adj"][:, op].nonzero().item() - ax.barh( - ma, - end - start, - left=start, - height=0.6, - color=cmap(job), - edgecolor="black", - linewidth=1, - ) - - ax.text( - start + (end - start) / 2, ma, op, ha="center", va="center", color="white" - ) - - # Set labels and title - ax.set_yticks(range(len(schedule))) - ax.set_yticklabels([f"Machine {i}" for i in range(len(schedule))]) - ax.set_xlabel("Time") - ax.set_title("Gantt Chart") - - # Add a legend for class labels - handles = [plt.Rectangle((0, 0), 1, 1, color=cmap(i)) for i in range(num_jobs)] - ax.legend( - handles, - [f"Job {label}" for label in range(num_jobs)], - loc="center left", - bbox_to_anchor=(1, 0.5), - ) - - plt.tight_layout() - # Show the Gantt chart - plt.show() diff --git a/rl4co/envs/scheduling/fjsp/utils.py b/rl4co/envs/scheduling/fjsp/utils.py deleted file mode 100644 index b3ee40b8..00000000 --- a/rl4co/envs/scheduling/fjsp/utils.py +++ /dev/null @@ -1,333 +0,0 @@ -import logging - -from typing import List, Tuple, Union - -import torch - -from tensordict import TensorDict -from torch import Size, Tensor - -from rl4co.envs.scheduling.fjsp import INIT_FINISH - -logger = logging.getLogger(__name__) - - -def get_op_features(td: TensorDict): - return torch.stack((td["lbs"], td["is_ready"], td["num_eligible"]), dim=-1) - - -def cat_and_norm_features( - td: TensorDict, feats: List[str], time_feats: List[str], norm_const: int -): - # logger.info(f"will scale the features {','.join(time_feats)} with a constant ({norm_const})") - feature_list = [] - for feat in feats: - if feat in time_feats: - feature_list.append(td[feat] / norm_const) - else: - feature_list.append(td[feat]) - - return torch.stack(feature_list, dim=-1).to(torch.float32) - - -def view( - tensor: Tensor, - idx: Tuple[Tensor], - pad_mask: Tensor, - new_shape: Union[Size, List[int]], - pad_value: Union[float, int], -): - # convert mask specifying which entries are padded into mask specifying which entries to keep - mask = ~pad_mask - new_view = tensor.new_full(size=new_shape, fill_value=pad_value) - new_view[idx] = tensor[mask] - return new_view - - -def _get_idx_for_job_op_view(td: TensorDict) -> tuple: - bs, _, n_total_ops = td["job_ops_adj"].shape - # (bs, ops) - batch_idx = torch.arange(bs, device=td.device).repeat_interleave(n_total_ops) - batch_idx = batch_idx.reshape(bs, -1) - # (bs, ops) - ops_job_map = td["ops_job_map"] - # (bs, ops) - ops_sequence_order = td["ops_sequence_order"] - # (bs*n_ops_max, 3) - idx = ( - torch.stack((batch_idx, ops_job_map, ops_sequence_order), dim=-1) - .to(torch.long) - .flatten(0, 1) - ) - # (bs, n_ops_max) - mask = ~td["pad_mask"] - # (total_ops_in_batch, 3) - idx = idx[mask.flatten(0, 1)] - b, j, o = map(lambda x: x.squeeze(1), idx.chunk(3, dim=-1)) - return b, j, o - - -def get_job_op_view( - td: TensorDict, keys: List[str] = [], pad_value: Union[float, int] = 0 -): 
- """This function reshapes all tensors of the tensordict from a flat operations-only view - to a nested job-operation view and creates a new tensordict from it. - :param _type_ td: tensordict - :return _type_: dict - """ - # ============= Prepare the new index ============= - bs, num_jobs, _ = td["job_ops_adj"].shape - max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) - idx = _get_idx_for_job_op_view(td) - new_shape = Size((bs, num_jobs, max_ops_per_job)) - pad_mask = td["pad_mask"] - # ============================================== - - # due to special structure, processing times are treated seperately - if "proc_times" in keys: - keys.remove("proc_times") - # reshape processing times; (bs, ma, ops) -> (bs, ma, jobs, ops_per_job) - new_proc_times_view = view( - td["proc_times"].permute(0, 2, 1), idx, pad_mask, new_shape, pad_value - ).permute(0, 3, 1, 2) - - # add padding mask if not in keys - if "pad_mask" not in keys: - keys.append("pad_mask") - - new_views = dict( - map(lambda key: (key, view(td[key], idx, pad_mask, new_shape)), keys) - ) - - # update tensordict clone with reshaped tensors - return {"proc_times": new_proc_times_view, **new_views} - - -def blockify(td, tensor: Tensor, pad_value: Union[float, int] = 0): - assert len(tensor.shape) in [ - 2, - 3, - ], "blockify only supports tensors of shape (bs, seq, (d)), where the feature dim d is optional" - # get the size of the blockified tensor - bs, _, *d = tensor.shape - num_jobs = td["job_ops_adj"].size(1) - max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) - new_shape = Size((bs, num_jobs, max_ops_per_job, *d)) - # get indices of valid entries of blockified tensor - idx = _get_idx_for_job_op_view(td) - pad_mask = td["pad_mask"] - # create the blockified view - new_view_tensor = view(tensor, idx, pad_mask, new_shape, pad_value) - return new_view_tensor - - -def unblockify( - td: TensorDict, tensor: Tensor, mask: Tensor = None, pad_value: Union[float, int] = 0 -): - assert len(tensor.shape) in [ - 3, - 4, - ], "blockify only supports tensors of shape (bs, nb, s, (d)), where the feature dim d is optional" - # get the size of the blockified tensor - bs, _, _, *d = tensor.shape - n_ops_per_batch = td["job_ops_adj"].sum((1, 2)).unsqueeze(1) # (bs) - seq_len = int(n_ops_per_batch.max()) - new_shape = Size((bs, seq_len, *d)) - - # create the mask to gather then entries of the blockified tensor. NOTE that only by - # blockifying the original pad_mask - pad_mask = td["pad_mask"] - pad_mask = blockify(td, pad_mask, True) - - # get indices of valid entrie in flat matrix - b = torch.arange(bs, device=td.device).repeat_interleave(seq_len).reshape(bs, seq_len) - i = torch.arange(seq_len, device=td.device)[None].repeat(bs, 1) - idx = tuple(map(lambda x: x[i < n_ops_per_batch], (b, i))) - # create view - new_tensor = view(tensor, idx, pad_mask, new_shape, pad_value=pad_value) - return new_tensor - - -def first_diff(x: Tensor, dim: int): - shape = x.shape - shape = (*shape[:dim], 1, *shape[dim + 1 :]) - seq_cutoff = x.index_select(dim, torch.arange(x.size(dim) - 1, device=x.device)) - lagged_seq = x - torch.cat((seq_cutoff.new_zeros(*shape), seq_cutoff), dim=dim) - return lagged_seq - - -def spatial_encoding(td: TensorDict): - """We use a spatial encoing as proposed in GraphFormer (https://arxiv.org/abs/2106.05234) - The spatial encoding in GraphFormer determines the distance of the shortest path between and - nodes i and j and uses a special value for node pairs that cannot be connected at all. 
- For any two operations i e=2) and for i>j the negative number of - operations that starting from j, have been completet before arriving at i (e.g. i=5 j=3 -> e=-2). - For i=j we set e=0 as well as for operations of different jobs. - - :param torch.Tensor[bs, n_ops] ops_job_map: tensor specifying the index of its corresponding job - :return torch.Tensor[bs, n_ops, n_ops]: length of shortest path between any two operations - """ - bs, _, n_total_ops = td["job_ops_adj"].shape - max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) - ops_job_map = td["ops_job_map"] - pad_mask = td["pad_mask"] - - same_job = (ops_job_map[:, None] == ops_job_map[..., None]).to(torch.int32) - # mask padded - same_job[pad_mask.unsqueeze(2).expand_as(same_job)] = 0 - same_job[pad_mask.unsqueeze(1).expand_as(same_job)] = 0 - # take upper triangular of same_job and set diagonal to zero for counting purposes - upper_tri = torch.triu(same_job) - torch.diag( - torch.ones(n_total_ops, device=td.device) - )[None].expand_as(same_job) - # cumsum and masking of operations that do not belong to the same job - num_jumps = upper_tri.cumsum(2) * upper_tri - # mirror the matrix - num_jumps = num_jumps + num_jumps.transpose(1, 2) - # NOTE: shifted this logic into the spatial encoding module - # num_jumps = num_jumps + (-num_jumps.transpose(1,2)) - assert not torch.any(num_jumps >= max_ops_per_job) - # special value for ops of different jobs and self-loops - num_jumps = torch.where(num_jumps == 0, -1, num_jumps) - self_mask = torch.eye(n_total_ops).repeat(bs, 1, 1).bool() - num_jumps[self_mask] = 0 - return num_jumps - - -def calc_lower_bound(td: TensorDict): - """Here we calculate the lower bound of the operations finish times. In the FJSP case, multiple things need to - be taken into account due to the usability of the different machines for multiple ops of different jobs: - - 1.) Operations may only start once their direct predecessor is finished. We calculate its lower bound by - adding the minimum possible operation time to this detected start time. However, we cannot use the proc_times - directly, but need to account for the fact, that machines might still be busy, once an operation can be processed. - We detect this offset by detecting ops-machine pairs, where the first possible start point of the operation is before - the machine becomes idle again - Therefore, we add this discrepancy to the proc_time of the respective ops-ma combination - - 2.) If an operation has been scheduled, we use its real finishing time as lower bound. In this case, using the cumulative sum - of all peedecessors of a job does not make sense, since it is likely to differ from the real finishing time of its direct - predecessor (its only a lower bound). Therefore, we add the finish time to the cumulative sum of processing time of all - UNSCHEDULED operations, to obtain the lower bound. 
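As a small worked example of point 2 (ignoring the machine-availability offsets from point 1): take a job with operations o1, o2, o3 where o1 is already scheduled and finishes at 7, and the cheapest machines for o2 and o3 need 4 and 2 time units. The per-operation entries become [7, 4, 2] (the finish time enters as a first difference, the unscheduled operations as their minimal processing times), and the cumulative sum [7, 11, 13] yields the exact finish time for o1 and lower bounds for o2 and o3. Using first differences is what keeps already-realised finish times from being added up twice: had o2 also been scheduled, finishing at 12, its entry would be 12 - 7 = 5 and the cumulative sum would reproduce 12 rather than 19.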
- Making this work is a bit hacky: We compute the first differences of finishing times of those operations scheduled and - add them to the matrix of processing times, where already processed operations are masked (with zero) - - - :param TensorDict td: _description_ - :return _type_: _description_ - """ - - proc_times = td["proc_times"].clone() # (bs, ma, ops) - busy_until = td["busy_until"] # (bs, ma) - ops_adj = td["ops_adj"] # (bs, ops, ops, 2) - finish_times = td["finish_times"] # (bs, ops) - job_ops_adj = td["job_ops_adj"] # (bs, jobs, ops) - op_scheduled = td["op_scheduled"].to(torch.float32) # (bs, ops) - - ############## REGARDING POINT 1 OF DOCSTRING ############## - # for operations whose immidiate predecessor is scheduled, we can determine its earliest - # start time by the end time of the predecessor. - # (bs, num_ops, 1) - maybe_start_at = torch.bmm(ops_adj[..., 0], finish_times[..., None]).squeeze(2) - # using the start_time, we can determine if and how long an op needs to wait for a machine to finish - wait_for_ma_offset = torch.clip(busy_until[..., None] - maybe_start_at[:, None], 0) - # we add this required waiting time to the respective processing time - after that we determine the best machine for each operation - mask = proc_times == 0 - proc_times[mask] = torch.inf - proc_times += wait_for_ma_offset - # select best machine for operation, given the offset - min_proc_times = proc_times.min(1).values - - ############### REGARDING POINT 2 OF DOCSTRING ################### - # Now we determine all operations that are not scheduled yet (and thus have no finish_time). We will compute the cumulative - # sum over the processing time to determine the lower bound of unscheduled operations... - proc_matrix = job_ops_adj - ops_assigned = proc_matrix * op_scheduled[:, None] - proc_matrix_not_scheduled = proc_matrix * ( - torch.ones_like(proc_matrix) - op_scheduled[:, None] - ) - - # ...and add the finish_time of the last scheduled operation of the respective job to that. To make this work, using the cumsum logic, - # we calc the first differences of the finish times and seperate by job. 
- # We use the first differences, so that the finish times do not add up during cumulative sum below - # (bs, num_jobs, num_ops) - finish_times_1st_diff = ops_assigned * first_diff( - ops_assigned * finish_times[:, None], 2 - ) - - # masking the processing time of scheduled operations and add their finish times instead (first diff thereof) - lb_end_expand = ( - proc_matrix_not_scheduled * min_proc_times.unsqueeze(1).expand_as(job_ops_adj) - + finish_times_1st_diff - ) - # (bs, max_ops); lower bound finish time per operation using the cumsum logic - LBs = torch.sum(job_ops_adj * lb_end_expand.cumsum(-1), dim=1) - # remove nans - LBs = torch.nan_to_num(LBs, nan=0.0) - - # test - assert torch.where( - finish_times != INIT_FINISH, torch.isclose(LBs, finish_times), True - ).all() - - return LBs - - -def op_is_ready(td: TensorDict): - # compare finish times of predecessors with current time step; shape=(b, n_ops_max) - is_ready = ( - torch.bmm(td["ops_adj"][..., 0], td["finish_times"][..., None]).squeeze(2) - <= td["time"][:, None] - ) - # shape=(b, n_ops_max) - is_scheduled = td["ma_assignment"].sum(1).bool() - # op is ready for scheduling if it has not been scheduled and its predecessor is finished - return torch.logical_and(is_ready, ~is_scheduled) - - -def get_job_ops_mapping( - start_op_per_job: torch.Tensor, end_op_per_job: torch.Tensor, n_ops_max: int -) -> Tuple[torch.Tensor, torch.Tensor]: - """Implements a mapping function from operations to jobs - - :param torch.Tensor start_op_per_job: index of first operation of each job - :param torch.Tensor end_op_per_job: index of last operation of each job - :return Tuple[torch.Tensor, torch.Tensor]: - 1st.) index mapping (bs, num_ops): [0,0,1,1,1] means that first two operations belong to job 0 - 2st.) binary mapping (bs, num_jobs, num_ops): [[1,1,0], [0,0,1]] means that first two operations belong to job 0 - """ - device = end_op_per_job.device - end_op_per_job = end_op_per_job.clone() - - bs, num_jobs = end_op_per_job.shape - - # in order to avoid shape conflicts, set the end operation id to the id of max_ops (all batches have same #ops) - end_op_per_job[:, -1] = n_ops_max - 1 - - # here we will generate the operations-job mapping: - # Therefore we first generate a sequence of operation ids and expand it the the size of the mapping matrix: - # (bs, jobs, max_ops) - ops_seq_exp = torch.arange(n_ops_max, device=device)[None, None].expand( - bs, num_jobs, -1 - ) - # (bs, jobs, max_ops) # expanding start and end operation ids - end_op_per_job_exp = end_op_per_job[..., None].expand_as(ops_seq_exp) - start_op_per_job_exp = start_op_per_job[..., None].expand_as(ops_seq_exp) - # given ids of start and end operations per job, this generates the mapping of ops to jobs - # (bs, jobs, max_ops) - ops_job_map = torch.nonzero( - (ops_seq_exp <= end_op_per_job_exp) & (ops_seq_exp >= start_op_per_job_exp) - ) - # (bs, max_ops) - ops_job_map = torch.stack(ops_job_map[:, 1].split(n_ops_max), dim=0) - - # we might also want a binary mapping / adjacency matrix connecting jobs to operations - # (bs, num_jobs, num_ops) - ops_job_bin_map = torch.scatter_add( - input=ops_job_map.new_zeros((bs, num_jobs, n_ops_max)), - dim=1, - index=ops_job_map.unsqueeze(1), - src=ops_job_map.new_ones((bs, num_jobs, n_ops_max)), - ) - - return ops_job_map, ops_job_bin_map diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py index 0ebec158..a99a4278 100644 --- a/rl4co/models/__init__.py +++ b/rl4co/models/__init__.py @@ -19,7 +19,6 @@ from rl4co.models.rl.ppo.ppo import 
PPO from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline from rl4co.models.rl.reinforce.reinforce import REINFORCE -from rl4co.models.zoo import HetGNNModel from rl4co.models.zoo.active_search import ActiveSearch from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy from rl4co.models.zoo.amppo import AMPPO diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py index 7fa92a0b..a4517912 100644 --- a/rl4co/models/nn/env_embeddings/context.py +++ b/rl4co/models/nn/env_embeddings/context.py @@ -19,10 +19,6 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module: "tsp": TSPContext, "atsp": TSPContext, "cvrp": VRPContext, - "vrpb": VRPContext, - "ovrp": VRPContext, - "vrpl": VRPContext, - "cvrptw": VRPTWContext, "cvrptw": VRPTWContext, "ffsp": FFSPContext, "svrp": SVRPContext, @@ -36,6 +32,7 @@ def env_context_embedding(env_name: str, config: dict) -> nn.Module: "mtsp": MTSPContext, "smtwtp": SMTWTPContext, "mdcpdp": MDCPDPContext, + "mtvrp": MTVRPContext } if env_name not in embedding_registry: @@ -150,6 +147,50 @@ def _state_embedding(self, embeddings, td): state_embedding = td["vehicle_capacity"] - td["used_capacity"] return state_embedding +class VRPBContext(EnvContext): + """Context embedding for the Capacitated Vehicle Routing Problem (CVRP). + Project the following to the embedding space: + - current node embedding + - remaining capacity (vehicle_capacity - used_capacity) + """ + + def __init__(self, embed_dim): + super(VRPContext, self).__init__( + embed_dim=embed_dim, step_context_dim=embed_dim + 1 + ) + + def _state_embedding(self, embeddings, td): + mask = (td["used_capacity_backhaul"] == 0) + used_capacity = torch.where(mask, td["used_capacity_linehaul"], td["used_capacity_backhaul"]) + state_embedding = td["vehicle_capacity"] - used_capacity + return state_embedding + +class MTVRPContext(VRPBContext): + """Context embedding for the Capacitated Vehicle Routing Problem (CVRP). + Project the following to the embedding space: + - current node embedding + - remaining capacity (vehicle_capacity - used_capacity) + - current time + - current route length + - if route should be open + """ + + def __init__(self, embed_dim): + super(VRPBContext, self).__init__( + embed_dim=embed_dim, step_context_dim=embed_dim + 4 + ) + + def _state_embedding(self, embeddings, td): + + capacity = super()._state_embedding(embeddings, td) + current_time = td["current_time"] + current_length = td["current_route_length"] + is_open = td["open_route"] + is_open_tensor = torch.zeros_like(is_open, dtype=torch.float) + is_open_tensor[is_open] = 1 + + return torch.cat([capacity, current_time, current_length, is_open_tensor], -1) + class VRPTWContext(VRPContext): """Context embedding for the Capacitated Vehicle Routing Problem (CVRP). diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py index 2d18dfc7..fef341aa 100644 --- a/rl4co/models/nn/env_embeddings/init.py +++ b/rl4co/models/nn/env_embeddings/init.py @@ -3,8 +3,6 @@ from tensordict.tensordict import TensorDict -from rl4co.models.nn.ops import PositionalEncoding - def env_init_embedding(env_name: str, config: dict) -> nn.Module: """Get environment initial embedding. 
The init embedding is used to initialize the @@ -20,11 +18,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: "atsp": TSPInitEmbedding, "matnet": MatNetInitEmbedding, "cvrp": VRPInitEmbedding, - "vrpb": VRPInitEmbedding, - "vrpl": VRPInitEmbedding, - "ovrp": VRPInitEmbedding, "cvrptw": VRPTWInitEmbedding, - "vrptw": VRPTWInitEmbedding, "svrp": SVRPInitEmbedding, "sdvrp": VRPInitEmbedding, "pctsp": PCTSPInitEmbedding, @@ -36,7 +30,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: "mtsp": MTSPInitEmbedding, "smtwtp": SMTWTPInitEmbedding, "mdcpdp": MDCPDPInitEmbedding, - "fjsp": FJSPFeatureEmbedding, + "mtvrp":MTVRPInitEmbedding, } if env_name not in embedding_registry: @@ -150,6 +144,28 @@ def forward(self, td): ) ) return torch.cat((depot_embedding, node_embeddings), -2) + + +class MTVRPInitEmbedding(VRPInitEmbedding): + def __init__(self, embed_dim, linear_bias=True, node_dim: int = 5): + # node_dim = 5: x, y, demand, tw start, tw end + super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim) + + def forward(self, td): + depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] + #durations = td["durations"][..., 1:] + time_windows = td["time_windows"][..., 1:, :] + # embeddings + demands = td["demand_linehaul"][..., None] - td["demand_backhaul"][..., None] + + depot_embedding = self.init_embed_depot(depot) + node_embeddings = self.init_embed( + torch.cat( + (cities, demands[:,1:], time_windows), -1 + ) + ) + + return torch.cat((depot_embedding, node_embeddings), -2) class SVRPInitEmbedding(nn.Module): @@ -386,65 +402,3 @@ def forward(self, td): delivery_embeddings = self.init_embed_delivery(delivery_feats) # concatenate on graph size dimension return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2) - - -class FJSPFeatureEmbedding(nn.Module): - def __init__(self, embed_dim, linear_bias=True, norm_coef: int = 100): - super().__init__() - self.embed_dim = embed_dim - self.norm_coef = norm_coef - - self.init_ope_embed = nn.Linear(4, self.embed_dim, bias=False) - self.edge_embed = nn.Linear(1, embed_dim, bias=False) - - self.ope_pos_enc = PositionalEncoding(embed_dim) - # TODO allow for reencoding after each step - self.stepwise = False - - def forward(self, td: TensorDict): - if self.stepwise: - ops_emb = self._stepwise_operations_embed(td) - ma_emb = self._stepwise_machine_embed(td) - edge_emb = None - else: - ops_emb = self._init_operations_embed(td) - ma_emb = self._init_machine_embed(td) - edge_emb = self._init_edge_embed(td) - return ma_emb, ops_emb, edge_emb - - def _init_operations_embed(self, td: TensorDict): - pos = td["ops_sequence_order"] - - features = [ - td["lbs"].unsqueeze(-1) / self.norm_coef, - td["is_ready"].unsqueeze(-1), - td["num_eligible"].unsqueeze(-1), - td["ops_job_map"].unsqueeze(-1), - ] - features = torch.cat(features, dim=-1) - # (bs, num_ops, emb_dim) - ops_embeddings = self.init_ope_embed(features) - - # (bs, num_ops, emb_dim) - ops_embeddings = self.ope_pos_enc(ops_embeddings, pos.to(torch.int64)) - # zero out padded entries - ops_embeddings[td["pad_mask"].unsqueeze(-1).expand_as(ops_embeddings)] = 0 - return ops_embeddings - - def _init_machine_embed(self, td: TensorDict): - bs, num_ma = td["busy_until"].shape - ma_embeddings = torch.zeros( - (bs, num_ma, self.embed_dim), device=td.device, dtype=torch.float32 - ) - return ma_embeddings - - def _init_edge_embed(self, td: TensorDict): - proc_times = td["proc_times"].unsqueeze(-1) / self.norm_coef - edge_embed = 
self.edge_embed(proc_times) - return edge_embed - - def _stepwise_operations_embed(self, td: TensorDict): - raise NotImplementedError("Stepwise encoding not yet implemented") - - def _stepwise_machine_embed(self, td: TensorDict): - raise NotImplementedError("Stepwise encoding not yet implemented") diff --git a/rl4co/models/nn/ops.py b/rl4co/models/nn/ops.py index ebbb5063..04fec365 100644 --- a/rl4co/models/nn/ops.py +++ b/rl4co/models/nn/ops.py @@ -1,6 +1,5 @@ import math -import torch import torch.nn as nn @@ -36,32 +35,3 @@ def forward(self, x): else: assert self.normalizer is None, "Unknown normalizer type" return x - - -class PositionalEncoding(nn.Module): - def __init__(self, embed_dim: int, dropout: float = 0.1, max_len: int = 1000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - self.d_model = embed_dim - max_len = max_len - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model) - ) - pe = torch.zeros(max_len, 1, self.d_model) - pe[:, 0, 0::2] = torch.sin(position * div_term) - pe[:, 0, 1::2] = torch.cos(position * div_term) - pe = pe.transpose(0, 1) # [1, max_len, d_model] - self.register_buffer("pe", pe) - - def forward(self, hidden: torch.Tensor, seq_pos) -> torch.Tensor: - """ - Arguments: - x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` - seq_pos: Tensor, shape ``[batch_size, seq_len]`` - """ - pes = self.pe.expand(hidden.size(0), -1, -1).gather( - 1, seq_pos.unsqueeze(-1).expand(-1, -1, self.d_model) - ) - hidden = hidden + pes - return self.dropout(hidden) diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 360c17aa..477750c4 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -53,7 +53,7 @@ def shared_step( ): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) - out = self.policy(td, self.env, phase=phase, select_best=phase != "train") + out = self.policy(td, self.env, phase=phase) # Compute loss if phase == "train": diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py index c16bbe9b..7796c630 100644 --- a/rl4co/models/zoo/__init__.py +++ b/rl4co/models/zoo/__init__.py @@ -10,7 +10,6 @@ HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy, ) -from rl4co.models.zoo.hetgnn import HetGNNModel from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy from rl4co.models.zoo.nargnn import NARGNNPolicy diff --git a/rl4co/models/zoo/hetgnn/__init__.py b/rl4co/models/zoo/hetgnn/__init__.py deleted file mode 100644 index f98562b4..00000000 --- a/rl4co/models/zoo/hetgnn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .model import HetGNNModel diff --git a/rl4co/models/zoo/hetgnn/decoder.py b/rl4co/models/zoo/hetgnn/decoder.py deleted file mode 100644 index 68bf1d36..00000000 --- a/rl4co/models/zoo/hetgnn/decoder.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - -from rl4co.models.common.constructive.autoregressive import AutoregressiveDecoder -from rl4co.models.nn.mlp import MLP -from rl4co.utils.ops import batchify, gather_by_index - - -class HetGNNDecoder(AutoregressiveDecoder): - def __init__( - self, embed_dim, feed_forward_hidden_dim: int = 64, feed_forward_layers: int = 2 - ) -> None: - super().__init__() - self.mlp = MLP( - input_dim=2 * embed_dim, - output_dim=1, - num_neurons=[feed_forward_hidden_dim] * 
feed_forward_layers, - ) - self.dummy = nn.Parameter(torch.rand(2 * embed_dim)) - - def pre_decoder_hook(self, td, env, hidden, num_starts): - return td, env, hidden - - def forward(self, td, hidden, num_starts): - if num_starts > 1: - hidden = tuple(map(lambda x: batchify(x, num_starts), hidden)) - - ma_emb, ops_emb = hidden - bs, n_rows, emb_dim = ma_emb.shape - - # (bs, n_jobs, emb) - job_emb = gather_by_index(ops_emb, td["next_op"], squeeze=False) - - # (bs, n_jobs, n_ma, emb) - job_emb_expanded = job_emb.unsqueeze(2).expand(-1, -1, n_rows, -1) - ma_emb_expanded = ma_emb.unsqueeze(1).expand_as(job_emb_expanded) - - # Input of actor MLP - # shape: [bs, num_jobs * n_ma, 2*emb] - h_actions = torch.cat((job_emb_expanded, ma_emb_expanded), dim=-1).flatten(1, 2) - no_ops = self.dummy[None, None].expand(bs, 1, -1) # [bs, 1, 2*emb_dim] - # [bs, num_jobs * n_ma + 1, 2*emb_dim] - h_actions_w_noop = torch.cat((no_ops, h_actions), 1) - - # (b, j*m) - mask = td["action_mask"] - - # (b, j*m) - logits = self.mlp(h_actions_w_noop).squeeze(-1) - - return logits, mask diff --git a/rl4co/models/zoo/hetgnn/encoder.py b/rl4co/models/zoo/hetgnn/encoder.py deleted file mode 100644 index 6f966cf8..00000000 --- a/rl4co/models/zoo/hetgnn/encoder.py +++ /dev/null @@ -1,132 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import einsum -from torch import Tensor - -from rl4co.models.nn.env_embeddings import env_init_embedding -from rl4co.models.nn.ops import Normalization - - -class HetGNNLayer(nn.Module): - def __init__( - self, - embed_dim: int, - ) -> None: - super().__init__() - - self.self_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) - self.cross_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) - self.edge_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) - self.activation = nn.ReLU() - self.scale = 1 / math.sqrt(embed_dim) - - def forward( - self, self_emb: Tensor, other_emb: Tensor, edge_emb: Tensor, edges: Tensor - ): - bs, n_rows, _ = self_emb.shape - - # concat operation embeddings and o-m edge features (proc times) - # Calculate attention coefficients - er = einsum(self_emb, self.self_attn, "b m e, e one -> b m") * self.scale - ec = einsum(other_emb, self.cross_attn, "b o e, e one -> b o") * self.scale - ee = einsum(edge_emb, self.edge_attn, "b m o e, e one -> b m o") * self.scale - - # element wise multiplication similar to broadcast column logits over rows with masking - ec_expanded = einsum(edges, ec, "b m o, b o -> b m o") - # element wise multiplication similar to broadcast row logits over cols with masking - er_expanded = einsum(edges, er, "b m o, b m -> b m o") - - # adding the projections of different node types and edges together (equivalent to first concat and then project) - # (bs, n_rows, n_cols) - cross_logits = self.activation(ec_expanded + ee + er_expanded) - - # (bs, n_rows, 1) - self_logits = self.activation(er + er).unsqueeze(-1) - - # (bs, n_ma, n_ops + 1) - mask = torch.cat( - ( - edges == 1, - torch.full( - size=(bs, n_rows, 1), - dtype=torch.bool, - fill_value=True, - device=edges.device, - ), - ), - dim=-1, - ) - - # (bs, n_ma, n_ops + 1) - all_logits = torch.cat((cross_logits, self_logits), dim=-1) - all_logits[~mask] = -torch.inf - attn_scores = F.softmax(all_logits, dim=-1) - # (bs, n_ma, n_ops) - cross_attn_scores = attn_scores[..., :-1] - # (bs, n_ma, 1) - self_attn_scores = attn_scores[..., -1].unsqueeze(-1) - - # augment column embeddings with 
edge features, (bs, r, c, e) - other_emb_aug = edge_emb + other_emb.unsqueeze(-3) - cross_emb = einsum(cross_attn_scores, other_emb_aug, "b m o, b m o e -> b m e") - self_emb = self_emb * self_attn_scores - # (bs, n_ma, emb_dim) - hidden = torch.sigmoid(cross_emb + self_emb) - return hidden - - -class HetGNNBlock(nn.Module): - def __init__(self, embed_dim) -> None: - super().__init__() - self.norm1 = Normalization(embed_dim, normalization="batch") - self.norm2 = Normalization(embed_dim, normalization="batch") - self.hgnn1 = HetGNNLayer(embed_dim) - self.hgnn2 = HetGNNLayer(embed_dim) - - def forward(self, x1, x2, edge_emb, edges): - h1 = self.hgnn1(x1, x2, edge_emb, edges) - h1 = self.norm1(h1 + x1) - - h2 = self.hgnn2(x2, x1, edge_emb.transpose(1, 2), edges.transpose(1, 2)) - h2 = self.norm2(h2 + x2) - - return h1, h2 - - -class HetGNNEncoder(nn.Module): - def __init__( - self, - embed_dim: int, - num_layers: int = 2, - init_embedding=None, - edge_key: str = "ops_ma_adj", - edge_weights_key: str = "proc_times", - linear_bias: bool = False, - ) -> None: - super().__init__() - - if init_embedding is None: - init_embedding = env_init_embedding("fjsp", {"embed_dim": embed_dim}) - self.init_embedding = init_embedding - - self.edge_key = edge_key - self.edge_weights_key = edge_weights_key - - self.num_layers = num_layers - self.layers = nn.ModuleList([HetGNNBlock(embed_dim) for _ in range(num_layers)]) - - def forward(self, td): - edges = td[self.edge_key] - bs, n_rows, n_cols = edges.shape - row_emb, col_emb, edge_emb = self.init_embedding(td) - assert row_emb.size(1) == n_rows, "incorrect number of row embeddings" - assert col_emb.size(1) == n_cols, "incorrect number of column embeddings" - - for layer in self.layers: - row_emb, col_emb = layer(row_emb, col_emb, edge_emb, edges) - - return (row_emb, col_emb), None diff --git a/rl4co/models/zoo/hetgnn/model.py b/rl4co/models/zoo/hetgnn/model.py deleted file mode 100644 index 40f27de2..00000000 --- a/rl4co/models/zoo/hetgnn/model.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Union - -from rl4co.envs.common.base import RL4COEnvBase -from rl4co.models.rl import REINFORCE -from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline - -from .policy import HetGNNPolicy - - -class HetGNNModel(REINFORCE): - """Heterogenous Graph Neural Network Model as described by Song et al. (2022): - 'Flexible Job Shop Scheduling via Graph Neural Network and Deep Reinforcement Learning' - - Args: - env: Environment to use for the algorithm - policy: Policy to use for the algorithm - baseline: REINFORCE baseline. 
Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) - policy_kwargs: Keyword arguments for policy - baseline_kwargs: Keyword arguments for baseline - **kwargs: Keyword arguments passed to the superclass - """ - - def __init__( - self, - env: RL4COEnvBase, - policy: HetGNNPolicy = None, - baseline: Union[REINFORCEBaseline, str] = "rollout", - policy_kwargs={}, - baseline_kwargs={}, - **kwargs, - ): - assert ( - env.name == "fjsp" - ), "HetGNNModel currently only works for FJSP (Flexible Job-Shop Scheduling Problem)" - if policy is None: - policy = HetGNNPolicy(env_name=env.name, **policy_kwargs) - - super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/hetgnn/policy.py b/rl4co/models/zoo/hetgnn/policy.py deleted file mode 100644 index c51dc30e..00000000 --- a/rl4co/models/zoo/hetgnn/policy.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Optional - -import torch.nn as nn - -from rl4co.models.common.constructive.autoregressive import ( - AutoregressiveDecoder, - AutoregressiveEncoder, - AutoregressivePolicy, -) -from rl4co.utils.pylogger import get_pylogger - -from .decoder import HetGNNDecoder -from .encoder import HetGNNEncoder - -log = get_pylogger(__name__) - - -class HetGNNPolicy(AutoregressivePolicy): - """ - Base Non-autoregressive policy for NCO construction methods. - This creates a heatmap of NxN for N nodes (i.e., heuristic) that models the probability to go from one node to another for all nodes. - - The policy performs the following steps: - 1. Encode the environment initial state into node embeddings - 2. Decode (non-autoregressively) to construct the solution to the NCO problem - - Warning: - The effectiveness of the non-autoregressive approach can vary significantly across different problem types and configurations. - It may require careful tuning of the model architecture and decoding strategy to achieve competitive results. - - Args: - encoder: Encoder module. Can be passed by sub-classes - decoder: Decoder module. Note that this moule defaults to the non-autoregressive decoder - embed_dim: Dimension of the embeddings - env_name: Name of the environment used to initialize embeddings - init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment - edge_embedding: Model to use for the edge embedding. If None, use the default embedding for the environment - graph_network: Model to use for the graph network. If None, use the default embedding for the environment - heatmap_generator: Model to use for the heatmap generator. 
If None, use the default embedding for the environment - num_layers_heatmap_generator: Number of layers in the heatmap generator - num_layers_graph_encoder: Number of layers in the graph encoder - act_fn: Activation function to use in the encoder - agg_fn: Aggregation function to use in the encoder - linear_bias: Whether to use bias in the encoder - train_decode_type: Type of decoding during training - val_decode_type: Type of decoding during validation - test_decode_type: Type of decoding during testing - **constructive_policy_kw: Unused keyword arguments - """ - - def __init__( - self, - encoder: Optional[AutoregressiveEncoder] = None, - decoder: Optional[AutoregressiveDecoder] = None, - embed_dim: int = 64, - num_encoder_layers: int = 2, - env_name: str = "fjsp", - init_embedding: Optional[nn.Module] = None, - linear_bias: bool = True, - train_decode_type: str = "sampling", - val_decode_type: str = "greedy", - test_decode_type: str = "multistart_sampling", - **constructive_policy_kw, - ): - if len(constructive_policy_kw) > 0: - log.warn(f"Unused kwargs: {constructive_policy_kw}") - - if encoder is None: - encoder = HetGNNEncoder( - embed_dim=embed_dim, - num_layers=num_encoder_layers, - init_embedding=init_embedding, - linear_bias=linear_bias, - ) - - # The decoder generates logits given the current td and heatmap - if decoder is None: - decoder = HetGNNDecoder( - embed_dim=embed_dim, - feed_forward_hidden_dim=embed_dim, - feed_forward_layers=2, - ) - else: - # check if the decoder has trainable parameters - if any(p.requires_grad for p in decoder.parameters()): - log.error( - "The decoder contains trainable parameters. This should not happen in a non-autoregressive policy." - ) - - # Pass to constructive policy - super(HetGNNPolicy, self).__init__( - encoder=encoder, - decoder=decoder, - env_name=env_name, - train_decode_type=train_decode_type, - val_decode_type=val_decode_type, - test_decode_type=test_decode_type, - **constructive_policy_kw, - ) diff --git a/rl4co/tasks/__init__.py b/rl4co/tasks/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/rl4co/tasks/eval.py b/rl4co/tasks/eval.py deleted file mode 100644 index bfa7de0a..00000000 --- a/rl4co/tasks/eval.py +++ /dev/null @@ -1,405 +0,0 @@ -import numpy as np -import torch - -from torch.utils.data import DataLoader -from tqdm.auto import tqdm - -from rl4co.data.transforms import StateAugmentation -from rl4co.utils.ops import batchify, gather_by_index, sample_n_random_actions, unbatchify - - -def check_unused_kwargs(class_, kwargs): - if len(kwargs) > 0 and not (len(kwargs) == 1 and "progress" in kwargs): - print(f"Warning: {class_.__class__.__name__} does not use kwargs {kwargs}") - - -class EvalBase: - """Base class for evaluation - - Args: - env: Environment - progress: Whether to show progress bar - **kwargs: Additional arguments (to be implemented in subclasses) - """ - - name = "base" - - def __init__(self, env, progress=True, **kwargs): - check_unused_kwargs(self, kwargs) - self.env = env - self.progress = progress - - def __call__(self, policy, dataloader, **kwargs): - """Evaluate the policy on the given dataloader with **kwargs parameter - self._inner is implemented in subclasses and returns actions and rewards - """ - - # Collect timings for evaluation (more accurate than timeit) - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - with torch.inference_mode(): - rewards_list = [] - actions_list = [] - - for batch in tqdm( 
- dataloader, disable=not self.progress, desc=f"Running {self.name}" - ): - td = batch.to(next(policy.parameters()).device) - td = self.env.reset(td) - actions, rewards = self._inner(policy, td, **kwargs) - rewards_list.append(rewards) - actions_list.append(actions) - - rewards = torch.cat(rewards_list) - - # Padding: pad actions to the same length with zeros - max_length = max(action.size(-1) for action in actions) - actions = torch.cat( - [ - torch.nn.functional.pad(action, (0, max_length - action.size(-1))) - for action in actions - ], - 0, - ) - - end_event.record() - torch.cuda.synchronize() - inference_time = start_event.elapsed_time(end_event) - - tqdm.write(f"Mean reward for {self.name}: {rewards.mean():.4f}") - tqdm.write(f"Time: {inference_time/1000:.4f}s") - - # Empty cache - torch.cuda.empty_cache() - - return { - "actions": actions.cpu(), - "rewards": rewards.cpu(), - "inference_time": inference_time, - "avg_reward": rewards.cpu().mean(), - } - - def _inner(self, policy, td): - """Inner function to be implemented in subclasses. - This function returns actions and rewards for the given policy - """ - raise NotImplementedError("Implement in subclass") - - -class GreedyEval(EvalBase): - """Evaluates the policy using greedy decoding and single trajectory""" - - name = "greedy" - - def __init__(self, env, **kwargs): - check_unused_kwargs(self, kwargs) - super().__init__(env, kwargs.get("progress", True)) - - def _inner(self, policy, td): - out = policy( - td.clone(), - decode_type="greedy", - num_starts=0, - return_actions=True, - ) - rewards = self.env.get_reward(td, out["actions"]) - return out["actions"], rewards - - -class AugmentationEval(EvalBase): - """Evaluates the policy via N state augmentations - `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO - https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8 - - Args: - num_augment (int): Number of state augmentations - force_dihedral_8 (bool): Whether to force the use of 8 augmentations - """ - - name = "augmentation" - - def __init__(self, env, num_augment=8, force_dihedral_8=False, feats=None, **kwargs): - check_unused_kwargs(self, kwargs) - super().__init__(env, kwargs.get("progress", True)) - self.augmentation = StateAugmentation( - num_augment=num_augment, - augment_fn="dihedral8" if force_dihedral_8 else "symmetric", - feats=feats, - ) - - def _inner(self, policy, td, num_augment=None): - if num_augment is None: - num_augment = self.augmentation.num_augment - td_init = td.clone() - td = self.augmentation(td) - out = policy(td.clone(), decode_type="greedy", num_starts=0, return_actions=True) - - # Move into batches and compute rewards - rewards = self.env.get_reward(batchify(td_init, num_augment), out["actions"]) - rewards = unbatchify(rewards, num_augment) - actions = unbatchify(out["actions"], num_augment) - - # Get best reward and corresponding action - rewards, max_idxs = rewards.max(dim=1) - actions = gather_by_index(actions, max_idxs, dim=1) - return actions, rewards - - @property - def num_augment(self): - return self.augmentation.num_augment - - -class SamplingEval(EvalBase): - """Evaluates the policy via N samples from the policy - - Args: - samples (int): Number of samples to take - softmax_temp (float): Temperature for softmax sampling. 
The higher the temperature, the more random the sampling - """ - - name = "sampling" - - def __init__(self, env, samples, softmax_temp=None, **kwargs): - check_unused_kwargs(self, kwargs) - super().__init__(env, kwargs.get("progress", True)) - - self.samples = samples - self.softmax_temp = softmax_temp - - def _inner(self, policy, td): - out = policy( - td.clone(), - decode_type="sampling", - num_starts=self.samples, - multistart=True, - return_actions=True, - softmax_temp=self.softmax_temp, - select_best=True, - select_start_nodes_fn=lambda td, _, n: sample_n_random_actions(td, n), - ) - - # Move into batches and compute rewards - rewards = out["reward"] - actions = out["actions"] - - return actions, rewards - - -class GreedyMultiStartEval(EvalBase): - """Evaluates the policy via `num_starts` greedy multistarts samples from the policy - - Args: - num_starts (int): Number of greedy multistarts to use - """ - - name = "multistart_greedy" - - def __init__(self, env, num_starts=None, **kwargs): - check_unused_kwargs(self, kwargs) - super().__init__(env, kwargs.get("progress", True)) - - assert num_starts is not None, "Must specify num_starts" - self.num_starts = num_starts - - def _inner(self, policy, td): - td_init = td.clone() - out = policy( - td.clone(), - decode_type="multistart_greedy", - num_starts=self.num_starts, - return_actions=True, - ) - - # Move into batches and compute rewards - td = batchify(td_init, self.num_starts) - rewards = self.env.get_reward(td, out["actions"]) - rewards = unbatchify(rewards, self.num_starts) - actions = unbatchify(out["actions"], self.num_starts) - - # Get the best trajectories - rewards, max_idxs = rewards.max(dim=1) - actions = gather_by_index(actions, max_idxs, dim=1) - return actions, rewards - - -class GreedyMultiStartAugmentEval(EvalBase): - """Evaluates the policy via `num_starts` samples from the policy - and `num_augment` augmentations of each sample.` - `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO - https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8 - - Args: - num_starts: Number of greedy multistart samples - num_augment: Number of augmentations per sample - force_dihedral_8: If True, force the use of 8 augmentations (rotations and flips) as in POMO - """ - - name = "multistart_greedy_augment" - - def __init__( - self, - env, - num_starts=None, - num_augment=8, - force_dihedral_8=False, - feats=None, - **kwargs, - ): - check_unused_kwargs(self, kwargs) - super().__init__(env, kwargs.get("progress", True)) - - assert num_starts is not None, "Must specify num_starts" - self.num_starts = num_starts - assert not ( - num_augment != 8 and force_dihedral_8 - ), "Cannot force dihedral 8 when num_augment != 8" - self.augmentation = StateAugmentation( - num_augment=num_augment, - augment_fn="dihedral8" if force_dihedral_8 else "symmetric", - feats=feats, - ) - - def _inner(self, policy, td, num_augment=None): - if num_augment is None: - num_augment = self.augmentation.num_augment - - td_init = td.clone() - - td = self.augmentation(td) - out = policy( - td.clone(), - decode_type="multistart_greedy", - num_starts=self.num_starts, - return_actions=True, - ) - - # Move into batches and compute rewards - td = batchify(td_init, (num_augment, self.num_starts)) - rewards = self.env.get_reward(td, out["actions"]) - rewards = unbatchify(rewards, self.num_starts * num_augment) - actions = unbatchify(out["actions"], self.num_starts * num_augment) - - # Get the best trajectories - rewards, max_idxs = 
rewards.max(dim=1) - actions = gather_by_index(actions, max_idxs, dim=1) - return actions, rewards - - @property - def num_augment(self): - return self.augmentation.num_augment - - -def get_automatic_batch_size(eval_fn, start_batch_size=8192, max_batch_size=4096): - """Automatically reduces the batch size based on the eval function - - Args: - eval_fn: The eval function - start_batch_size: The starting batch size. This should be the theoretical maximum batch size - max_batch_size: The maximum batch size. This is the practical maximum batch size - """ - batch_size = start_batch_size - - effective_ratio = 1 - - if hasattr(eval_fn, "num_starts"): - batch_size = batch_size // (eval_fn.num_starts // 10) - effective_ratio *= eval_fn.num_starts // 10 - if hasattr(eval_fn, "num_augment"): - batch_size = batch_size // eval_fn.num_augment - effective_ratio *= eval_fn.num_augment - if hasattr(eval_fn, "samples"): - batch_size = batch_size // eval_fn.samples - effective_ratio *= eval_fn.samples - - batch_size = min(batch_size, max_batch_size) - # get closest integer power of 2 - batch_size = 2 ** int(np.log2(batch_size)) - - print(f"Effective batch size: {batch_size} (ratio: {effective_ratio})") - - return batch_size - - -def evaluate_policy( - env, - policy, - dataset, - method="greedy", - batch_size=None, - max_batch_size=4096, - start_batch_size=8192, - auto_batch_size=True, - save_results=False, - save_fname="results.npz", - **kwargs, -): - num_loc = getattr(env.generator, "num_loc", None) - - methods_mapping = { - "greedy": {"func": GreedyEval, "kwargs": {}}, - "sampling": { - "func": SamplingEval, - "kwargs": {"samples": 100, "softmax_temp": 1.0}, - }, - "multistart_greedy": { - "func": GreedyMultiStartEval, - "kwargs": {"num_starts": num_loc}, - }, - "augment_dihedral_8": { - "func": AugmentationEval, - "kwargs": {"num_augment": 8, "force_dihedral_8": True}, - }, - "augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}}, - "multistart_greedy_augment_dihedral_8": { - "func": GreedyMultiStartAugmentEval, - "kwargs": { - "num_augment": 8, - "force_dihedral_8": True, - "num_starts": num_loc, - }, - }, - "multistart_greedy_augment": { - "func": GreedyMultiStartAugmentEval, - "kwargs": {"num_augment": 8, "num_starts": num_loc}, - }, - } - - assert method in methods_mapping, "Method {} not found".format(method) - - # Set up the evaluation function - eval_settings = methods_mapping[method] - func, kwargs_ = eval_settings["func"], eval_settings["kwargs"] - # subsitute kwargs with the ones passed in - kwargs_.update(kwargs) - kwargs = kwargs_ - eval_fn = func(env, **kwargs) - - if auto_batch_size: - assert ( - batch_size is None - ), "Cannot specify batch_size when auto_batch_size is True" - batch_size = get_automatic_batch_size( - eval_fn, max_batch_size=max_batch_size, start_batch_size=start_batch_size - ) - print("Using automatic batch size: {}".format(batch_size)) - - # Set up the dataloader - dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=False, - num_workers=0, - collate_fn=dataset.collate_fn, - ) - - # Run evaluation - retvals = eval_fn(policy, dataloader) - - # Save results - if save_results: - print("Saving results to {}".format(save_fname)) - np.savez(save_fname, **retvals) - - return retvals diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py deleted file mode 100644 index 6826382d..00000000 --- a/rl4co/tasks/train.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import List, Optional, Tuple - -import hydra -import lightning as L -import pyrootutils 
-import torch - -from lightning import Callback, LightningModule -from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig - -from rl4co import utils -from rl4co.utils import RL4COTrainer - -pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True) - - -log = utils.get_pylogger(__name__) - - -@utils.task_wrapper -def run(cfg: DictConfig) -> Tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. - - Args: - cfg (DictConfig): Configuration composed by Hydra. - Returns: - Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. - """ - - # set seed for random number generators in pytorch, numpy and python.random - if cfg.get("seed"): - L.seed_everything(cfg.seed, workers=True) - - # We instantiate the environment separately and then pass it to the model - log.info(f"Instantiating environment <{cfg.env._target_}>") - env = hydra.utils.instantiate(cfg.env) - - # Note that the RL environment is instantiated inside the model - log.info(f"Instantiating model <{cfg.model._target_}>") - model: LightningModule = hydra.utils.instantiate(cfg.model, env) - - log.info("Instantiating callbacks...") - callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) - - log.info("Instantiating loggers...") - logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) - - log.info("Instantiating trainer...") - trainer: RL4COTrainer = hydra.utils.instantiate( - cfg.trainer, - callbacks=callbacks, - logger=logger, - ) - - object_dict = { - "cfg": cfg, - "model": model, - "callbacks": callbacks, - "logger": logger, - "trainer": trainer, - } - - if logger: - log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) - - if cfg.get("compile", False): - log.info("Compiling model!") - model = torch.compile(model) - - if cfg.get("train"): - log.info("Starting training!") - trainer.fit(model=model, ckpt_path=cfg.get("ckpt_path")) - - train_metrics = trainer.callback_metrics - - if cfg.get("test"): - log.info("Starting testing!") - ckpt_path = trainer.checkpoint_callback.best_model_path - if ckpt_path == "": - log.warning("Best ckpt not found! Using current weights for testing...") - ckpt_path = None - trainer.test(model=model, ckpt_path=ckpt_path) - log.info(f"Best ckpt path: {ckpt_path}") - - test_metrics = trainer.callback_metrics - - # merge train and test metrics - metric_dict = {**train_metrics, **test_metrics} - - return metric_dict, object_dict - - -@hydra.main(version_base="1.3", config_path="../../configs", config_name="main.yaml") -def train(cfg: DictConfig) -> Optional[float]: - # apply extra utilities - # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
- utils.extras(cfg) - - # train the model - metric_dict, _ = run(cfg) - - # safely retrieve metric value for hydra-based hyperparameter optimization - metric_value = utils.get_metric_value( - metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") - ) - - # return optimized metric - return metric_value - - -if __name__ == "__main__": - train() diff --git a/rl4co/utils/__init__.py b/rl4co/utils/__init__.py deleted file mode 100644 index 4b0246aa..00000000 --- a/rl4co/utils/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from rl4co.utils.instantiators import instantiate_callbacks, instantiate_loggers -from rl4co.utils.pylogger import get_pylogger -from rl4co.utils.rich_utils import enforce_tags, print_config_tree -from rl4co.utils.trainer import RL4COTrainer -from rl4co.utils.utils import ( - extras, - get_metric_value, - log_hyperparameters, - show_versions, - task_wrapper, -) diff --git a/rl4co/utils/callbacks/speed_monitor.py b/rl4co/utils/callbacks/speed_monitor.py deleted file mode 100644 index 3f1ab6ae..00000000 --- a/rl4co/utils/callbacks/speed_monitor.py +++ /dev/null @@ -1,123 +0,0 @@ -# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor -# We only need the speed monitoring, not the GPU monitoring -import time - -import lightning as L - -from lightning.pytorch.callbacks import Callback -from lightning.pytorch.utilities.parsing import AttributeDict -from lightning.pytorch.utilities.rank_zero import rank_zero_only - - -class SpeedMonitor(Callback): - """Monitor the speed of each step and each epoch.""" - - def __init__( - self, - intra_step_time: bool = True, - inter_step_time: bool = True, - epoch_time: bool = True, - verbose=False, - ): - super().__init__() - self._log_stats = AttributeDict( - { - "intra_step_time": intra_step_time, - "inter_step_time": inter_step_time, - "epoch_time": epoch_time, - } - ) - self.verbose = verbose - - def on_train_start(self, trainer: "L.Trainer", L_module: "L.LightningModule") -> None: - self._snap_epoch_time = None - - def on_train_epoch_start( - self, trainer: "L.Trainer", L_module: "L.LightningModule" - ) -> None: - self._snap_intra_step_time = None - self._snap_inter_step_time = None - self._snap_epoch_time = time.time() - - def on_validation_epoch_start( - self, trainer: "L.Trainer", L_module: "L.LightningModule" - ) -> None: - self._snap_inter_step_time = None - - def on_test_epoch_start( - self, trainer: "L.Trainer", L_module: "L.LightningModule" - ) -> None: - self._snap_inter_step_time = None - - @rank_zero_only - def on_train_batch_start( - self, - trainer: "L.Trainer", - *unused_args, - **unused_kwargs, # easy fix for new pytorch lightning versions - ) -> None: - if self._log_stats.intra_step_time: - self._snap_intra_step_time = time.time() - - if not self._should_log(trainer): - return - - logs = {} - if self._log_stats.inter_step_time and self._snap_inter_step_time: - # First log at beginning of second step - logs["time/inter_step (ms)"] = ( - time.time() - self._snap_inter_step_time - ) * 1000 - - if trainer.logger is not None: - trainer.logger.log_metrics(logs, step=trainer.global_step) - - @rank_zero_only - def on_train_batch_end( - self, - trainer: "L.Trainer", - L_module: "L.LightningModule", - *unused_args, - **unused_kwargs, # easy fix for new pytorch lightning versions - ) -> None: - if self._log_stats.inter_step_time: - self._snap_inter_step_time = time.time() - - if ( - self.verbose - and self._log_stats.intra_step_time - and 
self._snap_intra_step_time - ): - L_module.print( - f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}" - ) - - if not self._should_log(trainer): - return - - logs = {} - if self._log_stats.intra_step_time and self._snap_intra_step_time: - logs["time/intra_step (ms)"] = ( - time.time() - self._snap_intra_step_time - ) * 1000 - - if trainer.logger is not None: - trainer.logger.log_metrics(logs, step=trainer.global_step) - - @rank_zero_only - def on_train_epoch_end( - self, - trainer: "L.Trainer", - L_module: "L.LightningModule", - ) -> None: - logs = {} - if self._log_stats.epoch_time and self._snap_epoch_time: - logs["time/epoch (s)"] = time.time() - self._snap_epoch_time - if trainer.logger is not None: - trainer.logger.log_metrics(logs, step=trainer.global_step) - - @staticmethod - def _should_log(trainer) -> bool: - return ( - trainer.global_step + 1 - ) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/rl4co/utils/decoding.py b/rl4co/utils/decoding.py deleted file mode 100644 index 76a33b82..00000000 --- a/rl4co/utils/decoding.py +++ /dev/null @@ -1,555 +0,0 @@ -import abc - -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F - -from tensordict.tensordict import TensorDict - -from rl4co.envs import RL4COEnvBase -from rl4co.utils.ops import batchify, unbatchify, unbatchify_and_gather -from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -def get_decoding_strategy(decoding_strategy, **config): - strategy_registry = { - "greedy": Greedy, - "sampling": Sampling, - "multistart_greedy": Greedy, - "multistart_sampling": Sampling, - "beam_search": BeamSearch, - "evaluate": Evaluate, - } - - if decoding_strategy not in strategy_registry: - log.warning( - f"Unknown decode type '{decoding_strategy}'. Available decode types: {strategy_registry.keys()}. Defaulting to Sampling." - ) - - if "multistart" in decoding_strategy: - config["multistart"] = True - - return strategy_registry.get(decoding_strategy, Sampling)(**config) - - -def get_log_likelihood(logprobs, actions, mask=None, return_sum: bool = True): - """Get log likelihood of selected actions. - Note that mask is a boolean tensor where True means the value should be kept. - - Args: - logprobs: Log probabilities of actions from the model (batch_size, seq_len, action_dim). - actions: Selected actions (batch_size, seq_len). - mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). - return_sum: Whether to return the sum of log probabilities or not. Defaults to True. - """ - logprobs = logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1) - - # Optional: mask out actions irrelevant to objective so they do not get reinforced - if mask is not None: - logprobs[~mask] = 0 - - assert ( - logprobs > -1000 - ).data.all(), "Logprobs should not be -inf, check sampling procedure!" - - # Calculate log_likelihood - if return_sum: - return logprobs.sum(1) # [batch] - else: - return logprobs # [batch, decode_len] - - -def decode_logprobs(logprobs, mask, decode_type="sampling"): - """Decode log probabilities to select actions with mask. - Note that mask is a boolean tensor where True means the value should be kept. 
- """ - if "greedy" in decode_type: - selected = DecodingStrategy.greedy(logprobs, mask) - elif "sampling" in decode_type: - selected = DecodingStrategy.sampling(logprobs, mask) - else: - assert False, "Unknown decode type: {}".format(decode_type) - return selected - - -def random_policy(td): - """Helper function to select a random action from available actions""" - action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1) - td.set("action", action) - return td - - -def rollout(env, td, policy, max_steps: int = None): - """Helper function to rollout a policy. Currently, TorchRL does not allow to step - over envs when done with `env.rollout()`. We need this because for environments that complete at different steps. - """ - - max_steps = float("inf") if max_steps is None else max_steps - actions = [] - steps = 0 - - while not td["done"].all(): - td = policy(td) - actions.append(td["action"]) - td = env.step(td)["next"] - steps += 1 - if steps > max_steps: - log.info("Max steps reached") - break - return ( - env.get_reward(td, torch.stack(actions, dim=1)), - td, - torch.stack(actions, dim=1), - ) - - -def modify_logits_for_top_k_filtering(logits, top_k): - """Set the logits for none top-k values to -inf. Done out-of-place. - Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6 - """ - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - return logits.masked_fill(indices_to_remove, float("-inf")) - - -def modify_logits_for_top_p_filtering(logits, top_p): - """Set the logits for none top-p values to -inf. Done out-of-place. - Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14 - """ - if top_p <= 0.0 or top_p >= 1.0: - return logits - - # First sort and calculate cumulative sum of probabilities. - sorted_logits, sorted_indices = torch.sort(logits, descending=False) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - - # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs <= (1 - top_p) - - # Scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter( - -1, sorted_indices, sorted_indices_to_remove - ) - return logits.masked_fill(indices_to_remove, float("-inf")) - - -def process_logits( - logits: torch.Tensor, - mask: torch.Tensor = None, - temperature: float = 1.0, - top_p: float = 0.0, - top_k: int = 0, - tanh_clipping: float = 0, - mask_logits: bool = True, -): - """Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling. - - Note: - We convert to log probabilities instead of probabilities to avoid numerical instability. - This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))), - and avoiding the division by the sum of exponentials can help with numerical stability. - You may check the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html). - - Args: - logits: Logits from the model (batch_size, num_actions). - mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). - temperature: Temperature scaling. Higher values make the distribution more uniform (exploration), - lower values make it more peaky (exploitation). - top_p: Top-p sampling, a.k.a. 
Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability - less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform. - top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and - do not return all the top-k logits here. - tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). - mask_logits: Whether to mask logits of infeasible actions. - """ - - # Tanh clipping from Bello et al. 2016 - if tanh_clipping > 0: - logits = torch.tanh(logits) * tanh_clipping - - # In RL, we want to mask the logits to prevent the agent from selecting infeasible actions - if mask_logits: - assert mask is not None, "mask must be provided if mask_logits is True" - logits[~mask] = float("-inf") - - logits = logits / temperature # temperature scaling - - if top_k > 0: - top_k = min(top_k, logits.size(-1)) # safety check - logits = modify_logits_for_top_k_filtering(logits, top_k) - - if top_p > 0: - assert top_p <= 1.0, "top-p should be in (0, 1]." - logits = modify_logits_for_top_p_filtering(logits, top_p) - - # Compute log probabilities - return F.log_softmax(logits, dim=-1) - - -class DecodingStrategy(metaclass=abc.ABCMeta): - """Base class for decoding strategies. Subclasses should implement the :meth:`_step` method. - Includes hooks for pre and post main decoding operations. - - Args: - temperature: Temperature scaling. Higher values make the distribution more uniform (exploration), - lower values make it more peaky (exploitation). Defaults to 1.0. - top_p: Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Defaults to 0.0. - top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Defaults to 0. - mask_logits: Whether to mask logits of infeasible actions. Defaults to True. - tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0. - multistart: Whether to use multistart decoding. Defaults to False. - num_starts: Number of starts for multistart decoding. Defaults to None. - """ - - name = "base" - - def __init__( - self, - temperature: float = 1.0, - top_p: float = 0.0, - top_k: int = 0, - mask_logits: bool = True, - tanh_clipping: float = 0, - multistart: bool = False, - num_starts: Optional[int] = None, - select_start_nodes_fn: Optional[callable] = None, - select_best: bool = False, - **kwargs, - ) -> None: - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - self.mask_logits = mask_logits - self.tanh_clipping = tanh_clipping - self.multistart = multistart - self.num_starts = num_starts - self.select_start_nodes_fn = select_start_nodes_fn - self.select_best = select_best - # initialize buffers - self.actions = [] - self.logprobs = [] - - @abc.abstractmethod - def _step( - self, - logprobs: torch.Tensor, - mask: torch.Tensor, - td: TensorDict, - action: torch.Tensor = None, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: - """Main decoding operation. This method should be called in a loop until all sequences are done. - - Args: - logprobs: Log probabilities processed from logits of the model. - mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). - td: TensorDict containing the current state of the environment. - action: Optional action to use, e.g. for evaluating log probabilities. 
- """ - raise NotImplementedError("Must be implemented by subclass") - - def pre_decoder_hook( - self, td: TensorDict, env: RL4COEnvBase, action: torch.Tensor = None - ): - """Pre decoding hook. This method is called before the main decoding operation.""" - # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask - if self.multistart: - if self.num_starts is None: - self.num_starts = env.get_num_starts(td) - else: - if self.num_starts is not None: - if self.num_starts >= 1: - log.warn( - f"num_starts={self.num_starts} is ignored for decode_type={self.name}" - ) - - self.num_starts = 0 - - # Multi-start decoding: first action is chosen by ad-hoc node selection - if self.num_starts >= 1: - if action is None: # if action is provided, we use it as the first action - if self.select_start_nodes_fn is not None: - action = self.select_start_nodes_fn(td, env, self.num_starts) - else: - action = env.select_start_nodes(td, num_starts=self.num_starts) - - # Expand td to batch_size * num_starts - td = batchify(td, self.num_starts) - - td.set("action", action) - td = env.step(td)["next"] - logprobs = torch.zeros_like( - td["action_mask"], device=td.device - ) # first logprobs is 0, so p = logprobs.exp() = 1 - - self.logprobs.append(logprobs) - self.actions.append(action) - - return td, env, self.num_starts - - def post_decoder_hook( - self, td: TensorDict, env: RL4COEnvBase - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict, RL4COEnvBase]: - assert ( - len(self.logprobs) > 0 - ), "No logprobs were collected because all environments were done. Check your initial state" - logprobs = torch.stack(self.logprobs, 1) - actions = torch.stack(self.actions, 1) - if self.num_starts > 0 and self.select_best: - logprobs, actions, td, env = self._select_best(logprobs, actions, td, env) - return logprobs, actions, td, env - - def step( - self, - logits: torch.Tensor, - mask: torch.Tensor, - td: TensorDict, - action: torch.Tensor = None, - **kwargs, - ) -> TensorDict: - """Main decoding operation. This method should be called in a loop until all sequences are done. - - Args: - logits: Logits from the model. - mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). - td: TensorDict containing the current state of the environment. - action: Optional action to use, e.g. for evaluating log probabilities. 
- """ - if not self.mask_logits: # set mask_logit to None if mask_logits is False - mask = None - - logprobs = process_logits( - logits, - mask, - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - tanh_clipping=self.tanh_clipping, - mask_logits=self.mask_logits, - ) - logprobs, selected_action, td = self._step( - logprobs, mask, td, action=action, **kwargs - ) - td.set("action", selected_action) - self.actions.append(selected_action) - self.logprobs.append(logprobs) - return td - - @staticmethod - def greedy(logprobs, mask=None): - """Select the action with the highest probability.""" - # [BS], [BS] - selected = logprobs.argmax(dim=-1) - if mask is not None: - assert ( - not (~mask).gather(1, selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" - - return selected - - @staticmethod - def sampling(logprobs, mask=None): - """Sample an action with a multinomial distribution given by the log probabilities.""" - probs = logprobs.exp() - selected = torch.multinomial(probs, 1).squeeze(1) - - if mask is not None: - while (~mask).gather(1, selected.unsqueeze(-1)).data.any(): - log.info("Sampled bad values, resampling!") - selected = probs.multinomial(1).squeeze(1) - assert ( - not (~mask).gather(1, selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" - - return selected - - def _select_best(self, logprobs, actions, td: TensorDict, env: RL4COEnvBase): - rewards = env.get_reward(td, actions) - _, max_idxs = unbatchify(rewards, self.num_starts).max(dim=-1) - - actions = unbatchify_and_gather(actions, max_idxs, self.num_starts) - logprobs = unbatchify_and_gather(logprobs, max_idxs, self.num_starts) - td = unbatchify_and_gather(td, max_idxs, self.num_starts) - - return logprobs, actions, td, env - - -class Greedy(DecodingStrategy): - name = "greedy" - - def _step( - self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: - """Select the action with the highest log probability""" - selected = self.greedy(logprobs, mask) - return logprobs, selected, td - - -class Sampling(DecodingStrategy): - name = "sampling" - - def _step( - self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: - """Sample an action with a multinomial distribution given by the log probabilities.""" - selected = self.sampling(logprobs, mask) - return logprobs, selected, td - - -class Evaluate(DecodingStrategy): - name = "evaluate" - - def _step( - self, - logprobs: torch.Tensor, - mask: torch.Tensor, - td: TensorDict, - action: torch.Tensor, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: - """The action is provided externally, so we just return the action""" - selected = action - return logprobs, selected, td - - -class BeamSearch(DecodingStrategy): - name = "beam_search" - - def __init__(self, beam_width=None, select_best=True, **kwargs) -> None: - super().__init__(**kwargs) - self.beam_width = beam_width - self.select_best = select_best - self.parent_beam_logprobs = None - self.beam_path = [] - - def _step( - self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: - selected, batch_beam_idx = self._make_beam_step(logprobs) - # select the correct state representation, logprobs and mask according to beam parent - td = td[batch_beam_idx] - logprobs = logprobs[batch_beam_idx] - mask = mask[batch_beam_idx] - - assert ( - not (~mask).gather(1, 
selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" - - return logprobs, selected, td - - def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs): - if self.beam_width is None: - self.beam_width = env.get_num_starts(td) - assert self.beam_width > 1, "beam width must be larger than 1" - - # select start nodes. TODO: include first step in beam search as well - if self.select_start_nodes_fn is not None: - action = self.select_start_nodes_fn(td, env, self.beam_width) - else: - action = env.select_start_nodes(td, num_starts=self.beam_width) - - # Expand td to batch_size * beam_width - td = batchify(td, self.beam_width) - - td.set("action", action) - td = env.step(td)["next"] - - logprobs = torch.zeros_like(td["action_mask"], device=td.device) - beam_parent = torch.zeros(logprobs.size(0), device=td.device, dtype=torch.int32) - - self.logprobs.append(logprobs) - self.actions.append(action) - self.parent_beam_logprobs = logprobs.gather(1, action[..., None]) - self.beam_path.append(beam_parent) - - return td, env, self.beam_width - - def post_decoder_hook(self, td, env): - # [BS*BW, seq_len] - aligned_sequences, aligned_logprobs = self._backtrack() - - if self.select_best: - return self._select_best_beam(aligned_logprobs, aligned_sequences, td, env) - else: - return aligned_logprobs, aligned_sequences, td, env - - def _backtrack(self): - # [BS*BW, seq_len] - actions = torch.stack(self.actions, 1) - # [BS*BW, seq_len] - logprobs = torch.stack(self.logprobs, 1) - assert actions.size(1) == len( - self.beam_path - ), "action idx shape and beam path shape dont match" - - # [BS*BW] - cur_parent = self.beam_path[-1] - # [BS*BW] - reversed_aligned_sequences = [actions[:, -1]] - reversed_aligned_logprobs = [logprobs[:, -1]] - - aug_batch_size = actions.size(0) - batch_size = aug_batch_size // self.beam_width - batch_beam_sequence = ( - torch.arange(0, batch_size).repeat(self.beam_width).to(actions.device) - ) - - for k in reversed(range(len(self.beam_path) - 1)): - batch_beam_idx = batch_beam_sequence + cur_parent * batch_size - - reversed_aligned_sequences.append(actions[batch_beam_idx, k]) - reversed_aligned_logprobs.append(logprobs[batch_beam_idx, k]) - cur_parent = self.beam_path[k][batch_beam_idx] - - # [BS*BW, seq_len*num_targets] - actions = torch.stack(list(reversed(reversed_aligned_sequences)), dim=1) - logprobs = torch.stack(list(reversed(reversed_aligned_logprobs)), dim=1) - - return actions, logprobs - - def _select_best_beam(self, logprobs, actions, td: TensorDict, env: RL4COEnvBase): - aug_batch_size = logprobs.size(0) # num nodes - batch_size = aug_batch_size // self.beam_width - rewards = env.get_reward(td, actions) - _, idx = torch.cat(rewards.unsqueeze(1).split(batch_size), 1).max(1) - flat_idx = torch.arange(batch_size, device=rewards.device) + idx * batch_size - return logprobs[flat_idx], actions[flat_idx], td[flat_idx], env - - def _make_beam_step(self, logprobs: torch.Tensor): - aug_batch_size, num_nodes = logprobs.shape # num nodes - batch_size = aug_batch_size // self.beam_width - batch_beam_sequence = ( - torch.arange(0, batch_size).repeat(self.beam_width).to(logprobs.device) - ) - - # [BS*BW, num_nodes] + [BS*BW, 1] -> [BS*BW, num_nodes] - log_beam_prob = logprobs + self.parent_beam_logprobs # - - # [BS, num_nodes * BW] - log_beam_prob_hstacked = torch.cat(log_beam_prob.split(batch_size), dim=1) - # [BS, BW] - topk_logprobs, topk_ind = torch.topk( - log_beam_prob_hstacked, self.beam_width, dim=1 - ) - - # [BS*BW, 1] - logprobs_selected = 
torch.hstack(torch.unbind(topk_logprobs, 1)).unsqueeze(1) - - # [BS*BW, 1] - topk_ind = torch.hstack(torch.unbind(topk_ind, 1)) - - # since we stack the logprobs from the distinct branches, the indices in - # topk dont correspond to node indices directly and need to be translated - selected = topk_ind % num_nodes # determine node index - - # calc parent this branch comes from - beam_parent = (topk_ind // num_nodes).int() - - batch_beam_idx = batch_beam_sequence + beam_parent * batch_size - - self.parent_beam_logprobs = logprobs_selected - self.beam_path.append(beam_parent) - - return selected, batch_beam_idx diff --git a/rl4co/utils/instantiators.py b/rl4co/utils/instantiators.py deleted file mode 100644 index e3b25183..00000000 --- a/rl4co/utils/instantiators.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import List - -import hydra - -from lightning import Callback -from lightning.pytorch.loggers import Logger -from omegaconf import DictConfig - -from rl4co.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger diff --git a/rl4co/utils/lightning.py b/rl4co/utils/lightning.py deleted file mode 100644 index a3f29cb7..00000000 --- a/rl4co/utils/lightning.py +++ /dev/null @@ -1,76 +0,0 @@ -import os - -import lightning as L -import torch - -from omegaconf import DictConfig - -# from rl4co. 
-from rl4co.utils.pylogger import get_pylogger - -log = get_pylogger(__name__) - - -def get_lightning_device(lit_module: L.LightningModule) -> torch.device: - """Get the device of the Lightning module before setup is called - See device setting issue in setup https://github.com/Lightning-AI/lightning/issues/2638 - """ - try: - if lit_module.trainer.strategy.root_device != lit_module.device: - return lit_module.trainer.strategy.root_device - return lit_module.device - except Exception: - return lit_module.device - - -def remove_key(config, key="wandb"): - """Remove keys containing 'key`""" - new_config = {} - for k, v in config.items(): - if key in k: - continue - else: - new_config[k] = v - return new_config - - -def clean_hydra_config( - config, keep_value_only=True, remove_keys="wandb", clean_cfg_path=True -): - """Clean hydra config by nesting dictionary and cleaning values""" - # Remove keys containing `remove_keys` - if not isinstance(remove_keys, list): - remove_keys = [remove_keys] - for key in remove_keys: - config = remove_key(config, key=key) - - new_config = {} - # Iterate over config dictionary - for key, value in config.items(): - # If key contains slash, split it and create nested dictionary recursively - if "/" in key: - keys = key.split("/") - d = new_config - for k in keys[:-1]: - d = d.setdefault(k, {}) - d[keys[-1]] = value["value"] if keep_value_only else value - else: - new_config[key] = value["value"] if keep_value_only else value - - cfg = DictConfig(new_config) - - if clean_cfg_path: - # Clean cfg_path recursively substituting root_dir with cwd - root_dir = cfg.paths.root_dir - - def replace_dir_recursive(d, search, replace): - for k, v in d.items(): - if isinstance(v, dict) or isinstance(v, DictConfig): - replace_dir_recursive(v, search, replace) - elif isinstance(v, str): - if search in v: - d[k] = v.replace(search, replace) - - replace_dir_recursive(cfg, root_dir, os.getcwd()) - - return cfg diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py deleted file mode 100644 index d59592e7..00000000 --- a/rl4co/utils/ops.py +++ /dev/null @@ -1,268 +0,0 @@ -from functools import lru_cache -from typing import Optional, Union - -import torch - -from einops import rearrange -from tensordict import TensorDict -from torch import Tensor - - -def _batchify_single( - x: Union[Tensor, TensorDict], repeats: int -) -> Union[Tensor, TensorDict]: - """Same as repeat on dim=0 for Tensordicts as well""" - s = x.shape - return x.expand(repeats, *s).contiguous().view(s[0] * repeats, *s[1:]) - - -def batchify( - x: Union[Tensor, TensorDict], shape: Union[tuple, int] -) -> Union[Tensor, TensorDict]: - """Same as `einops.repeat(x, 'b ... -> (b r) ...', r=repeats)` but ~1.5x faster and supports TensorDicts. - Repeats batchify operation `n` times as specified by each shape element. - If shape is a tuple, iterates over each element and repeats that many times to match the tuple shape. - - Example: - >>> x.shape: [a, b, c, ...] - >>> shape: [a, b, c] - >>> out.shape: [a*b*c, ...] 
- """ - shape = [shape] if isinstance(shape, int) else shape - for s in reversed(shape): - x = _batchify_single(x, s) if s > 0 else x - return x - - -def _unbatchify_single( - x: Union[Tensor, TensorDict], repeats: int -) -> Union[Tensor, TensorDict]: - """Undoes batchify operation for Tensordicts as well""" - s = x.shape - return x.view(repeats, s[0] // repeats, *s[1:]).permute(1, 0, *range(2, len(s) + 1)) - - -def unbatchify( - x: Union[Tensor, TensorDict], shape: Union[tuple, int] -) -> Union[Tensor, TensorDict]: - """Same as `einops.rearrange(x, '(r b) ... -> b r ...', r=repeats)` but ~2x faster and supports TensorDicts - Repeats unbatchify operation `n` times as specified by each shape element - If shape is a tuple, iterates over each element and unbatchifies that many times to match the tuple shape. - - Example: - >>> x.shape: [a*b*c, ...] - >>> shape: [a, b, c] - >>> out.shape: [a, b, c, ...] - """ - shape = [shape] if isinstance(shape, int) else shape - for s in reversed( - shape - ): # we need to reverse the shape to unbatchify in the right order - x = _unbatchify_single(x, s) if s > 0 else x - return x - - -def gather_by_index(src, idx, dim=1, squeeze=True): - """Gather elements from src by index idx along specified dim - - Example: - >>> src: shape [64, 20, 2] - >>> idx: shape [64, 3)] # 3 is the number of idxs on dim 1 - >>> Returns: [64, 3, 2] # get the 3 elements from src at idx - """ - expanded_shape = list(src.shape) - expanded_shape[dim] = -1 - idx = idx.view(idx.shape + (1,) * (src.dim() - idx.dim())).expand(expanded_shape) - return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx) - - -def unbatchify_and_gather(x: Tensor, idx: Tensor, n: int): - """first unbatchify a tensor by n and then gather (usually along the unbatchified dimension) - by the specified index - """ - x = unbatchify(x, n) - return gather_by_index(x, idx, dim=idx.dim()) - - -@torch.jit.script -def get_distance(x: Tensor, y: Tensor): - """Euclidean distance between two tensors of shape `[..., n, dim]`""" - return (x - y).norm(p=2, dim=-1) - - -@torch.jit.script -def get_tour_length(ordered_locs): - """Compute the total tour distance for a batch of ordered tours. - Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up. - - Args: - ordered_locs: Tensor of shape [batch_size, num_nodes, 2] containing the ordered locations of the tour - """ - ordered_locs_next = torch.roll(ordered_locs, -1, dims=-2) - return get_distance(ordered_locs_next, ordered_locs).sum(-1) - -@torch.jit.script -def get_open_tour_length(ordered_locs): - """Compute the total tour distance for a batch of ordered tours. - Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up. 
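# A small self-contained sketch, assuming random coordinates and an arbitrary
# visiting order, of how the ops helpers documented above combine in practice:
# gather_by_index reorders the locations by the action sequence, and
# get_tour_length sums the legs of the resulting closed tour.
import torch
from rl4co.utils.ops import gather_by_index, get_tour_length

locs = torch.rand(2, 6, 2)                                    # [batch, num_nodes, (x, y)]
actions = torch.stack([torch.randperm(6) for _ in range(2)])  # one visiting order per instance
ordered = gather_by_index(locs, actions, dim=1)               # [2, 6, 2] locations in visit order
lengths = get_tour_length(ordered)                            # [2] total closed-tour length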
- - Args: - ordered_locs: Tensor of shape [batch_size, num_nodes, 2] containing the ordered locations of the tour - """ - ordered_locs_next = torch.roll(ordered_locs, 1, dims=-2) - - segment_lengths = ((ordered_locs_next-ordered_locs)**2).sum(-1).sqrt() - - # Get the first value of ordered_locs - first_loc = ordered_locs[:, 0, :][:,None,:].expand(ordered_locs_next.shape) - - # Check the ids where the location is the same as the first value - same_loc_ids = torch.all(ordered_locs_next == first_loc, dim=-1) - - # for open VRP, the distance between last customer and the depot is not counted - segment_lengths[same_loc_ids] = 0 - - travel_distances = segment_lengths.sum(1) - - return travel_distances - -@torch.jit.script -def get_distance_matrix(locs: Tensor): - """Compute the euclidean distance matrix for the given coordinates. - - Args: - locs: Tensor of shape [..., n, dim] - """ - distance = (locs[..., :, None, :] - locs[..., None, :, :]).norm(p=2, dim=-1) - return distance - - -def calculate_entropy(logprobs: Tensor): - """Calculate the entropy of the log probabilities distribution - logprobs: Tensor of shape [batch, decoder_steps, num_actions] - """ - logprobs = torch.nan_to_num(logprobs, nan=0.0) - entropy = -(logprobs.exp() * logprobs).sum(dim=-1) # [batch, decoder steps] - entropy = entropy.sum(dim=1) # [batch] -- sum over decoding steps - assert entropy.isfinite().all(), "Entropy is not finite" - return entropy - - -# TODO: modularize inside the envs -def get_num_starts(td, env_name=None): - """Returns the number of possible start nodes for the environment based on the action mask""" - num_starts = td["action_mask"].shape[-1] - if env_name == "pdp": - num_starts = ( - num_starts - 1 - ) // 2 # only half of the nodes (i.e. pickup nodes) can be start nodes - elif env_name in ["cvrp", "cvrptw", "sdvrp", "mtsp", "op", "pctsp", "spctsp"]: - num_starts = num_starts - 1 # depot cannot be a start node - - return num_starts - - -def select_start_nodes(td, env, num_starts): - """Node selection strategy as proposed in POMO (Kwon et al. 2020) - and extended in SymNCO (Kim et al. 2022). - Selects different start nodes for each batch element - - Args: - td: TensorDict containing the data. We may need to access the available actions to select the start nodes - env: Environment may determine the node selection strategy - num_starts: Number of nodes to select. This may be passed when calling the policy directly. 
See :class:`rl4co.models.AutoregressiveDecoder` - """ - num_loc = env.generator.num_loc if hasattr(env.generator, "num_loc") else 0xFFFFFFFF - if env.name in ["tsp", "atsp"]: - selected = ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc - ) - elif env.name == "fjsp": - raise NotImplementedError("Multistart not yet supported for FJSP") - else: - # Environments with depot: we do not select the depot as a start node - selected = ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc - + 1 - ) - if env.name == "op": - if (td["action_mask"][..., 1:].float().sum(-1) < num_starts).any(): - # for the orienteering problem, we may have some nodes that are not available - # so we need to resample from the distribution of available nodes - selected = ( - torch.multinomial( - td["action_mask"][..., 1:].float(), num_starts, replacement=True - ) - + 1 - ) # re-add depot index - selected = rearrange(selected, "b n -> (n b)") - return selected - - -def get_best_actions(actions, max_idxs): - actions = unbatchify(actions, max_idxs.shape[0]) - return actions.gather(0, max_idxs[..., None, None]) - - -def sparsify_graph(cost_matrix: Tensor, k_sparse: Optional[int] = None, self_loop=False): - """Generate a sparsified graph for the cost_matrix by selecting k edges with the lowest cost for each node. - - Args: - cost_matrix: Tensor of shape [m, n] - k_sparse: Number of edges to keep for each node. Defaults to max(n//5, 10) if not provided. - self_loop: Include self-loop edges in the generated graph when m==n. Defaults to False. - """ - m, n = cost_matrix.shape - k_sparse = max(n // 5, 10) if k_sparse is None else k_sparse - - # fill diagonal value with +inf to exclude them from topk results - if not self_loop and m == n: - # k_sparse should not exceed n-1 in this occasion - k_sparse = min(k_sparse, n - 1) - cost_matrix.fill_diagonal_(torch.inf) - - # select top-k edges with least cost - topk_values, topk_indices = torch.topk( - cost_matrix, k=k_sparse, dim=-1, largest=False, sorted=False - ) - - # generate PyG-compatiable edge_index - edge_index_u = torch.repeat_interleave( - torch.arange(m, device=cost_matrix.device), topk_indices.shape[1] - ) - edge_index_v = topk_indices.flatten() - edge_index = torch.stack([edge_index_u, edge_index_v]) - - edge_attr = topk_values.flatten().unsqueeze(-1) - return edge_index, edge_attr - - -@lru_cache(5) -def get_full_graph_edge_index(num_node: int, self_loop=False) -> Tensor: - adj_matrix = torch.ones(num_node, num_node) - if not self_loop: - adj_matrix.fill_diagonal_(0) - edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0)) - return edge_index - - -def sample_n_random_actions(td: TensorDict, n: int): - """Helper function to sample n random actions from available actions. 
If - number of valid actions is less then n, we sample with replacement from the - valid actions - """ - action_mask = td["action_mask"] - # check whether to use replacement or not - n_valid_actions = torch.sum(action_mask[:, 1:], 1).min() - if n_valid_actions < n: - replace = True - else: - replace = False - ps = torch.rand((action_mask.shape)) - ps[~action_mask] = -torch.inf - ps = torch.softmax(ps, dim=1) - selected = torch.multinomial(ps, n, replacement=replace).squeeze(1) - selected = rearrange(selected, "b n -> (n b)") - return selected.to(td.device) diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py deleted file mode 100644 index 46367a37..00000000 --- a/rl4co/utils/optim_helpers.py +++ /dev/null @@ -1,38 +0,0 @@ -import inspect - -import torch -from torch.optim import Optimizer - - -def get_pytorch_lr_schedulers(): - """Get all learning rate schedulers from `torch.optim.lr_scheduler`""" - return torch.optim.lr_scheduler.__all__ - - -def get_pytorch_optimizers(): - """Get all optimizers from `torch.optim`""" - optimizers = [] - for name, obj in inspect.getmembers(torch.optim): - if inspect.isclass(obj) and issubclass(obj, Optimizer): - optimizers.append(name) - return optimizers - - -def create_optimizer(parameters, optimizer_name: str, **optimizer_kwargs) -> Optimizer: - """Create optimizer for model. If `optimizer_name` is not found, raise ValueError.""" - if optimizer_name in get_pytorch_optimizers(): - optimizer_cls = getattr(torch.optim, optimizer_name) - return optimizer_cls(parameters, **optimizer_kwargs) - else: - raise ValueError(f"Optimizer {optimizer_name} not found.") - - -def create_scheduler( - optimizer: Optimizer, scheduler_name: str, **scheduler_kwargs -) -> torch.optim.lr_scheduler.LRScheduler: - """Create scheduler for optimizer. If `scheduler_name` is not found, raise ValueError.""" - if scheduler_name in get_pytorch_lr_schedulers(): - scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name) - return scheduler_cls(optimizer, **scheduler_kwargs) - else: - raise ValueError(f"Scheduler {scheduler_name} not found.") diff --git a/rl4co/utils/param_grouping.py b/rl4co/utils/param_grouping.py deleted file mode 100644 index 529205a9..00000000 --- a/rl4co/utils/param_grouping.py +++ /dev/null @@ -1,138 +0,0 @@ -import inspect - -import hydra -import torch.nn as nn - - -def group_parameters_for_optimizer( - model, optimizer_cfg, bias_weight_decay=False, normalization_weight_decay=False -): - """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with - attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for - normalization parameters if normalization_weight_decay==False - """ - # Get the weight decay from the config, or from the default value of the optimizer constructor - # if it's not specified in the config. 
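# A short sketch, assuming a torch optimizer as the Hydra target, of the
# default-value lookup described in the comment above: read `weight_decay`
# from the constructor signature when the config does not set it.
import inspect
import torch

sig = inspect.signature(torch.optim.AdamW)
default_wd = sig.parameters["weight_decay"].default
if default_wd is inspect.Parameter.empty:
    default_wd = 0.0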
- if "weight_decay" in optimizer_cfg: - weight_decay = optimizer_cfg.weight_decay - else: - # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value - signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) - if "weight_decay" in signature.parameters: - weight_decay = signature.parameters["weight_decay"].default - if weight_decay is inspect.Parameter.empty: - weight_decay = 0.0 - else: - weight_decay = 0.0 - - # If none of the parameters have weight decay anyway, and there are no parameters with special - # optimization params - if weight_decay == 0.0 and not any(hasattr(p, "_optim") for p in model.parameters()): - return model.parameters() - - skip = model.no_weight_decay() if hasattr(model, "no_weight_decay") else set() - skip_keywords = ( - model.no_weight_decay_keywords() - if hasattr(model, "no_weight_decay_keywords") - else set() - ) - - # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. - """ - - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - special = set() - whitelist_weight_modules = (nn.Linear,) - blacklist_weight_modules = nn.Embedding - if not normalization_weight_decay: - blacklist_weight_modules += ( - nn.BatchNorm1d, - nn.BatchNorm2d, - nn.BatchNorm3d, - nn.LazyBatchNorm1d, - nn.LazyBatchNorm2d, - nn.LazyBatchNorm3d, - nn.GroupNorm, - nn.SyncBatchNorm, - nn.InstanceNorm1d, - nn.InstanceNorm2d, - nn.InstanceNorm3d, - nn.LayerNorm, - nn.LocalResponseNorm, - ) - - param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn # full param name - # In case of parameter sharing, some parameters show up here but are not in - # param_dict.keys() - if not p.requires_grad or fpn not in param_dict: - continue # frozen weights - if hasattr(p, "_optim"): - special.add(fpn) - elif fpn in skip or any( - skip_keyword in fpn for skip_keyword in skip_keywords - ): - no_decay.add(fpn) - elif getattr(p, "_no_weight_decay", False): - no_decay.add(fpn) - elif not bias_weight_decay and pn.endswith("bias"): - no_decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # special case the position embedding parameter in the root GPT module as not decayed - if "pos_emb" in param_dict: - no_decay.add("pos_emb") - - decay |= param_dict.keys() - no_decay - special - # validate that we considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert ( - len(param_dict.keys() - special - union_params) == 0 - ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 
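# A hedged sketch of the model-side opt-out hooks this function checks for
# (`no_weight_decay()` and the `_no_weight_decay` parameter attribute); the tiny
# module and optimizer config are illustrative stand-ins, not code from this repo.
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from rl4co.utils.param_grouping import group_parameters_for_optimizer

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.pos_emb = nn.Parameter(torch.zeros(1, 8))
        self.pos_emb._no_weight_decay = True      # per-parameter opt-out from weight decay

    def no_weight_decay(self):                    # module-level opt-out by parameter name
        return {"pos_emb"}

opt_cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "weight_decay": 0.01})
groups = group_parameters_for_optimizer(TinyModel(), opt_cfg)
optimizer = torch.optim.AdamW(groups)             # two groups: decayed and non-decayed params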
- - if weight_decay == 0.0 or not no_decay: - param_groups = [ - { - "params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], - "weight_decay": weight_decay, - } - ] - else: - param_groups = [ - { - "params": [param_dict[pn] for pn in sorted(list(decay))], - "weight_decay": weight_decay, - }, - { - "params": [param_dict[pn] for pn in sorted(list(no_decay))], - "weight_decay": 0.0, - }, - ] - # Add parameters with special hyperparameters - # Unique dicts - hps = [ - dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special) - ] - for hp in hps: - params = [param_dict[pn] for pn in special if param_dict[pn]._optim == hp] - param_groups.append({"params": params, **hp}) - - return param_groups diff --git a/rl4co/utils/pylogger.py b/rl4co/utils/pylogger.py deleted file mode 100644 index aa1b5f1a..00000000 --- a/rl4co/utils/pylogger.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging - -from lightning.pytorch.utilities.rank_zero import rank_zero_only - - -def get_pylogger(name=__name__) -> logging.Logger: - """Initializes multi-GPU-friendly python command line logger.""" - - logger = logging.getLogger(name) - - # this ensures all logging levels get marked with the rank zero decorator - # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ( - "debug", - "info", - "warning", - "error", - "exception", - "fatal", - "critical", - ) - for level in logging_levels: - setattr(logger, level, rank_zero_only(getattr(logger, level))) - - return logger diff --git a/rl4co/utils/rich_utils.py b/rl4co/utils/rich_utils.py deleted file mode 100644 index 652ba568..00000000 --- a/rl4co/utils/rich_utils.py +++ /dev/null @@ -1,97 +0,0 @@ -from pathlib import Path -from typing import Sequence - -import rich -import rich.syntax -import rich.tree - -from hydra.core.hydra_config import HydraConfig -from lightning.pytorch.utilities.rank_zero import rank_zero_only -from omegaconf import DictConfig, OmegaConf, open_dict -from rich.prompt import Prompt - -from rl4co.utils.utils import pylogger - -log = pylogger.get_pylogger(__name__) - - -@rank_zero_only -def print_config_tree( - cfg: DictConfig, - print_order: Sequence[str] = ( - # "data", # note: data is dealt with in model - "model", - "callbacks", - "logger", - "trainer", - "paths", - "extras", - ), - resolve: bool = True, - save_to_file: bool = False, -) -> None: - """Prints content of DictConfig using Rich library and its tree structure. - Args: - cfg (DictConfig): Configuration composed by Hydra. - print_order (Sequence[str], optional): Determines in what order config components are printed. - resolve (bool, optional): Whether to resolve reference fields of DictConfig. - save_to_file (bool, optional): Whether to export config to the hydra output folder. - """ - - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - queue = [] - - # add fields from `print_order` to queue - for field in print_order: - queue.append(field) if field in cfg else log.warning( - f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
- ) - - # add all the other fields to queue (not specified in `print_order`) - for field in cfg: - if field not in queue: - queue.append(field) - - # generate config tree from queue - for field in queue: - branch = tree.add(field, style=style, guide_style=style) - - config_group = cfg[field] - if isinstance(config_group, DictConfig): - branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) - else: - branch_content = str(config_group) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - # print config tree - rich.print(tree) - - # save config tree to file - if save_to_file: - with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: - rich.print(tree, file=file) - - -@rank_zero_only -def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config.""" - - if not cfg.get("tags"): - if "id" in HydraConfig().cfg.hydra.job: - raise ValueError("Specify tags before launching a multirun!") - - log.warning("No tags provided in config. Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") - tags = [t.strip() for t in tags.split(",") if t != ""] - - with open_dict(cfg): - cfg.tags = tags - - log.info(f"Tags: {cfg.tags}") - - if save_to_file: - with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: - rich.print(cfg.tags, file=file) diff --git a/rl4co/utils/test_utils.py b/rl4co/utils/test_utils.py deleted file mode 100644 index baa7f31d..00000000 --- a/rl4co/utils/test_utils.py +++ /dev/null @@ -1,62 +0,0 @@ -from torch.utils.data import DataLoader - -from rl4co.envs import ( - CVRPEnv, - CVRPTWEnv, - DPPEnv, - MDPPEnv, - MTSPEnv, - OPEnv, - PCTSPEnv, - PDPEnv, - SDVRPEnv, - SMTWTPEnv, - SPCTSPEnv, - TSPEnv, -) - - -def get_env(name, size): - if name == "tsp": - env = TSPEnv(generator_params=dict(num_loc=size)) - elif name == "cvrp": - env = CVRPEnv(generator_params=dict(num_loc=size)) - elif name == "cvrptw": - env = CVRPTWEnv(generator_params=dict(num_loc=size)) - elif name == "sdvrp": - env = SDVRPEnv(generator_params=dict(num_loc=size)) - elif name == "pdp": - env = PDPEnv(generator_params=dict(num_loc=size)) - elif name == "op": - env = OPEnv(generator_params=dict(num_loc=size)) - elif name == "mtsp": - env = MTSPEnv(generator_params=dict(num_loc=size)) - elif name == "pctsp": - env = PCTSPEnv(generator_params=dict(num_loc=size)) - elif name == "spctsp": - env = SPCTSPEnv(generator_params=dict(num_loc=size)) - elif name == "dpp": - env = DPPEnv() - elif name == "mdpp": - env = MDPPEnv() - elif name == "smtwtp": - env = SMTWTPEnv() - else: - raise ValueError(f"Unknown env_name: {name}") - - return env.transform() - - -def generate_env_data(env, size, batch_size): - env = get_env(env, size) - dataset = env.dataset([batch_size]) - - dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=False, - num_workers=0, - collate_fn=dataset.collate_fn, - ) - - return env, next(iter(dataloader)) diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py deleted file mode 100644 index 0ad10fa1..00000000 --- a/rl4co/utils/trainer.py +++ /dev/null @@ -1,152 +0,0 @@ -from typing import Iterable, List, Optional, Union - -import lightning.pytorch as pl -import torch - -from lightning import Callback, Trainer -from lightning.fabric.accelerators.cuda import num_cuda_devices -from lightning.pytorch.accelerators import Accelerator -from lightning.pytorch.core.datamodule import LightningDataModule -from 
lightning.pytorch.loggers import Logger -from lightning.pytorch.strategies import DDPStrategy, Strategy -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS - -from rl4co import utils - -log = utils.get_pylogger(__name__) - - -class RL4COTrainer(Trainer): - """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training. - - Note: - The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`. - This allows for datasets to be re-created on the run and distributed by Lightning across - devices on each epoch. Setting to a value different than 1 may lead to overfitting to a - specific (such as the initial) data distribution. - - Args: - accelerator: hardware accelerator to use. - callbacks: list of callbacks. - logger: logger (or iterable collection of loggers) for experiment tracking. - min_epochs: minimum number of training epochs. - max_epochs: maximum number of training epochs. - strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP). - devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node. - gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability. - precision: allows for mixed precision training. Can be specified as a string (e.g., '16'). - This also allows to use `FlashAttention` by default. - disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed. - auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available. - reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs. - matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision - **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details. - """ - - def __init__( - self, - accelerator: Union[str, Accelerator] = "auto", - callbacks: Optional[List[Callback]] = None, - logger: Optional[Union[Logger, Iterable[Logger]]] = None, - min_epochs: Optional[int] = None, - max_epochs: Optional[int] = None, - strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", - gradient_clip_val: Union[int, float] = 1.0, - precision: Union[str, int] = "16-mixed", - reload_dataloaders_every_n_epochs: int = 1, - disable_profiling_executor: bool = True, - auto_configure_ddp: bool = True, - matmul_precision: Union[str, int] = "medium", - **kwargs, - ): - # Disable JIT profiling executor. This reduces memory and increases speed. 
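# A hedged usage sketch of this trainer wrapper. The TSP/AttentionModel pairing is
# an assumption for illustration; the trainer arguments are the ones documented in
# the class docstring above.
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

env = TSPEnv(generator_params=dict(num_loc=20))
model = AttentionModel(env)
trainer = RL4COTrainer(max_epochs=3, devices=1, precision="16-mixed", gradient_clip_val=1.0)
trainer.fit(model)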
- # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17 - if disable_profiling_executor: - try: - torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - except AttributeError: - pass - - # Configure DDP automatically if multiple GPUs are available - if auto_configure_ddp and strategy == "auto": - if devices == "auto": - n_devices = num_cuda_devices() - elif isinstance(devices, Iterable): - n_devices = len(devices) - else: - n_devices = devices - if n_devices > 1: - log.info( - "Configuring DDP strategy automatically with {} GPUs".format( - n_devices - ) - ) - strategy = DDPStrategy( - find_unused_parameters=True, # We set to True due to RL envs - gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations - ) - - # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision - if matmul_precision is not None: - torch.set_float32_matmul_precision(matmul_precision) - - # Check if gradient_clip_val is set to None - if gradient_clip_val is None: - log.warning( - "gradient_clip_val is set to None. This may lead to unstable training." - ) - - # We should reload dataloaders every epoch for RL training - if reload_dataloaders_every_n_epochs != 1: - log.warning( - "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 " - + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs." - ) - - # Main call to `Trainer` superclass - super().__init__( - accelerator=accelerator, - callbacks=callbacks, - logger=logger, - min_epochs=min_epochs, - max_epochs=max_epochs, - strategy=strategy, - gradient_clip_val=gradient_clip_val, - devices=devices, - precision=precision, - reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, - **kwargs, - ) - - def fit( - self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[str] = None, - ) -> None: - """ - We override the `fit` method to automatically apply and handle RL4CO magic - to 'self.automatic_optimization = False' models, such as PPO - - It behaves exactly like the original `fit` method, but with the following changes: - - if the given model is 'self.automatic_optimization = False', we override 'gradient_clip_val' as None - """ - - if not model.automatic_optimization: - if self.gradient_clip_val is not None: - log.warning( - "Overriding gradient_clip_val to None for 'automatic_optimization=False' models" - ) - self.gradient_clip_val = None - - super().fit( - model=model, - train_dataloaders=train_dataloaders, - val_dataloaders=val_dataloaders, - datamodule=datamodule, - ckpt_path=ckpt_path, - ) diff --git a/rl4co/utils/utils.py b/rl4co/utils/utils.py deleted file mode 100644 index 547e0e2d..00000000 --- a/rl4co/utils/utils.py +++ /dev/null @@ -1,287 +0,0 @@ -import importlib -import platform -import sys -import warnings - -from importlib.util import find_spec -from typing import Callable, List - -import hydra - -from lightning import Callback -from lightning.pytorch.loggers.logger import Logger - -# Import the necessary PyTorch Lightning component -from 
lightning.pytorch.trainer.connectors.accelerator_connector import ( - _AcceleratorConnector, -) -from lightning.pytorch.utilities.rank_zero import rank_zero_only -from omegaconf import DictConfig - -from rl4co.utils import pylogger, rich_utils - -log = pylogger.get_pylogger(__name__) - - -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that wraps the task function in extra utilities. - - Makes multirun more resistant to failure. - - Utilities: - - Calling the `utils.extras()` before the task is started - - Calling the `utils.close_loggers()` after the task is finished or failed - - Logging the exception if occurs - - Logging the output dir - """ - - def wrap(cfg: DictConfig): - # execute the task - try: - # apply extra utilities - extras(cfg) - - metric_dict, object_dict = task_func(cfg=cfg) - - # things to do if exception occurs - except Exception as ex: - # save exception to `.log` file - log.exception("") - - # when using hydra plugins like Optuna, you might want to disable raising exception - # to avoid multirun failure - raise ex - - # things to always do after either success or exception - finally: - # display output dir path in terminal - log.info(f"Output dir: {cfg.paths.output_dir}") - - # close loggers (even if exception occurs so multirun won't fail) - close_loggers() - - return metric_dict, object_dict - - return wrap - - -def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. - - Utilities: - - Ignoring python warnings - - Setting tags from command line - - Rich config printing - """ - - # return if no `extras` config - if not cfg.get("extras"): - log.warning("Extras config not found! ") - return - - # disable python warnings - if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - # prompt user to input tags from command line if none are provided in the config - if cfg.extras.get("enforce_tags"): - log.info("Enforcing tags! ") - rich_utils.enforce_tags(cfg, save_to_file=True) - - # pretty print config tree using Rich library - if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") - rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("No callback configs found! Skipping..") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: - """Instantiates loggers from config.""" - logger: List[Logger] = [] - - if not logger_cfg: - log.warning("No logger configs found! 
Skipping...") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) - - return logger - - -@rank_zero_only -def log_hyperparameters(object_dict: dict) -> None: - """Controls which config parts are saved by lightning loggers. - - Additionally saves: - - Number of model parameters - """ - - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! Skipping hyperparameter logging...") - return - - hparams["model"] = cfg["model"] - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - ## Note: we do not use the data config, since it is dealt with in the model - ## which is a `LightningModule` - # hparams["data"] = cfg["data"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) - - -def get_metric_value(metric_dict: dict, metric_name: str) -> float: - """Safely retrieves value of the metric logged in LightningModule.""" - - if not metric_name: - log.info("Metric name is None! Skipping metric value retrieval...") - return None - - if metric_name not in metric_dict: - raise Exception( - f"Metric value not found! \n" - "Make sure metric name logged in LightningModule is correct!\n" - "Make sure `optimized_metric` name in `hparams_search` config is correct!" - ) - - metric_value = metric_dict[metric_name].item() - log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") - - return metric_value - - -def close_loggers() -> None: - """Makes sure all loggers closed properly (prevents logging failure during multirun).""" - - log.info("Closing loggers...") - - if find_spec("wandb"): # if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - -@rank_zero_only -def save_file(path: str, content: str) -> None: - """Save file in rank zero mode (only on one process in multi-GPU setup).""" - with open(path, "w+") as file: - file.write(content) - - -def merge_with_defaults(_config=None, **defaults) -> dict: - """Merge configuration with default values. - - This function merges a provided configuration dictionary with default values. - If no configuration is provided (`_config` is None), it returns the default values. - If a dictionary is provided, it updates the defaults dictionary with the values from the provided dictionary. - Otherwise, it sets all keys in the defaults dictionary to `_config`. - - Args: - _config: Configuration to merge. Defaults to None. - **defaults: Default values to merge with the configuration. - - Returns: - dict: Merged configuration with default values. 
- """ - if _config is None: - return defaults - elif isinstance(_config, (DictConfig, dict)): - defaults.update(dict(**_config)) # type: ignore - return defaults - else: - return {key: _config for key in defaults.keys()} - - -def show_versions(): - """ - This function prints version information that is useful when filing bug - reports. Inspired by https://github.com/PyVRP/PyVRP - """ - - modules = { - "rl4co": "rl4co", - "torch": "torch", - "lightning": "pytorch_lightning", # Updated module name if necessary - "torchrl": "torchrl", - "tensordict": "tensordict", - "numpy": "numpy", - "pytorch_geometric": "torch_geometric", - "hydra-core": "hydra", - "omegaconf": "omegaconf", - "matplotlib": "matplotlib", - } - - # Find the longest module name for formatting - longest_name = max(len(name) for name in modules.keys()) - - print("INSTALLED VERSIONS") - print("-" * (longest_name + 20)) - # modules - for name, module in modules.items(): - try: - imported_module = importlib.import_module(module) - version = imported_module.__version__ - except ImportError: - version = "Not installed" - print(f"{name.rjust(longest_name)} : {version}") - # platform information - print(f'{"Python".rjust(longest_name)} : {sys.version.split()[0]}') - print(f'{"Platform".rjust(longest_name)} : {platform.platform()}') - try: - lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator(None) - except Exception: - lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator() - # lightning hardware accelerators - print(f'{"Lightning device".rjust(longest_name)} : {lightning_auto_device}') From 99af6f90e070fa7a4d3bd4a33860e82f7696b63e Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Mon, 13 May 2024 17:34:14 +0800 Subject: [PATCH 3/6] update mtvrp --- rl4co/tasks/__init__.py | 0 rl4co/tasks/eval.py | 407 +++++++++++++++++++ rl4co/tasks/train.py | 117 ++++++ rl4co/utils/__init__.py | 11 + rl4co/utils/callbacks/speed_monitor.py | 123 ++++++ rl4co/utils/decoding.py | 540 +++++++++++++++++++++++++ rl4co/utils/instantiators.py | 51 +++ rl4co/utils/lightning.py | 76 ++++ rl4co/utils/ops.py | 214 ++++++++++ rl4co/utils/optim_helpers.py | 38 ++ rl4co/utils/param_grouping.py | 138 +++++++ rl4co/utils/pylogger.py | 25 ++ rl4co/utils/rich_utils.py | 97 +++++ rl4co/utils/test_utils.py | 62 +++ rl4co/utils/trainer.py | 152 +++++++ rl4co/utils/utils.py | 287 +++++++++++++ 16 files changed, 2338 insertions(+) create mode 100644 rl4co/tasks/__init__.py create mode 100644 rl4co/tasks/eval.py create mode 100644 rl4co/tasks/train.py create mode 100644 rl4co/utils/__init__.py create mode 100644 rl4co/utils/callbacks/speed_monitor.py create mode 100644 rl4co/utils/decoding.py create mode 100644 rl4co/utils/instantiators.py create mode 100644 rl4co/utils/lightning.py create mode 100644 rl4co/utils/ops.py create mode 100644 rl4co/utils/optim_helpers.py create mode 100644 rl4co/utils/param_grouping.py create mode 100644 rl4co/utils/pylogger.py create mode 100644 rl4co/utils/rich_utils.py create mode 100644 rl4co/utils/test_utils.py create mode 100644 rl4co/utils/trainer.py create mode 100644 rl4co/utils/utils.py diff --git a/rl4co/tasks/__init__.py b/rl4co/tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rl4co/tasks/eval.py b/rl4co/tasks/eval.py new file mode 100644 index 00000000..5be1abcc --- /dev/null +++ b/rl4co/tasks/eval.py @@ -0,0 +1,407 @@ +import numpy as np +import torch + +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from 
rl4co.data.transforms import StateAugmentation +from rl4co.utils.ops import batchify, gather_by_index, unbatchify + + +def check_unused_kwargs(class_, kwargs): + if len(kwargs) > 0 and not (len(kwargs) == 1 and "progress" in kwargs): + print(f"Warning: {class_.__class__.__name__} does not use kwargs {kwargs}") + + +class EvalBase: + """Base class for evaluation + + Args: + env: Environment + progress: Whether to show progress bar + **kwargs: Additional arguments (to be implemented in subclasses) + """ + + name = "base" + + def __init__(self, env, progress=True, **kwargs): + check_unused_kwargs(self, kwargs) + self.env = env + self.progress = progress + + def __call__(self, policy, dataloader, **kwargs): + """Evaluate the policy on the given dataloader with **kwargs parameter + self._inner is implemented in subclasses and returns actions and rewards + """ + + # Collect timings for evaluation (more accurate than timeit) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + with torch.inference_mode(): + rewards_list = [] + actions_list = [] + + for batch in tqdm( + dataloader, disable=not self.progress, desc=f"Running {self.name}" + ): + td = batch.to(next(policy.parameters()).device) + td = self.env.reset(td) + actions, rewards = self._inner(policy, td, **kwargs) + rewards_list.append(rewards) + actions_list.append(actions) + + rewards = torch.cat(rewards_list) + + # Padding: pad actions to the same length with zeros + max_length = max(action.size(-1) for action in actions) + actions = torch.cat( + [ + torch.nn.functional.pad(action, (0, max_length - action.size(-1))) + for action in actions + ], + 0, + ) + + end_event.record() + torch.cuda.synchronize() + inference_time = start_event.elapsed_time(end_event) + + tqdm.write(f"Mean reward for {self.name}: {rewards.mean():.4f}") + tqdm.write(f"Time: {inference_time/1000:.4f}s") + + # Empty cache + torch.cuda.empty_cache() + + return { + "actions": actions.cpu(), + "rewards": rewards.cpu(), + "inference_time": inference_time, + "avg_reward": rewards.cpu().mean(), + } + + def _inner(self, policy, td): + """Inner function to be implemented in subclasses. 
+ This function returns actions and rewards for the given policy + """ + raise NotImplementedError("Implement in subclass") + + +class GreedyEval(EvalBase): + """Evaluates the policy using greedy decoding and single trajectory""" + + name = "greedy" + + def __init__(self, env, **kwargs): + check_unused_kwargs(self, kwargs) + super().__init__(env, kwargs.get("progress", True)) + + def _inner(self, policy, td): + out = policy( + td.clone(), + decode_type="greedy", + num_starts=0, + return_actions=True, + ) + rewards = self.env.get_reward(td, out["actions"]) + return out["actions"], rewards + + +class AugmentationEval(EvalBase): + """Evaluates the policy via N state augmentations + `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO + https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8 + + Args: + num_augment (int): Number of state augmentations + force_dihedral_8 (bool): Whether to force the use of 8 augmentations + """ + + name = "augmentation" + + def __init__(self, env, num_augment=8, force_dihedral_8=False, feats=None, **kwargs): + check_unused_kwargs(self, kwargs) + super().__init__(env, kwargs.get("progress", True)) + self.augmentation = StateAugmentation( + num_augment=num_augment, + augment_fn="dihedral8" if force_dihedral_8 else "symmetric", + feats=feats, + ) + + def _inner(self, policy, td, num_augment=None): + if num_augment is None: + num_augment = self.augmentation.num_augment + td_init = td.clone() + td = self.augmentation(td) + out = policy(td.clone(), decode_type="greedy", num_starts=0, return_actions=True) + + # Move into batches and compute rewards + rewards = self.env.get_reward(batchify(td_init, num_augment), out["actions"]) + rewards = unbatchify(rewards, num_augment) + actions = unbatchify(out["actions"], num_augment) + + # Get best reward and corresponding action + rewards, max_idxs = rewards.max(dim=1) + actions = gather_by_index(actions, max_idxs, dim=1) + return actions, rewards + + @property + def num_augment(self): + return self.augmentation.num_augment + + +class SamplingEval(EvalBase): + """Evaluates the policy via N samples from the policy + + Args: + samples (int): Number of samples to take + softmax_temp (float): Temperature for softmax sampling. 
The higher the temperature, the more random the sampling + """ + + name = "sampling" + + def __init__(self, env, samples, softmax_temp=None, **kwargs): + check_unused_kwargs(self, kwargs) + super().__init__(env, kwargs.get("progress", True)) + + self.samples = samples + self.softmax_temp = softmax_temp + + def _inner(self, policy, td): + td = batchify(td, self.samples) + out = policy( + td.clone(), + decode_type="sampling", + num_starts=0, + return_actions=True, + softmax_temp=self.softmax_temp, + ) + + # Move into batches and compute rewards + rewards = self.env.get_reward(td, out["actions"]) + rewards = unbatchify(rewards, self.samples) + actions = unbatchify(out["actions"], self.samples) + + # Get the best reward and action for each sample + rewards, max_idxs = rewards.max(dim=1) + actions = gather_by_index(actions, max_idxs, dim=1) + return actions, rewards + + +class GreedyMultiStartEval(EvalBase): + """Evaluates the policy via `num_starts` greedy multistarts samples from the policy + + Args: + num_starts (int): Number of greedy multistarts to use + """ + + name = "multistart_greedy" + + def __init__(self, env, num_starts=None, **kwargs): + check_unused_kwargs(self, kwargs) + super().__init__(env, kwargs.get("progress", True)) + + assert num_starts is not None, "Must specify num_starts" + self.num_starts = num_starts + + def _inner(self, policy, td): + td_init = td.clone() + out = policy( + td.clone(), + decode_type="multistart_greedy", + num_starts=self.num_starts, + return_actions=True, + ) + + # Move into batches and compute rewards + td = batchify(td_init, self.num_starts) + rewards = self.env.get_reward(td, out["actions"]) + rewards = unbatchify(rewards, self.num_starts) + actions = unbatchify(out["actions"], self.num_starts) + + # Get the best trajectories + rewards, max_idxs = rewards.max(dim=1) + actions = gather_by_index(actions, max_idxs, dim=1) + return actions, rewards + + +class GreedyMultiStartAugmentEval(EvalBase): + """Evaluates the policy via `num_starts` samples from the policy + and `num_augment` augmentations of each sample.` + `force_dihedral_8` forces the use of 8 augmentations (rotations and flips) as in POMO + https://en.wikipedia.org/wiki/Examples_of_groups#dihedral_group_of_order_8 + + Args: + num_starts: Number of greedy multistart samples + num_augment: Number of augmentations per sample + force_dihedral_8: If True, force the use of 8 augmentations (rotations and flips) as in POMO + """ + + name = "multistart_greedy_augment" + + def __init__( + self, + env, + num_starts=None, + num_augment=8, + force_dihedral_8=False, + feats=None, + **kwargs, + ): + check_unused_kwargs(self, kwargs) + super().__init__(env, kwargs.get("progress", True)) + + assert num_starts is not None, "Must specify num_starts" + self.num_starts = num_starts + assert not ( + num_augment != 8 and force_dihedral_8 + ), "Cannot force dihedral 8 when num_augment != 8" + self.augmentation = StateAugmentation( + num_augment=num_augment, + augment_fn="dihedral8" if force_dihedral_8 else "symmetric", + feats=feats, + ) + + def _inner(self, policy, td, num_augment=None): + if num_augment is None: + num_augment = self.augmentation.num_augment + + td_init = td.clone() + + td = self.augmentation(td) + out = policy( + td.clone(), + decode_type="multistart_greedy", + num_starts=self.num_starts, + return_actions=True, + ) + + # Move into batches and compute rewards + td = batchify(td_init, (num_augment, self.num_starts)) + rewards = self.env.get_reward(td, out["actions"]) + rewards = unbatchify(rewards, 
self.num_starts * num_augment) + actions = unbatchify(out["actions"], self.num_starts * num_augment) + + # Get the best trajectories + rewards, max_idxs = rewards.max(dim=1) + actions = gather_by_index(actions, max_idxs, dim=1) + return actions, rewards + + @property + def num_augment(self): + return self.augmentation.num_augment + + +def get_automatic_batch_size(eval_fn, start_batch_size=8192, max_batch_size=4096): + """Automatically reduces the batch size based on the eval function + + Args: + eval_fn: The eval function + start_batch_size: The starting batch size. This should be the theoretical maximum batch size + max_batch_size: The maximum batch size. This is the practical maximum batch size + """ + batch_size = start_batch_size + + effective_ratio = 1 + + if hasattr(eval_fn, "num_starts"): + batch_size = batch_size // (eval_fn.num_starts // 10) + effective_ratio *= eval_fn.num_starts // 10 + if hasattr(eval_fn, "num_augment"): + batch_size = batch_size // eval_fn.num_augment + effective_ratio *= eval_fn.num_augment + if hasattr(eval_fn, "samples"): + batch_size = batch_size // eval_fn.samples + effective_ratio *= eval_fn.samples + + batch_size = min(batch_size, max_batch_size) + # get closest integer power of 2 + batch_size = 2 ** int(np.log2(batch_size)) + + print(f"Effective batch size: {batch_size} (ratio: {effective_ratio})") + + return batch_size + + +def evaluate_policy( + env, + policy, + dataset, + method="greedy", + batch_size=None, + max_batch_size=4096, + start_batch_size=8192, + auto_batch_size=True, + save_results=False, + save_fname="results.npz", + **kwargs, +): + num_loc = getattr(env.generator, "num_loc", None) + + methods_mapping = { + "greedy": {"func": GreedyEval, "kwargs": {}}, + "sampling": { + "func": SamplingEval, + "kwargs": {"samples": 100, "softmax_temp": 1.0}, + }, + "multistart_greedy": { + "func": GreedyMultiStartEval, + "kwargs": {"num_starts": num_loc}, + }, + "augment_dihedral_8": { + "func": AugmentationEval, + "kwargs": {"num_augment": 8, "force_dihedral_8": True}, + }, + "augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}}, + "multistart_greedy_augment_dihedral_8": { + "func": GreedyMultiStartAugmentEval, + "kwargs": { + "num_augment": 8, + "force_dihedral_8": True, + "num_starts": num_loc, + }, + }, + "multistart_greedy_augment": { + "func": GreedyMultiStartAugmentEval, + "kwargs": {"num_augment": 8, "num_starts": num_loc}, + }, + } + + assert method in methods_mapping, "Method {} not found".format(method) + + # Set up the evaluation function + eval_settings = methods_mapping[method] + func, kwargs_ = eval_settings["func"], eval_settings["kwargs"] + # subsitute kwargs with the ones passed in + kwargs_.update(kwargs) + kwargs = kwargs_ + eval_fn = func(env, **kwargs) + + if auto_batch_size: + assert ( + batch_size is None + ), "Cannot specify batch_size when auto_batch_size is True" + batch_size = get_automatic_batch_size( + eval_fn, max_batch_size=max_batch_size, start_batch_size=start_batch_size + ) + print("Using automatic batch size: {}".format(batch_size)) + + # Set up the dataloader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + collate_fn=dataset.collate_fn, + ) + + # Run evaluation + retvals = eval_fn(policy, dataloader) + + # Save results + if save_results: + print("Saving results to {}".format(save_fname)) + np.savez(save_fname, **retvals) + + return retvals diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py new file mode 100644 index 00000000..6826382d --- /dev/null 
+++ b/rl4co/tasks/train.py @@ -0,0 +1,117 @@ +from typing import List, Optional, Tuple + +import hydra +import lightning as L +import pyrootutils +import torch + +from lightning import Callback, LightningModule +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from rl4co import utils +from rl4co.utils import RL4COTrainer + +pyrootutils.setup_root(__file__, indicator=".gitignore", pythonpath=True) + + +log = utils.get_pylogger(__name__) + + +@utils.task_wrapper +def run(cfg: DictConfig) -> Tuple[dict, dict]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + Args: + cfg (DictConfig): Configuration composed by Hydra. + Returns: + Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. + """ + + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + # We instantiate the environment separately and then pass it to the model + log.info(f"Instantiating environment <{cfg.env._target_}>") + env = hydra.utils.instantiate(cfg.env) + + # Note that the RL environment is instantiated inside the model + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model, env) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) + + log.info("Instantiating trainer...") + trainer: RL4COTrainer = hydra.utils.instantiate( + cfg.trainer, + callbacks=callbacks, + logger=logger, + ) + + object_dict = { + "cfg": cfg, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + utils.log_hyperparameters(object_dict) + + if cfg.get("compile", False): + log.info("Compiling model!") + model = torch.compile(model) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../../configs", config_name="main.yaml") +def train(cfg: DictConfig) -> Optional[float]: + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
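# A hedged sketch of driving the Hydra entry point defined above; the override
# keys below are assumptions for illustration, not verified config names.
# From a shell, the usual pattern is:
#   python rl4co/tasks/train.py trainer.max_epochs=1 seed=1234
# Programmatically, roughly the same config can be composed and handed to `run`:
from hydra import compose, initialize

with initialize(version_base="1.3", config_path="../../configs"):
    cfg = compose(config_name="main", overrides=["trainer.max_epochs=1", "seed=1234"])
metric_dict, object_dict = run(cfg)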
+ utils.extras(cfg) + + # train the model + metric_dict, _ = run(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = utils.get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + train() diff --git a/rl4co/utils/__init__.py b/rl4co/utils/__init__.py new file mode 100644 index 00000000..4b0246aa --- /dev/null +++ b/rl4co/utils/__init__.py @@ -0,0 +1,11 @@ +from rl4co.utils.instantiators import instantiate_callbacks, instantiate_loggers +from rl4co.utils.pylogger import get_pylogger +from rl4co.utils.rich_utils import enforce_tags, print_config_tree +from rl4co.utils.trainer import RL4COTrainer +from rl4co.utils.utils import ( + extras, + get_metric_value, + log_hyperparameters, + show_versions, + task_wrapper, +) diff --git a/rl4co/utils/callbacks/speed_monitor.py b/rl4co/utils/callbacks/speed_monitor.py new file mode 100644 index 00000000..3f1ab6ae --- /dev/null +++ b/rl4co/utils/callbacks/speed_monitor.py @@ -0,0 +1,123 @@ +# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor +# We only need the speed monitoring, not the GPU monitoring +import time + +import lightning as L + +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities.parsing import AttributeDict +from lightning.pytorch.utilities.rank_zero import rank_zero_only + + +class SpeedMonitor(Callback): + """Monitor the speed of each step and each epoch.""" + + def __init__( + self, + intra_step_time: bool = True, + inter_step_time: bool = True, + epoch_time: bool = True, + verbose=False, + ): + super().__init__() + self._log_stats = AttributeDict( + { + "intra_step_time": intra_step_time, + "inter_step_time": inter_step_time, + "epoch_time": epoch_time, + } + ) + self.verbose = verbose + + def on_train_start(self, trainer: "L.Trainer", L_module: "L.LightningModule") -> None: + self._snap_epoch_time = None + + def on_train_epoch_start( + self, trainer: "L.Trainer", L_module: "L.LightningModule" + ) -> None: + self._snap_intra_step_time = None + self._snap_inter_step_time = None + self._snap_epoch_time = time.time() + + def on_validation_epoch_start( + self, trainer: "L.Trainer", L_module: "L.LightningModule" + ) -> None: + self._snap_inter_step_time = None + + def on_test_epoch_start( + self, trainer: "L.Trainer", L_module: "L.LightningModule" + ) -> None: + self._snap_inter_step_time = None + + @rank_zero_only + def on_train_batch_start( + self, + trainer: "L.Trainer", + *unused_args, + **unused_kwargs, # easy fix for new pytorch lightning versions + ) -> None: + if self._log_stats.intra_step_time: + self._snap_intra_step_time = time.time() + + if not self._should_log(trainer): + return + + logs = {} + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs["time/inter_step (ms)"] = ( + time.time() - self._snap_inter_step_time + ) * 1000 + + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_batch_end( + self, + trainer: "L.Trainer", + L_module: "L.LightningModule", + *unused_args, + **unused_kwargs, # easy fix for new pytorch lightning versions + ) -> None: + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if ( + self.verbose + and self._log_stats.intra_step_time + and 
self._snap_intra_step_time + ): + L_module.print( + f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}" + ) + + if not self._should_log(trainer): + return + + logs = {} + if self._log_stats.intra_step_time and self._snap_intra_step_time: + logs["time/intra_step (ms)"] = ( + time.time() - self._snap_intra_step_time + ) * 1000 + + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_epoch_end( + self, + trainer: "L.Trainer", + L_module: "L.LightningModule", + ) -> None: + logs = {} + if self._log_stats.epoch_time and self._snap_epoch_time: + logs["time/epoch (s)"] = time.time() - self._snap_epoch_time + if trainer.logger is not None: + trainer.logger.log_metrics(logs, step=trainer.global_step) + + @staticmethod + def _should_log(trainer) -> bool: + return ( + trainer.global_step + 1 + ) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/rl4co/utils/decoding.py b/rl4co/utils/decoding.py new file mode 100644 index 00000000..b0a0ae90 --- /dev/null +++ b/rl4co/utils/decoding.py @@ -0,0 +1,540 @@ +import abc + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from tensordict.tensordict import TensorDict + +from rl4co.envs import RL4COEnvBase +from rl4co.utils.ops import batchify +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def get_decoding_strategy(decoding_strategy, **config): + strategy_registry = { + "greedy": Greedy, + "sampling": Sampling, + "multistart_greedy": Greedy, + "multistart_sampling": Sampling, + "beam_search": BeamSearch, + "evaluate": Evaluate, + } + + if decoding_strategy not in strategy_registry: + log.warning( + f"Unknown decode type '{decoding_strategy}'. Available decode types: {strategy_registry.keys()}. Defaulting to Sampling." + ) + + if "multistart" in decoding_strategy: + config["multistart"] = True + + return strategy_registry.get(decoding_strategy, Sampling)(**config) + + +def get_log_likelihood(logprobs, actions, mask=None, return_sum: bool = True): + """Get log likelihood of selected actions. + Note that mask is a boolean tensor where True means the value should be kept. + + Args: + logprobs: Log probabilities of actions from the model (batch_size, seq_len, action_dim). + actions: Selected actions (batch_size, seq_len). + mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). + return_sum: Whether to return the sum of log probabilities or not. Defaults to True. + """ + logprobs = logprobs.gather(-1, actions.unsqueeze(-1)).squeeze(-1) + + # Optional: mask out actions irrelevant to objective so they do not get reinforced + if mask is not None: + logprobs[~mask] = 0 + + assert ( + logprobs > -1000 + ).data.all(), "Logprobs should not be -inf, check sampling procedure!" + + # Calculate log_likelihood + if return_sum: + return logprobs.sum(1) # [batch] + else: + return logprobs # [batch, decode_len] + + +def decode_logprobs(logprobs, mask, decode_type="sampling"): + """Decode log probabilities to select actions with mask. + Note that mask is a boolean tensor where True means the value should be kept. 
+ """ + if "greedy" in decode_type: + selected = DecodingStrategy.greedy(logprobs, mask) + elif "sampling" in decode_type: + selected = DecodingStrategy.sampling(logprobs, mask) + else: + assert False, "Unknown decode type: {}".format(decode_type) + return selected + + +def random_policy(td): + """Helper function to select a random action from available actions""" + action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1) + td.set("action", action) + return td + + +def rollout(env, td, policy, max_steps: int = None): + """Helper function to rollout a policy. Currently, TorchRL does not allow to step + over envs when done with `env.rollout()`. We need this because for environments that complete at different steps. + """ + + max_steps = float("inf") if max_steps is None else max_steps + actions = [] + steps = 0 + + while not td["done"].all(): + td = policy(td) + actions.append(td["action"]) + td = env.step(td)["next"] + steps += 1 + if steps > max_steps: + log.info("Max steps reached") + break + return ( + env.get_reward(td, torch.stack(actions, dim=1)), + td, + torch.stack(actions, dim=1), + ) + + +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf. Done out-of-place. + Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L6 + """ + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + return logits.masked_fill(indices_to_remove, float("-inf")) + + +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf. Done out-of-place. + Ref: https://github.com/togethercomputer/stripedhyena/blob/7e13f618027fea9625be1f2d2d94f9a361f6bd02/stripedhyena/sample.py#L14 + """ + if top_p <= 0.0 or top_p >= 1.0: + return logits + + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + + # Scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + -1, sorted_indices, sorted_indices_to_remove + ) + return logits.masked_fill(indices_to_remove, float("-inf")) + + +def process_logits( + logits: torch.Tensor, + mask: torch.Tensor = None, + temperature: float = 1.0, + top_p: float = 0.0, + top_k: int = 0, + tanh_clipping: float = 0, + mask_logits: bool = True, +): + """Convert logits to log probabilities with additional features like temperature scaling, top-k and top-p sampling. + + Note: + We convert to log probabilities instead of probabilities to avoid numerical instability. + This is because, roughly, softmax = exp(logits) / sum(exp(logits)) and log(softmax) = logits - log(sum(exp(logits))), + and avoiding the division by the sum of exponentials can help with numerical stability. + You may check the [official PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html). + + Args: + logits: Logits from the model (batch_size, num_actions). + mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). + temperature: Temperature scaling. Higher values make the distribution more uniform (exploration), + lower values make it more peaky (exploitation). + top_p: Top-p sampling, a.k.a. 
Nucleus Sampling (https://arxiv.org/abs/1904.09751). Remove tokens that have a cumulative probability + less than the threshold 1 - top_p (lower tail of the distribution). If 0, do not perform. + top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Note that we only do filtering and + do not return all the top-k logits here. + tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). + mask_logits: Whether to mask logits of infeasible actions. + """ + + # Tanh clipping from Bello et al. 2016 + if tanh_clipping > 0: + logits = torch.tanh(logits) * tanh_clipping + + # In RL, we want to mask the logits to prevent the agent from selecting infeasible actions + if mask_logits: + assert mask is not None, "mask must be provided if mask_logits is True" + logits[~mask] = float("-inf") + + logits = logits / temperature # temperature scaling + + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # safety check + logits = modify_logits_for_top_k_filtering(logits, top_k) + + if top_p > 0: + assert top_p <= 1.0, "top-p should be in (0, 1]." + logits = modify_logits_for_top_p_filtering(logits, top_p) + + # Compute log probabilities + return F.log_softmax(logits, dim=-1) + + +class DecodingStrategy(metaclass=abc.ABCMeta): + """Base class for decoding strategies. Subclasses should implement the :meth:`_step` method. + Includes hooks for pre and post main decoding operations. + + Args: + temperature: Temperature scaling. Higher values make the distribution more uniform (exploration), + lower values make it more peaky (exploitation). Defaults to 1.0. + top_p: Top-p sampling, a.k.a. Nucleus Sampling (https://arxiv.org/abs/1904.09751). Defaults to 0.0. + top_k: Top-k sampling, i.e. restrict sampling to the top k logits. If 0, do not perform. Defaults to 0. + mask_logits: Whether to mask logits of infeasible actions. Defaults to True. + tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0. + multistart: Whether to use multistart decoding. Defaults to False. + num_starts: Number of starts for multistart decoding. Defaults to None. + """ + + name = "base" + + def __init__( + self, + temperature: float = 1.0, + top_p: float = 0.0, + top_k: int = 0, + mask_logits: bool = True, + tanh_clipping: float = 0, + multistart: bool = False, + num_starts: Optional[int] = None, + select_start_nodes_fn: Optional[callable] = None, + **kwargs, + ) -> None: + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.mask_logits = mask_logits + self.tanh_clipping = tanh_clipping + self.multistart = multistart + self.num_starts = num_starts + self.select_start_nodes_fn = select_start_nodes_fn + # initialize buffers + self.actions = [] + self.logprobs = [] + + @abc.abstractmethod + def _step( + self, + logprobs: torch.Tensor, + mask: torch.Tensor, + td: TensorDict, + action: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Main decoding operation. This method should be called in a loop until all sequences are done. + + Args: + logprobs: Log probabilities processed from logits of the model. + mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). + td: TensorDict containing the current state of the environment. + action: Optional action to use, e.g. for evaluating log probabilities. + """ + raise NotImplementedError("Must be implemented by subclass") + + def pre_decoder_hook( + self, td: TensorDict, env: RL4COEnvBase, action: torch.Tensor = None + ): + """Pre decoding hook. 
This method is called before the main decoding operation.""" + # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask + if self.multistart: + if self.num_starts is None: + self.num_starts = env.get_num_starts(td) + else: + if self.num_starts is not None: + if self.num_starts >= 1: + log.warn( + f"num_starts={self.num_starts} is ignored for decode_type={self.name}" + ) + + self.num_starts = 0 + + # Multi-start decoding: first action is chosen by ad-hoc node selection + if self.num_starts >= 1: + if action is None: # if action is provided, we use it as the first action + if self.select_start_nodes_fn is not None: + action = self.select_start_nodes_fn(td, env, self.num_starts) + else: + action = env.select_start_nodes(td, num_starts=self.num_starts) + + # Expand td to batch_size * num_starts + td = batchify(td, self.num_starts) + + td.set("action", action) + td = env.step(td)["next"] + logprobs = torch.zeros_like( + td["action_mask"], device=td.device + ) # first logprobs is 0, so p = logprobs.exp() = 1 + + self.logprobs.append(logprobs) + self.actions.append(action) + + return td, env, self.num_starts + + def post_decoder_hook( + self, td: TensorDict, env: RL4COEnvBase + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict, RL4COEnvBase]: + assert ( + len(self.logprobs) > 0 + ), "No logprobs were collected because all environments were done. Check your initial state" + + return torch.stack(self.logprobs, 1), torch.stack(self.actions, 1), td, env + + def step( + self, + logits: torch.Tensor, + mask: torch.Tensor, + td: TensorDict, + action: torch.Tensor = None, + **kwargs, + ) -> TensorDict: + """Main decoding operation. This method should be called in a loop until all sequences are done. + + Args: + logits: Logits from the model. + mask: Action mask. 1 if feasible, 0 otherwise (so we keep if 1 as done in PyTorch). + td: TensorDict containing the current state of the environment. + action: Optional action to use, e.g. for evaluating log probabilities. 
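+ Example (sketch of a full decoding loop; ``policy_logits_fn``, ``env`` and ``td`` are assumed to exist and are not part of this module):
+ >>> strategy = Sampling(temperature=1.0)
+ >>> td, env, _ = strategy.pre_decoder_hook(td, env)
+ >>> while not td["done"].all():
+ ...     logits = policy_logits_fn(td)  # hypothetical model forward pass returning logits
+ ...     td = strategy.step(logits, td["action_mask"], td)
+ ...     td = env.step(td)["next"]
+ >>> logprobs, actions, td, env = strategy.post_decoder_hook(td, env)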
+ """ + if not self.mask_logits: # set mask_logit to None if mask_logits is False + mask = None + + logprobs = process_logits( + logits, + mask, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + tanh_clipping=self.tanh_clipping, + mask_logits=self.mask_logits, + ) + logprobs, selected_action, td = self._step( + logprobs, mask, td, action=action, **kwargs + ) + td.set("action", selected_action) + self.actions.append(selected_action) + self.logprobs.append(logprobs) + return td + + @staticmethod + def greedy(logprobs, mask=None): + """Select the action with the highest probability.""" + # [BS], [BS] + selected = logprobs.argmax(dim=-1) + if mask is not None: + assert ( + not (~mask).gather(1, selected.unsqueeze(-1)).data.any() + ), "infeasible action selected" + + return selected + + @staticmethod + def sampling(logprobs, mask=None): + """Sample an action with a multinomial distribution given by the log probabilities.""" + probs = logprobs.exp() + selected = torch.multinomial(probs, 1).squeeze(1) + + if mask is not None: + while (~mask).gather(1, selected.unsqueeze(-1)).data.any(): + log.info("Sampled bad values, resampling!") + selected = probs.multinomial(1).squeeze(1) + assert ( + not (~mask).gather(1, selected.unsqueeze(-1)).data.any() + ), "infeasible action selected" + + return selected + + +class Greedy(DecodingStrategy): + name = "greedy" + + def _step( + self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Select the action with the highest log probability""" + selected = self.greedy(logprobs, mask) + return logprobs, selected, td + + +class Sampling(DecodingStrategy): + name = "sampling" + + def _step( + self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """Sample an action with a multinomial distribution given by the log probabilities.""" + selected = self.sampling(logprobs, mask) + return logprobs, selected, td + + +class Evaluate(DecodingStrategy): + name = "evaluate" + + def _step( + self, + logprobs: torch.Tensor, + mask: torch.Tensor, + td: TensorDict, + action: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + """The action is provided externally, so we just return the action""" + selected = action + return logprobs, selected, td + + +class BeamSearch(DecodingStrategy): + name = "beam_search" + + def __init__(self, beam_width=None, select_best=True, **kwargs) -> None: + super().__init__(**kwargs) + self.beam_width = beam_width + self.select_best = select_best + self.parent_beam_logprobs = None + self.beam_path = [] + + def _step( + self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + selected, batch_beam_idx = self._make_beam_step(logprobs) + # select the correct state representation, logprobs and mask according to beam parent + td = td[batch_beam_idx] + logprobs = logprobs[batch_beam_idx] + mask = mask[batch_beam_idx] + + assert ( + not (~mask).gather(1, selected.unsqueeze(-1)).data.any() + ), "infeasible action selected" + + return logprobs, selected, td + + def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs): + if self.beam_width is None: + self.beam_width = env.get_num_starts(td) + assert self.beam_width > 1, "beam width must be larger than 1" + + # select start nodes. 
TODO: include first step in beam search as well + if self.select_start_nodes_fn is not None: + action = self.select_start_nodes_fn(td, env, self.beam_width) + else: + action = env.select_start_nodes(td, num_starts=self.beam_width) + + # Expand td to batch_size * beam_width + td = batchify(td, self.beam_width) + + td.set("action", action) + td = env.step(td)["next"] + + logprobs = torch.zeros_like(td["action_mask"], device=td.device) + beam_parent = torch.zeros(logprobs.size(0), device=td.device, dtype=torch.int32) + + self.logprobs.append(logprobs) + self.actions.append(action) + self.parent_beam_logprobs = logprobs.gather(1, action[..., None]) + self.beam_path.append(beam_parent) + + return td, env, self.beam_width + + def post_decoder_hook(self, td, env): + # [BS*BW, seq_len] + aligned_sequences, aligned_logprobs = self._backtrack() + + if self.select_best: + return self._select_best_beam(aligned_logprobs, aligned_sequences, td, env) + else: + return aligned_logprobs, aligned_sequences, td, env + + def _backtrack(self): + # [BS*BW, seq_len] + actions = torch.stack(self.actions, 1) + # [BS*BW, seq_len] + logprobs = torch.stack(self.logprobs, 1) + assert actions.size(1) == len( + self.beam_path + ), "action idx shape and beam path shape dont match" + + # [BS*BW] + cur_parent = self.beam_path[-1] + # [BS*BW] + reversed_aligned_sequences = [actions[:, -1]] + reversed_aligned_logprobs = [logprobs[:, -1]] + + aug_batch_size = actions.size(0) + batch_size = aug_batch_size // self.beam_width + batch_beam_sequence = ( + torch.arange(0, batch_size).repeat(self.beam_width).to(actions.device) + ) + + for k in reversed(range(len(self.beam_path) - 1)): + batch_beam_idx = batch_beam_sequence + cur_parent * batch_size + + reversed_aligned_sequences.append(actions[batch_beam_idx, k]) + reversed_aligned_logprobs.append(logprobs[batch_beam_idx, k]) + cur_parent = self.beam_path[k][batch_beam_idx] + + # [BS*BW, seq_len*num_targets] + actions = torch.stack(list(reversed(reversed_aligned_sequences)), dim=1) + logprobs = torch.stack(list(reversed(reversed_aligned_logprobs)), dim=1) + + return actions, logprobs + + def _select_best_beam(self, logprobs, actions, td: TensorDict, env: RL4COEnvBase): + aug_batch_size = logprobs.size(0) # num nodes + batch_size = aug_batch_size // self.beam_width + rewards = env.get_reward(td, actions) + _, idx = torch.cat(rewards.unsqueeze(1).split(batch_size), 1).max(1) + flat_idx = torch.arange(batch_size, device=rewards.device) + idx * batch_size + return logprobs[flat_idx], actions[flat_idx], td[flat_idx], env + + def _make_beam_step(self, logprobs: torch.Tensor): + aug_batch_size, num_nodes = logprobs.shape # num nodes + batch_size = aug_batch_size // self.beam_width + batch_beam_sequence = ( + torch.arange(0, batch_size).repeat(self.beam_width).to(logprobs.device) + ) + + # [BS*BW, num_nodes] + [BS*BW, 1] -> [BS*BW, num_nodes] + log_beam_prob = logprobs + self.parent_beam_logprobs # + + # [BS, num_nodes * BW] + log_beam_prob_hstacked = torch.cat(log_beam_prob.split(batch_size), dim=1) + # [BS, BW] + topk_logprobs, topk_ind = torch.topk( + log_beam_prob_hstacked, self.beam_width, dim=1 + ) + + # [BS*BW, 1] + logprobs_selected = torch.hstack(torch.unbind(topk_logprobs, 1)).unsqueeze(1) + + # [BS*BW, 1] + topk_ind = torch.hstack(torch.unbind(topk_ind, 1)) + + # since we stack the logprobs from the distinct branches, the indices in + # topk dont correspond to node indices directly and need to be translated + selected = topk_ind % num_nodes # determine node index + + # calc parent 
this branch comes from + beam_parent = (topk_ind // num_nodes).int() + + batch_beam_idx = batch_beam_sequence + beam_parent * batch_size + + self.parent_beam_logprobs = logprobs_selected + self.beam_path.append(beam_parent) + + return selected, batch_beam_idx diff --git a/rl4co/utils/instantiators.py b/rl4co/utils/instantiators.py new file mode 100644 index 00000000..e3b25183 --- /dev/null +++ b/rl4co/utils/instantiators.py @@ -0,0 +1,51 @@ +from typing import List + +import hydra + +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from rl4co.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/rl4co/utils/lightning.py b/rl4co/utils/lightning.py new file mode 100644 index 00000000..a3f29cb7 --- /dev/null +++ b/rl4co/utils/lightning.py @@ -0,0 +1,76 @@ +import os + +import lightning as L +import torch + +from omegaconf import DictConfig + +# from rl4co. 
+from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def get_lightning_device(lit_module: L.LightningModule) -> torch.device: + """Get the device of the Lightning module before setup is called + See device setting issue in setup https://github.com/Lightning-AI/lightning/issues/2638 + """ + try: + if lit_module.trainer.strategy.root_device != lit_module.device: + return lit_module.trainer.strategy.root_device + return lit_module.device + except Exception: + return lit_module.device + + +def remove_key(config, key="wandb"): + """Remove keys containing 'key`""" + new_config = {} + for k, v in config.items(): + if key in k: + continue + else: + new_config[k] = v + return new_config + + +def clean_hydra_config( + config, keep_value_only=True, remove_keys="wandb", clean_cfg_path=True +): + """Clean hydra config by nesting dictionary and cleaning values""" + # Remove keys containing `remove_keys` + if not isinstance(remove_keys, list): + remove_keys = [remove_keys] + for key in remove_keys: + config = remove_key(config, key=key) + + new_config = {} + # Iterate over config dictionary + for key, value in config.items(): + # If key contains slash, split it and create nested dictionary recursively + if "/" in key: + keys = key.split("/") + d = new_config + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value["value"] if keep_value_only else value + else: + new_config[key] = value["value"] if keep_value_only else value + + cfg = DictConfig(new_config) + + if clean_cfg_path: + # Clean cfg_path recursively substituting root_dir with cwd + root_dir = cfg.paths.root_dir + + def replace_dir_recursive(d, search, replace): + for k, v in d.items(): + if isinstance(v, dict) or isinstance(v, DictConfig): + replace_dir_recursive(v, search, replace) + elif isinstance(v, str): + if search in v: + d[k] = v.replace(search, replace) + + replace_dir_recursive(cfg, root_dir, os.getcwd()) + + return cfg diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py new file mode 100644 index 00000000..35a7441b --- /dev/null +++ b/rl4co/utils/ops.py @@ -0,0 +1,214 @@ +from functools import lru_cache +from typing import Optional, Union + +import torch + +from einops import rearrange +from tensordict import TensorDict +from torch import Tensor + + +def _batchify_single( + x: Union[Tensor, TensorDict], repeats: int +) -> Union[Tensor, TensorDict]: + """Same as repeat on dim=0 for Tensordicts as well""" + s = x.shape + return x.expand(repeats, *s).contiguous().view(s[0] * repeats, *s[1:]) + + +def batchify( + x: Union[Tensor, TensorDict], shape: Union[tuple, int] +) -> Union[Tensor, TensorDict]: + """Same as `einops.repeat(x, 'b ... -> (b r) ...', r=repeats)` but ~1.5x faster and supports TensorDicts. + Repeats batchify operation `n` times as specified by each shape element. + If shape is a tuple, iterates over each element and repeats that many times to match the tuple shape. + + Example: + >>> x.shape: [a, b, c, ...] + >>> shape: [a, b, c] + >>> out.shape: [a*b*c, ...] 
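+ A concrete shape check (illustrative):
+ >>> x = torch.randn(4, 10, 2)
+ >>> batchify(x, 3).shape
+ torch.Size([12, 10, 2])
+ >>> batchify(x, (3, 2)).shape
+ torch.Size([24, 10, 2])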
+ """ + shape = [shape] if isinstance(shape, int) else shape + for s in reversed(shape): + x = _batchify_single(x, s) if s > 0 else x + return x + + +def _unbatchify_single( + x: Union[Tensor, TensorDict], repeats: int +) -> Union[Tensor, TensorDict]: + """Undoes batchify operation for Tensordicts as well""" + s = x.shape + return x.view(repeats, s[0] // repeats, *s[1:]).permute(1, 0, *range(2, len(s) + 1)) + + +def unbatchify( + x: Union[Tensor, TensorDict], shape: Union[tuple, int] +) -> Union[Tensor, TensorDict]: + """Same as `einops.rearrange(x, '(r b) ... -> b r ...', r=repeats)` but ~2x faster and supports TensorDicts + Repeats unbatchify operation `n` times as specified by each shape element + If shape is a tuple, iterates over each element and unbatchifies that many times to match the tuple shape. + + Example: + >>> x.shape: [a*b*c, ...] + >>> shape: [a, b, c] + >>> out.shape: [a, b, c, ...] + """ + shape = [shape] if isinstance(shape, int) else shape + for s in reversed( + shape + ): # we need to reverse the shape to unbatchify in the right order + x = _unbatchify_single(x, s) if s > 0 else x + return x + + +def gather_by_index(src, idx, dim=1, squeeze=True): + """Gather elements from src by index idx along specified dim + + Example: + >>> src: shape [64, 20, 2] + >>> idx: shape [64, 3)] # 3 is the number of idxs on dim 1 + >>> Returns: [64, 3, 2] # get the 3 elements from src at idx + """ + expanded_shape = list(src.shape) + expanded_shape[dim] = -1 + idx = idx.view(idx.shape + (1,) * (src.dim() - idx.dim())).expand(expanded_shape) + return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx) + + +@torch.jit.script +def get_distance(x: Tensor, y: Tensor): + """Euclidean distance between two tensors of shape `[..., n, dim]`""" + return (x - y).norm(p=2, dim=-1) + + +@torch.jit.script +def get_tour_length(ordered_locs): + """Compute the total tour distance for a batch of ordered tours. + Computes the L2 norm between each pair of consecutive nodes in the tour and sums them up. + + Args: + ordered_locs: Tensor of shape [batch_size, num_nodes, 2] containing the ordered locations of the tour + """ + ordered_locs_next = torch.roll(ordered_locs, -1, dims=-2) + return get_distance(ordered_locs_next, ordered_locs).sum(-1) + + +@torch.jit.script +def get_distance_matrix(locs: Tensor): + """Compute the euclidean distance matrix for the given coordinates. + + Args: + locs: Tensor of shape [..., n, dim] + """ + distance = (locs[..., :, None, :] - locs[..., None, :, :]).norm(p=2, dim=-1) + return distance + + +def calculate_entropy(logprobs: Tensor): + """Calculate the entropy of the log probabilities distribution + logprobs: Tensor of shape [batch, decoder_steps, num_actions] + """ + logprobs = torch.nan_to_num(logprobs, nan=0.0) + entropy = -(logprobs.exp() * logprobs).sum(dim=-1) # [batch, decoder steps] + entropy = entropy.sum(dim=1) # [batch] -- sum over decoding steps + assert entropy.isfinite().all(), "Entropy is not finite" + return entropy + + +# TODO: modularize inside the envs +def get_num_starts(td, env_name=None): + """Returns the number of possible start nodes for the environment based on the action mask""" + num_starts = td["action_mask"].shape[-1] + if env_name == "pdp": + num_starts = ( + num_starts - 1 + ) // 2 # only half of the nodes (i.e. 
pickup nodes) can be start nodes + elif env_name in ["cvrp", "cvrptw", "sdvrp", "mtsp", "op", "pctsp", "spctsp"]: + num_starts = num_starts - 1 # depot cannot be a start node + + return num_starts + + +def select_start_nodes(td, env, num_starts): + """Node selection strategy as proposed in POMO (Kwon et al. 2020) + and extended in SymNCO (Kim et al. 2022). + Selects different start nodes for each batch element + + Args: + td: TensorDict containing the data. We may need to access the available actions to select the start nodes + env: Environment may determine the node selection strategy + num_starts: Number of nodes to select. This may be passed when calling the policy directly. See :class:`rl4co.models.AutoregressiveDecoder` + """ + num_loc = env.generator.num_loc if hasattr(env.generator, "num_loc") else 0xFFFFFFFF + if env.name in ["tsp", "atsp"]: + selected = ( + torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) + % num_loc + ) + else: + # Environments with depot: we do not select the depot as a start node + selected = ( + torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) + % num_loc + + 1 + ) + if env.name == "op": + if (td["action_mask"][..., 1:].float().sum(-1) < num_starts).any(): + # for the orienteering problem, we may have some nodes that are not available + # so we need to resample from the distribution of available nodes + selected = ( + torch.multinomial( + td["action_mask"][..., 1:].float(), num_starts, replacement=True + ) + + 1 + ) # re-add depot index + selected = rearrange(selected, "b n -> (n b)") + return selected + + +def get_best_actions(actions, max_idxs): + actions = unbatchify(actions, max_idxs.shape[0]) + return actions.gather(0, max_idxs[..., None, None]) + + +def sparsify_graph(cost_matrix: Tensor, k_sparse: Optional[int] = None, self_loop=False): + """Generate a sparsified graph for the cost_matrix by selecting k edges with the lowest cost for each node. + + Args: + cost_matrix: Tensor of shape [m, n] + k_sparse: Number of edges to keep for each node. Defaults to max(n//5, 10) if not provided. + self_loop: Include self-loop edges in the generated graph when m==n. Defaults to False. 
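+ Example (illustrative; note that the diagonal of ``cost_matrix`` is modified in place when self-loops are excluded and m == n):
+ >>> cost = get_distance_matrix(torch.rand(6, 2))
+ >>> edge_index, edge_attr = sparsify_graph(cost, k_sparse=3)
+ >>> edge_index.shape, edge_attr.shape
+ (torch.Size([2, 18]), torch.Size([18, 1]))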
+ """ + m, n = cost_matrix.shape + k_sparse = max(n // 5, 10) if k_sparse is None else k_sparse + + # fill diagonal value with +inf to exclude them from topk results + if not self_loop and m == n: + # k_sparse should not exceed n-1 in this occasion + k_sparse = min(k_sparse, n - 1) + cost_matrix.fill_diagonal_(torch.inf) + + # select top-k edges with least cost + topk_values, topk_indices = torch.topk( + cost_matrix, k=k_sparse, dim=-1, largest=False, sorted=False + ) + + # generate PyG-compatiable edge_index + edge_index_u = torch.repeat_interleave( + torch.arange(m, device=cost_matrix.device), topk_indices.shape[1] + ) + edge_index_v = topk_indices.flatten() + edge_index = torch.stack([edge_index_u, edge_index_v]) + + edge_attr = topk_values.flatten().unsqueeze(-1) + return edge_index, edge_attr + + +@lru_cache(5) +def get_full_graph_edge_index(num_node: int, self_loop=False) -> Tensor: + adj_matrix = torch.ones(num_node, num_node) + if not self_loop: + adj_matrix.fill_diagonal_(0) + edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0)) + return edge_index diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py new file mode 100644 index 00000000..46367a37 --- /dev/null +++ b/rl4co/utils/optim_helpers.py @@ -0,0 +1,38 @@ +import inspect + +import torch +from torch.optim import Optimizer + + +def get_pytorch_lr_schedulers(): + """Get all learning rate schedulers from `torch.optim.lr_scheduler`""" + return torch.optim.lr_scheduler.__all__ + + +def get_pytorch_optimizers(): + """Get all optimizers from `torch.optim`""" + optimizers = [] + for name, obj in inspect.getmembers(torch.optim): + if inspect.isclass(obj) and issubclass(obj, Optimizer): + optimizers.append(name) + return optimizers + + +def create_optimizer(parameters, optimizer_name: str, **optimizer_kwargs) -> Optimizer: + """Create optimizer for model. If `optimizer_name` is not found, raise ValueError.""" + if optimizer_name in get_pytorch_optimizers(): + optimizer_cls = getattr(torch.optim, optimizer_name) + return optimizer_cls(parameters, **optimizer_kwargs) + else: + raise ValueError(f"Optimizer {optimizer_name} not found.") + + +def create_scheduler( + optimizer: Optimizer, scheduler_name: str, **scheduler_kwargs +) -> torch.optim.lr_scheduler.LRScheduler: + """Create scheduler for optimizer. If `scheduler_name` is not found, raise ValueError.""" + if scheduler_name in get_pytorch_lr_schedulers(): + scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_name) + return scheduler_cls(optimizer, **scheduler_kwargs) + else: + raise ValueError(f"Scheduler {scheduler_name} not found.") diff --git a/rl4co/utils/param_grouping.py b/rl4co/utils/param_grouping.py new file mode 100644 index 00000000..529205a9 --- /dev/null +++ b/rl4co/utils/param_grouping.py @@ -0,0 +1,138 @@ +import inspect + +import hydra +import torch.nn as nn + + +def group_parameters_for_optimizer( + model, optimizer_cfg, bias_weight_decay=False, normalization_weight_decay=False +): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + # Get the weight decay from the config, or from the default value of the optimizer constructor + # if it's not specified in the config. 
+ if "weight_decay" in optimizer_cfg: + weight_decay = optimizer_cfg.weight_decay + else: + # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value + signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) + if "weight_decay" in signature.parameters: + weight_decay = signature.parameters["weight_decay"].default + if weight_decay is inspect.Parameter.empty: + weight_decay = 0.0 + else: + weight_decay = 0.0 + + # If none of the parameters have weight decay anyway, and there are no parameters with special + # optimization params + if weight_decay == 0.0 and not any(hasattr(p, "_optim") for p in model.parameters()): + return model.parameters() + + skip = model.no_weight_decay() if hasattr(model, "no_weight_decay") else set() + skip_keywords = ( + model.no_weight_decay_keywords() + if hasattr(model, "no_weight_decay_keywords") + else set() + ) + + # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear,) + blacklist_weight_modules = nn.Embedding + if not normalization_weight_decay: + blacklist_weight_modules += ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.LazyBatchNorm1d, + nn.LazyBatchNorm2d, + nn.LazyBatchNorm3d, + nn.GroupNorm, + nn.SyncBatchNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + nn.LocalResponseNorm, + ) + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + # In case of parameter sharing, some parameters show up here but are not in + # param_dict.keys() + if not p.requires_grad or fpn not in param_dict: + continue # frozen weights + if hasattr(p, "_optim"): + special.add(fpn) + elif fpn in skip or any( + skip_keyword in fpn for skip_keyword in skip_keywords + ): + no_decay.add(fpn) + elif getattr(p, "_no_weight_decay", False): + no_decay.add(fpn) + elif not bias_weight_decay and pn.endswith("bias"): + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + if "pos_emb" in param_dict: + no_decay.add("pos_emb") + + decay |= param_dict.keys() - no_decay - special + # validate that we considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert ( + len(param_dict.keys() - special - union_params) == 0 + ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 
+ + if weight_decay == 0.0 or not no_decay: + param_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], + "weight_decay": weight_decay, + } + ] + else: + param_groups = [ + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0, + }, + ] + # Add parameters with special hyperparameters + # Unique dicts + hps = [ + dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special) + ] + for hp in hps: + params = [param_dict[pn] for pn in special if param_dict[pn]._optim == hp] + param_groups.append({"params": params, **hp}) + + return param_groups diff --git a/rl4co/utils/pylogger.py b/rl4co/utils/pylogger.py new file mode 100644 index 00000000..aa1b5f1a --- /dev/null +++ b/rl4co/utils/pylogger.py @@ -0,0 +1,25 @@ +import logging + +from lightning.pytorch.utilities.rank_zero import rank_zero_only + + +def get_pylogger(name=__name__) -> logging.Logger: + """Initializes multi-GPU-friendly python command line logger.""" + + logger = logging.getLogger(name) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + logging_levels = ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ) + for level in logging_levels: + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger diff --git a/rl4co/utils/rich_utils.py b/rl4co/utils/rich_utils.py new file mode 100644 index 00000000..652ba568 --- /dev/null +++ b/rl4co/utils/rich_utils.py @@ -0,0 +1,97 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree + +from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from rl4co.utils.utils import pylogger + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + # "data", # note: data is dealt with in model + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = True, + save_to_file: bool = False, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + Args: + cfg (DictConfig): Configuration composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + save_to_file (bool, optional): Whether to export config to the hydra output folder. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
+ ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config.""" + + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/rl4co/utils/test_utils.py b/rl4co/utils/test_utils.py new file mode 100644 index 00000000..baa7f31d --- /dev/null +++ b/rl4co/utils/test_utils.py @@ -0,0 +1,62 @@ +from torch.utils.data import DataLoader + +from rl4co.envs import ( + CVRPEnv, + CVRPTWEnv, + DPPEnv, + MDPPEnv, + MTSPEnv, + OPEnv, + PCTSPEnv, + PDPEnv, + SDVRPEnv, + SMTWTPEnv, + SPCTSPEnv, + TSPEnv, +) + + +def get_env(name, size): + if name == "tsp": + env = TSPEnv(generator_params=dict(num_loc=size)) + elif name == "cvrp": + env = CVRPEnv(generator_params=dict(num_loc=size)) + elif name == "cvrptw": + env = CVRPTWEnv(generator_params=dict(num_loc=size)) + elif name == "sdvrp": + env = SDVRPEnv(generator_params=dict(num_loc=size)) + elif name == "pdp": + env = PDPEnv(generator_params=dict(num_loc=size)) + elif name == "op": + env = OPEnv(generator_params=dict(num_loc=size)) + elif name == "mtsp": + env = MTSPEnv(generator_params=dict(num_loc=size)) + elif name == "pctsp": + env = PCTSPEnv(generator_params=dict(num_loc=size)) + elif name == "spctsp": + env = SPCTSPEnv(generator_params=dict(num_loc=size)) + elif name == "dpp": + env = DPPEnv() + elif name == "mdpp": + env = MDPPEnv() + elif name == "smtwtp": + env = SMTWTPEnv() + else: + raise ValueError(f"Unknown env_name: {name}") + + return env.transform() + + +def generate_env_data(env, size, batch_size): + env = get_env(env, size) + dataset = env.dataset([batch_size]) + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + collate_fn=dataset.collate_fn, + ) + + return env, next(iter(dataloader)) diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py new file mode 100644 index 00000000..77350a50 --- /dev/null +++ b/rl4co/utils/trainer.py @@ -0,0 +1,152 @@ +from typing import Iterable, List, Optional, Union + +import lightning.pytorch as pl +import torch + +from lightning import Callback, Trainer +from lightning.fabric.accelerators.cuda import num_cuda_devices +from lightning.pytorch.accelerators import Accelerator +from lightning.pytorch.core.datamodule import LightningDataModule +from 
lightning.pytorch.loggers import Logger +from lightning.pytorch.strategies import DDPStrategy, Strategy +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS + +from rl4co import utils + +log = utils.get_pylogger(__name__) + + +class RL4COTrainer(Trainer): + """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training. + + Note: + The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`. + This allows for datasets to be re-created on the run and distributed by Lightning across + devices on each epoch. Setting to a value different than 1 may lead to overfitting to a + specific (such as the initial) data distribution. + + Args: + accelerator: hardware accelerator to use. + callbacks: list of callbacks. + logger: logger (or iterable collection of loggers) for experiment tracking. + min_epochs: minimum number of training epochs. + max_epochs: maximum number of training epochs. + strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP). + devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node. + gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability. + precision: allows for mixed precision training. Can be specified as a string (e.g., '16'). + This also allows to use `FlashAttention` by default. + disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed. + auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available. + reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs. + matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision + **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details. + """ + + def __init__( + self, + accelerator: Union[str, Accelerator] = "auto", + callbacks: Optional[List[Callback]] = None, + logger: Optional[Union[Logger, Iterable[Logger]]] = None, + min_epochs: Optional[int] = None, + max_epochs: Optional[int] = None, + strategy: Union[str, Strategy] = "auto", + devices: Union[List[int], str, int] = "auto", + gradient_clip_val: Union[int, float] = 1.0, + precision: Union[str, int] = "16-mixed", + reload_dataloaders_every_n_epochs: int = 1, + disable_profiling_executor: bool = True, + auto_configure_ddp: bool = True, + matmul_precision: Union[str, int] = "medium", + **kwargs, + ): + # Disable JIT profiling executor. This reduces memory and increases speed. 
+ # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17 + if disable_profiling_executor: + try: + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + except AttributeError: + pass + + # Configure DDP automatically if multiple GPUs are available + if auto_configure_ddp and strategy == "auto": + if devices == "auto": + n_devices = num_cuda_devices() + elif isinstance(devices, list): + n_devices = len(devices) + else: + n_devices = devices + if n_devices > 1: + log.info( + "Configuring DDP strategy automatically with {} GPUs".format( + n_devices + ) + ) + strategy = DDPStrategy( + find_unused_parameters=True, # We set to True due to RL envs + gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations + ) + + # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision + if matmul_precision is not None: + torch.set_float32_matmul_precision(matmul_precision) + + # Check if gradient_clip_val is set to None + if gradient_clip_val is None: + log.warning( + "gradient_clip_val is set to None. This may lead to unstable training." + ) + + # We should reload dataloaders every epoch for RL training + if reload_dataloaders_every_n_epochs != 1: + log.warning( + "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 " + + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs." + ) + + # Main call to `Trainer` superclass + super().__init__( + accelerator=accelerator, + callbacks=callbacks, + logger=logger, + min_epochs=min_epochs, + max_epochs=max_epochs, + strategy=strategy, + gradient_clip_val=gradient_clip_val, + devices=devices, + precision=precision, + reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs, + **kwargs, + ) + + def fit( + self, + model: "pl.LightningModule", + train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, + val_dataloaders: Optional[EVAL_DATALOADERS] = None, + datamodule: Optional[LightningDataModule] = None, + ckpt_path: Optional[str] = None, + ) -> None: + """ + We override the `fit` method to automatically apply and handle RL4CO magic + to 'self.automatic_optimization = False' models, such as PPO + + It behaves exactly like the original `fit` method, but with the following changes: + - if the given model is 'self.automatic_optimization = False', we override 'gradient_clip_val' as None + """ + + if not model.automatic_optimization: + if self.gradient_clip_val is not None: + log.warning( + "Overriding gradient_clip_val to None for 'automatic_optimization=False' models" + ) + self.gradient_clip_val = None + + super().fit( + model=model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + ckpt_path=ckpt_path, + ) diff --git a/rl4co/utils/utils.py b/rl4co/utils/utils.py new file mode 100644 index 00000000..547e0e2d --- /dev/null +++ b/rl4co/utils/utils.py @@ -0,0 +1,287 @@ +import importlib +import platform +import sys +import warnings + +from importlib.util import find_spec +from typing import Callable, List + +import hydra + +from lightning import Callback +from lightning.pytorch.loggers.logger import Logger + +# Import the necessary PyTorch Lightning component +from 
lightning.pytorch.trainer.connectors.accelerator_connector import ( + _AcceleratorConnector, +) +from lightning.pytorch.utilities.rank_zero import rank_zero_only +from omegaconf import DictConfig + +from rl4co.utils import pylogger, rich_utils + +log = pylogger.get_pylogger(__name__) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that wraps the task function in extra utilities. + + Makes multirun more resistant to failure. + + Utilities: + - Calling the `utils.extras()` before the task is started + - Calling the `utils.close_loggers()` after the task is finished or failed + - Logging the exception if occurs + - Logging the output dir + """ + + def wrap(cfg: DictConfig): + # execute the task + try: + # apply extra utilities + extras(cfg) + + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # when using hydra plugins like Optuna, you might want to disable raising exception + # to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # close loggers (even if exception occurs so multirun won't fail) + close_loggers() + + return metric_dict, object_dict + + return wrap + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + """ + + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + ## Note: we do not use the data config, since it is dealt with in the model + ## which is a `LightningModule` + # hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) + + +def get_metric_value(metric_dict: dict, metric_name: str) -> float: + """Safely retrieves value of the metric logged in LightningModule.""" + + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value + + +def close_loggers() -> None: + """Makes sure all loggers closed properly (prevents logging failure during multirun).""" + + log.info("Closing loggers...") + + if find_spec("wandb"): # if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content) + + +def merge_with_defaults(_config=None, **defaults) -> dict: + """Merge configuration with default values. + + This function merges a provided configuration dictionary with default values. + If no configuration is provided (`_config` is None), it returns the default values. + If a dictionary is provided, it updates the defaults dictionary with the values from the provided dictionary. + Otherwise, it sets all keys in the defaults dictionary to `_config`. + + Args: + _config: Configuration to merge. Defaults to None. + **defaults: Default values to merge with the configuration. + + Returns: + dict: Merged configuration with default values. 
+ """ + if _config is None: + return defaults + elif isinstance(_config, (DictConfig, dict)): + defaults.update(dict(**_config)) # type: ignore + return defaults + else: + return {key: _config for key in defaults.keys()} + + +def show_versions(): + """ + This function prints version information that is useful when filing bug + reports. Inspired by https://github.com/PyVRP/PyVRP + """ + + modules = { + "rl4co": "rl4co", + "torch": "torch", + "lightning": "pytorch_lightning", # Updated module name if necessary + "torchrl": "torchrl", + "tensordict": "tensordict", + "numpy": "numpy", + "pytorch_geometric": "torch_geometric", + "hydra-core": "hydra", + "omegaconf": "omegaconf", + "matplotlib": "matplotlib", + } + + # Find the longest module name for formatting + longest_name = max(len(name) for name in modules.keys()) + + print("INSTALLED VERSIONS") + print("-" * (longest_name + 20)) + # modules + for name, module in modules.items(): + try: + imported_module = importlib.import_module(module) + version = imported_module.__version__ + except ImportError: + version = "Not installed" + print(f"{name.rjust(longest_name)} : {version}") + # platform information + print(f'{"Python".rjust(longest_name)} : {sys.version.split()[0]}') + print(f'{"Platform".rjust(longest_name)} : {platform.platform()}') + try: + lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator(None) + except Exception: + lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator() + # lightning hardware accelerators + print(f'{"Lightning device".rjust(longest_name)} : {lightning_auto_device}') From 991c51642e15521232683054016b3d0e9413ac27 Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Mon, 13 May 2024 17:49:00 +0800 Subject: [PATCH 4/6] update mtvrp --- rl4co/envs/__init__.py | 3 +- rl4co/envs/scheduling/__init__.py | 1 + rl4co/envs/scheduling/fjsp/__init__.py | 2 + rl4co/envs/scheduling/fjsp/env.py | 424 ++++++++++++++++++++++++ rl4co/envs/scheduling/fjsp/generator.py | 216 ++++++++++++ rl4co/envs/scheduling/fjsp/parser.py | 180 ++++++++++ rl4co/envs/scheduling/fjsp/render.py | 72 ++++ rl4co/envs/scheduling/fjsp/utils.py | 333 +++++++++++++++++++ rl4co/models/__init__.py | 1 + rl4co/models/nn/ops.py | 30 ++ rl4co/models/rl/reinforce/reinforce.py | 2 +- rl4co/models/zoo/__init__.py | 1 + rl4co/models/zoo/hetgnn/__init__.py | 1 + rl4co/models/zoo/hetgnn/decoder.py | 51 +++ rl4co/models/zoo/hetgnn/encoder.py | 132 ++++++++ rl4co/models/zoo/hetgnn/model.py | 38 +++ rl4co/models/zoo/hetgnn/policy.py | 99 ++++++ rl4co/tasks/eval.py | 16 +- rl4co/utils/decoding.py | 21 +- rl4co/utils/ops.py | 30 ++ rl4co/utils/trainer.py | 2 +- 21 files changed, 1640 insertions(+), 15 deletions(-) create mode 100644 rl4co/envs/scheduling/fjsp/__init__.py create mode 100644 rl4co/envs/scheduling/fjsp/env.py create mode 100644 rl4co/envs/scheduling/fjsp/generator.py create mode 100644 rl4co/envs/scheduling/fjsp/parser.py create mode 100644 rl4co/envs/scheduling/fjsp/render.py create mode 100644 rl4co/envs/scheduling/fjsp/utils.py create mode 100644 rl4co/models/zoo/hetgnn/__init__.py create mode 100644 rl4co/models/zoo/hetgnn/decoder.py create mode 100644 rl4co/models/zoo/hetgnn/encoder.py create mode 100644 rl4co/models/zoo/hetgnn/model.py create mode 100644 rl4co/models/zoo/hetgnn/policy.py diff --git a/rl4co/envs/__init__.py b/rl4co/envs/__init__.py index 17961f6f..ac588739 100644 --- a/rl4co/envs/__init__.py +++ b/rl4co/envs/__init__.py @@ -22,7 +22,7 @@ ) # Scheduling -from rl4co.envs.scheduling import 
FFSPEnv, SMTWTPEnv +from rl4co.envs.scheduling import FFSPEnv, FJSPEnv, SMTWTPEnv # Register environments ENV_REGISTRY = { @@ -31,6 +31,7 @@ "cvrptw": CVRPTWEnv, "dpp": DPPEnv, "ffsp": FFSPEnv, + "fjsp": FJSPEnv, "mdpp": MDPPEnv, "mtsp": MTSPEnv, "op": OPEnv, diff --git a/rl4co/envs/scheduling/__init__.py b/rl4co/envs/scheduling/__init__.py index 1c63820f..897ee755 100644 --- a/rl4co/envs/scheduling/__init__.py +++ b/rl4co/envs/scheduling/__init__.py @@ -1,2 +1,3 @@ from rl4co.envs.scheduling.ffsp.env import FFSPEnv +from rl4co.envs.scheduling.fjsp.env import FJSPEnv from rl4co.envs.scheduling.smtwtp.env import SMTWTPEnv diff --git a/rl4co/envs/scheduling/fjsp/__init__.py b/rl4co/envs/scheduling/fjsp/__init__.py new file mode 100644 index 00000000..4eb6d9df --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/__init__.py @@ -0,0 +1,2 @@ +NO_OP_ID = -1 +INIT_FINISH = 9999.0 diff --git a/rl4co/envs/scheduling/fjsp/env.py b/rl4co/envs/scheduling/fjsp/env.py new file mode 100644 index 00000000..4a6a217f --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/env.py @@ -0,0 +1,424 @@ +import torch + +from einops import rearrange, reduce +from tensordict.tensordict import TensorDict +from torchrl.data import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) + +from rl4co.envs.common.base import RL4COEnvBase as EnvBase +from rl4co.utils.ops import gather_by_index, sample_n_random_actions + +from . import INIT_FINISH, NO_OP_ID +from .generator import FJSPFileGenerator, FJSPGenerator +from .render import render +from .utils import calc_lower_bound, get_job_ops_mapping, op_is_ready + + +class FJSPEnv(EnvBase): + """Flexible Job-Shop Scheduling Problem (FJSP) environment + At each step, the agent chooses a job-machine combination. The operation to be processed next for the selected job is + then executed on the selected machine. The reward is 0 unless the agent scheduled all operations of all jobs. + In that case, the reward is (-)makespan of the schedule: maximizing the reward is equivalent to minimizing the makespan. 
+ + Observations: + - time: current time + - next_op: next operation per job + - proc_times: processing time of operation-machine pairs + - pad_mask: specifies padded operations + - start_op_per_job: id of first operation per job + - end_op_per_job: id of last operation per job + - start_times: start time of operation (defaults to 0 if not scheduled) + - finish_times: finish time of operation (defaults to INIT_FINISH if not scheduled) + - job_ops_adj: adjacency matrix specifying job-operation affiliation + - ops_job_map: same as above but using ids of jobs to indicate affiliation + - ops_sequence_order: specifies the order in which operations have to be processed + - ma_assignment: specifies which operation has been scheduled on which machine + - busy_until: specifies until when the machine will be busy + - num_eligible: number of machines that can process an operation + - job_in_process: whether job is currently being processed + - job_done: whether the job is done + + Constrains: + the agent may not select: + - machines that are currently busy + - jobs that are done already + - jobs that are currently processed + - job-machine combinations, where the machine cannot process the next operation of the job + + Finish condition: + - the agent has scheduled all operations of all jobs + + Reward: + - the negative makespan of the final schedule + + Args: + generator: FJSPGenerator instance as the data generator + generator_params: parameters for the generator + mask_no_ops: if True, agent may not select waiting operation (unless instance is done) + """ + + name = "fjsp" + + def __init__( + self, + generator: FJSPGenerator = None, + generator_params: dict = {}, + mask_no_ops: bool = True, + **kwargs, + ): + super().__init__(check_solution=False, **kwargs) + if generator is None: + if generator_params.get("file_path", None) is not None: + generator = FJSPFileGenerator(**generator_params) + else: + generator = FJSPGenerator(**generator_params) + self.generator = generator + self.num_mas = generator.num_mas + self.num_jobs = generator.num_jobs + self.n_ops_max = generator.max_ops_per_job * self.num_jobs + self.mask_no_ops = mask_no_ops + self._make_spec(self.generator) + + def _decode_graph_structure(self, td: TensorDict): + batch_size = td.batch_size + start_op_per_job = td["start_op_per_job"] + end_op_per_job = td["end_op_per_job"] + pad_mask = td["pad_mask"] + n_ops_max = td["pad_mask"].size(-1) + + # here we will generate the operations-job mapping: + ops_job_map, ops_job_bin_map = get_job_ops_mapping( + start_op_per_job, end_op_per_job, n_ops_max + ) + + # mask invalid edges (caused by padding) + ops_job_bin_map[pad_mask.unsqueeze(1).expand_as(ops_job_bin_map)] = 0 + + # generate for each batch a sequence specifying the position of all operations in their respective jobs, + # e.g. 
[0,1,0,0,1,2,0,1,2,3,0,0] for jops with n_ops=[2,1,3,4,1,1] + # (bs, max_ops) + ops_seq_order = torch.sum( + ops_job_bin_map * (ops_job_bin_map.cumsum(2) - 1), dim=1 + ) + + # predecessor and successor adjacency matrices + pred = torch.diag_embed(torch.ones(n_ops_max - 1), offset=-1)[None].expand( + *batch_size, -1, -1 + ) + # the start of the sequence (of each job) does not have a predecessor, therefore we can + # mask all first ops of a job in the predecessor matrix + pred = pred * ops_seq_order.gt(0).unsqueeze(-1).expand_as(pred).to(pred) + succ = torch.diag_embed(torch.ones(n_ops_max - 1), offset=1)[None].expand( + *batch_size, -1, -1 + ) + # apply the same logic as above to mask the last op of a job, which does not have a successor. The last job of a job + # always comes before the 1st op of the next job, therefore performing a left shift of the ops seq tensor here + succ = succ * torch.cat( + (ops_seq_order[:, 1:], ops_seq_order.new_full((*batch_size, 1), 0)), dim=1 + ).gt(0).to(succ).unsqueeze(-1).expand_as(succ) + + # adjacency matrix = predecessors, successors and self loops + # (bs, max_ops, max_ops, 2) + ops_adj = torch.stack((pred, succ), dim=3) + + td = td.update( + { + "ops_adj": ops_adj, + "job_ops_adj": ops_job_bin_map, + "ops_job_map": ops_job_map, + # "op_spatial_enc": ops_spatial_enc, + "ops_sequence_order": ops_seq_order, + } + ) + + return td, n_ops_max + + def _reset(self, td: TensorDict = None, batch_size=None) -> TensorDict: + td_reset = td.clone() + + td_reset, n_ops_max = self._decode_graph_structure(td_reset) + + # schedule + start_op_per_job = td_reset["start_op_per_job"] + start_times = torch.zeros((*batch_size, n_ops_max)) + finish_times = torch.full((*batch_size, n_ops_max), INIT_FINISH) + ma_assignment = torch.zeros((*batch_size, self.num_mas, n_ops_max)) + + # reset feature space + busy_until = torch.zeros((*batch_size, self.num_mas)) + # (bs, ma, ops) + ops_ma_adj = (td_reset["proc_times"] > 0).to(torch.float32) + # (bs, ops) + num_eligible = torch.sum(ops_ma_adj, dim=1) + + td_reset = td_reset.update( + { + "start_times": start_times, + "finish_times": finish_times, + "ma_assignment": ma_assignment, + "busy_until": busy_until, + "num_eligible": num_eligible, + "next_op": start_op_per_job.clone().to(torch.int64), + "ops_ma_adj": ops_ma_adj, + "op_scheduled": torch.full((*batch_size, n_ops_max), False), + "job_in_process": torch.full((*batch_size, self.num_jobs), False), + "reward": torch.zeros((*batch_size,), dtype=torch.float32), + "time": torch.zeros((*batch_size,)), + "job_done": torch.full((*batch_size, self.num_jobs), False), + "done": torch.full((*batch_size, 1), False), + }, + ) + + td_reset.set("lbs", calc_lower_bound(td_reset)) + td_reset.set("is_ready", op_is_ready(td_reset)) + td_reset.set("action_mask", self.get_action_mask(td_reset)) + + return td_reset + + def get_action_mask(self, td: TensorDict) -> torch.Tensor: + batch_size = td.size(0) + + # (bs, jobs, machines) + action_mask = torch.full((batch_size, self.num_jobs, self.num_mas), False).to( + td.device + ) + + # mask jobs that are done already + action_mask.add_(td["job_done"].unsqueeze(2)) + # as well as jobs that are currently processed + action_mask.add_(td["job_in_process"].unsqueeze(2)) + + # mask machines that are currently busy + action_mask.add_(td["busy_until"].gt(td["time"].unsqueeze(1)).unsqueeze(1)) + + # exclude job-machine combinations, where the machine cannot process the next op of the job + next_ops_proc_times = gather_by_index( + td["proc_times"], 
td["next_op"].unsqueeze(1), dim=2, squeeze=False + ).transpose(1, 2) + action_mask.add_(next_ops_proc_times == 0) + if self.mask_no_ops: + no_op_mask = ~td["done"] + else: + no_op_mask = ~td["job_in_process"].any(1, keepdims=True) & ~td["done"] + # flatten action mask to correspond with logit shape + action_mask = rearrange(action_mask, "bs j m -> bs (j m)") + # NOTE: 1 means feasible action, 0 means infeasible action + mask = torch.cat((~no_op_mask, ~action_mask), dim=1) + return mask + + def _step(self, td: TensorDict): + # cloning required to avoid inplace operation which avoids gradient backtracking + td = td.clone() + td["action"].subtract_(1) + # (bs) + dones = td["done"].squeeze(1) + # specify which batch instances require which operation + no_op = td["action"].eq(NO_OP_ID) + no_op = no_op & ~dones + req_op = ~no_op & ~dones + + # transition to next time for no op instances + if no_op.any(): + td, dones = self._transit_to_next_time(no_op, td) + + td_op = td.masked_select(req_op) + + # (#req_op) + selected_job = td_op["action"] // self.num_mas + # (#req_op) + selected_machine = td_op["action"] % self.num_mas + td_op = self._make_step(td_op, selected_job, selected_machine) + + td[req_op] = td_op + + # action mask + td.set("action_mask", self.get_action_mask(td)) + + step_complete = self._check_step_complete(td, dones) + while step_complete.any(): + td, dones = self._transit_to_next_time(step_complete, td) + td.set("action_mask", self.get_action_mask(td)) + step_complete = self._check_step_complete(td, dones) + + # after we have transitioned to a next time step, we determine which operations are ready + td["is_ready"] = op_is_ready(td) + + td["lbs"] = calc_lower_bound(td) + + return td + + @staticmethod + def _check_step_complete(td, dones): + """check whether there a feasible actions left to be taken during the current + time step. If this is not the case (and the instance is not done), + we need to adance the timer of the repsective instance + """ + return ~reduce(td["action_mask"], "bs ... -> bs", "any") & ~dones + + def _make_step(self, td: TensorDict, selected_job, selected_machine) -> TensorDict: + """ + Environment transition function + """ + + batch_idx = torch.arange(td.size(0)) + + td["job_in_process"][batch_idx, selected_job] = 1 + + # (#req_op) + selected_op = td["next_op"].gather(1, selected_job[:, None]).squeeze(1) + + # mark op as schedules + td["op_scheduled"][batch_idx, selected_op] = True + + # update machine state + proc_time_of_action = td["proc_times"][batch_idx, selected_machine, selected_op] + # we may not select a machine that is busy + assert torch.all(td["busy_until"][batch_idx, selected_machine] <= td["time"]) + + # update schedule + td["start_times"][batch_idx, selected_op] = td["time"] + td["finish_times"][batch_idx, selected_op] = td["time"] + proc_time_of_action + td["ma_assignment"][batch_idx, selected_machine, selected_op] = 1 + # update the state of the selected machine + td["busy_until"][batch_idx, selected_machine] = td["time"] + proc_time_of_action + + return td + + def _transit_to_next_time(self, step_complete, td: TensorDict) -> TensorDict: + """ + Transit to the next time + """ + + # we need a transition to a next time step if either + # 1.) all machines are busy + # 2.) all operations are already currently in process (can only happen if num_jobs < num_machines) + # 3.) idle machines can not process any of the not yet scheduled operations + # 4.) 
no_op is choosen + available_time_ma = td["busy_until"] + end_op_per_job = td["end_op_per_job"] + # we want to transition to the next time step where a machine becomes idle again. This time step must be + # in the future, therefore we mask all machine idle times lying in the past / present + available_time = ( + torch.where( + available_time_ma > td["time"][:, None], available_time_ma, torch.inf + ) + .min(1) + .values + ) + + assert not torch.any(available_time[step_complete].isinf()) + td["time"] = torch.where(step_complete, available_time, td["time"]) + + # this may only be set when the operation is finished, not when it is scheduled + # operation of job is finished, set next operation and flag job as being idle + curr_ops_end = td["finish_times"].gather(1, td["next_op"]) + op_finished = td["job_in_process"] & (curr_ops_end <= td["time"][:, None]) + # check whether a job is finished, which is the case when the last operation of the job is finished + job_finished = op_finished & (td["next_op"] == end_op_per_job) + # determine the next operation for a job that is not done, but whose latest operation is finished + td["next_op"] = torch.where( + op_finished & ~job_finished, + td["next_op"] + 1, + td["next_op"], + ) + td["job_in_process"][op_finished] = False + + td["job_done"] = td["job_done"] + job_finished + td["done"] = td["job_done"].all(1, keepdim=True) + + return td, td["done"].squeeze(1) + + def _get_reward(self, td, actions=None) -> TensorDict: + return -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values + + def _make_spec(self, generator: FJSPGenerator): + self.observation_spec = CompositeSpec( + time=UnboundedDiscreteTensorSpec( + shape=(1,), + dtype=torch.int64, + ), + next_op=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs,), + dtype=torch.int64, + ), + proc_times=UnboundedDiscreteTensorSpec( + shape=(self.num_mas, self.n_ops_max), + dtype=torch.float32, + ), + pad_mask=UnboundedDiscreteTensorSpec( + shape=(self.num_mas, self.n_ops_max), + dtype=torch.bool, + ), + start_op_per_job=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs,), + dtype=torch.bool, + ), + end_op_per_job=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs,), + dtype=torch.bool, + ), + start_times=UnboundedDiscreteTensorSpec( + shape=(self.n_ops_max,), + dtype=torch.int64, + ), + finish_times=UnboundedDiscreteTensorSpec( + shape=(self.n_ops_max,), + dtype=torch.int64, + ), + job_ops_adj=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs, self.n_ops_max), + dtype=torch.int64, + ), + ops_job_map=UnboundedDiscreteTensorSpec( + shape=(self.n_ops_max), + dtype=torch.int64, + ), + ops_sequence_order=UnboundedDiscreteTensorSpec( + shape=(self.n_ops_max), + dtype=torch.int64, + ), + ma_assignment=UnboundedDiscreteTensorSpec( + shape=(self.num_mas, self.n_ops_max), + dtype=torch.int64, + ), + busy_until=UnboundedDiscreteTensorSpec( + shape=(self.num_mas,), + dtype=torch.int64, + ), + num_eligible=UnboundedDiscreteTensorSpec( + shape=(self.n_ops_max,), + dtype=torch.int64, + ), + job_in_process=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs,), + dtype=torch.bool, + ), + job_done=UnboundedDiscreteTensorSpec( + shape=(self.num_jobs,), + dtype=torch.bool, + ), + shape=(), + ) + self.action_spec = BoundedTensorSpec( + shape=(1,), + dtype=torch.int64, + low=-1, + high=self.n_ops_max, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,)) + self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool) + + @staticmethod + def render(td, idx): + return render(td, idx) + 
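+ + # Rough usage sketch (parameter values are illustrative only): + # env = FJSPEnv(generator_params={"num_jobs": 5, "num_machines": 3}) + # td = env.reset(batch_size=[4]) + # while not td["done"].all(): + # td["action"] = td["action_mask"].float().multinomial(1).squeeze(-1) # random feasible action + # td = env.step(td)["next"] + # makespan = -env._get_reward(td) # negative reward of the finished schedule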
+ + def select_start_nodes(self, td: TensorDict, num_starts: int): + return sample_n_random_actions(td, num_starts) + + def get_num_starts(self, td): + # NOTE in the paper they use N_s = 100 + return 100 diff --git a/rl4co/envs/scheduling/fjsp/generator.py b/rl4co/envs/scheduling/fjsp/generator.py new file mode 100644 index 00000000..f1ae6202 --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/generator.py @@ -0,0 +1,216 @@ +from functools import partial +from typing import List + +import numpy as np +import torch + +from tensordict.tensordict import TensorDict + +from rl4co.envs.common.utils import Generator +from rl4co.utils.pylogger import get_pylogger + +from .parser import get_max_ops_from_files, read + +log = get_pylogger(__name__) + + +class FJSPGenerator(Generator): + + """Data generator for the Flexible Job-Shop Scheduling Problem (FJSP). + + Args: + num_jobs: number of jobs + num_machines: number of machines + min_ops_per_job: minimum number of operations per job + max_ops_per_job: maximum number of operations per job + min_processing_time: minimum processing time of an operation on a machine + max_processing_time: maximum processing time of an operation on a machine + min_eligible_ma_per_op: minimum number of machines that can process an operation + max_eligible_ma_per_op: maximum number of machines that can process an operation (defaults to num_machines) + + Returns: + A TensorDict with the following keys: + start_op_per_job [batch_size, num_jobs]: first operation of each job + end_op_per_job [batch_size, num_jobs]: last operation of each job + proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines + pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used + + """ + + def __init__( + self, + num_jobs: int = 10, + num_machines: int = 5, + min_ops_per_job: int = 4, + max_ops_per_job: int = 6, + min_processing_time: int = 1, + max_processing_time: int = 20, + min_eligible_ma_per_op: int = 1, + max_eligible_ma_per_op: int = None, + **unused_kwargs, + ): + self.num_jobs = num_jobs + self.num_mas = num_machines + self.min_ops_per_job = min_ops_per_job + self.max_ops_per_job = max_ops_per_job + self.min_processing_time = min_processing_time + self.max_processing_time = max_processing_time + self.min_eligible_ma_per_op = min_eligible_ma_per_op + self.max_eligible_ma_per_op = max_eligible_ma_per_op or num_machines + # determines whether to use a fixed number of total operations or let it vary between instances + # NOTE: due to the way rl4co builds datasets, we need a fixed size here + self.n_ops_max = max_ops_per_job * num_jobs + + # the FJSP generator does not accept any other kwargs + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + def _simulate_processing_times( + self, n_eligible_per_ops: torch.Tensor + ) -> torch.Tensor: + bs, n_ops_max = n_eligible_per_ops.shape + + # (bs, max_ops, machines) + ma_seq_per_ops = torch.arange(1, self.num_mas + 1)[None, None].expand( + bs, n_ops_max, self.num_mas + ) + # generate a matrix of size (ops, mas) per batch, each row having as many ones as the operation's eligible machines + # E.g.
n_eligible_per_ops=[1,3,2]; num_mas=4 + # [[1,0,0,0], + # 1,1,1,0], + # 1,1,0,0]] + # This will be shuffled randomly to generate a machine-operation mapping + ma_ops_edges_unshuffled = torch.Tensor.float( + ma_seq_per_ops <= n_eligible_per_ops[..., None] + ) + # random shuffling + idx = torch.rand_like(ma_ops_edges_unshuffled).argsort() + ma_ops_edges = ma_ops_edges_unshuffled.gather(2, idx).transpose(1, 2) + + # (bs, max_ops, machines) + proc_times = torch.ones((bs, n_ops_max, self.num_mas)) + proc_times = torch.randint( + self.min_processing_time, + self.max_processing_time + 1, + size=(bs, self.num_mas, n_ops_max), + ) + + # remove proc_times for which there is no corresponding ma-ops connection + proc_times = proc_times * ma_ops_edges + return proc_times + + def _generate(self, batch_size) -> TensorDict: + # simulate how many operations each job has + n_ope_per_job = torch.randint( + self.min_ops_per_job, + self.max_ops_per_job + 1, + size=(*batch_size, self.num_jobs), + ) + + # determine the total number of operations per batch instance (which may differ) + n_ops_batch = n_ope_per_job.sum(1) # (bs) + # determine the maximum total number of operations over all batch instances + n_ops_max = self.n_ops_max or n_ops_batch.max() + + # generate a mask, specifying which operations are padded + pad_mask = torch.arange(n_ops_max).unsqueeze(0).expand(*batch_size, -1) + pad_mask = pad_mask.ge(n_ops_batch[:, None].expand_as(pad_mask)) + + # determine the id of the end operation for each job + end_op_per_job = n_ope_per_job.cumsum(1) - 1 + + # determine the id of the starting operation for each job + # (bs, num_jobs) + start_op_per_job = torch.cat( + ( + torch.zeros((*batch_size, 1)).to(end_op_per_job), + end_op_per_job[:, :-1] + 1, + ), + dim=1, + ) + + # here we simulate the eligible machines per operation and the processing times + n_eligible_per_ops = torch.randint( + self.min_eligible_ma_per_op, + self.max_eligible_ma_per_op + 1, + (*batch_size, n_ops_max), + ) + n_eligible_per_ops[pad_mask] = 0 + + # simulate processing times for machine-operation pairs + # (bs, num_mas, n_ops_max) + proc_times = self._simulate_processing_times(n_eligible_per_ops) + + td = TensorDict( + { + "start_op_per_job": start_op_per_job, + "end_op_per_job": end_op_per_job, + "proc_times": proc_times, + "pad_mask": pad_mask, + }, + batch_size=batch_size, + ) + + return td + + +class FJSPFileGenerator(Generator): + """Data generator for the Flexible Job-Shop Scheduling Problem (FJSP) using instance files + + Args: + path: path to files + + Returns: + A TensorDict with the following key: + start_op_per_job [batch_size, num_jobs]: first operation of each job + end_op_per_job [batch_size, num_jobs]: last operation of each job + proc_times [batch_size, num_machines, total_n_ops]: processing time of ops on machines + pad_mask [batch_size, total_n_ops]: not all instances have the same number of ops, so padding is used + + """ + + def __init__(self, file_path: str, n_ops_max: int = None, **unused_kwargs): + self.files = self.list_files(file_path) + self.num_samples = len(self.files) + + if len(unused_kwargs) > 0: + log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") + + if len(self.files) > 1: + n_ops_max = get_max_ops_from_files(self.files) + + ret = map(partial(read, max_ops=n_ops_max), self.files) + + td_list, num_jobs, num_machines, max_ops_per_job = list(zip(*list(ret))) + num_jobs, num_machines = map(lambda x: x[0], (num_jobs, num_machines)) + max_ops_per_job = max(max_ops_per_job) + + self.td = 
torch.cat(td_list, dim=0) + self.num_mas = num_machines + self.num_jobs = num_jobs + self.max_ops_per_job = max_ops_per_job + self.start_idx = 0 + + def _generate(self, batch_size: List[int]) -> TensorDict: + batch_size = np.prod(batch_size) + if batch_size > self.num_samples: + log.warning( + f"Only found {self.num_samples} instance files, but specified dataset size is {batch_size}" + ) + end_idx = self.start_idx + batch_size + td = self.td[self.start_idx : end_idx] + self.start_idx += batch_size + return td + + @staticmethod + def list_files(path): + import os + + files = [ + os.path.join(path, f) + for f in os.listdir(path) + if os.path.isfile(os.path.join(path, f)) + ] + assert len(files) > 0 + files = sorted( + files, key=lambda f: int(os.path.splitext(os.path.basename(f))[0][:4]) + ) + return files diff --git a/rl4co/envs/scheduling/fjsp/parser.py b/rl4co/envs/scheduling/fjsp/parser.py new file mode 100644 index 00000000..f05c8fca --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/parser.py @@ -0,0 +1,180 @@ +import os + +from functools import partial +from pathlib import Path +from typing import List, Tuple, Union + +import torch + +from tensordict import TensorDict + +ProcessingData = List[Tuple[int, int]] + + +def list_files(path): + import os + + files = [ + os.path.join(path, f) + for f in os.listdir(path) + if os.path.isfile(os.path.join(path, f)) + ] + return files + + +def parse_job_line(line: Tuple[int]) -> Tuple[ProcessingData]: + """ + Parses an FJSPLIB job data line of the following form: + + <num_operations> * (<num_machines> * (<machine> <processing_time>)) + + In words, the first value is the number of operations. Then, for each + operation, the first number represents the number of machines that can + process the operation, followed by the machine index and processing time + for each eligible machine. + + Note that the machine indices start from 1, so we subtract 1 to make them + zero-based. + """ + num_operations = line[0] + operations = [] + idx = 1 + + for _ in range(num_operations): + num_pairs = int(line[idx]) * 2 + machines = line[idx + 1 : idx + 1 + num_pairs : 2] + durations = line[idx + 2 : idx + 2 + num_pairs : 2] + operations.append([(m, d) for m, d in zip(machines, durations)]) + + idx += 1 + num_pairs + + return operations + + +def get_n_ops_of_instance(file): + lines = file2lines(file) + jobs = [parse_job_line(line) for line in lines[1:]] + n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) + total_ops = int(n_ope_per_job.sum()) + return total_ops + + +def get_max_ops_from_files(files): + return max(map(get_n_ops_of_instance, files)) + + +def read(loc: Path, max_ops=None): + """ + Reads an FJSPLIB instance. + + Args: + loc: location of instance file + max_ops: optionally specify the maximum number of total operations (will be filled by padding) + + Returns: + instance: the parsed instance + """ + lines = file2lines(loc) + + # First line contains metadata. + num_jobs, num_machines = lines[0][0], lines[0][1] + + # The remaining lines contain the job-operation data, where each line + # represents a job and its operations.
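+ # e.g. the job line "2 1 1 5 2 2 4 1 6" describes a job with 2 operations: the first can only run on machine 1 (5 time units), + # the second on machine 2 (4 units) or machine 1 (6 units).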
+ jobs = [parse_job_line(line) for line in lines[1:]] + n_ope_per_job = torch.Tensor([len(x) for x in jobs]).unsqueeze(0) + total_ops = int(n_ope_per_job.sum()) + if max_ops is not None: + assert total_ops <= max_ops, "got more operations then specified through max_ops" + max_ops = max_ops or total_ops + max_ops_per_job = int(n_ope_per_job.max()) + + end_op_per_job = n_ope_per_job.cumsum(1) - 1 + start_op_per_job = torch.cat((torch.zeros((1, 1)), end_op_per_job[:, :-1] + 1), dim=1) + + pad_mask = torch.arange(max_ops) + pad_mask = pad_mask.ge(total_ops).unsqueeze(0) + + proc_times = torch.zeros((num_machines, max_ops)) + op_cnt = 0 + for job in jobs: + for op in job: + for ma, dur in op: + # subtract one to let indices start from zero + proc_times[ma - 1, op_cnt] = dur + op_cnt += 1 + proc_times = proc_times.unsqueeze(0) + + td = TensorDict( + { + "start_op_per_job": start_op_per_job, + "end_op_per_job": end_op_per_job, + "proc_times": proc_times, + "pad_mask": pad_mask, + }, + batch_size=[1], + ) + + return td, num_jobs, num_machines, max_ops_per_job + + +def file2lines(loc: Union[Path, str]) -> List[List[int]]: + with open(loc, "r") as fh: + lines = [line for line in fh.readlines() if line.strip()] + + def parse_num(word: str): + return int(word) if "." not in word else int(float(word)) + + return [[parse_num(x) for x in line.split()] for line in lines] + + +def write_one(args, where=None): + id, instance = args + assert ( + len(instance["proc_times"].shape) == 2 + ), "no batch dimension allowed in write operation" + lines = [] + + # The flexibility is the average number of eligible machines per operation. + num_eligible = (instance["proc_times"] > 0).sum() + n_ops = (~instance["pad_mask"]).sum() + num_jobs = instance["next_op"].size(0) + num_machines = instance["proc_times"].size(0) + flexibility = round(int(num_eligible) / int(n_ops), 5) + + metadata = f"{num_jobs}\t{num_machines}\t{flexibility}" + lines.append(metadata) + + for i in range(num_jobs): + ops_of_job = instance["job_ops_adj"][i].nonzero().squeeze(1) + job = [len(ops_of_job)] # number of operations of the job + + for op in ops_of_job: + eligible_ma = instance["proc_times"][:, op].nonzero().squeeze(1) + job.append(eligible_ma.size(0)) # num_eligible + + for machine in eligible_ma: + duration = instance["proc_times"][machine, op] + assert duration > 0, "something is wrong" + # add one since in song instances ma indices start from one + job.extend([int(machine.item()) + 1, int(duration.item())]) + + line = " ".join(str(num) for num in job) + lines.append(line) + + formatted = "\n".join(lines) + + file_name = f"{str(id+1).rjust(4, '0')}_{num_jobs}j_{num_machines}m.txt" + full_path = os.path.join(where, file_name) + + with open(full_path, "w") as fh: + fh.write(formatted) + + return formatted + + +def write(where: Union[Path, str], instances: TensorDict): + if not os.path.exists(where): + os.makedirs(where) + + return list(map(partial(write_one, where=where), enumerate(iter(instances)))) diff --git a/rl4co/envs/scheduling/fjsp/render.py b/rl4co/envs/scheduling/fjsp/render.py new file mode 100644 index 00000000..bfb86bf4 --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/render.py @@ -0,0 +1,72 @@ +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np + +from matplotlib.colors import ListedColormap +from tensordict.tensordict import TensorDict + +from rl4co.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +def render(td: TensorDict, idx: int): + inst = td[idx] + num_jobs = 
inst["job_ops_adj"].size(0) + + # Define a colormap with a color for each job + colors = plt.cm.tab10(np.linspace(0, 1, num_jobs)) + cmap = ListedColormap(colors) + + assign = inst["ma_assignment"].nonzero() + + schedule = defaultdict(list) + + for val in assign: + machine = val[0].item() + op = val[1].item() + # get start and end times of operation + start = inst["start_times"][val[1]] + end = inst["finish_times"][val[1]] + # write information to schedule dictionary + schedule[machine].append((op, start, end)) + + _, ax = plt.subplots() + + # Plot horizontal bars for each task + for ma, ops in schedule.items(): + for op, start, end in ops: + job = inst["job_ops_adj"][:, op].nonzero().item() + ax.barh( + ma, + end - start, + left=start, + height=0.6, + color=cmap(job), + edgecolor="black", + linewidth=1, + ) + + ax.text( + start + (end - start) / 2, ma, op, ha="center", va="center", color="white" + ) + + # Set labels and title + ax.set_yticks(range(len(schedule))) + ax.set_yticklabels([f"Machine {i}" for i in range(len(schedule))]) + ax.set_xlabel("Time") + ax.set_title("Gantt Chart") + + # Add a legend for class labels + handles = [plt.Rectangle((0, 0), 1, 1, color=cmap(i)) for i in range(num_jobs)] + ax.legend( + handles, + [f"Job {label}" for label in range(num_jobs)], + loc="center left", + bbox_to_anchor=(1, 0.5), + ) + + plt.tight_layout() + # Show the Gantt chart + plt.show() diff --git a/rl4co/envs/scheduling/fjsp/utils.py b/rl4co/envs/scheduling/fjsp/utils.py new file mode 100644 index 00000000..b3ee40b8 --- /dev/null +++ b/rl4co/envs/scheduling/fjsp/utils.py @@ -0,0 +1,333 @@ +import logging + +from typing import List, Tuple, Union + +import torch + +from tensordict import TensorDict +from torch import Size, Tensor + +from rl4co.envs.scheduling.fjsp import INIT_FINISH + +logger = logging.getLogger(__name__) + + +def get_op_features(td: TensorDict): + return torch.stack((td["lbs"], td["is_ready"], td["num_eligible"]), dim=-1) + + +def cat_and_norm_features( + td: TensorDict, feats: List[str], time_feats: List[str], norm_const: int +): + # logger.info(f"will scale the features {','.join(time_feats)} with a constant ({norm_const})") + feature_list = [] + for feat in feats: + if feat in time_feats: + feature_list.append(td[feat] / norm_const) + else: + feature_list.append(td[feat]) + + return torch.stack(feature_list, dim=-1).to(torch.float32) + + +def view( + tensor: Tensor, + idx: Tuple[Tensor], + pad_mask: Tensor, + new_shape: Union[Size, List[int]], + pad_value: Union[float, int], +): + # convert mask specifying which entries are padded into mask specifying which entries to keep + mask = ~pad_mask + new_view = tensor.new_full(size=new_shape, fill_value=pad_value) + new_view[idx] = tensor[mask] + return new_view + + +def _get_idx_for_job_op_view(td: TensorDict) -> tuple: + bs, _, n_total_ops = td["job_ops_adj"].shape + # (bs, ops) + batch_idx = torch.arange(bs, device=td.device).repeat_interleave(n_total_ops) + batch_idx = batch_idx.reshape(bs, -1) + # (bs, ops) + ops_job_map = td["ops_job_map"] + # (bs, ops) + ops_sequence_order = td["ops_sequence_order"] + # (bs*n_ops_max, 3) + idx = ( + torch.stack((batch_idx, ops_job_map, ops_sequence_order), dim=-1) + .to(torch.long) + .flatten(0, 1) + ) + # (bs, n_ops_max) + mask = ~td["pad_mask"] + # (total_ops_in_batch, 3) + idx = idx[mask.flatten(0, 1)] + b, j, o = map(lambda x: x.squeeze(1), idx.chunk(3, dim=-1)) + return b, j, o + + +def get_job_op_view( + td: TensorDict, keys: List[str] = [], pad_value: Union[float, int] = 0 +): + 
"""This function reshapes all tensors of the tensordict from a flat operations-only view + to a nested job-operation view and creates a new tensordict from it. + :param _type_ td: tensordict + :return _type_: dict + """ + # ============= Prepare the new index ============= + bs, num_jobs, _ = td["job_ops_adj"].shape + max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) + idx = _get_idx_for_job_op_view(td) + new_shape = Size((bs, num_jobs, max_ops_per_job)) + pad_mask = td["pad_mask"] + # ============================================== + + # due to special structure, processing times are treated seperately + if "proc_times" in keys: + keys.remove("proc_times") + # reshape processing times; (bs, ma, ops) -> (bs, ma, jobs, ops_per_job) + new_proc_times_view = view( + td["proc_times"].permute(0, 2, 1), idx, pad_mask, new_shape, pad_value + ).permute(0, 3, 1, 2) + + # add padding mask if not in keys + if "pad_mask" not in keys: + keys.append("pad_mask") + + new_views = dict( + map(lambda key: (key, view(td[key], idx, pad_mask, new_shape)), keys) + ) + + # update tensordict clone with reshaped tensors + return {"proc_times": new_proc_times_view, **new_views} + + +def blockify(td, tensor: Tensor, pad_value: Union[float, int] = 0): + assert len(tensor.shape) in [ + 2, + 3, + ], "blockify only supports tensors of shape (bs, seq, (d)), where the feature dim d is optional" + # get the size of the blockified tensor + bs, _, *d = tensor.shape + num_jobs = td["job_ops_adj"].size(1) + max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) + new_shape = Size((bs, num_jobs, max_ops_per_job, *d)) + # get indices of valid entries of blockified tensor + idx = _get_idx_for_job_op_view(td) + pad_mask = td["pad_mask"] + # create the blockified view + new_view_tensor = view(tensor, idx, pad_mask, new_shape, pad_value) + return new_view_tensor + + +def unblockify( + td: TensorDict, tensor: Tensor, mask: Tensor = None, pad_value: Union[float, int] = 0 +): + assert len(tensor.shape) in [ + 3, + 4, + ], "blockify only supports tensors of shape (bs, nb, s, (d)), where the feature dim d is optional" + # get the size of the blockified tensor + bs, _, _, *d = tensor.shape + n_ops_per_batch = td["job_ops_adj"].sum((1, 2)).unsqueeze(1) # (bs) + seq_len = int(n_ops_per_batch.max()) + new_shape = Size((bs, seq_len, *d)) + + # create the mask to gather then entries of the blockified tensor. NOTE that only by + # blockifying the original pad_mask + pad_mask = td["pad_mask"] + pad_mask = blockify(td, pad_mask, True) + + # get indices of valid entrie in flat matrix + b = torch.arange(bs, device=td.device).repeat_interleave(seq_len).reshape(bs, seq_len) + i = torch.arange(seq_len, device=td.device)[None].repeat(bs, 1) + idx = tuple(map(lambda x: x[i < n_ops_per_batch], (b, i))) + # create view + new_tensor = view(tensor, idx, pad_mask, new_shape, pad_value=pad_value) + return new_tensor + + +def first_diff(x: Tensor, dim: int): + shape = x.shape + shape = (*shape[:dim], 1, *shape[dim + 1 :]) + seq_cutoff = x.index_select(dim, torch.arange(x.size(dim) - 1, device=x.device)) + lagged_seq = x - torch.cat((seq_cutoff.new_zeros(*shape), seq_cutoff), dim=dim) + return lagged_seq + + +def spatial_encoding(td: TensorDict): + """We use a spatial encoing as proposed in GraphFormer (https://arxiv.org/abs/2106.05234) + The spatial encoding in GraphFormer determines the distance of the shortest path between and + nodes i and j and uses a special value for node pairs that cannot be connected at all. 
+ For any two operations i and j of the same job with i<j, e is the number of operations that, starting from i, + have to be completed before arriving at j (e.g. i=3, j=5 -> e=2); for i>j, e is the negative number of + operations that, starting from j, have been completed before arriving at i (e.g. i=5, j=3 -> e=-2). + For i=j we set e=0, as well as for operations of different jobs. + + :param torch.Tensor[bs, n_ops] ops_job_map: tensor specifying the index of the job each operation belongs to + :return torch.Tensor[bs, n_ops, n_ops]: length of shortest path between any two operations + """ + bs, _, n_total_ops = td["job_ops_adj"].shape + max_ops_per_job = int(td["job_ops_adj"].sum(-1).max()) + ops_job_map = td["ops_job_map"] + pad_mask = td["pad_mask"] + + same_job = (ops_job_map[:, None] == ops_job_map[..., None]).to(torch.int32) + # mask padded + same_job[pad_mask.unsqueeze(2).expand_as(same_job)] = 0 + same_job[pad_mask.unsqueeze(1).expand_as(same_job)] = 0 + # take upper triangular of same_job and set diagonal to zero for counting purposes + upper_tri = torch.triu(same_job) - torch.diag( + torch.ones(n_total_ops, device=td.device) + )[None].expand_as(same_job) + # cumsum and masking of operations that do not belong to the same job + num_jumps = upper_tri.cumsum(2) * upper_tri + # mirror the matrix + num_jumps = num_jumps + num_jumps.transpose(1, 2) + # NOTE: shifted this logic into the spatial encoding module + # num_jumps = num_jumps + (-num_jumps.transpose(1,2)) + assert not torch.any(num_jumps >= max_ops_per_job) + # special value for ops of different jobs and self-loops + num_jumps = torch.where(num_jumps == 0, -1, num_jumps) + self_mask = torch.eye(n_total_ops).repeat(bs, 1, 1).bool() + num_jumps[self_mask] = 0 + return num_jumps + + +def calc_lower_bound(td: TensorDict): + """Here we calculate the lower bound of the operations' finish times. In the FJSP case, multiple things need to + be taken into account due to the usability of the different machines for multiple ops of different jobs: + + 1.) Operations may only start once their direct predecessor is finished. We calculate their lower bound by + adding the minimum possible operation time to this earliest start time. However, we cannot use the proc_times + directly, but need to account for the fact that machines might still be busy once an operation can be processed. + We detect this offset by identifying ops-machine pairs where the first possible start point of the operation is before + the machine becomes idle again. Therefore, we add this discrepancy to the proc_time of the respective ops-ma combination. + + 2.) If an operation has been scheduled, we use its real finishing time as lower bound. In this case, using the cumulative sum + over all predecessors of a job does not make sense, since it is likely to differ from the real finishing time of its direct + predecessor (it is only a lower bound). Therefore, we add the finish time to the cumulative sum of processing time of all + UNSCHEDULED operations, to obtain the lower bound.
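+ + As a small example of the second case: if a job's first operation is already finished at t=5 and its two remaining + operations have minimum (availability-adjusted) processing times 3 and 4, the per-operation summands become [5, 3, 4] + and their cumulative sum [5, 8, 12] yields the per-operation lower bounds.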
+ Making this work is a bit hacky: We compute the first differences of finishing times of those operations scheduled and + add them to the matrix of processing times, where already processed operations are masked (with zero) + + + :param TensorDict td: _description_ + :return _type_: _description_ + """ + + proc_times = td["proc_times"].clone() # (bs, ma, ops) + busy_until = td["busy_until"] # (bs, ma) + ops_adj = td["ops_adj"] # (bs, ops, ops, 2) + finish_times = td["finish_times"] # (bs, ops) + job_ops_adj = td["job_ops_adj"] # (bs, jobs, ops) + op_scheduled = td["op_scheduled"].to(torch.float32) # (bs, ops) + + ############## REGARDING POINT 1 OF DOCSTRING ############## + # for operations whose immidiate predecessor is scheduled, we can determine its earliest + # start time by the end time of the predecessor. + # (bs, num_ops, 1) + maybe_start_at = torch.bmm(ops_adj[..., 0], finish_times[..., None]).squeeze(2) + # using the start_time, we can determine if and how long an op needs to wait for a machine to finish + wait_for_ma_offset = torch.clip(busy_until[..., None] - maybe_start_at[:, None], 0) + # we add this required waiting time to the respective processing time - after that we determine the best machine for each operation + mask = proc_times == 0 + proc_times[mask] = torch.inf + proc_times += wait_for_ma_offset + # select best machine for operation, given the offset + min_proc_times = proc_times.min(1).values + + ############### REGARDING POINT 2 OF DOCSTRING ################### + # Now we determine all operations that are not scheduled yet (and thus have no finish_time). We will compute the cumulative + # sum over the processing time to determine the lower bound of unscheduled operations... + proc_matrix = job_ops_adj + ops_assigned = proc_matrix * op_scheduled[:, None] + proc_matrix_not_scheduled = proc_matrix * ( + torch.ones_like(proc_matrix) - op_scheduled[:, None] + ) + + # ...and add the finish_time of the last scheduled operation of the respective job to that. To make this work, using the cumsum logic, + # we calc the first differences of the finish times and seperate by job. 
+ # We use the first differences, so that the finish times do not add up during cumulative sum below + # (bs, num_jobs, num_ops) + finish_times_1st_diff = ops_assigned * first_diff( + ops_assigned * finish_times[:, None], 2 + ) + + # masking the processing time of scheduled operations and add their finish times instead (first diff thereof) + lb_end_expand = ( + proc_matrix_not_scheduled * min_proc_times.unsqueeze(1).expand_as(job_ops_adj) + + finish_times_1st_diff + ) + # (bs, max_ops); lower bound finish time per operation using the cumsum logic + LBs = torch.sum(job_ops_adj * lb_end_expand.cumsum(-1), dim=1) + # remove nans + LBs = torch.nan_to_num(LBs, nan=0.0) + + # test + assert torch.where( + finish_times != INIT_FINISH, torch.isclose(LBs, finish_times), True + ).all() + + return LBs + + +def op_is_ready(td: TensorDict): + # compare finish times of predecessors with current time step; shape=(b, n_ops_max) + is_ready = ( + torch.bmm(td["ops_adj"][..., 0], td["finish_times"][..., None]).squeeze(2) + <= td["time"][:, None] + ) + # shape=(b, n_ops_max) + is_scheduled = td["ma_assignment"].sum(1).bool() + # op is ready for scheduling if it has not been scheduled and its predecessor is finished + return torch.logical_and(is_ready, ~is_scheduled) + + +def get_job_ops_mapping( + start_op_per_job: torch.Tensor, end_op_per_job: torch.Tensor, n_ops_max: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Implements a mapping function from operations to jobs + + :param torch.Tensor start_op_per_job: index of first operation of each job + :param torch.Tensor end_op_per_job: index of last operation of each job + :return Tuple[torch.Tensor, torch.Tensor]: + 1st.) index mapping (bs, num_ops): [0,0,1,1,1] means that first two operations belong to job 0 + 2st.) binary mapping (bs, num_jobs, num_ops): [[1,1,0], [0,0,1]] means that first two operations belong to job 0 + """ + device = end_op_per_job.device + end_op_per_job = end_op_per_job.clone() + + bs, num_jobs = end_op_per_job.shape + + # in order to avoid shape conflicts, set the end operation id to the id of max_ops (all batches have same #ops) + end_op_per_job[:, -1] = n_ops_max - 1 + + # here we will generate the operations-job mapping: + # Therefore we first generate a sequence of operation ids and expand it the the size of the mapping matrix: + # (bs, jobs, max_ops) + ops_seq_exp = torch.arange(n_ops_max, device=device)[None, None].expand( + bs, num_jobs, -1 + ) + # (bs, jobs, max_ops) # expanding start and end operation ids + end_op_per_job_exp = end_op_per_job[..., None].expand_as(ops_seq_exp) + start_op_per_job_exp = start_op_per_job[..., None].expand_as(ops_seq_exp) + # given ids of start and end operations per job, this generates the mapping of ops to jobs + # (bs, jobs, max_ops) + ops_job_map = torch.nonzero( + (ops_seq_exp <= end_op_per_job_exp) & (ops_seq_exp >= start_op_per_job_exp) + ) + # (bs, max_ops) + ops_job_map = torch.stack(ops_job_map[:, 1].split(n_ops_max), dim=0) + + # we might also want a binary mapping / adjacency matrix connecting jobs to operations + # (bs, num_jobs, num_ops) + ops_job_bin_map = torch.scatter_add( + input=ops_job_map.new_zeros((bs, num_jobs, n_ops_max)), + dim=1, + index=ops_job_map.unsqueeze(1), + src=ops_job_map.new_ones((bs, num_jobs, n_ops_max)), + ) + + return ops_job_map, ops_job_bin_map diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py index a99a4278..0ebec158 100644 --- a/rl4co/models/__init__.py +++ b/rl4co/models/__init__.py @@ -19,6 +19,7 @@ from rl4co.models.rl.ppo.ppo import 
PPO from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, get_reinforce_baseline from rl4co.models.rl.reinforce.reinforce import REINFORCE +from rl4co.models.zoo import HetGNNModel from rl4co.models.zoo.active_search import ActiveSearch from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy from rl4co.models.zoo.amppo import AMPPO diff --git a/rl4co/models/nn/ops.py b/rl4co/models/nn/ops.py index 04fec365..ebbb5063 100644 --- a/rl4co/models/nn/ops.py +++ b/rl4co/models/nn/ops.py @@ -1,5 +1,6 @@ import math +import torch import torch.nn as nn @@ -35,3 +36,32 @@ def forward(self, x): else: assert self.normalizer is None, "Unknown normalizer type" return x + + +class PositionalEncoding(nn.Module): + def __init__(self, embed_dim: int, dropout: float = 0.1, max_len: int = 1000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + self.d_model = embed_dim + max_len = max_len + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model) + ) + pe = torch.zeros(max_len, 1, self.d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + pe = pe.transpose(0, 1) # [1, max_len, d_model] + self.register_buffer("pe", pe) + + def forward(self, hidden: torch.Tensor, seq_pos) -> torch.Tensor: + """ + Arguments: + x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` + seq_pos: Tensor, shape ``[batch_size, seq_len]`` + """ + pes = self.pe.expand(hidden.size(0), -1, -1).gather( + 1, seq_pos.unsqueeze(-1).expand(-1, -1, self.d_model) + ) + hidden = hidden + pes + return self.dropout(hidden) diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 477750c4..360c17aa 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -53,7 +53,7 @@ def shared_step( ): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) - out = self.policy(td, self.env, phase=phase) + out = self.policy(td, self.env, phase=phase, select_best=phase != "train") # Compute loss if phase == "train": diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py index 7796c630..c16bbe9b 100644 --- a/rl4co/models/zoo/__init__.py +++ b/rl4co/models/zoo/__init__.py @@ -10,6 +10,7 @@ HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy, ) +from rl4co.models.zoo.hetgnn import HetGNNModel from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy from rl4co.models.zoo.nargnn import NARGNNPolicy diff --git a/rl4co/models/zoo/hetgnn/__init__.py b/rl4co/models/zoo/hetgnn/__init__.py new file mode 100644 index 00000000..f98562b4 --- /dev/null +++ b/rl4co/models/zoo/hetgnn/__init__.py @@ -0,0 +1 @@ +from .model import HetGNNModel diff --git a/rl4co/models/zoo/hetgnn/decoder.py b/rl4co/models/zoo/hetgnn/decoder.py new file mode 100644 index 00000000..68bf1d36 --- /dev/null +++ b/rl4co/models/zoo/hetgnn/decoder.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn + +from rl4co.models.common.constructive.autoregressive import AutoregressiveDecoder +from rl4co.models.nn.mlp import MLP +from rl4co.utils.ops import batchify, gather_by_index + + +class HetGNNDecoder(AutoregressiveDecoder): + def __init__( + self, embed_dim, feed_forward_hidden_dim: int = 64, feed_forward_layers: int = 2 + ) -> None: + super().__init__() + self.mlp = MLP( + input_dim=2 * embed_dim, + output_dim=1, 
+ num_neurons=[feed_forward_hidden_dim] * feed_forward_layers, + ) + self.dummy = nn.Parameter(torch.rand(2 * embed_dim)) + + def pre_decoder_hook(self, td, env, hidden, num_starts): + return td, env, hidden + + def forward(self, td, hidden, num_starts): + if num_starts > 1: + hidden = tuple(map(lambda x: batchify(x, num_starts), hidden)) + + ma_emb, ops_emb = hidden + bs, n_rows, emb_dim = ma_emb.shape + + # (bs, n_jobs, emb) + job_emb = gather_by_index(ops_emb, td["next_op"], squeeze=False) + + # (bs, n_jobs, n_ma, emb) + job_emb_expanded = job_emb.unsqueeze(2).expand(-1, -1, n_rows, -1) + ma_emb_expanded = ma_emb.unsqueeze(1).expand_as(job_emb_expanded) + + # Input of actor MLP + # shape: [bs, num_jobs * n_ma, 2*emb] + h_actions = torch.cat((job_emb_expanded, ma_emb_expanded), dim=-1).flatten(1, 2) + no_ops = self.dummy[None, None].expand(bs, 1, -1) # [bs, 1, 2*emb_dim] + # [bs, num_jobs * n_ma + 1, 2*emb_dim] + h_actions_w_noop = torch.cat((no_ops, h_actions), 1) + + # (b, j*m) + mask = td["action_mask"] + + # (b, j*m) + logits = self.mlp(h_actions_w_noop).squeeze(-1) + + return logits, mask diff --git a/rl4co/models/zoo/hetgnn/encoder.py b/rl4co/models/zoo/hetgnn/encoder.py new file mode 100644 index 00000000..6f966cf8 --- /dev/null +++ b/rl4co/models/zoo/hetgnn/encoder.py @@ -0,0 +1,132 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import einsum +from torch import Tensor + +from rl4co.models.nn.env_embeddings import env_init_embedding +from rl4co.models.nn.ops import Normalization + + +class HetGNNLayer(nn.Module): + def __init__( + self, + embed_dim: int, + ) -> None: + super().__init__() + + self.self_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) + self.cross_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) + self.edge_attn = nn.Parameter(torch.rand(size=(embed_dim, 1), dtype=torch.float)) + self.activation = nn.ReLU() + self.scale = 1 / math.sqrt(embed_dim) + + def forward( + self, self_emb: Tensor, other_emb: Tensor, edge_emb: Tensor, edges: Tensor + ): + bs, n_rows, _ = self_emb.shape + + # concat operation embeddings and o-m edge features (proc times) + # Calculate attention coefficients + er = einsum(self_emb, self.self_attn, "b m e, e one -> b m") * self.scale + ec = einsum(other_emb, self.cross_attn, "b o e, e one -> b o") * self.scale + ee = einsum(edge_emb, self.edge_attn, "b m o e, e one -> b m o") * self.scale + + # element wise multiplication similar to broadcast column logits over rows with masking + ec_expanded = einsum(edges, ec, "b m o, b o -> b m o") + # element wise multiplication similar to broadcast row logits over cols with masking + er_expanded = einsum(edges, er, "b m o, b m -> b m o") + + # adding the projections of different node types and edges together (equivalent to first concat and then project) + # (bs, n_rows, n_cols) + cross_logits = self.activation(ec_expanded + ee + er_expanded) + + # (bs, n_rows, 1) + self_logits = self.activation(er + er).unsqueeze(-1) + + # (bs, n_ma, n_ops + 1) + mask = torch.cat( + ( + edges == 1, + torch.full( + size=(bs, n_rows, 1), + dtype=torch.bool, + fill_value=True, + device=edges.device, + ), + ), + dim=-1, + ) + + # (bs, n_ma, n_ops + 1) + all_logits = torch.cat((cross_logits, self_logits), dim=-1) + all_logits[~mask] = -torch.inf + attn_scores = F.softmax(all_logits, dim=-1) + # (bs, n_ma, n_ops) + cross_attn_scores = attn_scores[..., :-1] + # (bs, n_ma, 1) + self_attn_scores = attn_scores[..., -1].unsqueeze(-1) 
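+ # for each row node, the attention scores over its connected column nodes plus its self-connection now sum to one + # (masked row-column pairs were assigned -inf before the softmax)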
+ + # augment column embeddings with edge features, (bs, r, c, e) + other_emb_aug = edge_emb + other_emb.unsqueeze(-3) + cross_emb = einsum(cross_attn_scores, other_emb_aug, "b m o, b m o e -> b m e") + self_emb = self_emb * self_attn_scores + # (bs, n_ma, emb_dim) + hidden = torch.sigmoid(cross_emb + self_emb) + return hidden + + +class HetGNNBlock(nn.Module): + def __init__(self, embed_dim) -> None: + super().__init__() + self.norm1 = Normalization(embed_dim, normalization="batch") + self.norm2 = Normalization(embed_dim, normalization="batch") + self.hgnn1 = HetGNNLayer(embed_dim) + self.hgnn2 = HetGNNLayer(embed_dim) + + def forward(self, x1, x2, edge_emb, edges): + h1 = self.hgnn1(x1, x2, edge_emb, edges) + h1 = self.norm1(h1 + x1) + + h2 = self.hgnn2(x2, x1, edge_emb.transpose(1, 2), edges.transpose(1, 2)) + h2 = self.norm2(h2 + x2) + + return h1, h2 + + +class HetGNNEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + num_layers: int = 2, + init_embedding=None, + edge_key: str = "ops_ma_adj", + edge_weights_key: str = "proc_times", + linear_bias: bool = False, + ) -> None: + super().__init__() + + if init_embedding is None: + init_embedding = env_init_embedding("fjsp", {"embed_dim": embed_dim}) + self.init_embedding = init_embedding + + self.edge_key = edge_key + self.edge_weights_key = edge_weights_key + + self.num_layers = num_layers + self.layers = nn.ModuleList([HetGNNBlock(embed_dim) for _ in range(num_layers)]) + + def forward(self, td): + edges = td[self.edge_key] + bs, n_rows, n_cols = edges.shape + row_emb, col_emb, edge_emb = self.init_embedding(td) + assert row_emb.size(1) == n_rows, "incorrect number of row embeddings" + assert col_emb.size(1) == n_cols, "incorrect number of column embeddings" + + for layer in self.layers: + row_emb, col_emb = layer(row_emb, col_emb, edge_emb, edges) + + return (row_emb, col_emb), None diff --git a/rl4co/models/zoo/hetgnn/model.py b/rl4co/models/zoo/hetgnn/model.py new file mode 100644 index 00000000..40f27de2 --- /dev/null +++ b/rl4co/models/zoo/hetgnn/model.py @@ -0,0 +1,38 @@ +from typing import Union + +from rl4co.envs.common.base import RL4COEnvBase +from rl4co.models.rl import REINFORCE +from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline + +from .policy import HetGNNPolicy + + +class HetGNNModel(REINFORCE): + """Heterogenous Graph Neural Network Model as described by Song et al. (2022): + 'Flexible Job Shop Scheduling via Graph Neural Network and Deep Reinforcement Learning' + + Args: + env: Environment to use for the algorithm + policy: Policy to use for the algorithm + baseline: REINFORCE baseline. 
Defaults to rollout (1 epoch of exponential, then greedy rollout baseline) + policy_kwargs: Keyword arguments for policy + baseline_kwargs: Keyword arguments for baseline + **kwargs: Keyword arguments passed to the superclass + """ + + def __init__( + self, + env: RL4COEnvBase, + policy: HetGNNPolicy = None, + baseline: Union[REINFORCEBaseline, str] = "rollout", + policy_kwargs={}, + baseline_kwargs={}, + **kwargs, + ): + assert ( + env.name == "fjsp" + ), "HetGNNModel currently only works for FJSP (Flexible Job-Shop Scheduling Problem)" + if policy is None: + policy = HetGNNPolicy(env_name=env.name, **policy_kwargs) + + super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/hetgnn/policy.py b/rl4co/models/zoo/hetgnn/policy.py new file mode 100644 index 00000000..c51dc30e --- /dev/null +++ b/rl4co/models/zoo/hetgnn/policy.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch.nn as nn + +from rl4co.models.common.constructive.autoregressive import ( + AutoregressiveDecoder, + AutoregressiveEncoder, + AutoregressivePolicy, +) +from rl4co.utils.pylogger import get_pylogger + +from .decoder import HetGNNDecoder +from .encoder import HetGNNEncoder + +log = get_pylogger(__name__) + + +class HetGNNPolicy(AutoregressivePolicy): + """ + Base Non-autoregressive policy for NCO construction methods. + This creates a heatmap of NxN for N nodes (i.e., heuristic) that models the probability to go from one node to another for all nodes. + + The policy performs the following steps: + 1. Encode the environment initial state into node embeddings + 2. Decode (non-autoregressively) to construct the solution to the NCO problem + + Warning: + The effectiveness of the non-autoregressive approach can vary significantly across different problem types and configurations. + It may require careful tuning of the model architecture and decoding strategy to achieve competitive results. + + Args: + encoder: Encoder module. Can be passed by sub-classes + decoder: Decoder module. Note that this moule defaults to the non-autoregressive decoder + embed_dim: Dimension of the embeddings + env_name: Name of the environment used to initialize embeddings + init_embedding: Model to use for the initial embedding. If None, use the default embedding for the environment + edge_embedding: Model to use for the edge embedding. If None, use the default embedding for the environment + graph_network: Model to use for the graph network. If None, use the default embedding for the environment + heatmap_generator: Model to use for the heatmap generator. 
If None, use the default embedding for the environment + num_layers_heatmap_generator: Number of layers in the heatmap generator + num_layers_graph_encoder: Number of layers in the graph encoder + act_fn: Activation function to use in the encoder + agg_fn: Aggregation function to use in the encoder + linear_bias: Whether to use bias in the encoder + train_decode_type: Type of decoding during training + val_decode_type: Type of decoding during validation + test_decode_type: Type of decoding during testing + **constructive_policy_kw: Unused keyword arguments + """ + + def __init__( + self, + encoder: Optional[AutoregressiveEncoder] = None, + decoder: Optional[AutoregressiveDecoder] = None, + embed_dim: int = 64, + num_encoder_layers: int = 2, + env_name: str = "fjsp", + init_embedding: Optional[nn.Module] = None, + linear_bias: bool = True, + train_decode_type: str = "sampling", + val_decode_type: str = "greedy", + test_decode_type: str = "multistart_sampling", + **constructive_policy_kw, + ): + if len(constructive_policy_kw) > 0: + log.warn(f"Unused kwargs: {constructive_policy_kw}") + + if encoder is None: + encoder = HetGNNEncoder( + embed_dim=embed_dim, + num_layers=num_encoder_layers, + init_embedding=init_embedding, + linear_bias=linear_bias, + ) + + # The decoder generates logits given the current td and heatmap + if decoder is None: + decoder = HetGNNDecoder( + embed_dim=embed_dim, + feed_forward_hidden_dim=embed_dim, + feed_forward_layers=2, + ) + else: + # check if the decoder has trainable parameters + if any(p.requires_grad for p in decoder.parameters()): + log.error( + "The decoder contains trainable parameters. This should not happen in a non-autoregressive policy." + ) + + # Pass to constructive policy + super(HetGNNPolicy, self).__init__( + encoder=encoder, + decoder=decoder, + env_name=env_name, + train_decode_type=train_decode_type, + val_decode_type=val_decode_type, + test_decode_type=test_decode_type, + **constructive_policy_kw, + ) diff --git a/rl4co/tasks/eval.py b/rl4co/tasks/eval.py index 5be1abcc..bfa7de0a 100644 --- a/rl4co/tasks/eval.py +++ b/rl4co/tasks/eval.py @@ -5,7 +5,7 @@ from tqdm.auto import tqdm from rl4co.data.transforms import StateAugmentation -from rl4co.utils.ops import batchify, gather_by_index, unbatchify +from rl4co.utils.ops import batchify, gather_by_index, sample_n_random_actions, unbatchify def check_unused_kwargs(class_, kwargs): @@ -169,23 +169,21 @@ def __init__(self, env, samples, softmax_temp=None, **kwargs): self.softmax_temp = softmax_temp def _inner(self, policy, td): - td = batchify(td, self.samples) out = policy( td.clone(), decode_type="sampling", - num_starts=0, + num_starts=self.samples, + multistart=True, return_actions=True, softmax_temp=self.softmax_temp, + select_best=True, + select_start_nodes_fn=lambda td, _, n: sample_n_random_actions(td, n), ) # Move into batches and compute rewards - rewards = self.env.get_reward(td, out["actions"]) - rewards = unbatchify(rewards, self.samples) - actions = unbatchify(out["actions"], self.samples) + rewards = out["reward"] + actions = out["actions"] - # Get the best reward and action for each sample - rewards, max_idxs = rewards.max(dim=1) - actions = gather_by_index(actions, max_idxs, dim=1) return actions, rewards diff --git a/rl4co/utils/decoding.py b/rl4co/utils/decoding.py index b0a0ae90..76a33b82 100644 --- a/rl4co/utils/decoding.py +++ b/rl4co/utils/decoding.py @@ -8,7 +8,7 @@ from tensordict.tensordict import TensorDict from rl4co.envs import RL4COEnvBase -from rl4co.utils.ops 
import batchify +from rl4co.utils.ops import batchify, unbatchify, unbatchify_and_gather from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) @@ -215,6 +215,7 @@ def __init__( multistart: bool = False, num_starts: Optional[int] = None, select_start_nodes_fn: Optional[callable] = None, + select_best: bool = False, **kwargs, ) -> None: self.temperature = temperature @@ -225,6 +226,7 @@ def __init__( self.multistart = multistart self.num_starts = num_starts self.select_start_nodes_fn = select_start_nodes_fn + self.select_best = select_best # initialize buffers self.actions = [] self.logprobs = [] @@ -293,8 +295,11 @@ def post_decoder_hook( assert ( len(self.logprobs) > 0 ), "No logprobs were collected because all environments were done. Check your initial state" - - return torch.stack(self.logprobs, 1), torch.stack(self.actions, 1), td, env + logprobs = torch.stack(self.logprobs, 1) + actions = torch.stack(self.actions, 1) + if self.num_starts > 0 and self.select_best: + logprobs, actions, td, env = self._select_best(logprobs, actions, td, env) + return logprobs, actions, td, env def step( self, @@ -360,6 +365,16 @@ def sampling(logprobs, mask=None): return selected + def _select_best(self, logprobs, actions, td: TensorDict, env: RL4COEnvBase): + rewards = env.get_reward(td, actions) + _, max_idxs = unbatchify(rewards, self.num_starts).max(dim=-1) + + actions = unbatchify_and_gather(actions, max_idxs, self.num_starts) + logprobs = unbatchify_and_gather(logprobs, max_idxs, self.num_starts) + td = unbatchify_and_gather(td, max_idxs, self.num_starts) + + return logprobs, actions, td, env + class Greedy(DecodingStrategy): name = "greedy" diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py index 35a7441b..c6ad889a 100644 --- a/rl4co/utils/ops.py +++ b/rl4co/utils/ops.py @@ -76,6 +76,14 @@ def gather_by_index(src, idx, dim=1, squeeze=True): return src.gather(dim, idx).squeeze() if squeeze else src.gather(dim, idx) +def unbatchify_and_gather(x: Tensor, idx: Tensor, n: int): + """first unbatchify a tensor by n and then gather (usually along the unbatchified dimension) + by the specified index + """ + x = unbatchify(x, n) + return gather_by_index(x, idx, dim=idx.dim()) + + @torch.jit.script def get_distance(x: Tensor, y: Tensor): """Euclidean distance between two tensors of shape `[..., n, dim]`""" @@ -146,6 +154,8 @@ def select_start_nodes(td, env, num_starts): torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc ) + elif env.name == "fjsp": + raise NotImplementedError("Multistart not yet supported for FJSP") else: # Environments with depot: we do not select the depot as a start node selected = ( @@ -212,3 +222,23 @@ def get_full_graph_edge_index(num_node: int, self_loop=False) -> Tensor: adj_matrix.fill_diagonal_(0) edge_index = torch.permute(torch.nonzero(adj_matrix), (1, 0)) return edge_index + + +def sample_n_random_actions(td: TensorDict, n: int): + """Helper function to sample n random actions from available actions. 
If + the number of valid actions is less than n, we sample with replacement from the + valid actions. + """ + action_mask = td["action_mask"] + # check whether to use replacement or not + n_valid_actions = torch.sum(action_mask[:, 1:], 1).min() + if n_valid_actions < n: + replace = True + else: + replace = False + ps = torch.rand(action_mask.shape, device=action_mask.device) + ps[~action_mask] = -torch.inf + ps = torch.softmax(ps, dim=1) + selected = torch.multinomial(ps, n, replacement=replace) + selected = rearrange(selected, "b n -> (n b)") + return selected.to(td.device) diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py index 77350a50..0ad10fa1 100644 --- a/rl4co/utils/trainer.py +++ b/rl4co/utils/trainer.py @@ -73,7 +73,7 @@ def __init__( if auto_configure_ddp and strategy == "auto": if devices == "auto": n_devices = num_cuda_devices() - elif isinstance(devices, list): + elif isinstance(devices, Iterable): n_devices = len(devices) else: n_devices = devices From 1fd55d6663574e5ad168e2512ff5a16460ea0a91 Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Mon, 13 May 2024 17:53:36 +0800 Subject: [PATCH 5/6] Update init.py --- rl4co/models/nn/env_embeddings/init.py | 64 +++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py index fef341aa..d5c00be0 100644 --- a/rl4co/models/nn/env_embeddings/init.py +++ b/rl4co/models/nn/env_embeddings/init.py @@ -2,7 +2,7 @@ import torch.nn as nn from tensordict.tensordict import TensorDict - +from rl4co.models.nn.ops import PositionalEncoding def env_init_embedding(env_name: str, config: dict) -> nn.Module: """Get environment initial embedding. The init embedding is used to initialize the @@ -30,6 +30,7 @@ def env_init_embedding(env_name: str, config: dict) -> nn.Module: "mtsp": MTSPInitEmbedding, "smtwtp": SMTWTPInitEmbedding, "mdcpdp": MDCPDPInitEmbedding, + "fjsp": FJSPFeatureEmbedding, "mtvrp":MTVRPInitEmbedding, } @@ -402,3 +403,64 @@ def forward(self, td): delivery_embeddings = self.init_embed_delivery(delivery_feats) # concatenate on graph size dimension return torch.cat([depot_embeddings, pick_embeddings, delivery_embeddings], -2) + +class FJSPFeatureEmbedding(nn.Module): + def __init__(self, embed_dim, linear_bias=True, norm_coef: int = 100): + super().__init__() + self.embed_dim = embed_dim + self.norm_coef = norm_coef + + self.init_ope_embed = nn.Linear(4, self.embed_dim, bias=False) + self.edge_embed = nn.Linear(1, embed_dim, bias=False) + + self.ope_pos_enc = PositionalEncoding(embed_dim) + # TODO allow for reencoding after each step + self.stepwise = False + + def forward(self, td: TensorDict): + if self.stepwise: + ops_emb = self._stepwise_operations_embed(td) + ma_emb = self._stepwise_machine_embed(td) + edge_emb = None + else: + ops_emb = self._init_operations_embed(td) + ma_emb = self._init_machine_embed(td) + edge_emb = self._init_edge_embed(td) + return ma_emb, ops_emb, edge_emb + + def _init_operations_embed(self, td: TensorDict): + pos = td["ops_sequence_order"] + + features = [ + td["lbs"].unsqueeze(-1) / self.norm_coef, + td["is_ready"].unsqueeze(-1), + td["num_eligible"].unsqueeze(-1), + td["ops_job_map"].unsqueeze(-1), + ] + features = torch.cat(features, dim=-1) + # (bs, num_ops, emb_dim) + ops_embeddings = self.init_ope_embed(features) + + # (bs, num_ops, emb_dim) + ops_embeddings = self.ope_pos_enc(ops_embeddings, pos.to(torch.int64)) + # zero out padded entries +
ops_embeddings[td["pad_mask"].unsqueeze(-1).expand_as(ops_embeddings)] = 0 + return ops_embeddings + + def _init_machine_embed(self, td: TensorDict): + bs, num_ma = td["busy_until"].shape + ma_embeddings = torch.zeros( + (bs, num_ma, self.embed_dim), device=td.device, dtype=torch.float32 + ) + return ma_embeddings + + def _init_edge_embed(self, td: TensorDict): + proc_times = td["proc_times"].unsqueeze(-1) / self.norm_coef + edge_embed = self.edge_embed(proc_times) + return edge_embed + + def _stepwise_operations_embed(self, td: TensorDict): + raise NotImplementedError("Stepwise encoding not yet implemented") + + def _stepwise_machine_embed(self, td: TensorDict): + raise NotImplementedError("Stepwise encoding not yet implemented") \ No newline at end of file From 4e02795850a14b51dba6b719258768c1ad25fbdd Mon Sep 17 00:00:00 2001 From: FeiLiu <18729537605@163.com> Date: Mon, 13 May 2024 17:55:33 +0800 Subject: [PATCH 6/6] Update env.py --- rl4co/envs/routing/mtvrp/env.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/rl4co/envs/routing/mtvrp/env.py b/rl4co/envs/routing/mtvrp/env.py index e46ccc1e..fa74f1f2 100644 --- a/rl4co/envs/routing/mtvrp/env.py +++ b/rl4co/envs/routing/mtvrp/env.py @@ -349,25 +349,10 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): curr_time = torch.max( curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0] ) - # if not torch.all( - # curr_time-1E-6 <= gather_by_index(td["time_windows"], next_node)[..., 1] - # ): - # unsatisfied_indices = torch.nonzero(~(curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] - # ), as_tuple=True) - # print() + new_shape = curr_time.size() skip_open_end = td["open_route"].view(*new_shape) & (next_node == 0).view(*new_shape) - if not torch.all( - (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end - ): - unsatisfied_indices = torch.nonzero(~((curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end - ), as_tuple=True) - print(skip_open_end) - print(unsatisfied_indices) - print(curr_time) - print(curr_time[unsatisfied_indices]) - print(next_node[unsatisfied_indices]) - input() + assert torch.all( (curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]) | skip_open_end ), "vehicle cannot start service after deadline"
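The best-of-n reduction introduced above (the select_best path in decoding.py, built on unbatchify and unbatchify_and_gather, with sample_n_random_actions providing the random starts) boils down to the following standalone PyTorch sketch. The tensor names and sizes are illustrative assumptions only, and the starts-major (num_starts * batch) layout is assumed here to mirror the "b n -> (n b)" rearrange used above; this is a sketch of the idea, not the library implementation.

import torch

# Illustrative sizes (assumptions): 3 sampled rollouts ("starts") per instance, batch of 2 instances.
num_starts, batch_size, seq_len = 3, 2, 5

# Folded layout assumed here: the start dimension varies slowest, shape (num_starts * batch_size, ...).
rewards = torch.randn(num_starts * batch_size)                       # one reward per rollout
actions = torch.randint(0, 10, (num_starts * batch_size, seq_len))   # one action sequence per rollout

# "Unbatchify": recover (batch, starts) so each row holds all rollouts of one instance.
rewards_bs = rewards.view(num_starts, batch_size).transpose(0, 1)    # (batch, starts)
best_reward, best_idx = rewards_bs.max(dim=-1)                       # best rollout per instance

# Gather the action sequence of the best rollout for each instance.
actions_bs = actions.view(num_starts, batch_size, seq_len).transpose(0, 1)  # (batch, starts, seq_len)
best_actions = actions_bs.gather(1, best_idx[:, None, None].expand(-1, 1, seq_len)).squeeze(1)

print(best_reward.shape, best_actions.shape)  # torch.Size([2]) torch.Size([2, 5])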