diff --git a/test/support/conditions.go b/test/support/conditions.go index e7c5097a..16b26583 100644 --- a/test/support/conditions.go +++ b/test/support/conditions.go @@ -18,6 +18,7 @@ limitations under the License. package support import ( + appsv1 "k8s.io/api/apps/v1" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" ) @@ -34,7 +35,10 @@ func ConditionStatus[T conditionType](conditionType T) func(any) corev1.Conditio if c := getJobCondition(o.Status.Conditions, batchv1.JobConditionType(conditionType)); c != nil { return c.Status } - + case *appsv1.Deployment: + if c := getDeploymentCondition(o.Status.Conditions, appsv1.DeploymentConditionType(conditionType)); c != nil { + return c.Status + } } return corev1.ConditionUnknown @@ -51,3 +55,12 @@ func getJobCondition(conditions []batchv1.JobCondition, conditionType batchv1.Jo } return nil } + +func getDeploymentCondition(conditions []appsv1.DeploymentCondition, conditionType appsv1.DeploymentConditionType) *appsv1.DeploymentCondition { + for _, c := range conditions { + if c.Type == conditionType { + return &c + } + } + return nil +} diff --git a/test/support/ray.go b/test/support/ray.go index cd5a9d87..52bf2d65 100644 --- a/test/support/ray.go +++ b/test/support/ray.go @@ -68,3 +68,20 @@ func WriteRayJobLogs(t Test, namespace, name string) { t.T().Logf("Retrieving RayJob %s/%s logs", namespace, name) WriteToOutputDir(t, name, Log, GetRayJobLogs(t, namespace, name)) } + +func RayCluster(t Test, namespace, name string) func(g gomega.Gomega) *rayv1alpha1.RayCluster { + return func(g gomega.Gomega) *rayv1alpha1.RayCluster { + cluster, err := t.Client().Ray().RayV1alpha1().RayClusters(namespace).Get(t.Ctx(), name, metav1.GetOptions{}) + g.Expect(err).NotTo(gomega.HaveOccurred()) + return cluster + } +} + +func GetRayCluster(t Test, namespace, name string) *rayv1alpha1.RayCluster { + t.T().Helper() + return RayCluster(t, namespace, name)(t) +} + +func RayClusterState(cluster *rayv1alpha1.RayCluster) rayv1alpha1.ClusterState { + return cluster.Status.State +}