From 3c5644dca0d502d261b1ded885e8eff9fb5b2856 Mon Sep 17 00:00:00 2001
From: SherlockShemol
Date: Sun, 29 Sep 2024 11:08:19 +0800
Subject: [PATCH] enable JointInferenceService controller to update related pods when modifying CRD, and add test files to ensure functionality stability and correctness

Signed-off-by: SherlockShemol
---
Review notes:
- The label selector is generated per service in sync and updateService
  instead of being stored on the controller, which is shared by all services.
- updateService works on a copy of the object delivered by the informer and
  logs the errors of the delete/create/update calls instead of dropping them.
- The blob hashes in the "index" lines below were not recomputed.

 .../jointinference/jointinferenceservice.go   | 103 ++++--
 .../jointinferenceservice_test.go             | 304 ++++++++++++++++++
 2 files changed, 392 insertions(+), 15 deletions(-)
 create mode 100644 pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go

diff --git a/pkg/globalmanager/controllers/jointinference/jointinferenceservice.go b/pkg/globalmanager/controllers/jointinference/jointinferenceservice.go
index 7bf8a168..55780d44 100644
--- a/pkg/globalmanager/controllers/jointinference/jointinferenceservice.go
+++ b/pkg/globalmanager/controllers/jointinference/jointinferenceservice.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"reflect"
 	"strconv"
 	"time"
 
@@ -94,6 +95,9 @@ type Controller struct {
 	cfg *config.ControllerConfig
 
 	sendToEdgeFunc runtime.DownstreamSendFunc
+
+	// bigModelHost is the host returned by CreateEdgeMeshService in createWorkers
+	bigModelHost string
 }
 
 // Run starts the main goroutine responsible for watching and syncing services.
