Skip to content

Commit

Permalink
Add spot option
Browse files (browse the repository at this point in the history)
  • Loading branch information
nstogner committed Mar 4, 2024
1 parent 7c71200 commit 594259b
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions tpu-provisioner/test/e2e/test/jobset_test.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,34 +39,38 @@ const (
)

func TestTPUJobsets(t *testing.T) {
spot := os.Getenv("TEST_SPOT") == "true"
var (
spot = os.Getenv("TEST_SPOT") == "true"
reservation = os.Getenv("TEST_RESERVATION")
)

cases := []struct {
name string
config tpuConfig
}{
// v4
{
name: "v4-2x2x2-tpu",
config: tpuConfig{
accelerator: "tpu-v4-podslice",
topoX: 2,
topoY: 2,
topoZ: 2,
chipsPerNode: 4,
sliceCount: 1,
},
},
{
name: "v4-2x2x4-tpu",
config: tpuConfig{
accelerator: "tpu-v4-podslice",
topoX: 2,
topoY: 2,
topoZ: 4,
chipsPerNode: 4,
sliceCount: 1,
},
},
// {
// name: "v4-2x2x2-tpu",
// config: tpuConfig{
// accelerator: "tpu-v4-podslice",
// topoX: 2,
// topoY: 2,
// topoZ: 2,
// chipsPerNode: 4,
// sliceCount: 1,
// },
// },
// {
// name: "v4-2x2x4-tpu",
// config: tpuConfig{
// accelerator: "tpu-v4-podslice",
// topoX: 2,
// topoY: 2,
// topoZ: 4,
// chipsPerNode: 4,
// sliceCount: 1,
// },
// },
// v5e
/*
{
Expand Down Expand Up @@ -97,6 +101,7 @@ func TestTPUJobsets(t *testing.T) {
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
c.config.spot = spot
c.config.reservation = reservation

js := newJobset(c.name, c.config)
err := client.Create(ctx, js)
Expand Down Expand Up @@ -232,7 +237,7 @@ func newJobset(name string, c tpuConfig) *jobset.JobSet {
Containers: []v1.Container{
{
Name: "main",
Image: "python:3.8",
Image: "python:3.11",
Ports: []v1.ContainerPort{
{
ContainerPort: 8471,
Expand All @@ -244,7 +249,7 @@ func newJobset(name string, c tpuConfig) *jobset.JobSet {
Command: []string{
"bash",
"-c",
`pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; env; python -c 'import jax; print("Device count: total =", jax.device_count(), " local =", jax.local_device_count())'`,
`pip install "jax[tpu]==0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"`,
},
Args: []string{"echo", "job1"},
Resources: v1.ResourceRequirements{
Expand Down

0 comments on commit 594259b

Please sign in to comment.