Skip to content

Commit

Permalink
Merge pull request PelicanPlatform#1575 from bbockelm/cache_stat
Browse files Browse the repository at this point in the history
Enable caching of object existence queries
  • Loading branch information
jhiemstrawisc authored Oct 23, 2024
2 parents 0c2ab6e + 707d5d0 commit 1ddcc0e
Show file tree
Hide file tree
Showing 12 changed files with 495 additions and 72 deletions.
2 changes: 2 additions & 0 deletions cmd/fed_serve_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func TestFedServeCache(t *testing.T) {
viper.Set("Origin.EnableCmsd", false)
viper.Set("Origin.EnableMacaroons", false)
viper.Set("Origin.EnableVoms", false)
viper.Set("Server.WebPort", 0)
viper.Set("Origin.Port", 0)
viper.Set("TLSSkipVerify", true)
viper.Set("Server.EnableUI", false)
viper.Set("Registry.DbLocation", filepath.Join(t.TempDir(), "ns-registry.sqlite"))
Expand Down
8 changes: 6 additions & 2 deletions config/resources/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ Director:
CacheSortMethod: "distance"
MinStatResponse: 1
MaxStatResponse: 1
StatTimeout: 1000ms
StatTimeout: 2000ms
StatConcurrencyLimit: 1000
AdvertisementTTL: 15m
OriginCacheHealthTestInterval: 15s
EnableBroker: true
EnableStat: true
CheckOriginPresence: true
CheckCachePresence: true
AssumePresenceAtSingleOrigin: true
CachePresenceTTL: 1m
CachePresenceCapacity: 10000
Cache:
Port: 8442
SelfTest: true
Expand Down
8 changes: 7 additions & 1 deletion director/cache_ads.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (

"github.com/pelicanplatform/pelican/param"
"github.com/pelicanplatform/pelican/server_structs"
"github.com/pelicanplatform/pelican/utils"
)

type filterType string
Expand Down Expand Up @@ -143,12 +144,17 @@ func recordAd(ctx context.Context, sAd server_structs.ServerAd, namespaceAds *[]
if concLimit == 0 {
concLimit = -1
}
statErrGrp := errgroup.Group{}
statErrGrp := utils.Group{}
statErrGrp.SetLimit(concLimit)
newUtil := serverStatUtil{
Errgroup: &statErrGrp,
Cancel: cancel,
Context: baseCtx,
ResultCache: ttlcache.New[string, *objectMetadata](
ttlcache.WithTTL[string, *objectMetadata](param.Director_CachePresenceTTL.GetDuration()),
ttlcache.WithDisableTouchOnHit[string, *objectMetadata](),
ttlcache.WithCapacity[string, *objectMetadata](uint64(param.Director_CachePresenceCapacity.GetInt())),
),
}
statUtils[ad.URL.String()] = newUtil
}
Expand Down
14 changes: 8 additions & 6 deletions director/director.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/hashicorp/go-version"
"github.com/jellydator/ttlcache/v3"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -67,9 +68,10 @@ type (
}
// Utility struct to keep track of the `stat` call the director made to the origin/cache servers
serverStatUtil struct {
Context context.Context
Cancel context.CancelFunc
Errgroup *errgroup.Group
Context context.Context
Cancel context.CancelFunc
Errgroup *utils.Group
ResultCache *ttlcache.Cache[string, *objectMetadata]
}

// Context key for the project name
Expand Down Expand Up @@ -382,7 +384,7 @@ func redirectToCache(ginCtx *gin.Context) {

reqParams := getRequestParameters(ginCtx.Request)

disableStat := !param.Director_EnableStat.GetBool()
disableStat := !param.Director_CheckCachePresence.GetBool()

// Skip the stat check for object availability
// If either disableStat or skipstat is set, then skip the stat query
Expand Down Expand Up @@ -603,7 +605,7 @@ func redirectToOrigin(ginCtx *gin.Context) {
reqParams := getRequestParameters(ginCtx.Request)

// Skip the stat check for object availability if either disableStat or skipstat is set
skipStat := reqParams.Has(pelican_url.QuerySkipStat) || !param.Director_EnableStat.GetBool()
skipStat := reqParams.Has(pelican_url.QuerySkipStat) || !param.Director_CheckOriginPresence.GetBool()

// Include caches in the response if Director.CachesPullFromCaches is enabled
// AND prefercached query parameter is set
Expand All @@ -622,7 +624,7 @@ func redirectToOrigin(ginCtx *gin.Context) {
}

// If the namespace requires a token yet there's no token available, skip the stat.
if !namespaceAd.Caps.PublicReads && reqParams.Get("authz") == "" {
if (!namespaceAd.Caps.PublicReads && reqParams.Get("authz") == "") || (param.Director_AssumePresenceAtSingleOrigin.GetBool() && len(originAds) == 1) {
skipStat = true
}

Expand Down
111 changes: 83 additions & 28 deletions director/stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"strconv"
"time"

"github.com/jellydator/ttlcache/v3"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -263,6 +264,21 @@ func WithToken(tk string) queryOption {
}
}

// getStatUtils collects the stat utilities registered for the given server ads.
//
// The returned map is keyed by each ad's URL string. Ads that have no entry in
// statUtils are left out of the result. Every value points at a copy of the
// registered serverStatUtil struct taken while statUtilsMutex is read-locked,
// so the caller can keep using the result after the lock has been released.
func getStatUtils(ads []server_structs.ServerAd) map[string]*serverStatUtil {
	found := make(map[string]*serverStatUtil, len(ads))

	statUtilsMutex.RLock()
	defer statUtilsMutex.RUnlock()

	for _, serverAd := range ads {
		key := serverAd.URL.String()
		// entry is scoped to this if-statement, so each stored pointer refers
		// to its own copy of the struct rather than a shared variable.
		if entry, present := statUtils[key]; present {
			found[key] = &entry
		}
	}
	return found
}

// Implementation of querying origins/cache servers for their availability of an object.
// It blocks until max successful requests has been received, all potential origins/caches responded (or timeout), or cancelContext was closed.
//
Expand Down Expand Up @@ -310,9 +326,11 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
return
}
timeout := param.Director_StatTimeout.GetDuration()
positiveReqChan := make(chan *objectMetadata)
negativeReqChan := make(chan error)
deniedReqChan := make(chan *headReqForbiddenErr) // Requests with 403 response
// Note there is a small buffer in each channel; in the case of a cache hit, we write
// to the channel from within this goroutine.
positiveReqChan := make(chan *objectMetadata, 5)
negativeReqChan := make(chan error, 5)
deniedReqChan := make(chan *headReqForbiddenErr, 5) // Requests with 403 response
// Cancel the rest of the requests when requests received >= max required
maxCancelCtx, maxCancel := context.WithCancel(ctx)
numTotalReq := 0
Expand All @@ -327,12 +345,9 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
return
}

// Use RLock to allow multiple queries
statUtilsMutex.RLock()
defer statUtilsMutex.RUnlock()

utils := getStatUtils(ads)
for _, adExt := range ads {
statUtil, ok := statUtils[adExt.URL.String()]
statUtil, ok := utils[adExt.URL.String()]
if !ok {
numTotalReq += 1
log.Debugf("Server %q is missing data for stat call, skip querying...", adExt.Name)
Expand All @@ -345,15 +360,24 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
}
// Use an anonymous func to pass variable safely to the goroutine
func(serverAd server_structs.ServerAd) {
statUtil.Errgroup.Go(func() error {
baseUrl := serverAd.URL

// For the topology server, if the server does not support public read,
// or the token is provided, or the object is protected, then it's safe to assume this request goes to authenticated endpoint
// For Pelican server, we don't populate authURL and only use server URL as the base URL
if serverAd.FromTopology && (!serverAd.Caps.PublicReads || cfg.protected || cfg.token != "") && serverAd.AuthURL.String() != "" {
baseUrl = serverAd.AuthURL
}

baseUrl := serverAd.URL
// For the topology server, if the server does not support public read,
// or the token is provided, or the object is protected, then it's safe to assume this request goes to authenticated endpoint
// For Pelican server, we don't populate authURL and only use server URL as the base URL
if serverAd.FromTopology && (!serverAd.Caps.PublicReads || cfg.protected || cfg.token != "") && serverAd.AuthURL.String() != "" {
baseUrl = serverAd.AuthURL
}

totalLabels := prometheus.Labels{
"server_name": serverAd.Name,
"server_url": baseUrl.String(),
"server_type": string(serverAd.Type),
"cached_result": "false",
"result": "",
}

queryFunc := func() (metadata *objectMetadata, err error) {

activeLabels := prometheus.Labels{
"server_name": serverAd.Name,
Expand All @@ -363,22 +387,30 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
metrics.PelicanDirectorStatActive.With(activeLabels).Inc()
defer metrics.PelicanDirectorStatActive.With(activeLabels).Dec()

metadata, err := stat.ReqHandler(maxCancelCtx, objectName, baseUrl, true, cfg.token, timeout)
metadata, err = stat.ReqHandler(maxCancelCtx, objectName, baseUrl, true, cfg.token, timeout)

var reqNotFound *headReqNotFoundErr
cancelErr := &headReqCancelledErr{}
if err != nil && !errors.As(err, &cancelErr) { // Skip additional requests if the previous one is cancelled
if err != nil && !errors.As(err, &cancelErr) && !errors.As(err, &reqNotFound) {
// If the request returns 403 or 500, it could be because we request a digest and xrootd
// does not have this turned on, or had trouble calculating the checksum
// Retry without digest
metadata, err = stat.ReqHandler(maxCancelCtx, objectName, baseUrl, false, cfg.token, timeout)
}

totalLabels := prometheus.Labels{
"server_name": serverAd.Name,
"server_url": baseUrl.String(),
"server_type": string(serverAd.Type),
"result": "",
// If we get a 404, record it in the cache.
if errors.As(err, &reqNotFound) {
statUtil.ResultCache.Set(objectName, nil, ttlcache.DefaultTTL)
} else if err == nil {
statUtil.ResultCache.Set(objectName, metadata, ttlcache.DefaultTTL)
}

return
}

lookupFunc := func() error {

metadata, err := queryFunc()
if err != nil {
switch e := err.(type) {
case *headReqTimeoutErr:
Expand Down Expand Up @@ -418,7 +450,30 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
positiveReqChan <- metadata
}
return nil
})
}

if item := statUtil.ResultCache.Get(objectName); item != nil {
// If we get a cache hit -- but the cache item is going to expire in the next 10 seconds,
// then we assume this is a "hot" object and we'll benefit from preemptively refreshing
// the ttlcache. If we can, asynchronously query the service.
if time.Until(item.ExpiresAt()) < 10*time.Second {
statUtil.Errgroup.TryGo(func() (err error) { _, err = queryFunc(); return })
}
totalLabels["cached_result"] = "true"
if metadata := item.Value(); metadata != nil {
totalLabels["result"] = string(metrics.StatSucceeded)
metrics.PelicanDirectorStatTotal.With(totalLabels).Inc()
positiveReqChan <- metadata
} else {
log.Debugf("Object %s not found at %s server %s: (cached result)", objectName, serverAd.Type, baseUrl.String())
negativeReqChan <- &headReqNotFoundErr{}
totalLabels["result"] = string(metrics.StatNotFound)
metrics.PelicanDirectorStatTotal.With(totalLabels).Inc()
}
metrics.PelicanDirectorStatTotal.With(totalLabels).Inc()
} else {
statUtil.Errgroup.TryGoUntil(ctx, lookupFunc)
}
}(adExt)
}

Expand Down Expand Up @@ -454,9 +509,9 @@ func (stat *ObjectStat) queryServersForObject(ctx context.Context, objectName st
qResult.Status = queryFailed
qResult.ErrorType = queryInsufficientResErr
qResult.Msg = fmt.Sprintf("Number of success response: %d is less than MinStatResponse (%d) required.", len(successResult), minReq)
serverIssuers := []string{}
for _, dErr := range deniedResult {
serverIssuers = append(serverIssuers, dErr.IssuerUrl)
serverIssuers := make([]string, len(deniedResult))
for idx, dErr := range deniedResult {
serverIssuers[idx] = dErr.IssuerUrl
}
qResult.DeniedServers = serverIssuers
return
Expand Down
Loading

0 comments on commit 1ddcc0e

Please sign in to comment.