21
21
CompositeSpec ,
22
22
UnboundedContinuousTensorSpec ,
23
23
)
24
+ from torchrl .envs import EnvCreator , SerialEnv
24
25
from torchrl .envs .utils import set_exploration_type , step_mdp
25
26
from torchrl .modules import (
26
27
AdditiveGaussianWrapper ,
@@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based):
1782
1783
)
1783
1784
1784
1785
@pytest .mark .parametrize ("python_based" , [True , False ])
1785
- def test_lstm_parallel_env (self , python_based ):
1786
+ @pytest .mark .parametrize ("parallel" , [True , False ])
1787
+ @pytest .mark .parametrize ("heterogeneous" , [True , False ])
1788
+ def test_lstm_parallel_env (self , python_based , parallel , heterogeneous ):
1786
1789
from torchrl .envs import InitTracker , ParallelEnv , TransformedEnv
1787
1790
1791
+ torch .manual_seed (0 )
1788
1792
device = "cuda" if torch .cuda .device_count () else "cpu"
1789
1793
# tests that hidden states are carried over with parallel envs
1790
1794
lstm_module = LSTMModule (
@@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based):
1796
1800
device = device ,
1797
1801
python_based = python_based ,
1798
1802
)
1803
+ if parallel :
1804
+ cls = ParallelEnv
1805
+ else :
1806
+ cls = SerialEnv
1799
1807
1800
1808
def create_transformed_env ():
1801
1809
primer = lstm_module .make_tensordict_primer ()
@@ -1807,7 +1815,12 @@ def create_transformed_env():
1807
1815
env .append_transform (primer )
1808
1816
return env
1809
1817
1810
- env = ParallelEnv (
1818
+ if heterogeneous :
1819
+ create_transformed_env = [
1820
+ EnvCreator (create_transformed_env ),
1821
+ EnvCreator (create_transformed_env ),
1822
+ ]
1823
+ env = cls (
1811
1824
create_env_fn = create_transformed_env ,
1812
1825
num_workers = 2 ,
1813
1826
)
@@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based):
2109
2122
)
2110
2123
2111
2124
@pytest .mark .parametrize ("python_based" , [True , False ])
2112
- def test_gru_parallel_env (self , python_based ):
2125
+ @pytest .mark .parametrize ("parallel" , [True , False ])
2126
+ @pytest .mark .parametrize ("heterogeneous" , [True , False ])
2127
+ def test_gru_parallel_env (self , python_based , parallel , heterogeneous ):
2113
2128
from torchrl .envs import InitTracker , ParallelEnv , TransformedEnv
2114
2129
2130
+ torch .manual_seed (0 )
2131
+
2115
2132
device = "cuda" if torch .cuda .device_count () else "cpu"
2116
2133
# tests that hidden states are carried over with parallel envs
2117
2134
gru_module = GRUModule (
@@ -2134,7 +2151,17 @@ def create_transformed_env():
2134
2151
env .append_transform (primer )
2135
2152
return env
2136
2153
2137
- env = ParallelEnv (
2154
+ if parallel :
2155
+ cls = ParallelEnv
2156
+ else :
2157
+ cls = SerialEnv
2158
+ if heterogeneous :
2159
+ create_transformed_env = [
2160
+ EnvCreator (create_transformed_env ),
2161
+ EnvCreator (create_transformed_env ),
2162
+ ]
2163
+
2164
+ env = cls (
2138
2165
create_env_fn = create_transformed_env ,
2139
2166
num_workers = 2 ,
2140
2167
)
0 commit comments