From fff50df8f4453d2dc3cd460558e126d6685d628f Mon Sep 17 00:00:00 2001 From: Kirby Chin <37311900+kabicin@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:37:13 -0400 Subject: [PATCH] Remove additional endpoint checks when DNS is unreachable also, add nil check to dnsEndpoints.Subsets and GetEndpointPortByName --- .../webspherelibertyapplication_controller.go | 69 ++++++++++--------- utils/utils.go | 8 ++- 2 files changed, 42 insertions(+), 35 deletions(-) diff --git a/controllers/webspherelibertyapplication_controller.go b/controllers/webspherelibertyapplication_controller.go index 3b157842..63902178 100644 --- a/controllers/webspherelibertyapplication_controller.go +++ b/controllers/webspherelibertyapplication_controller.go @@ -425,36 +425,39 @@ func (r *ReconcileWebSphereLiberty) Reconcile(ctx context.Context, request ctrl. } apiServerNetworkPolicy.Spec.Egress = append(apiServerNetworkPolicy.Spec.Egress, dnsRule) - // If allowed, add an Egress rule to access the API server. - // Otherwise, if the OpenShift DNS or K8s CoreDNS Egress rule does not provide permissive cluster-wide access - // and the K8s API server could not be found, use a permissive cluster-wide Egress rule. - if apiServerEndpoints, err := r.getEndpoints("kubernetes", "default"); err == nil { - rule := networkingv1.NetworkPolicyEgressRule{} - // Define the port - port := networkingv1.NetworkPolicyPort{} - port.Protocol = &apiServerEndpoints.Subsets[0].Ports[0].Protocol - var portNumber intstr.IntOrString = intstr.FromInt((int)(apiServerEndpoints.Subsets[0].Ports[0].Port)) - port.Port = &portNumber - rule.Ports = append(rule.Ports, port) - - // Add the endpoint address as ipBlock entries - for _, endpoint := range apiServerEndpoints.Subsets { - for _, address := range endpoint.Addresses { - peer := networkingv1.NetworkPolicyPeer{} - ipBlock := networkingv1.IPBlock{} - ipBlock.CIDR = address.IP + "/32" - - peer.IPBlock = &ipBlock - rule.To = append(rule.To, peer) + // If the DNS rule is a specific Egress rule also check if another Egress rule can be created for the API server. + // Otherwise, fallback to a permissive cluster-wide Egress rule. + if !usingPermissiveRule { + if apiServerEndpoints, err := r.getEndpoints("kubernetes", "default"); err == nil { + rule := networkingv1.NetworkPolicyEgressRule{} + // Define the port + port := networkingv1.NetworkPolicyPort{} + port.Protocol = &apiServerEndpoints.Subsets[0].Ports[0].Protocol + var portNumber intstr.IntOrString = intstr.FromInt((int)(apiServerEndpoints.Subsets[0].Ports[0].Port)) + port.Port = &portNumber + rule.Ports = append(rule.Ports, port) + + // Add the endpoint address as ipBlock entries + for _, endpoint := range apiServerEndpoints.Subsets { + for _, address := range endpoint.Addresses { + peer := networkingv1.NetworkPolicyPeer{} + ipBlock := networkingv1.IPBlock{} + ipBlock.CIDR = address.IP + "/32" + + peer.IPBlock = &ipBlock + rule.To = append(rule.To, peer) + } } + apiServerNetworkPolicy.Spec.Egress = append(apiServerNetworkPolicy.Spec.Egress, rule) + reqLogger.Info("Found endpoints for kubernetes service in the default namespace") + } else { + // The operator couldn't create a rule for the K8s API server so add a permissive Egress rule + rule := networkingv1.NetworkPolicyEgressRule{} + apiServerNetworkPolicy.Spec.Egress = append(apiServerNetworkPolicy.Spec.Egress, rule) + reqLogger.Info("Found endpoints for kubernetes service in the default namespace") } - apiServerNetworkPolicy.Spec.Egress = append(apiServerNetworkPolicy.Spec.Egress, rule) - reqLogger.Info("Found endpoints for kubernetes service in the default namespace") - } else if !usingPermissiveRule { - rule := networkingv1.NetworkPolicyEgressRule{} - apiServerNetworkPolicy.Spec.Egress = append(apiServerNetworkPolicy.Spec.Egress, rule) - reqLogger.Info("Found endpoints for kubernetes service in the default namespace") } + apiServerNetworkPolicy.Labels = ba.GetLabels() apiServerNetworkPolicy.Annotations = oputils.MergeMaps(apiServerNetworkPolicy.Annotations, ba.GetAnnotations()) apiServerNetworkPolicy.Spec.PolicyTypes = []networkingv1.PolicyType{networkingv1.PolicyTypeEgress} @@ -953,11 +956,13 @@ func (r *ReconcileWebSphereLiberty) getEndpoints(serviceName string, namespace s func (r *ReconcileWebSphereLiberty) getDNSEgressRule(reqLogger logr.Logger, endpointsName string, endpointsNamespace string) (bool, networkingv1.NetworkPolicyEgressRule) { dnsRule := networkingv1.NetworkPolicyEgressRule{} if dnsEndpoints, err := r.getEndpoints(endpointsName, endpointsNamespace); err == nil { - if endpointPort := lutils.GetEndpointPortByName(&dnsEndpoints.Subsets[0].Ports, "dns"); endpointPort != nil { - dnsRule.Ports = append(dnsRule.Ports, lutils.CreateNetworkPolicyPortFromEndpointPort(endpointPort)) - } - if endpointPort := lutils.GetEndpointPortByName(&dnsEndpoints.Subsets[0].Ports, "dns-tcp"); endpointPort != nil { - dnsRule.Ports = append(dnsRule.Ports, lutils.CreateNetworkPolicyPortFromEndpointPort(endpointPort)) + if len(dnsEndpoints.Subsets) > 0 { + if endpointPort := lutils.GetEndpointPortByName(&dnsEndpoints.Subsets[0].Ports, "dns"); endpointPort != nil { + dnsRule.Ports = append(dnsRule.Ports, lutils.CreateNetworkPolicyPortFromEndpointPort(endpointPort)) + } + if endpointPort := lutils.GetEndpointPortByName(&dnsEndpoints.Subsets[0].Ports, "dns-tcp"); endpointPort != nil { + dnsRule.Ports = append(dnsRule.Ports, lutils.CreateNetworkPolicyPortFromEndpointPort(endpointPort)) + } } peer := networkingv1.NetworkPolicyPeer{} peer.NamespaceSelector = &metav1.LabelSelector{ diff --git a/utils/utils.go b/utils/utils.go index e96d0cfb..da3ef373 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -776,9 +776,11 @@ func GetRequiredLabels(name string, instance string) map[string]string { } func GetEndpointPortByName(endpointPorts *[]corev1.EndpointPort, name string) *corev1.EndpointPort { - for _, endpointPort := range *endpointPorts { - if endpointPort.Name == name { - return &endpointPort + if endpointPorts != nil { + for _, endpointPort := range *endpointPorts { + if endpointPort.Name == name { + return &endpointPort + } } } return nil