Skip to content

Commit

Permalink
Update TPU_guide.md
Browse files Browse the repository at this point in the history
  • Loading branch information
richardsliu authored Sep 21, 2023
1 parent cf9715c commit 5f89864
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions ray-on-gke/TPU_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,18 @@ If you are using multiple TPU hosts with JAX, you need to manually set JAX envir
```
@ray.remote(resources={"google.com/tpu": 4})
def get_hostname():
import socket
import time;
import time
time.sleep(1)
return socket.gethostname()
return ray.util.get_node_ip_address()
@ray.remote(resources={"google.com/tpu": 4})
def init_tpu_env_from_ray(id_hostname_map):
import os
import socket
import time;
import time
time.sleep(1)
hostname = socket.gethostname()
hostname = ray.util.get_node_ip_address()
worker_id = id_hostname_map[hostname]
os.environ["TPU_WORKER_ID"] = str(worker_id)
Expand All @@ -104,6 +102,7 @@ def init_jax_from_ray(num_workers: int):
result = [init_tpu_env_from_ray.remote(id_hostname_map) for _ in range(num_workers)]
print(ray.get(result))
init_jax_from_ray(num_workers=2)
```
Expand Down

0 comments on commit 5f89864

Please sign in to comment.