diff --git a/director/cache_ads.go b/director/cache_ads.go
index 1cc1ca2dd..6f5160cb2 100644
--- a/director/cache_ads.go
+++ b/director/cache_ads.go
@@ -108,6 +108,7 @@ func UpdateLatLong(ad *ServerAd) error {
 
 func matchesPrefix(reqPath string, namespaceAds []NamespaceAd) *NamespaceAd {
 	var best *NamespaceAd
+
 	for _, namespace := range namespaceAds {
 		serverPath := namespace.Path
 		if strings.Compare(serverPath, reqPath) == 0 {
@@ -116,7 +117,8 @@ func matchesPrefix(reqPath string, namespaceAds []NamespaceAd) *NamespaceAd {
 
 		// Some namespaces in Topology already have the trailing /, some don't
 		// Perhaps this should be standardized, but in case it isn't we need to
-		// handle it
+		// handle it throughout this function. Note that reqPath already ends
+		// in a trailing / because GetAdsForPath appends one before calling us
 		if serverPath[len(serverPath)-1:] != "/" {
 			serverPath += "/"
 		}
@@ -145,8 +147,16 @@ func matchesPrefix(reqPath string, namespaceAds []NamespaceAd) *NamespaceAd {
 func GetAdsForPath(reqPath string) (originNamespace NamespaceAd, originAds []ServerAd, cacheAds []ServerAd) {
 	serverAdMutex.RLock()
 	defer serverAdMutex.RUnlock()
+
+	// Clean the path, but re-append a trailing / to deal with some namespaces
+	// from topo that have a trailing /
 	reqPath = path.Clean(reqPath)
+	reqPath += "/"
 
+	// Iterate through all of the server ads. For each "item", the key
+	// is the server ad itself (either cache or origin), and the value
+	// is a slice of namespace prefixes that are supported by that server
+	var best *NamespaceAd
 	for _, item := range serverAds.Items() {
 		if item == nil {
 			continue
@@ -155,15 +165,35 @@ func GetAdsForPath(reqPath string) (originNamespace NamespaceAd, originAds []Ser
 		if serverAd.Type == OriginType {
 			ns := matchesPrefix(reqPath, item.Value())
 			if ns != nil {
-				originNamespace = *ns
-				originAds = append(originAds, serverAd)
+				if best == nil || len(ns.Path) > len(best.Path) {
+					best = ns
+					// If anything was previously set by a namespace that constituted a shorter
+					// prefix, we overwrite that here because we found a better ns. We also clear
+					// the other slice of server ads, because we know those aren't good anymore
+					originAds = append(originAds[:0], serverAd)
+					cacheAds = []ServerAd{}
+				} else if ns.Path == best.Path {
+					originAds = append(originAds, serverAd)
+				}
 			}
 			continue
 		} else if serverAd.Type == CacheType {
-			if matchesPrefix(reqPath, item.Value()) != nil {
-				cacheAds = append(cacheAds, serverAd)
+			if ns := matchesPrefix(reqPath, item.Value()); ns != nil {
+				if best == nil || len(ns.Path) > len(best.Path) {
+					best = ns
+					cacheAds = append(cacheAds[:0], serverAd)
+					originAds = []ServerAd{}
+				} else if ns.Path == best.Path {
+					cacheAds = append(cacheAds, serverAd)
+				}
 			}
 		}
 	}
+
+	// best is still nil when no server ad matched reqPath; leave
+	// originNamespace as its zero value rather than dereferencing nil
+	if best != nil {
+		originNamespace = *best
+	}
 	return
 }
diff --git a/director/cache_ads_test.go b/director/cache_ads_test.go
new file mode 100644
index 000000000..2c17ec766
--- /dev/null
+++ b/director/cache_ads_test.go
@@ -0,0 +1,179 @@
+/***************************************************************
+ *
+ * Copyright (C) 2023, Pelican Project, Morgridge Institute for Research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you
+ * may not use this file except in compliance with the License. You may
+ * obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ ***************************************************************/
+
+package director
+
+import (
+	"net/url"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// hasServerAdWithName reports whether any ad in serverAds has the given name.
+func hasServerAdWithName(serverAds []ServerAd, name string) bool {
+	for _, serverAd := range serverAds {
+		if serverAd.Name == name {
+			return true
+		}
+	}
+	return false
+}
+
+// TestGetAdsForPath makes sure various nuanced cases work. Under the hood
+// this really tests matchesPrefix, but we test this higher level function to
+// avoid having to mess with the cache.
+func TestGetAdsForPath(t *testing.T) {
+	/*
+		FLOW:
+			- Set up a few dummy namespaces, origin, and cache ads
+			- Record the ads
+			- Query for a few paths and make sure the correct ads are returned
+	*/
+	nsAd1 := NamespaceAd{
+		RequireToken: true,
+		Path:         "/chtc",
+		Issuer: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+	}
+
+	nsAd2 := NamespaceAd{
+		RequireToken: false,
+		Path:         "/chtc/PUBLIC",
+		Issuer: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+	}
+
+	nsAd3 := NamespaceAd{
+		RequireToken: false,
+		Path:         "/chtc/PUBLIC2/",
+		Issuer: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+	}
+
+	cacheAd1 := ServerAd{
+		Name: "cache1",
+		AuthURL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		URL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		Type: CacheType,
+	}
+
+	cacheAd2 := ServerAd{
+		Name: "cache2",
+		AuthURL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		URL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		Type: CacheType,
+	}
+
+	originAd1 := ServerAd{
+		Name: "origin1",
+		AuthURL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		URL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		Type: OriginType,
+	}
+
+	originAd2 := ServerAd{
+		Name: "origin2",
+		AuthURL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		URL: url.URL{
+			Scheme: "https",
+			Host:   "wisc.edu",
+		},
+		Type: OriginType,
+	}
+
+	o1Slice := []NamespaceAd{nsAd1}
+	o2Slice := []NamespaceAd{nsAd2, nsAd3}
+	c1Slice := []NamespaceAd{nsAd1, nsAd2}
+	RecordAd(originAd2, &o2Slice)
+	RecordAd(originAd1, &o1Slice)
+	RecordAd(cacheAd1, &c1Slice)
+	RecordAd(cacheAd2, &o1Slice)
+
+	nsAd, oAds, cAds := GetAdsForPath("/chtc")
+	assert.Equal(t, "/chtc", nsAd.Path)
+	assert.Equal(t, 1, len(oAds))
+	assert.Equal(t, 2, len(cAds))
+	assert.True(t, hasServerAdWithName(oAds, "origin1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache2"))
+
+	nsAd, oAds, cAds = GetAdsForPath("/chtc/")
+	assert.Equal(t, "/chtc", nsAd.Path)
+	assert.Equal(t, 1, len(oAds))
+	assert.Equal(t, 2, len(cAds))
+	assert.True(t, hasServerAdWithName(oAds, "origin1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache2"))
+
+	nsAd, oAds, cAds = GetAdsForPath("/chtc/PUBLI")
+	assert.Equal(t, "/chtc", nsAd.Path)
+	assert.Equal(t, 1, len(oAds))
+	assert.Equal(t, 2, len(cAds))
+	assert.True(t, hasServerAdWithName(oAds, "origin1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache1"))
+	assert.True(t, hasServerAdWithName(cAds, "cache2"))
+
+	nsAd, oAds, cAds = GetAdsForPath("/chtc/PUBLIC")
+	assert.Equal(t, "/chtc/PUBLIC", nsAd.Path)
+	assert.Equal(t, 1, len(oAds))
+	assert.Equal(t, 1, len(cAds))
+	assert.True(t, hasServerAdWithName(oAds, "origin2"))
+	assert.True(t, hasServerAdWithName(cAds, "cache1"))
+
+	nsAd, oAds, cAds = GetAdsForPath("/chtc/PUBLIC2")
+	// since the stored path is actually /chtc/PUBLIC2/, the extra / is returned
+	assert.Equal(t, "/chtc/PUBLIC2/", nsAd.Path)
+	assert.Equal(t, 1, len(oAds))
+	assert.Equal(t, 0, len(cAds))
+	assert.True(t, hasServerAdWithName(oAds, "origin2"))
+
+	// A path outside every namespace recorded above must not panic and must
+	// yield the zero-value namespace with no origin or cache ads
+	nsAd, oAds, cAds = GetAdsForPath("/does/not/exist")
+	assert.Equal(t, "", nsAd.Path)
+	assert.Equal(t, 0, len(oAds))
+	assert.Equal(t, 0, len(cAds))
+}