diff --git a/jumanji/types.py b/jumanji/types.py index ecbb069a7..249851846 100644 --- a/jumanji/types.py +++ b/jumanji/types.py @@ -95,6 +95,7 @@ def restart( observation: Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`. @@ -107,6 +108,8 @@ def restart( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the rewards and discounts. + Defaults to `float`. Returns: TimeStep identified as a reset. @@ -114,8 +117,8 @@ def restart( extras = extras or {} return TimeStep( step_type=StepType.FIRST, - reward=jnp.zeros(shape, dtype=float), - discount=jnp.ones(shape, dtype=float), + reward=jnp.zeros(shape, dtype=dtype), + discount=jnp.ones(shape, dtype=dtype), observation=observation, extras=extras, ) @@ -127,6 +130,7 @@ def transition( discount: Optional[Array] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.MID`. @@ -141,11 +145,13 @@ def transition( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. Returns: TimeStep identified as a transition. """ - discount = discount if discount is not None else jnp.ones(shape, dtype=float) + discount = discount if discount is not None else jnp.ones(shape, dtype=dtype) extras = extras or {} return TimeStep( step_type=StepType.MID, @@ -161,6 +167,7 @@ def termination( observation: Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.LAST`. @@ -174,6 +181,8 @@ def termination( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. Returns: TimeStep identified as the termination of an episode. @@ -182,7 +191,7 @@ def termination( return TimeStep( step_type=StepType.LAST, reward=reward, - discount=jnp.zeros(shape, dtype=float), + discount=jnp.zeros(shape, dtype=dtype), observation=observation, extras=extras, ) @@ -194,6 +203,7 @@ def truncation( discount: Optional[Array] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = (), + dtype: Union[jnp.dtype, type] = float, ) -> TimeStep: """Returns a `TimeStep` with `step_type` set to `StepType.LAST`. @@ -208,10 +218,13 @@ def truncation( shape: optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount. + dtype: Optional parameter to specify the data type of the discounts. Defaults + to `float`. + Returns: TimeStep identified as the truncation of an episode. """ - discount = discount if discount is not None else jnp.ones(shape, dtype=float) + discount = discount if discount is not None else jnp.ones(shape, dtype=dtype) extras = extras or {} return TimeStep( step_type=StepType.LAST,