Skip to content

Commit

Permalink
bump: torch to >=2.1.1 (rm workaround) (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
quinn-dougherty committed Apr 9, 2024
1 parent c264c94 commit d276e4a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 62 deletions.
93 changes: 55 additions & 38 deletions python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 2 additions & 24 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,8 @@ importlib-metadata = ">=5.1.0"
numpy = [{ version = ">=1.20,<1.25", python = ">=3.8,<3.9" },
{ version = ">=1.24", python = ">=3.9,<3.12" },
{ version = ">=1.26", python = ">=3.12,<3.13" }]
python = ">=3.8"
torch = ">=1.10" # See PyTorch 2 fix below
# PyTorch 2.1.0 Bug Fix PyTorch didn't put their dependencies metadata into all wheels for 2.1.0, so
# it doesn't work with Poetry. This is a known bug - the workaround is to place them manually here
# (from the one wheel that did correctly list them). This was broken in 2.0.1 and the fix wasn't
# made for 2.1.0, however Meta are aware of the issue and once it is fixed (and the torch version
# requirement bumped) this should be removed. Note also the python version is used to specify that
# this is only added where v2 torch is installed (as per the torch version requirement above).
# https://github.com/pytorch/pytorch/issues/100974
# https://github.com/python-poetry/poetry/issues/7902#issuecomment-1583078794
nvidia-cuda-nvrtc-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cuda-runtime-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cuda-cupti-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cudnn-cu12 = { version = "==8.9.2.26", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cublas-cu12 = { version = "==12.1.3.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cufft-cu12 = { version = "==11.0.2.54", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-curand-cu12 = { version = "==10.3.2.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cusolver-cu12 = { version = "==11.4.5.107", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-cusparse-cu12 = { version = "==12.1.0.106", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-nccl-cu12 = { version = "==2.18.1", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
nvidia-nvtx-cu12 = { version = "==12.1.105", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
triton = { version = "==2.1.0", markers = "platform_system == 'Linux' and platform_machine == 'x86_64'" }
# End PyTorch 2.1.0 Bug Fix
python = ">=3.8"
torch = ">=2.1.1"

[tool.poetry.group.dev.dependencies]
autopep8 = ">=2.0"
Expand All @@ -44,7 +23,6 @@ pytest = ">=7.2"
snapshottest = ">=0.6"
twine = ">=4.0.1"


[tool.poetry.group.jupyter.dependencies]
jupyterlab = ">=3.5"

Expand Down

0 comments on commit d276e4a

Please sign in to comment.