@@ -220,7 +220,7 @@ class PendulumEnv(EnvBase):
220220
221221 def __init__ (self , td_params = None , seed = None , device = None ):
222222 if td_params is None :
223- td_params = self .gen_params ()
223+ td_params = self .gen_params (device = self . device )
224224
225225 super ().__init__ (device = device )
226226 self ._make_spec (td_params )
@@ -273,7 +273,7 @@ def _reset(self, tensordict):
273273 # if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274 # Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275 # parameters to get started.
276- tensordict = self .gen_params (batch_size = batch_size )
276+ tensordict = self .gen_params (batch_size = batch_size , device = self . device )
277277
278278 high_th = torch .tensor (self .DEFAULT_X , device = self .device )
279279 high_thdot = torch .tensor (self .DEFAULT_Y , device = self .device )
@@ -355,12 +355,12 @@ def make_composite_from_td(td):
355355 return composite
356356
357357 def _set_seed (self , seed : int ):
358- rng = torch .Generator ()
358+ rng = torch .Generator (device = self . device )
359359 rng .manual_seed (seed )
360360 self .rng = rng
361361
362362 @staticmethod
363- def gen_params (g = 10.0 , batch_size = None ) -> TensorDictBase :
363+ def gen_params (g = 10.0 , batch_size = None , device = None ) -> TensorDictBase :
364364 """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
365365 if batch_size is None :
366366 batch_size = []
@@ -379,6 +379,7 @@ def gen_params(g=10.0, batch_size=None) -> TensorDictBase:
379379 )
380380 },
381381 [],
382+ device = device ,
382383 )
383384 if batch_size :
384385 td = td .expand (batch_size ).contiguous ()
0 commit comments