Ensure safe_to_tensor moves tensors to the specified device. (#831)
This PR fixes a bug in the `safe_to_tensor` utility: previously, it did not move existing tensors to the device specified by the `device` kwarg, which caused issues when more than one device was available.
The bug went unnoticed for a long time because our CircleCI runners do not have GPUs enabled.
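
As an illustration (not part of the commit), here is a minimal sketch of the behavior this fix restores, assuming PyTorch is installed and a CUDA device is available:

import numpy as np
import torch as th

from imitation.util.util import safe_to_tensor

x = th.zeros(3)  # a tensor that already lives on the CPU

# Before the fix, the early return for tensor inputs ignored the `device`
# kwarg, so `y` stayed on the CPU; with the fix it is moved to the GPU.
y = safe_to_tensor(x, device="cuda:0")
assert y.device.type == "cuda"

# NumPy arrays are converted as before (and copied first if non-writeable).
z = safe_to_tensor(np.zeros(3), device="cuda:0")
assert z.device.type == "cuda"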
ernestum authored Jan 7, 2025
1 parent a8b079c commit e5ef188
Showing 10 changed files with 32 additions and 16 deletions.
20 changes: 14 additions & 6 deletions .circleci/config.yml
@@ -90,7 +90,7 @@ commands:
steps:
- run:
name: install macOS packages
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils parallel gnu-getopt
command: HOMEBREW_NO_AUTO_UPDATE=1 brew install coreutils gnu-getopt parallel python@3.9 virtualenv

- checkout

@@ -138,11 +138,13 @@ commands:
# Download and cache dependencies
- restore_cache:
keys:
- v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
- v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}

- run:
name: install python
command: choco install --allow-downgrade -y python --version=3.8.10
# Use python3.9 in Windows instead of python3.8 because otherwise
# pytest-notebook's indirect dependency pywinpty will fail to build.
command: choco install --allow-downgrade -y python --version=3.9.13
shell: powershell.exe

- run:
@@ -163,14 +165,20 @@

- run:
name: install dependencies
# Only create venv if it's not been restored from cache
command: if (-not (Test-Path venv)) { .\ci\build_and_activate_venv.ps1 -venv venv }
# Only create venv if it's not been restored from cache.
# Need to throw error explicitly on error or else {} will get rid of
# the exit code.
command: |
if (-not (Test-Path venv)) {
.\ci\build_and_activate_venv.ps1 -venv venv
if ($LASTEXITCODE -ne 0) { throw "Failed to create venv" }
}
shell: powershell.exe

- save_cache:
paths:
- .\venv
key: v11win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}
key: v13win-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_and_activate_venv.ps1" }}

- run:
name: install imitation
2 changes: 1 addition & 1 deletion ci/build_and_activate_venv.ps1
@@ -7,6 +7,6 @@ If ($venv -eq $null) {
$venv = "venv"
}

virtualenv -p python3.8 $venv
virtualenv -p python3.9 $venv
& $venv\Scripts\activate
pip install ".[docs,parallel,test]"
5 changes: 5 additions & 0 deletions mypy.ini
@@ -1,3 +1,8 @@
[mypy]
ignore_missing_imports = true
exclude = output

# torch had some type errors, we ignore them because they're not our fault
[mypy-torch._dynamo.*]
follow_imports = skip
follow_imports_for_stubs = True
2 changes: 1 addition & 1 deletion src/imitation/algorithms/adversarial/common.py
@@ -207,7 +207,7 @@ def __init__(
self.debug_use_ground_truth = debug_use_ground_truth
self.venv = venv
self.gen_algo = gen_algo
self._reward_net = reward_net.to(gen_algo.device)
self._reward_net: reward_nets.RewardNet = reward_net.to(gen_algo.device)
self._log_dir = util.parse_path(log_dir)

# Create graph for optimising/recording stats on discriminator
2 changes: 1 addition & 1 deletion src/imitation/data/serialize.py
@@ -20,7 +20,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
trajectories: The trajectories to save.
"""
p = util.parse_path(path)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(str(p))
logging.info(f"Dumped demonstrations to {p}.")


2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
@@ -31,7 +31,7 @@ def _predict(
):
np_actions = []
if isinstance(obs, dict):
np_obs = types.DictObs(
np_obs: Union[types.DictObs, np.ndarray] = types.DictObs(
{k: v.detach().cpu().numpy() for k, v in obs.items()},
)
else:
7 changes: 2 additions & 5 deletions src/imitation/util/util.py
@@ -255,10 +255,7 @@ def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor:
Returns:
A PyTorch tensor with the same content as `array`.
"""
if isinstance(array, th.Tensor):
return array

if not array.flags.writeable:
if isinstance(array, np.ndarray) and not array.flags.writeable:
array = array.copy()

return th.as_tensor(array, **kwargs)
@@ -476,6 +473,6 @@ def split_in_half(x: int) -> Tuple[int, int]:
def clear_screen() -> None:
"""Clears the console screen."""
if os.name == "nt": # Windows
os.system("cls")
os.system("cls") # pragma: no cover
else:
os.system("clear")
2 changes: 1 addition & 1 deletion tests/algorithms/test_sqil.py
@@ -247,7 +247,7 @@ def test_sqil_performance_continuous(
pytestconfig: pytest.Config,
pendulum_single_venv: vec_env.VecEnv,
rl_algo_class: Type[off_policy_algorithm.OffPolicyAlgorithm],
):
): # pragma: no cover
rl_kwargs = dict(
learning_starts=500,
learning_rate=0.001,
4 changes: 4 additions & 0 deletions tests/data/test_huggingface_utils.py
@@ -63,6 +63,8 @@ def test_save_load_roundtrip(


@hypothesis.given(st.data(), h_strats.trajectories_list)
# the first run sometimes takes longer, so we give it more time
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Trajectory]):
"""Test that slicing a TrajectoryDatasetSequence behaves as expected."""
# GIVEN
@@ -84,6 +86,8 @@ def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Traject


@hypothesis.given(st.data(), h_strats.trajectory)
# the first run sometimes takes longer, so we give it more time
@hypothesis.settings(deadline=datetime.timedelta(milliseconds=300))
def test_sliced_info_dict_access(
data: st.DataObject,
trajectory: types.Trajectory,
2 changes: 2 additions & 0 deletions tests/util/test_util.py
@@ -103,6 +103,8 @@ def test_safe_to_numpy():
numpy = util.safe_to_numpy(tensor)
assert (numpy == tensor.numpy()).all()
assert util.safe_to_numpy(None) is None
with pytest.warns(UserWarning, match=".*performance.*"):
util.safe_to_numpy(tensor, warn=True)


def test_tensor_iter_norm():
