Skip to content

Commit

Permalink
Merge branch 'GoogleCloudPlatform:main' into dws-mulitclusters-gkeap
Browse files Browse the repository at this point in the history
  • Loading branch information
leroyjb authored Feb 11, 2025
2 parents 79c7c1a + c5ffe8d commit 489dd0e
Show file tree
Hide file tree
Showing 31 changed files with 1,709 additions and 115 deletions.
33 changes: 27 additions & 6 deletions cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ steps:
args:
- '-c'
- |
mkdir -p security_test/scan_target/ && find . -mindepth 1 -maxdepth 1 -type d ! -name "security_test" -exec cp -r {} security_test/scan_target/ \;
mkdir -p /workspace/security_test/scan_target
# Exclude /workspace/security_test from the copy; otherwise scan_target would be copied into itself recursively
find . -mindepth 1 -maxdepth 1 ! -path "./security_test" -exec cp -r {} /workspace/security_test/scan_target/ \;
chown -R 65532:65532 /workspace/security_test/scan_target
mkdir -p /workspace/security_test/allowlist
cp security_test/config.yaml /workspace/security_test/config.yaml
cp -r security_test/allowlist/* /workspace/security_test/allowlist/ || echo "Allowlist folder is empty or not found"
Expand All @@ -308,7 +313,7 @@ steps:
# gcr.io/cloud-builders/docker is a special image provided by Google Cloud: it contains the docker command-line tool, which is required to execute Docker commands such as docker build and docker run within Cloud Build steps.
# gcr.io/${_PROJECT_ID}/check_violations:latest is the application image: it contains the security check tool and its dependencies. It is designed to be run, not to build or run other Docker images.
- name: 'gcr.io/cloud-builders/docker'
id: 'run shipshape'
id: 'Run shipshape on cluster'
args:
- 'run'
- '--network=cloudbuild'
Expand All @@ -322,16 +327,32 @@ steps:
- '${_SHIPSHAPE_IMAGE}'
- '--mode=cluster'
- '--allowlist_folder=/workspace/security_test/allowlist'
- '--cluster_name=ml-${SHORT_SHA}-${_PR_NUMBER}-${_BUILD_ID}-cluster'
- '--project_id=$PROJECT_ID'
- '--location=${_REGION}'
- '--kube_config_path=/root/.kube/config'
- '--max_wait_duration=3000'
- '--max_parallel=100'
- '--cluster_scan_config_path=/workspace/security_test/config.yaml'
allowFailure: true
waitFor: ['Copy metadata']


- id: 'Run Shipshape on helm'
name: 'gcr.io/cloud-builders/docker'
args:
- 'run'
- '--network=cloudbuild'
- '--rm'
- '-v'
- '/workspace/security_test/allowlist:/workspace/security_test/allowlist'
- '-v'
- '/workspace/security_test/scan_target:/workspace/security_test/scan_target'
- '${_SHIPSHAPE_IMAGE}'
- '--mode=helm'
- '--allowlist_folder=/workspace/security_test/allowlist'
- '--scan_path=/workspace/security_test/scan_target'
- '--max_wait_duration=60'
allowFailure: true
waitFor: ['Copy metadata']

- id: 'cleanup rag'
name: 'gcr.io/$PROJECT_ID/terraform'
entrypoint: 'bash'
Expand All @@ -358,7 +379,7 @@ steps:
-var=cloudsql_instance=pgvector-instance-$SHORT_SHA-$_BUILD_ID \
-auto-approve -no-color
allowFailure: true
waitFor: ['run shipshape']
waitFor: ['Run shipshape on cluster', 'Run Shipshape on helm']

- id: 'cleanup gke cluster'
name: 'gcr.io/$PROJECT_ID/terraform'
Expand Down Expand Up @@ -441,7 +462,7 @@ substitutions:
_USER_NAME: github
_AUTOPILOT_CLUSTER: "false"
_BUILD_ID: ${BUILD_ID:0:8}
_SHIPSHAPE_IMAGE: us-docker.pkg.dev/k8ssecurityvalidation-agent/k8ssecurityvalidation-agent/k8ssecurityvalidation-agent@sha256:1f80ae746014330a3a83b5ee2fabeacf90fa488c4f0922072b628b95144f25c8
_SHIPSHAPE_IMAGE: us-docker.pkg.dev/k8ssecurityvalidation-agent/k8ssecurityvalidation-agent/k8ssecurityvalidation-agent@sha256:cd45e6cd84e9a45462ddbca18c4731fd4e264d517ee98131eb5be4eb57691f44
logsBucket: gs://ai-on-gke-build-logs
options:
substitutionOption: "ALLOW_LOOSE"
Expand Down
11 changes: 11 additions & 0 deletions infrastructure/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ output "ca_certificate" {

}

# Service account reported by whichever of the four GKE cluster modules
# (private/public x autopilot/standard) matches the create_cluster,
# autopilot_cluster and private_cluster variables. Every branch requires
# create_cluster to be true, so the value falls through to "" when it is false.
output "service_account" {
value = var.create_cluster && var.autopilot_cluster && var.private_cluster ? module.private-gke-autopilot-cluster[0].service_account : (
var.create_cluster && !var.autopilot_cluster && var.private_cluster ? module.private-gke-standard-cluster[0].service_account : (
var.create_cluster && var.autopilot_cluster && !var.private_cluster ? module.public-gke-autopilot-cluster[0].service_account : (
var.create_cluster && !var.autopilot_cluster && !var.private_cluster ? module.public-gke-standard-cluster[0].service_account :
"")))
# Flagged as sensitive; the output explicitly depends on all four cluster modules.
sensitive = true
depends_on = [module.private-gke-autopilot-cluster, module.private-gke-standard-cluster, module.public-gke-autopilot-cluster, module.public-gke-standard-cluster]

}

# Echoes the private_cluster input variable back to the caller.
output "private_cluster" {
value = var.private_cluster
}
4 changes: 4 additions & 0 deletions modules/gke-autopilot-private-cluster/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ output "endpoint" {

# Passes through the ca_certificate output of the nested gke module.
output "ca_certificate" {
value = module.gke.ca_certificate
}

# Passes through the service_account output of the nested gke module.
output "service_account" {
value = module.gke.service_account
}
4 changes: 4 additions & 0 deletions modules/gke-autopilot-public-cluster/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ output "endpoint" {

# Passes through the ca_certificate output of the nested gke module.
output "ca_certificate" {
value = module.gke.ca_certificate
}

# Passes through the service_account output of the nested gke module.
output "service_account" {
value = module.gke.service_account
}
4 changes: 4 additions & 0 deletions modules/gke-standard-private-cluster/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ output "endpoint" {

# Passes through the ca_certificate output of the nested gke module.
output "ca_certificate" {
value = module.gke.ca_certificate
}

# Passes through the service_account output of the nested gke module.
output "service_account" {
value = module.gke.service_account
}
5 changes: 5 additions & 0 deletions modules/gke-standard-public-cluster/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ output "endpoint" {
# Passes through the ca_certificate output of the nested gke module.
output "ca_certificate" {
value = module.gke.ca_certificate
}


# Passes through the service_account output of the nested gke module.
output "service_account" {
value = module.gke.service_account
}
8 changes: 4 additions & 4 deletions ray-on-gke/tpu/kuberay-tpu-webhook/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ require (
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect
golang.org/x/net v0.23.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/oauth2 v0.12.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/time v0.3.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
Expand Down
20 changes: 10 additions & 10 deletions ray-on-gke/tpu/kuberay-tpu-webhook/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.12.0 h1:smVPGxink+n1ZI5pkQa8y6fZT0RW0MgCO5bFpepy4B4=
golang.org/x/oauth2 v0.12.0/go.mod h1:A74bZ3aGXgCY0qaIC9Ahg6Lglin4AMAco8cIv9baba4=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -139,23 +139,23 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ=
golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
19 changes: 15 additions & 4 deletions ray-on-gke/tpu/kuberay-tpu-webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"time"

ray "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
utils "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
admissionv1 "k8s.io/api/admission/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -244,16 +245,26 @@ func extractRayCluster(admissionReview *admissionv1.AdmissionReview) (*ray.RayCl
return &rayCluster, nil
}

// generateHeadlessServiceName returns the expected TPU headless service name for a RayCluster
func generateHeadlessServiceName(clusterName string) string {
serviceName := fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)

// Apply the same truncation as in the RayCluster controller when generating the headless service
// name. This is to maintain the up-to 63 char compatibility guarantee for hostnames (RFC 1123).
return utils.CheckName(serviceName)
}

// genDNSHostnames returns list of DNS hostnames for TPU VM hosts as a string
func genDNSHostnames(numOfHosts int32, groupName string, clusterName string, namespace string, replicaIndex int) (string, error) {
if numOfHosts == 0 {
err := errors.New("workerGroupSpec NumOfHosts not set")
return "", err
}
headlessServiceName := generateHeadlessServiceName(clusterName)
hostNames := make([]string, numOfHosts)
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.headless-worker-svc
// Host names will be of the form {WORKER_GROUP_NAME}-{REPLICA_INDEX}-{HOST_INDEX}.{CLUSTER_NAME}-headless-worker-svc
for j := 0; j < int(numOfHosts); j++ {
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s-%s", groupName, replicaIndex, j, clusterName, headlessServiceSuffix)
hostNames[j] = fmt.Sprintf("%s-%d-%d.%s", groupName, replicaIndex, j, headlessServiceName)
}
klog.V(1).InfoS("genDNSHostnames", "RayCluster", namespace+"/"+clusterName, "NumOfHosts", numOfHosts, "Replica Index", replicaIndex)
return strings.Join(hostNames, ","), nil
Expand All @@ -268,7 +279,7 @@ func injectHostnames(clusterName string, hostNames string, envPath string, conta
Value: hostNames,
}
subdomainPatch["path"] = subdomainPath
subdomainPatch["value"] = fmt.Sprintf("%s-%s", clusterName, headlessServiceSuffix)
subdomainPatch["value"] = generateHeadlessServiceName(clusterName)
// create new EnvVar array if container.Env is empty, and append hostnames if not
if len(container.Env) == 0 {
hostNamesPatch["path"] = envPath
Expand Down Expand Up @@ -678,7 +689,7 @@ func (t *TPUWebhookServer) mutatePod(admissionReview *admissionv1.AdmissionRevie
return nil, err
}
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "TPU_WORKER_HOSTNAMES", hostnames)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", clusterName+"-"+headlessServiceSuffix)
klog.V(1).InfoS("mutatePod", "RayCluster", namespace+"/"+clusterName, "subdomain", generateHeadlessServiceName(clusterName))
injectHostnames(clusterName, hostnames, path, container, &patches)
}
// inject TPU_WORKER_ID
Expand Down
78 changes: 60 additions & 18 deletions ray-on-gke/tpu/kuberay-tpu-webhook/webhook_main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,26 +778,30 @@ func Test_ExtractRayCluster(t *testing.T) {

func Test_GenDNSHostnames(t *testing.T) {
tests := map[string]struct {
clusterName string
replicaIndex int
numOfHosts int32
expectedHostnames string
expectedError error
}{
"genDNSHostnames with NumOfHosts == 0": {
// a workergroup can't have NumOfHosts set to 0 so this should error out
clusterName: "test-cluster",
replicaIndex: 0,
numOfHosts: int32(0),
expectedError: errors.New("workerGroupSpec NumOfHosts not set"),
},
"genDNSHostnames with NumOfHosts == 1": {
// Single-host worker group, should return a single DNS hostname. This function will
// never be called for single-host groups, but we don't necessarily want it to error if it does.
clusterName: "test-cluster",
replicaIndex: 0,
numOfHosts: int32(1),
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
},
"genDNSHostnames with NumOfHosts > 1": {
// multi-host worker group, should return a string list of DNS hostnames for the given replica
clusterName: "test-cluster",
replicaIndex: 1,
numOfHosts: int32(4),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
Expand All @@ -806,12 +810,21 @@ func Test_GenDNSHostnames(t *testing.T) {
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
}, ","),
},
"genDNSHostnames with long RayCluster name": {
// Multi-host worker group in a RayCluster with a name that will be truncated
clusterName: "long-raycluster-name-to-be-truncated",
replicaIndex: 1,
numOfHosts: int32(2),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "aycluster-name-to-be-truncated", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "aycluster-name-to-be-truncated", headlessServiceSuffix),
}, ","),
},
}

// validate that genDNSHostnames correctly returns a string list of DNS addressable hostnames
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", "test-cluster", "test-namespace", tc.replicaIndex)
hostnames, err := genDNSHostnames(tc.numOfHosts, "test-group", tc.clusterName, "test-namespace", tc.replicaIndex)
if err != nil {
assert.Equal(t, tc.expectedError, err)
} else {
Expand All @@ -823,21 +836,15 @@ func Test_GenDNSHostnames(t *testing.T) {

func Test_InjectHostnames(t *testing.T) {
tests := map[string]struct {
numOfHosts int
clusterName string
groupName string
expectedSubdomain string
expectedHostnames string
}{
"injectHostnames for single-host worker group": {
// should create a patch to set the subdomain and a single TPU_WORKER_HOSTNAMES DNS hostname
numOfHosts: 1,
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
expectedHostnames: fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 0, 0, "test-cluster", headlessServiceSuffix),
},
"injectHostnames for multi-host worker group": {
// should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts
numOfHosts: 1,
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts.
// This function is only called for multi-host TPU worker groups.
clusterName: "test-cluster",
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "test-cluster", headlessServiceSuffix),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "test-cluster", headlessServiceSuffix),
Expand All @@ -846,21 +853,33 @@ func Test_InjectHostnames(t *testing.T) {
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "test-cluster", headlessServiceSuffix),
}, ","),
},
"injectHostnames for multi-host worker group with truncated service name": {
// Should create a patch to set the subdomain and TPU_WORKER_HOSTNAMES for all hosts, with the
// correct subdomain truncated to match the created service name.
clusterName: "extremely-long-test-raycluster-name",
groupName: "test-group-name",
expectedSubdomain: fmt.Sprintf("%s-%s", "mely-long-test-raycluster-name", headlessServiceSuffix),
expectedHostnames: strings.Join([]string{fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 0, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 1, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 2, "mely-long-test-raycluster-name", headlessServiceSuffix),
fmt.Sprintf("%s-%d-%d.%s-%s", "test-group", 1, 3, "mely-long-test-raycluster-name", headlessServiceSuffix),
}, ","),
},
}

// check that a valid subdomain and TPU_WORKER_HOSTNAMES are injected into the Pod
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
testPod := getTestTPUWorker("test-cluster", "test-group", "test-namespace", "tpu-v4-podslice", "2x2x1", "4")
testPod := getTestTPUWorker(tc.clusterName, tc.groupName, "test-namespace", "tpu-v4-podslice", "2x2x2", "4")
expectedEnv := []corev1.EnvVar{corev1.EnvVar{Name: "TPU_WORKER_HOSTNAMES", Value: tc.expectedHostnames}}
expectedPatches := []patch{}
injectHostnames("test-cluster", tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &expectedPatches)
patches := []patch{}
injectHostnames(tc.clusterName, tc.expectedHostnames, "/spec/containers/0/env", testPod.Spec.Containers[0], &patches)
// check subdomain patch
assert.Equal(t, "/spec/subdomain", expectedPatches[0]["path"])
assert.Equal(t, tc.expectedSubdomain, expectedPatches[0]["value"])
assert.Equal(t, "/spec/subdomain", patches[0]["path"])
assert.Equal(t, tc.expectedSubdomain, patches[0]["value"])
// check hostnames patch
assert.Equal(t, "/spec/containers/0/env", expectedPatches[1]["path"])
assert.Equal(t, expectedEnv, expectedPatches[1]["value"])
assert.Equal(t, "/spec/containers/0/env", patches[1]["path"])
assert.Equal(t, expectedEnv, patches[1]["value"])
})
}
}
Expand Down Expand Up @@ -1464,3 +1483,26 @@ func Test_MutatePod(t *testing.T) {
})
}
}

// Test_GenerateHeadlessServiceName checks the headless service name derived
// from a RayCluster name, both when the combined name is short enough to be
// used as-is and when it is long enough to be truncated.
func Test_GenerateHeadlessServiceName(t *testing.T) {
	type testCase struct {
		testRayClusterName  string
		expectedServiceName string
	}

	cases := map[string]testCase{
		// Short cluster name: the suffix is appended and nothing is cut.
		"RayCluster name + headless-worker-svc is less than 50 chars, no truncation": {
			testRayClusterName:  "test-raycluster",                     // 15 chars
			expectedServiceName: "test-raycluster-headless-worker-svc", // 35 chars
		},
		// Long cluster name: leading characters are dropped to fit 50 chars.
		"RayCluster name + headless-worker-svc is more than 50 chars, name is truncated": {
			testRayClusterName:  "extremely-long-test-raycluster-name",                // 35 chars
			expectedServiceName: "mely-long-test-raycluster-name-headless-worker-svc", // 50 chars
		},
	}

	for desc, c := range cases {
		t.Run(desc, func(t *testing.T) {
			got := generateHeadlessServiceName(c.testRayClusterName)
			assert.Equal(t, c.expectedServiceName, got)
		})
	}
}
Loading

0 comments on commit 489dd0e

Please sign in to comment.