diff --git a/examples/resnet_distributed_torch.yaml b/examples/resnet_distributed_torch.yaml index 7690642b80e..b92ed9feef9 100644 --- a/examples/resnet_distributed_torch.yaml +++ b/examples/resnet_distributed_torch.yaml @@ -11,7 +11,7 @@ setup: | git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet cd pytorch-distributed-resnet # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5). - pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + pip3 install -r requirements.txt numpy==1.26.4 torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 mkdir -p data && mkdir -p saved_models && cd data && \ wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz tar -xvzf cifar-10-python.tar.gz diff --git a/examples/tpu/tpuvm_mnist.yaml b/examples/tpu/tpuvm_mnist.yaml index d4e119119e0..d1fd434fad6 100644 --- a/examples/tpu/tpuvm_mnist.yaml +++ b/examples/tpu/tpuvm_mnist.yaml @@ -14,10 +14,11 @@ setup: | conda create -n flax python=3.10 -y conda activate flax # Make sure to install TPU related packages in a conda env to avoid package conflicts. - pip install "jax[tpu]==0.4.26" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - pip install clu + pip install \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.25" \ + clu \ + tensorflow tensorflow-datasets pip install -e flax - pip install tensorflow tensorflow-datasets fi