diff --git a/control/trust.go b/control/trust.go index b16e49c05f..cff947de3d 100644 --- a/control/trust.go +++ b/control/trust.go @@ -103,7 +103,7 @@ func newCachingSignerGen( gen := trust.SignerGen{ IA: ia, - DB: cstrust.CryptoLoader{ + DB: &cstrust.CryptoLoader{ Dir: filepath.Join(cfgDir, "crypto/as"), TRCDirs: []string{filepath.Join(cfgDir, "certs")}, DB: db, diff --git a/control/trust/crypto_loader.go b/control/trust/crypto_loader.go index 238e69b30e..1081499468 100644 --- a/control/trust/crypto_loader.go +++ b/control/trust/crypto_loader.go @@ -18,6 +18,7 @@ import ( "context" "crypto/x509" "errors" + "sync" "github.com/scionproto/scion/pkg/log" "github.com/scionproto/scion/pkg/private/serrors" @@ -33,9 +34,13 @@ type CryptoLoader struct { Dir string // TRCDirs are optional directories from which TRCs are loaded. TRCDirs []string + + trcLoaders map[string]*trust.TRCLoader + trcLoadersInitMtx sync.Mutex } -func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) { +func (l *CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) { + l.initTRCLoaders() if err := l.loadTRCs(ctx); err != nil { log.FromCtx(ctx).Info("Failed to load TRCs from disk, continuing", "err", err) } @@ -44,7 +49,7 @@ func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.Sign // Chains loads chains from disk, stores them to DB, and returns the result from // DB. The fallback mode is always the result of the DB. -func (l CryptoLoader) Chains(ctx context.Context, +func (l *CryptoLoader) Chains(ctx context.Context, query trust.ChainQuery) ([][]*x509.Certificate, error) { r, err := trust.LoadChains(ctx, l.Dir, l.DB) @@ -70,7 +75,24 @@ func (l CryptoLoader) Chains(ctx context.Context, return l.DB.Chains(ctx, query) } -func (l CryptoLoader) loadTRCs(ctx context.Context) error { +func (l *CryptoLoader) initTRCLoaders() { + l.trcLoadersInitMtx.Lock() + defer l.trcLoadersInitMtx.Unlock() + if l.trcLoaders != nil { + return + } + l.trcLoaders = make(map[string]*trust.TRCLoader, len(l.TRCDirs)+1) + for _, dir := range append([]string{l.Dir}, l.TRCDirs...) { + if _, ok := l.trcLoaders[dir]; !ok { + l.trcLoaders[dir] = &trust.TRCLoader{ + DB: l.DB, + Dir: dir, + } + } + } +} + +func (l *CryptoLoader) loadTRCs(ctx context.Context) error { var errs serrors.List for _, dir := range append([]string{l.Dir}, l.TRCDirs...) { if err := l.loadTRCsFromDir(ctx, dir); err != nil { @@ -80,8 +102,8 @@ func (l CryptoLoader) loadTRCs(ctx context.Context) error { return errs.ToError() } -func (l CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error { - r, err := trust.LoadTRCs(ctx, dir, l.DB) +func (l *CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error { + r, err := l.trcLoaders[dir].Load(ctx) if err != nil { return err } diff --git a/daemon/cmd/daemon/main.go b/daemon/cmd/daemon/main.go index 5faf0c9bc2..32fd124427 100644 --- a/daemon/cmd/daemon/main.go +++ b/daemon/cmd/daemon/main.go @@ -159,10 +159,13 @@ func realMain(ctx context.Context) error { CacheHits: metrics.NewPromCounter(trustmetrics.CacheHitsTotal), MaxCacheExpiration: globalCfg.TrustEngine.Cache.Expiration.Duration, } - trcLoader := periodic.Start(periodic.Func{ + trcLoader := trust.TRCLoader{ + Dir: filepath.Join(globalCfg.General.ConfigDir, "certs"), + DB: trustDB, + } + trcLoaderTask := periodic.Start(periodic.Func{ Task: func(ctx context.Context) { - trcDirs := filepath.Join(globalCfg.General.ConfigDir, "certs") - res, err := trust.LoadTRCs(ctx, trcDirs, trustDB) + res, err := trcLoader.Load(ctx) if err != nil { log.SafeInfo(log.FromCtx(ctx), "TRC loading failed", "err", err) } @@ -172,7 +175,7 @@ func realMain(ctx context.Context) error { }, TaskName: "daemon_trc_loader", }, 10*time.Second, 10*time.Second) - defer trcLoader.Stop() + defer trcLoaderTask.Stop() var drkeyClientEngine *sd_drkey.ClientEngine if globalCfg.DRKeyLevel2DB.Connection != "" { diff --git a/private/trust/store.go b/private/trust/store.go index 1eaea6f017..76815cb656 100644 --- a/private/trust/store.go +++ b/private/trust/store.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path/filepath" + "sync" "time" "github.com/scionproto/scion/pkg/private/serrors" @@ -112,9 +113,22 @@ func LoadChains(ctx context.Context, dir string, db DB) (LoadResult, error) { } // LoadTRCs loads all *.trc located in a directory in the database. This -// function exits on the first encountered error. TRCs with a not before time -// in the future are ignored. +// function exits on the first encountered error. TRCs with a not before time in +// the future are ignored. +// +// This function is not recommended for repeated use as it will read all TRC +// files in a directory on every invocation. Consider using a TRCLoader if you +// want to monitor a directory for new TRCs. func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { + return loadTRCs(ctx, dir, db, nil) +} + +func loadTRCs( + ctx context.Context, + dir string, + db DB, + ignoreFiles map[string]struct{}, +) (LoadResult, error) { if _, err := os.Stat(dir); err != nil { return LoadResult{}, serrors.WrapNoStack("stating directory", err, "dir", dir) } @@ -127,6 +141,10 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { res := LoadResult{Ignored: map[string]error{}} // TODO(roosd): should probably be a transaction. for _, f := range files { + // ignore as per request of the caller + if _, ok := ignoreFiles[f]; ok { + continue + } raw, err := os.ReadFile(f) if err != nil { return res, serrors.WrapNoStack("reading TRC", err, "file", f) @@ -148,10 +166,41 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) { return res, serrors.WrapNoStack("adding TRC to DB", err, "file", f) } if !inserted { - res.Ignored[f] = serrors.JoinNoStack(ErrAlreadyExists, err) + res.Ignored[f] = ErrAlreadyExists continue } res.Loaded = append(res.Loaded, f) } return res, nil } + +// TRCLoader loads TRCs from a directory and stores them in the database. It +// tracks files that it has already loaded and does not load them again. +type TRCLoader struct { + Dir string + DB DB + + seen map[string]struct{} + mtx sync.Mutex +} + +// Load loads all TRCs from the directory into database. Files that have been +// loaded by a previous Load invocation are silently ignored. +func (l *TRCLoader) Load(ctx context.Context) (LoadResult, error) { + l.mtx.Lock() + defer l.mtx.Unlock() + if l.seen == nil { + l.seen = make(map[string]struct{}) + } + + result, err := loadTRCs(ctx, l.Dir, l.DB, l.seen) + for _, f := range result.Loaded { + l.seen[f] = struct{}{} + } + for f, err := range result.Ignored { + if errors.Is(err, ErrAlreadyExists) { + l.seen[f] = struct{}{} + } + } + return result, err +} diff --git a/private/trust/store_test.go b/private/trust/store_test.go index b5c040342a..642b07bde7 100644 --- a/private/trust/store_test.go +++ b/private/trust/store_test.go @@ -317,3 +317,45 @@ func TestLoadTRCs(t *testing.T) { }) } } + +func TestTRCLoaderLoad(t *testing.T) { + dir := genCrypto(t) + + testCases := map[string]struct { + inputDir string + setupDB func(ctrl *gomock.Controller) trust.DB + test func(t *testing.T, loader *trust.TRCLoader) + }{ + "repeated load": { + inputDir: filepath.Join(dir, "ISD1/trcs"), + setupDB: func(ctrl *gomock.Controller) trust.DB { + db := mock_trust.NewMockDB(ctrl) + db.EXPECT().InsertTRC(gomock.Any(), gomock.Any()).Times(2).Return( + true, nil, + ) + return db + }, + test: func(t *testing.T, loader *trust.TRCLoader) { + res, err := loader.Load(context.Background()) + require.NoError(t, err) + assert.Len(t, res.Loaded, 2) + res, err = loader.Load(context.Background()) + require.NoError(t, err) + assert.Len(t, res.Loaded, 0) + }, + }, + } + for name, tc := range testCases { + name, tc := name, tc + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + db := tc.setupDB(ctrl) + loader := &trust.TRCLoader{ + DB: db, + Dir: tc.inputDir, + } + tc.test(t, loader) + }) + } +}