Skip to content

Commit

Permalink
[Minor] Make fbcode happy with imports
Browse files Browse the repository at this point in the history
ghstack-source-id: d4bfce9d51269bc0ab6154ee4c2d1e1ff7af0895
Pull Request resolved: #2517
  • Loading branch information
vmoens committed Oct 26, 2024
1 parent 5e03a55 commit a70b258
Show file tree
Hide file tree
Showing 17 changed files with 227 additions and 76 deletions.
6 changes: 5 additions & 1 deletion test/smoke_test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import os
import tempfile

import pytest
Expand Down Expand Up @@ -46,7 +47,10 @@ def test_gym():
from torchrl.envs.libs.gym import _has_gym, GymEnv # noqa

assert _has_gym
from _utils_internal import PONG_VERSIONED
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import PONG_VERSIONED
else:
from _utils_internal import PONG_VERSIONED

env = GymEnv(PONG_VERSIONED())
env.reset()
Expand Down
7 changes: 6 additions & 1 deletion test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os

import pytest
import torch

from _utils_internal import get_default_devices
from mocking_classes import NestedCountingEnv
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
Expand All @@ -31,6 +31,11 @@
ValueOperator,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices


@pytest.mark.parametrize(
"log_prob_key",
Expand Down
41 changes: 28 additions & 13 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,13 @@
import argparse
import functools
import gc
import os

import sys

import numpy as np
import pytest
import torch

from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
PONG_VERSIONED,
retry,
)
from mocking_classes import (
ContinuousActionVecMockEnv,
CountingBatchedEnv,
Expand Down Expand Up @@ -102,6 +90,33 @@
)
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
PONG_VERSIONED,
retry,
)
else:
from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
PONG_VERSIONED,
retry,
)

# torch.set_default_dtype(torch.double)
IS_WINDOWS = sys.platform == "win32"
IS_OSX = sys.platform == "darwin"
Expand Down
20 changes: 15 additions & 5 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import functools
import itertools
import operator
import os

import sys
import warnings
Expand All @@ -16,11 +17,7 @@
import numpy as np
import pytest
import torch
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)

from mocking_classes import ContinuousActionConvMockEnv

from packaging import version, version as pack_version
Expand Down Expand Up @@ -138,6 +135,19 @@
_split_and_pad_sequence,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)
else:
from _utils_internal import ( # noqa
dtype_fixture,
get_available_devices,
get_default_devices,
)

_has_functorch = True
try:
import functorch as ft # noqa
Expand Down
7 changes: 6 additions & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

import argparse
import importlib.util
import os

import pytest
import torch
import torch.nn.functional as F

from _utils_internal import get_default_devices
from tensordict import TensorDictBase
from torch import autograd, nn
from torch.utils._pytree import tree_map
Expand All @@ -29,6 +29,11 @@
)
from torchrl.modules.distributions.continuous import SafeTanhTransform

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices

_has_scipy = importlib.util.find_spec("scipy", None) is not None


Expand Down
35 changes: 24 additions & 11 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,30 @@
import torch
import yaml

from _utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
else:
from _utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
from mocking_classes import (
ActionObsMergeLinear,
AutoResetHeteroCountingEnv,
Expand Down
7 changes: 6 additions & 1 deletion test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# LICENSE file in the root directory of this source tree.

import argparse
import os

import pytest
import torch
from _utils_internal import get_default_devices
from mocking_classes import (
ContinuousActionVecMockEnv,
CountingEnvCountModule,
Expand Down Expand Up @@ -47,6 +47,11 @@
OrnsteinUhlenbeckProcessWrapper,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices


class TestEGreedy:
@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1])
Expand Down
8 changes: 6 additions & 2 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@

import argparse
import dataclasses
import os
import pathlib
import sys
from time import sleep

import pytest
import torch

from _utils_internal import generate_seeds, get_default_devices
from torchrl._utils import timeit

try:
Expand Down Expand Up @@ -50,6 +49,11 @@
make_dqn_actor,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import generate_seeds, get_default_devices
else:
from _utils_internal import generate_seeds, get_default_devices

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
if TORCH_VERSION < version.parse("1.12.0"):
UNSQUEEZE_SINGLETON = True
Expand Down
38 changes: 26 additions & 12 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,32 @@
import pytest
import torch

from _utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
else:
from _utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
from packaging import version
from tensordict import (
assert_allclose_td,
Expand Down
6 changes: 5 additions & 1 deletion test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import re

from numbers import Number
Expand All @@ -11,7 +12,10 @@
import pytest
import torch

from _utils_internal import get_default_devices, retry
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices, retry
else:
from _utils_internal import get_default_devices, retry
from mocking_classes import MockBatchedUnLockedEnv
from packaging import version
from tensordict import TensorDict
Expand Down
7 changes: 6 additions & 1 deletion test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import os

import pytest
import torch
from _utils_internal import get_default_devices
from tensordict import assert_allclose_td, TensorDict

from torchrl._utils import _ends_with
from torchrl.collectors.utils import split_trajectories
from torchrl.data.postprocs.postprocs import MultiStep

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices


@pytest.mark.parametrize("n", range(1, 14))
@pytest.mark.parametrize("device", get_default_devices())
Expand Down
25 changes: 18 additions & 7 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@
import pytest
import torch

from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
else:
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)

from mocking_classes import CountingEnv
from packaging import version
Expand Down Expand Up @@ -3640,7 +3648,10 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls):
def test_rb_multidim_collector(
self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device
):
from _utils_internal import CARTPOLE_VERSIONED
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED
else:
from _utils_internal import CARTPOLE_VERSIONED

torch.manual_seed(0)
env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device)
Expand Down
Loading

0 comments on commit a70b258

Please sign in to comment.