Skip to content

Commit

Permalink
Merge pull request #54 from k8gb-io/support-case-sensitive-queries
Browse files Browse the repository at this point in the history
Supporting CaseInsensitive Queries
  • Loading branch information
kuritka authored Aug 3, 2023
2 parents be4e720 + 3521238 commit 659b497
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 33 deletions.
59 changes: 53 additions & 6 deletions common/k8sctrl/ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,60 @@ func endpointHostnameIndexFunc(obj interface{}) ([]string, error) {

func (ctrl *KubeController) getEndpointByName(host string, clientIP net.IP) (lep LocalDNSEndpoint) {
log.Infof("Index key %+v", host)
objs, _ := ctrl.epc.GetIndexer().ByIndex(endpointHostnameIndex, strings.ToLower(host))
for _, obj := range objs {
endpoints := ctrl.getEndpointsByCaseInsensitiveName(host, clientIP)
lep = ctrl.margeLocalDNSEndpoints(host, endpoints)
return lep
}

// The function tries to find all case sensitive variants. Returns a map where the call is hostname and the value is LocalDNSEndpoint
func (ctrl *KubeController) getEndpointsByCaseInsensitiveName(host string, clientIP net.IP) (result map[string]LocalDNSEndpoint) {

// The function extracts LocalDNSEndpoints from *DNSEndpoint. The function is hardwired with a case-sensitive extraction scenario and is only used in a
// single location, so it is currently declared inside the calling function.
extractLocalEndpoints := func(ep *endpoint.DNSEndpoint, ip net.IP, host string) (result []LocalDNSEndpoint) {
result = []LocalDNSEndpoint{}
for _, e := range ep.Spec.Endpoints {
if strings.EqualFold(e.DNSName, host) {
r := LocalDNSEndpoint{}
r.DNSName = e.DNSName
r.Labels = e.Labels
r.TTL = e.RecordTTL
r.Targets = e.Targets
if e.Labels["strategy"] == "geoip" {
targets := r.extractGeo(e, ip)
if len(targets) > 0 {
r.Targets = targets
}
}
result = append(result, r)
}
}
return result
}

epList := ctrl.epc.GetIndexer().List()
result = make(map[string]LocalDNSEndpoint, 0)
for _, obj := range epList {
ep := obj.(*endpoint.DNSEndpoint)
lep = extractLocalEndpoint(ep, clientIP, host)
if !lep.isEmpty() {
break
extracts := extractLocalEndpoints(ep, clientIP, host)
for _, extracted := range extracts {
if strings.EqualFold(extracted.DNSName, host) {
result[extracted.DNSName] = extracted
log.Debugf("including DNSEndpoint: %s", extracted.String())
}
}
}
return lep
return result
}

func (ctrl *KubeController) margeLocalDNSEndpoints(host string, endpoints map[string]LocalDNSEndpoint) LocalDNSEndpoint {
result := LocalDNSEndpoint{
DNSName: host,
}
result.Labels = endpoints[host].Labels
result.TTL = endpoints[host].TTL
for _, ep := range endpoints {
result.Targets = append(result.Targets, ep.Targets...)
}
return result
}
144 changes: 141 additions & 3 deletions common/k8sctrl/ctrl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Generated by GoLic, for more details see: https://github.com/AbsaOSS/golic
import (
"context"
"reflect"
"sort"
"testing"
"unsafe"

Expand All @@ -37,6 +38,11 @@ import (
func TestKubeController(t *testing.T) {
const label = "k8gb.absa.oss/dnstype=local"
const host = "roundrobin.cloud.example.com"
const hOSTCaseInsensitive = "roundrobin-case-insensitive.CLOUD.EXAMPLE.COM"
const hostCaseInsensitive = "roundrobin-case-insensitive.cloud.example.com"
const embeddedCaseSensitive = "embedded.CLOUD.EXAMPLE.COM"
const embeddedCaseInsensitive = "embedded.cloud.example.com"
var clientIP = []byte{0x0A, 0x0A, 0x0A, 0x01}
defaultEP := &endpoint.DNSEndpoint{
Spec: endpoint.DNSEndpointSpec{
Endpoints: []*endpoint.Endpoint{
Expand All @@ -48,16 +54,77 @@ func TestKubeController(t *testing.T) {
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"10.0.0.1", "10.0.0.2"},
},
{
DNSName: hOSTCaseInsensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"1.1.1.1", "1.1.1.2"},
},
{
DNSName: hostCaseInsensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"2.2.2.2"},
},
},
},
}

ep1 := &endpoint.DNSEndpoint{
Spec: endpoint.DNSEndpointSpec{
Endpoints: []*endpoint.Endpoint{
{
DNSName: "localtargets-" + hOSTCaseInsensitive,
},
{
DNSName: hOSTCaseInsensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"1.1.1.1", "1.1.1.2"},
},
},
},
}

ep2 := &endpoint.DNSEndpoint{
Spec: endpoint.DNSEndpointSpec{
Endpoints: []*endpoint.Endpoint{
{
DNSName: "localtargets-" + hOSTCaseInsensitive,
},
{
DNSName: hostCaseInsensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"2.2.2.2"},
},
},
},
}

epEmbedded := &endpoint.DNSEndpoint{
Spec: endpoint.DNSEndpointSpec{
Endpoints: []*endpoint.Endpoint{
{
DNSName: "localtargets-" + embeddedCaseSensitive,
},
{
DNSName: embeddedCaseInsensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"10.10.10.32", "10.10.10.30"},
},
{
DNSName: embeddedCaseSensitive,
Labels: map[string]string{"strategy": "roundrobin"},
Targets: []string{"10.10.10.2"},
},
},
},
}

