From fbd9eb0d92615dec753cf878058f506d29fd795c Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Wed, 28 Aug 2024 13:19:48 +0200 Subject: [PATCH] crypto: do not load already loaded TRC files (#4588) This commit introduces a TRC loader that keeps track of what TRC files have already been read. By doing so we can omit duplicate work of parsing TRC files and storing them in the DB. This applies to the control service and the daemon. --- control/trust.go | 2 +- control/trust/crypto_loader.go | 32 ++++++++++++++++---- daemon/cmd/daemon/main.go | 11 ++++--- private/trust/store.go | 55 ++++++++++++++++++++++++++++++++-- private/trust/store_test.go | 42 ++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 13 deletions(-) diff --git a/control/trust.go b/control/trust.go index b16e49c05f..cff947de3d 100644 --- a/control/trust.go +++ b/control/trust.go @@ -103,7 +103,7 @@ func newCachingSignerGen( gen := trust.SignerGen{ IA: ia, - DB: cstrust.CryptoLoader{ + DB: &cstrust.CryptoLoader{ Dir: filepath.Join(cfgDir, "crypto/as"), TRCDirs: []string{filepath.Join(cfgDir, "certs")}, DB: db, diff --git a/control/trust/crypto_loader.go b/control/trust/crypto_loader.go index 238e69b30e..1081499468 100644 --- a/control/trust/crypto_loader.go +++ b/control/trust/crypto_loader.go @@ -18,6 +18,7 @@ import ( "context" "crypto/x509" "errors" + "sync" "github.com/scionproto/scion/pkg/log" "github.com/scionproto/scion/pkg/private/serrors" @@ -33,9 +34,13 @@ type CryptoLoader struct { Dir string // TRCDirs are optional directories from which TRCs are loaded. TRCDirs []string + + trcLoaders map[string]*trust.TRCLoader + trcLoadersInitMtx sync.Mutex } -func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) { +func (l *CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) { + l.initTRCLoaders() if err := l.loadTRCs(ctx); err != nil { log.FromCtx(ctx).Info("Failed to load TRCs from disk, continuing", "err", err) } @@ -44,7 +49,7 @@ func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.Sign // Chains loads chains from disk, stores them to DB, and returns the result from // DB. The fallback mode is always the result of the DB. -func (l CryptoLoader) Chains(ctx context.Context, +func (l *CryptoLoader) Chains(ctx context.Context, query trust.ChainQuery) ([][]*x509.Certificate, error) { r, err := trust.LoadChains(ctx, l.Dir, l.DB) @@ -70,7 +75,24 @@ func (l CryptoLoader) Chains(ctx context.Context, return l.DB.Chains(ctx, query) } -func (l CryptoLoader) loadTRCs(ctx context.Context) error { +func (l *CryptoLoader) initTRCLoaders() { + l.trcLoadersInitMtx.Lock() + defer l.trcLoadersInitMtx.Unlock() + if l.trcLoaders != nil { + return + } + l.trcLoaders = make(map[string]*trust.TRCLoader, len(l.TRCDirs)+1) + for _, dir := range append([]string{l.Dir}, l.TRCDirs...) { + if _, ok := l.trcLoaders[dir]; !ok { + l.trcLoaders[dir] = &trust.TRCLoader{ + DB: l.DB, + Dir: dir, + } + } + } +} + +func (l *CryptoLoader) loadTRCs(ctx context.Context) error { var errs serrors.List for _, dir := range append([]string{l.Dir}, l.TRCDirs...) { if err := l.loadTRCsFromDir(ctx, dir); err != nil { @@ -80,8 +102,8 @@ func (l CryptoLoader) loadTRCs(ctx context.Context) error { return errs.ToError() } -func (l CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error { - r, err := trust.LoadTRCs(ctx, dir, l.DB) +func (l *CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error { + r, err := l.trcLoaders[dir].Load(ctx) if err != nil { return err } diff --git a/daemon/cmd/daemon/main.go b/daemon/cmd/daemon/main.go index 5faf0c9bc2..32fd124427 100644 --- a/daemon/cmd/daemon/main.go +++ b/daemon/cmd/daemon/main.go @@ -159,10 +159,13 @@ func realMain(ctx context.Context) error { CacheHits: metrics.NewPromCounter(trustmetrics.CacheHitsTotal), MaxCacheExpiration: globalCfg.TrustEngine.Cache.Expiration.Duration, } - trcLoader := periodic.Start(periodic.Func{ + trcLoader := trust.TRCLoader{ + Dir: filepath.Join(globalCfg.General.ConfigDir, "certs"), + DB: trustDB, + } + trcLoaderTask := periodic.Start(periodic.Func{ Task: func(ctx context.Context) { - trcDirs := filepath.Join(globalCfg.General.ConfigDir, "certs") - res, err := trust.LoadTRCs(ctx, trcDirs, trustDB) + res, err := trcLoader.Load(ctx) if err != nil { log.SafeInfo(log.FromCtx(ctx), "TRC loading failed", "err", err) } @@ -172,7 +175,7 @@ func realMain(ctx context.Context) error { }, TaskName: "daemon_trc_loader", }, 10*time.Second, 10*time.Second) - defer trcLoader.Stop() + defer trcLoaderTask.Stop() var drkeyClientEngine *sd_drkey.ClientEngine if globalCfg.DRKeyLevel2DB.Connection != "" { diff --git a/private/trust/store.go b/private/trust/store.go index 1eaea6f017..76815cb656 100644 --- a/private/trust/store.go +++ b/private/trust/store.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "time" "github.com/scionproto/scion/pkg/private/serrors" @@ -112,9 +113,22 @@ func LoadChains(ctx context.Context, dir string, db DB) (LoadResult, error) { } // LoadTRCs loads all *.trc located in a directory in the database. This -// function exits on the first encountered error. TRCs with a not before time -// in the future are ignored. +// function exits on the first encountered error. TRCs with a not before time in +// the future are ignored. +// +// This function is not recommended for repeated use as it will read all TRC +// files in a directory on every invocation. Consider using a TRCLoader if you +// want to monitor a directory for new TRCs. func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { + return loadTRCs(ctx, dir, db, nil) +} + +func loadTRCs( + ctx context.Context, + dir string, + db DB, + ignoreFiles map[string]struct{}, +) (LoadResult, error) { if _, err := os.Stat(dir); err != nil { return LoadResult{}, serrors.WrapNoStack("stating directory", err, "dir", dir) } @@ -127,6 +141,10 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { res := LoadResult{Ignored: map[string]error{}} // TODO(roosd): should probably be a transaction. for _, f := range files { + // ignore as per request of the caller + if _, ok := ignoreFiles[f]; ok { + continue + } raw, err := os.ReadFile(f) if err != nil { return res, serrors.WrapNoStack("reading TRC", err, "file", f) @@ -148,10 +166,41 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { return res, serrors.WrapNoStack("adding TRC to DB", err, "file", f) } if !inserted { - res.Ignored[f] = serrors.JoinNoStack(ErrAlreadyExists, err) + res.Ignored[f] = ErrAlreadyExists continue } res.Loaded = append(res.Loaded, f) } return res, nil } + +// TRCLoader loads TRCs from a directory and stores them in the database. It +// tracks files that it has already loaded and does not load them again. +type TRCLoader struct { + Dir string + DB DB + + seen map[string]struct{} + mtx sync.Mutex +} + +// Load loads all TRCs from the directory into database. Files that have been +// loaded by a previous Load invocation are silently ignored. +func (l *TRCLoader) Load(ctx context.Context) (LoadResult, error) { + l.mtx.Lock() + defer l.mtx.Unlock() + if l.seen == nil { + l.seen = make(map[string]struct{}) + } + + result, err := loadTRCs(ctx, l.Dir, l.DB, l.seen) + for _, f := range result.Loaded { + l.seen[f] = struct{}{} + } + for f, err := range result.Ignored { + if errors.Is(err, ErrAlreadyExists) { + l.seen[f] = struct{}{} + } + } + return result, err +} diff --git a/private/trust/store_test.go b/private/trust/store_test.go index b5c040342a..642b07bde7 100644 --- a/private/trust/store_test.go +++ b/private/trust/store_test.go @@ -317,3 +317,45 @@ func TestLoadTRCs(t *testing.T) { }) } } + +func TestTRCLoaderLoad(t *testing.T) { + dir := genCrypto(t) + + testCases := map[string]struct { + inputDir string + setupDB func(ctrl *gomock.Controller) trust.DB + test func(t *testing.T, loader *trust.TRCLoader) + }{ + "repeated load": { + inputDir: filepath.Join(dir, "ISD1/trcs"), + setupDB: func(ctrl *gomock.Controller) trust.DB { + db := mock_trust.NewMockDB(ctrl) + db.EXPECT().InsertTRC(gomock.Any(), gomock.Any()).Times(2).Return( + true, nil, + ) + return db + }, + test: func(t *testing.T, loader *trust.TRCLoader) { + res, err := loader.Load(context.Background()) + require.NoError(t, err) + assert.Len(t, res.Loaded, 2) + res, err = loader.Load(context.Background()) + require.NoError(t, err) + assert.Len(t, res.Loaded, 0) + }, + }, + } + for name, tc := range testCases { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + db := tc.setupDB(ctrl) + loader := &trust.TRCLoader{ + DB: db, + Dir: tc.inputDir, + } + tc.test(t, loader) + }) + } +}