From 324960a95c00112ce6b9b858d9311da1597cfb8b Mon Sep 17 00:00:00 2001
From: Siyuan Liu <lsiyuan@google.com>
Date: Fri, 24 Jan 2025 23:23:03 -0800
Subject: [PATCH] [TPU][CI] Update torchxla version in requirement-tpu.txt
 (#12422)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
---
 Dockerfile.tpu       |  2 +-
 requirements-tpu.txt | 21 +++++++++++----------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index ee0d94d98e82b..e268b39476665 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20250122"
+ARG NIGHTLY_DATE="20250124"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
diff --git a/requirements-tpu.txt b/requirements-tpu.txt
index 8ab18b3770ae8..51a0c65eac5aa 100644
--- a/requirements-tpu.txt
+++ b/requirements-tpu.txt
@@ -10,16 +10,17 @@ wheel
 jinja2
 ray[default]
 
-# Install torch_xla
---pre
---extra-index-url https://download.pytorch.org/whl/nightly/cpu
+# Install torch, torch_xla
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.6.0.dev20241126+cpu
-torchvision==0.20.0.dev20241126+cpu
-torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-jaxlib==0.4.36.dev20241122
-jax==0.4.36.dev20241122
+# Note: This torch whl can be slightly different from the official torch nightly whl
+# since they are not built on the same commit (but on the same day). This difference may cause C++ undefined symbol issue
+# if some change between the 2 commits introduce some C++ API change.
+# Here we install the exact torch whl from which torch_xla is built from, to avoid potential C++ undefined symbol issue.
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"