@@ -278,9 +282,12 @@ func (c *Controller) sync(key string) (bool, error) {
 	// more details at https://github.com/kubernetes/kubernetes/issues/3030
 	service.SetGroupVersionKind(Kind)
 
 	selector, _ := runtime.GenerateSelector(&service)
 	pods, err := c.podStore.Pods(service.Namespace).List(selector)
+	if err != nil {
+		return false, err
+	}
 	deployments, err := c.deploymentsLister.Deployments(service.Namespace).List(selector)
 
 	if err != nil {
 		return false, err
@@ -422,31 +429,35 @@ func isServiceFinished(j *sednav1.JointInferenceService) bool {
 func (c *Controller) createWorkers(service *sednav1.JointInferenceService, activeCloudPod *bool, activeCloudDeployment *bool, activeEdgePod *bool, activeEdgeDeployment *bool) (activePods, activeDeployments int32, err error) {
 	var bigModelPort int32 = BigModelPort
 	// create cloud worker
-	err = c.createCloudWorker(service, bigModelPort, activeCloudPod, activeCloudDeployment)
+	err = c.createCloudWorker(service, bigModelPort)
 	if err != nil {
 		return activePods, activeDeployments, fmt.Errorf("failed to create cloudWorker: %w", err)
 	}
+	*activeCloudPod = true
+	*activeCloudDeployment = true
 	activePods++
 	activeDeployments++
 
 	// create k8s service for cloudPod
-	bigModelHost, err := runtime.CreateEdgeMeshService(c.kubeClient, service, jointInferenceForCloud, bigModelPort)
+	c.bigModelHost, err = runtime.CreateEdgeMeshService(c.kubeClient, service, jointInferenceForCloud, bigModelPort)
 	if err != nil {
 		return activePods, activeDeployments, fmt.Errorf("failed to create edgemesh service: %w", err)
 	}
 
 	// create edge worker
-	err = c.createEdgeWorker(service, bigModelHost, bigModelPort, activeEdgePod, activeEdgeDeployment)
+	err = c.createEdgeWorker(service, c.bigModelHost, bigModelPort)
 	if err != nil {
 		return activePods, activeDeployments, fmt.Errorf("failed to create edgeWorker: %w", err)
 	}
+	*activeEdgePod = true
+	*activeEdgeDeployment = true
 	activePods++
 	activeDeployments++
 
 	return activePods, activeDeployments, err
 }
 
-func (c *Controller) createCloudWorker(service *sednav1.JointInferenceService, bigModelPort int32, activeCloudPod *bool, activeCloudDeployment *bool) error {
+func (c *Controller) createCloudWorker(service *sednav1.JointInferenceService, bigModelPort int32) error {
 	// deliver deployment for cloudworker
 	cloudModelName := service.Spec.CloudWorker.Model.Name
 	cloudModel, err := c.client.Models(service.Namespace).Get(context.Background(), cloudModelName, metav1.GetOptions{})
@@ -494,14 +505,10 @@ func (c *Controller) createCloudWorker(service *sednav1.JointInferenceService, b
 	if err != nil {
 		return fmt.Errorf("failed to create cloudWorker deployment: %w", err)
 	}
-
-	*activeCloudDeployment = true
-	*activeCloudPod = true
-
 	return nil
 }
 
-func (c *Controller) createEdgeWorker(service *sednav1.JointInferenceService, bigModelHost string, bigModelPort int32, activeEdgePod *bool, activeEdgeDeployment *bool) error {
+func (c *Controller) createEdgeWorker(service *sednav1.JointInferenceService, bigModelHost string, bigModelPort int32) error {
 	// deliver edge deployment for edgeworker
 	ctx := context.Background()
 	edgeModelName := service.Spec.EdgeWorker.Model.Name
@@ -565,8 +572,6 @@ func (c *Controller) createEdgeWorker(service *sednav1.JointInferenceService, bi
 		return fmt.Errorf("failed to create edgeWorker deployment: %w", err)
 	}
 
-	*activeEdgeDeployment = true
-	*activeEdgePod = true
 	return nil
 }
 
@@ -599,10 +604,7 @@ func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
 			jc.syncToEdge(watch.Added, obj)
 		},
 
-		UpdateFunc: func(old, cur interface{}) {
-			jc.enqueueController(cur, true)
-			jc.syncToEdge(watch.Added, cur)
-		},
+		UpdateFunc: jc.updateService,
 
 		DeleteFunc: func(obj interface{}) {
 			jc.enqueueController(obj, true)
@@ -692,3 +694,74 @@ func (c *Controller) enqueueByDeployment(deployment *appsv1.Deployment) {
 
 	c.enqueueController(service, true)
 }
+
+// updateService handles update events of JointInferenceService objects.
+// When the spec of a service changes, the deployments created for it are
+// deleted and both workers are created again from the new spec; afterwards the
+// service is enqueued and synced to the edge. Updates that leave the spec
+// untouched are ignored.
+func (c *Controller) updateService(old, cur interface{}) {
+	oldService, ok := old.(*sednav1.JointInferenceService)
+	if !ok {
+		return
+	}
+	curService, ok := cur.(*sednav1.JointInferenceService)
+	if !ok {
+		return
+	}
+
+	// identical objects and updates that leave the spec untouched need no action
+	if oldService == curService || reflect.DeepEqual(oldService.Spec, curService.Spec) {
+		return
+	}
+
+	// Objects handed out by the informer are shared with its cache and must not
+	// be modified, so everything below works on a private copy.
+	service := curService.DeepCopy()
+	// set the kind explicitly like sync does; more details at https://github.com/kubernetes/kubernetes/issues/3030
+	service.SetGroupVersionKind(Kind)
+
+	// if the service.Generation is changed, update deployment settings
+	if oldService.Generation != service.Generation {
+		klog.Infof("Service is updated, delete previous deployments")
+
+		// The selector is generated from this very service: the controller
+		// handles many services, so a selector stored on the controller could
+		// belong to a different one.
+		selector, err := runtime.GenerateSelector(service)
+		if err != nil {
+			klog.Errorf("Failed to generate selector: %v", err)
+			return
+		}
+		deployments, err := c.deploymentsLister.Deployments(service.Namespace).List(selector)
+		if err != nil {
+			klog.Errorf("Failed to list deployments: %v", err)
+			return
+		}
+		for _, deployment := range deployments {
+			err = c.kubeClient.AppsV1().Deployments(service.Namespace).Delete(context.TODO(), deployment.Name, metav1.DeleteOptions{})
+			// a deployment that is already gone needs no further action
+			if err != nil && !errors.IsNotFound(err) {
+				klog.Errorf("Failed to delete deployment %s/%s: %v", service.Namespace, deployment.Name, err)
+			}
+		}
+
+		// The remaining steps do not depend on each other, so a failure is
+		// logged and the following steps are still attempted.
+		// NOTE(review): within this patch c.bigModelHost is only assigned in
+		// createWorkers; confirm it is always set before an update arrives.
+		if err := c.createEdgeWorker(service, c.bigModelHost, BigModelPort); err != nil {
+			klog.Errorf("Failed to create edgeWorker: %v", err)
+		}
+		if err := c.createCloudWorker(service, BigModelPort); err != nil {
+			klog.Errorf("Failed to create cloudWorker: %v", err)
+		}
+
+		// update the service status
+		if _, err := c.client.JointInferenceServices(service.Namespace).UpdateStatus(context.TODO(), service, metav1.UpdateOptions{}); err != nil {
+			klog.Errorf("Failed to update service status: %v", err)
+		}
+	}
+	c.enqueueController(service, true)
+	c.syncToEdge(watch.Added, service)
+}
diff --git a/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go b/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go
new file mode 100644
index 00000000..a6aa637f
--- /dev/null
+++ b/pkg/globalmanager/controllers/jointinference/jointinferenceservice_test.go
@@ -0,0 +1,304 @@
+package jointinference
+
+import (
+	"context"
+	"testing"
+
+	sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
+	fakeseednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/fake"
+	"github.com/kubeedge/sedna/pkg/globalmanager/config"
+	"github.com/kubeedge/sedna/pkg/globalmanager/runtime"
+	appsv1 "k8s.io/api/apps/v1"
+	v1 "k8s.io/api/core/v1"
+	k8serrors "k8s.io/apimachinery/pkg/api/errors"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
+	"k8s.io/apimachinery/pkg/watch"
+	kubernetesfake "k8s.io/client-go/kubernetes/fake"
+	"k8s.io/client-go/kubernetes/scheme"
+	v1core "k8s.io/client-go/kubernetes/typed/core/v1"
+	corelisters "k8s.io/client-go/listers/apps/v1"
+	corelistersv1 "k8s.io/client-go/listers/core/v1"
+	"k8s.io/client-go/tools/record"
+	"k8s.io/client-go/util/workqueue"
+)
+
+type mockPodLister struct {
+	pods []*v1.Pod
+}
+
+type mockPodNamespaceLister struct {
+	pods      []*v1.Pod
+	namespace string
+}
+
+func (m *mockPodLister) Pods(namespace string) corelistersv1.PodNamespaceLister {
+	return mockPodNamespaceLister{pods: m.pods, namespace: namespace}
+}
+
+func (m *mockPodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) {
+	return m.pods, nil
+}
+
+func (m mockPodNamespaceLister) List(selector labels.Selector) ([]*v1.Pod, error) {
+	var filteredPods []*v1.Pod
+	for _, pod := range m.pods {
+		if pod.Namespace == m.namespace {
+			filteredPods = append(filteredPods, pod)
+		}
+	}
+	return filteredPods, nil
+}
+
+type mockDeploymentLister struct {
+	deployments []*appsv1.Deployment
+}
+
+func (m *mockDeploymentLister) List(selector labels.Selector) (ret []*appsv1.Deployment, err error) {
+	return m.deployments, nil
+}
+
+func (m *mockDeploymentLister) Deployments(namespace string) corelisters.DeploymentNamespaceLister {
+	return mockDeploymentNamespaceLister{deployments: m.deployments, namespace: namespace}
+}
+
+func (m mockPodNamespaceLister) Get(name string) (*v1.Pod, error) {
+	for _, pod := range m.pods {
+		if pod.Namespace == m.namespace && pod.Name == name {
+			return pod, nil
+		}
+	}
+	return nil, k8serrors.NewNotFound(v1.Resource("pod"), name)
+}
+
+type mockDeploymentNamespaceLister struct {
+	deployments []*appsv1.Deployment
+	namespace   string
+}
+
+func (m mockDeploymentNamespaceLister) List(selector labels.Selector) ([]*appsv1.Deployment, error) {
+	var filteredDeployments []*appsv1.Deployment
+	for _, deployment := range m.deployments {
+		if deployment.Namespace == m.namespace {
+			filteredDeployments = append(filteredDeployments, deployment)
+		}
+	}
+	return filteredDeployments, nil
+}
+
+func (m mockDeploymentNamespaceLister) Get(name string) (*appsv1.Deployment, error) {
+	for _, deployment := range m.deployments {
+		if deployment.Namespace == m.namespace && deployment.Name == name {
+			return deployment, nil
+		}
+	}
+	return nil, k8serrors.NewNotFound(v1.Resource("deployment"), name)
+}
+
+func Test_updateService(t *testing.T) {
+	t.Run("update joint inference service successfully", func(t *testing.T) {
+		// Create fake clients
+		fakeSednaClient := fakeseednaclientset.NewSimpleClientset()
+		fakeKubeClient := kubernetesfake.NewSimpleClientset()
+
+		// Create a test joint inference service
+		oldService := &sednav1.JointInferenceService{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:            "test-ji-service",
+				Namespace:       "default",
+				Generation:      1,
+				ResourceVersion: "1",
+			},
+			Spec: sednav1.JointInferenceServiceSpec{
+				EdgeWorker: sednav1.EdgeWorker{
+					Model: sednav1.SmallModel{
+						Name: "test-edge-model",
+					},
+					Template: v1.PodTemplateSpec{
+						Spec: v1.PodSpec{
+							Containers: []v1.Container{
+								{
+									Name:  "edge-container",
+									Image: "edge-image:v1",
+								},
+							},
+						},
+					},
+					HardExampleMining: sednav1.HardExampleMining{
+						Name: "test-hem",
+						Parameters: []sednav1.ParaSpec{
+							{
+								Key:   "param1",
+								Value: "value1",
+							},
+						},
+					},
+				},
+				CloudWorker: sednav1.CloudWorker{
+					Model: sednav1.BigModel{
+						Name: "test-cloud-model",
+					},
+					Template: v1.PodTemplateSpec{
+						Spec: v1.PodSpec{
+							Containers: []v1.Container{
+								{
+									Name:  "cloud-container",
+									Image: "cloud-image:v1",
+								},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Create Big Model Resource Object for Cloud
+		bigModel := &sednav1.Model{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-cloud-model",
+				Namespace: "default",
+			},
+		}
+		_, err := fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), bigModel, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create test big model: %v", err)
+		}
+
+		// Create Small Model Resource Object for Edge
+		smallModel := &sednav1.Model{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-edge-model",
+				Namespace: "default",
+			},
+		}
+		_, err = fakeSednaClient.SednaV1alpha1().Models("default").Create(context.TODO(), smallModel, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create test small model: %v", err)
+		}
+
+		// Create the service using the fake client
+		_, err = fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Create(context.TODO(), oldService, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create test service: %v", err)
+		}
+
+		// Create test deployments
+		edgeDeployment := &appsv1.Deployment{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-ji-deployment-edge",
+				Namespace: "default",
+			},
+			Spec: appsv1.DeploymentSpec{
+				Template: oldService.Spec.EdgeWorker.Template,
+			},
+		}
+		cloudDeployment := &appsv1.Deployment{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-ji-deployment-cloud",
+				Namespace: "default",
+			},
+			Spec: appsv1.DeploymentSpec{
+				Template: oldService.Spec.CloudWorker.Template,
+			},
+		}
+
+		_, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), edgeDeployment, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create edge deployment: %v", err)
+		}
+		_, err = fakeKubeClient.AppsV1().Deployments("default").Create(context.TODO(), cloudDeployment, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create cloud deployment: %v", err)
+		}
+
+		// Manually create pods for the deployments
+		edgePod := &v1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-ji-service-edge-pod",
+				Namespace: "default",
+				Labels: map[string]string{
+					"jointinferenceservice.sedna.io/service-name": "test-ji-service",
+				},
+				OwnerReferences: []metav1.OwnerReference{
+					{
+						APIVersion: "apps/v1",
+						Kind:       "Deployment",
+						Name:       edgeDeployment.Name,
+						UID:        edgeDeployment.UID,
+					},
+				},
+			},
+			Spec: edgeDeployment.Spec.Template.Spec,
+		}
+		cloudPod := &v1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      "test-ji-service-cloud-pod",
+				Namespace: "default",
+				Labels: map[string]string{
+					"jointinferenceservice.sedna.io/service-name": "test-ji-service",
+				},
+				OwnerReferences: []metav1.OwnerReference{
+					{
+						APIVersion: "apps/v1",
+						Kind:       "Deployment",
+						Name:       cloudDeployment.Name,
+						UID:        cloudDeployment.UID,
+					},
+				},
+			},
+			Spec: cloudDeployment.Spec.Template.Spec,
+		}
+
+		// Add pods to the fake client
+		_, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), edgePod, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create edge pod: %v", err)
+		}
+		_, err = fakeKubeClient.CoreV1().Pods("default").Create(context.TODO(), cloudPod, metav1.CreateOptions{})
+		if err != nil {
+			t.Fatalf("Failed to create cloud pod: %v", err)
+		}
+
+		cfg := &config.ControllerConfig{
+			LC: config.LCConfig{
+				Server: "http://test-lc-server:8080",
+			},
+		}
+
+		eventBroadcaster := record.NewBroadcaster()
+		eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: fakeKubeClient.CoreV1().Events("")})
+
+		// Create a controller with the fake clients
+		c := &Controller{
+			kubeClient:        fakeKubeClient,
+			client:            fakeSednaClient.SednaV1alpha1(),
+			queue:             workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "test-ji-service"),
+			recorder:          eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "test-ji-service"}),
+			cfg:               cfg,
+			deploymentsLister: &mockDeploymentLister{deployments: []*appsv1.Deployment{edgeDeployment, cloudDeployment}},
+			podStore:          &mockPodLister{pods: []*v1.Pod{edgePod, cloudPod}},
+			bigModelHost:      "test-ji-service-cloud.default",
+			sendToEdgeFunc: func(nodeName string, eventType watch.EventType, job interface{}) error {
+				return nil
+			},
+		}
+
+		// Update the service
+		newService := oldService.DeepCopy()
+		// change parameter of hard example mining
+		newService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value = "value2"
+		newService.Generation = 2
+		newService.ResourceVersion = "2"
+		// Call updateService function
+		c.updateService(oldService, newService)
+
+		// Verify that the updated service was written back through the client
+		updatedService, err := fakeSednaClient.SednaV1alpha1().JointInferenceServices("default").Get(context.TODO(), "test-ji-service", metav1.GetOptions{})
+		if err != nil {
+			t.Fatalf("Failed to get updated service: %v", err)
+		}
+		if updatedService.Spec.EdgeWorker.HardExampleMining.Parameters[0].Value != "value2" {
+			t.Fatalf("Service was not updated correctly")
+		}
+	})
+}