diff --git a/pearl/replay_buffers/replay_buffer.py b/pearl/replay_buffers/replay_buffer.py index fceba54..a45a26c 100644 --- a/pearl/replay_buffers/replay_buffer.py +++ b/pearl/replay_buffers/replay_buffer.py @@ -58,6 +58,8 @@ def push( curr_available_actions: ActionSpace | None = None, next_state: SubjectiveState | None = None, next_available_actions: ActionSpace | None = None, + # max_number_actions should be specified when the size of the action space + # varies across different time steps. max_number_actions: int | None = None, cost: float | None = None, ) -> None: diff --git a/pearl/replay_buffers/tensor_based_replay_buffer.py b/pearl/replay_buffers/tensor_based_replay_buffer.py index c31e6e2..a1486b7 100644 --- a/pearl/replay_buffers/tensor_based_replay_buffer.py +++ b/pearl/replay_buffers/tensor_based_replay_buffer.py @@ -81,6 +81,12 @@ def push( None, ) else: + # If the action space is discrete and max_number_actions is not specified, + # then we assume that the size of the action space does not change over time + # and use this size as max_number_actions. + if max_number_actions is None: + assert isinstance(curr_available_actions, DiscreteActionSpace) + max_number_actions = curr_available_actions.n ( curr_available_actions_tensor_with_padding, curr_unavailable_actions_mask,