ctrl := gomock.NewController(t)
defer ctrl.Finish()
mctrl := mocks.NewMockInterface(ctrl)
mcache := mocks.NewMockSharedIndexInformer(ctrl)
midx := mocks.NewMockIndexer(ctrl)
midx.EXPECT().ByIndex(endpointHostnameIndex, host).Return([]interface{}{defaultEP}, nil).Times(1)
mcache.EXPECT().GetIndexer().Return(midx)
midx.EXPECT().List().Return([]interface{}{defaultEP, ep1, ep2, epEmbedded}).AnyTimes()
mcache.EXPECT().GetIndexer().Return(midx).AnyTimes()

client := getClient(mctrl)

Expand All @@ -72,11 +139,82 @@ func TestKubeController(t *testing.T) {
k8sctrl.epc = mcache

t.Run("get no-geo endpoint by name", func(t *testing.T) {
lep := k8sctrl.getEndpointByName("roundrobin.cloud.example.com", []byte{0x0A, 0x0A, 0x0A, 0x01})
lep := k8sctrl.getEndpointByName(host, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "roundrobin.cloud.example.com: 0, Targets: [10.0.0.1 10.0.0.2], Labels: map[strategy:roundrobin]", lep.String())
})

t.Run("valid uppercase domain query", func(t *testing.T) {
lep := k8sctrl.getEndpointByName(hOSTCaseInsensitive, clientIP)
assert.NotNil(t, lep)
sort.Strings(lep.Targets)
assert.Equal(t, "roundrobin-case-insensitive.CLOUD.EXAMPLE.COM: 0, Targets: [1.1.1.1 1.1.1.2 2.2.2.2], Labels: map[strategy:roundrobin]", lep.String())

lep = k8sctrl.getEndpointByName(hostCaseInsensitive, clientIP)
assert.NotNil(t, lep)
sort.Strings(lep.Targets)
assert.Equal(t, "roundrobin-case-insensitive.cloud.example.com: 0, Targets: [1.1.1.1 1.1.1.2 2.2.2.2], Labels: map[strategy:roundrobin]", lep.String())
})

t.Run("handle multiple embedded endpoints", func(t *testing.T) {
lep := k8sctrl.getEndpointByName(embeddedCaseInsensitive, clientIP)
assert.NotNil(t, lep)
sort.Strings(lep.Targets)
assert.Equal(t, "embedded.cloud.example.com: 0, Targets: [10.10.10.2 10.10.10.30 10.10.10.32], Labels: map[strategy:roundrobin]", lep.String())

lep = k8sctrl.getEndpointByName(embeddedCaseSensitive, clientIP)
assert.NotNil(t, lep)
sort.Strings(lep.Targets)
assert.Equal(t, "embedded.CLOUD.EXAMPLE.COM: 0, Targets: [10.10.10.2 10.10.10.30 10.10.10.32], Labels: map[strategy:roundrobin]", lep.String())
})

t.Run("handle multiple embedded endpoints but one EP is empty", func(t *testing.T) {
epEmbedded.Spec.Endpoints[1].Targets = []string{}
lep := k8sctrl.getEndpointByName(embeddedCaseInsensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.cloud.example.com: 0, Targets: [10.10.10.2], Labels: map[strategy:roundrobin]", lep.String())

lep = k8sctrl.getEndpointByName(embeddedCaseSensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.CLOUD.EXAMPLE.COM: 0, Targets: [10.10.10.2], Labels: map[strategy:roundrobin]", lep.String())

})

t.Run("EP has empty targets", func(t *testing.T) {
epEmbedded.Spec.Endpoints[1].Targets = []string{}
epEmbedded.Spec.Endpoints[2].Targets = []string{}
epEmbedded.Spec.Endpoints[0].Targets = []string{}
lep := k8sctrl.getEndpointByName(embeddedCaseInsensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.cloud.example.com: 0, Targets: [], Labels: map[strategy:roundrobin]", lep.String())

lep = k8sctrl.getEndpointByName(embeddedCaseSensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.CLOUD.EXAMPLE.COM: 0, Targets: [], Labels: map[strategy:roundrobin]", lep.String())
})

t.Run("EP has no dns endpoints", func(t *testing.T) {
epEmbedded.Spec.Endpoints = []*endpoint.Endpoint{}
lep := k8sctrl.getEndpointByName(embeddedCaseInsensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.cloud.example.com: 0, Targets: [], Labels: map[]", lep.String())

lep = k8sctrl.getEndpointByName(embeddedCaseSensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.CLOUD.EXAMPLE.COM: 0, Targets: [], Labels: map[]", lep.String())
})

t.Run("EP has nil dns endpoints", func(t *testing.T) {
epEmbedded.Spec.Endpoints = nil
lep := k8sctrl.getEndpointByName(embeddedCaseInsensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.cloud.example.com: 0, Targets: [], Labels: map[]", lep.String())

lep = k8sctrl.getEndpointByName(embeddedCaseSensitive, clientIP)
assert.NotNil(t, lep)
assert.Equal(t, "embedded.CLOUD.EXAMPLE.COM: 0, Targets: [], Labels: map[]", lep.String())
})

}

func getClient(i rest.Interface) (c *dnsendpoint.ExtDNSClient) {
Expand Down
24 changes: 0 additions & 24 deletions common/k8sctrl/ep.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,6 @@ func (lep LocalDNSEndpoint) String() string {
return fmt.Sprintf("%s: %v, Targets: %v, Labels: %v", lep.DNSName, lep.TTL, lep.Targets, lep.Labels)
}

func extractLocalEndpoint(ep *endpoint.DNSEndpoint, ip net.IP, host string) (result LocalDNSEndpoint) {
result = LocalDNSEndpoint{}
for _, e := range ep.Spec.Endpoints {
if e.DNSName == host {
result.DNSName = host
result.Labels = e.Labels
result.TTL = e.RecordTTL
result.Targets = e.Targets
if e.Labels["strategy"] == "geoip" {
targets := result.extractGeo(e, ip)
if len(targets) > 0 {
result.Targets = targets
}
}
break
}
}
return result
}

func (lep LocalDNSEndpoint) isEmpty() bool {
return len(lep.Targets) == 0 && (len(lep.Labels) == 0) && (lep.TTL == 0)
}

func (lep LocalDNSEndpoint) extractGeo(endpoint *endpoint.Endpoint, clientIP net.IP) (result []string) {
db, err := maxminddb.Open("geoip.mmdb")
if err != nil {
Expand Down

0 comments on commit 659b497

Please sign in to comment.