Skip to content

Commit

Permalink
crypto: do not load already loaded TRC files (scionproto#4588)
Browse files Browse the repository at this point in the history
This commit introduces a TRC loader that keeps track of what TRC files
have already been read. By doing so we can omit duplicate work of
parsing TRC files and storing them in the DB.

This applies to the control service and the daemon.
  • Loading branch information
lukedirtwalker authored Aug 28, 2024
1 parent 08852b4 commit fbd9eb0
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 13 deletions.
2 changes: 1 addition & 1 deletion control/trust.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func newCachingSignerGen(

gen := trust.SignerGen{
IA: ia,
DB: cstrust.CryptoLoader{
DB: &cstrust.CryptoLoader{
Dir: filepath.Join(cfgDir, "crypto/as"),
TRCDirs: []string{filepath.Join(cfgDir, "certs")},
DB: db,
Expand Down
32 changes: 27 additions & 5 deletions control/trust/crypto_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/x509"
"errors"
"sync"

"github.com/scionproto/scion/pkg/log"
"github.com/scionproto/scion/pkg/private/serrors"
Expand All @@ -33,9 +34,13 @@ type CryptoLoader struct {
Dir string
// TRCDirs are optional directories from which TRCs are loaded.
TRCDirs []string

trcLoaders map[string]*trust.TRCLoader
trcLoadersInitMtx sync.Mutex
}

func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) {
func (l *CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.SignedTRC, error) {
l.initTRCLoaders()
if err := l.loadTRCs(ctx); err != nil {
log.FromCtx(ctx).Info("Failed to load TRCs from disk, continuing", "err", err)
}
Expand All @@ -44,7 +49,7 @@ func (l CryptoLoader) SignedTRC(ctx context.Context, id cppki.TRCID) (cppki.Sign

// Chains loads chains from disk, stores them to DB, and returns the result from
// DB. The fallback mode is always the result of the DB.
func (l CryptoLoader) Chains(ctx context.Context,
func (l *CryptoLoader) Chains(ctx context.Context,
query trust.ChainQuery) ([][]*x509.Certificate, error) {

r, err := trust.LoadChains(ctx, l.Dir, l.DB)
Expand All @@ -70,7 +75,24 @@ func (l CryptoLoader) Chains(ctx context.Context,
return l.DB.Chains(ctx, query)
}

func (l CryptoLoader) loadTRCs(ctx context.Context) error {
func (l *CryptoLoader) initTRCLoaders() {
l.trcLoadersInitMtx.Lock()
defer l.trcLoadersInitMtx.Unlock()
if l.trcLoaders != nil {
return
}
l.trcLoaders = make(map[string]*trust.TRCLoader, len(l.TRCDirs)+1)
for _, dir := range append([]string{l.Dir}, l.TRCDirs...) {
if _, ok := l.trcLoaders[dir]; !ok {
l.trcLoaders[dir] = &trust.TRCLoader{
DB: l.DB,
Dir: dir,
}
}
}
}

func (l *CryptoLoader) loadTRCs(ctx context.Context) error {
var errs serrors.List
for _, dir := range append([]string{l.Dir}, l.TRCDirs...) {
if err := l.loadTRCsFromDir(ctx, dir); err != nil {
Expand All @@ -80,8 +102,8 @@ func (l CryptoLoader) loadTRCs(ctx context.Context) error {
return errs.ToError()
}

func (l CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error {
r, err := trust.LoadTRCs(ctx, dir, l.DB)
func (l *CryptoLoader) loadTRCsFromDir(ctx context.Context, dir string) error {
r, err := l.trcLoaders[dir].Load(ctx)
if err != nil {
return err
}
Expand Down
11 changes: 7 additions & 4 deletions daemon/cmd/daemon/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,13 @@ func realMain(ctx context.Context) error {
CacheHits: metrics.NewPromCounter(trustmetrics.CacheHitsTotal),
MaxCacheExpiration: globalCfg.TrustEngine.Cache.Expiration.Duration,
}
trcLoader := periodic.Start(periodic.Func{
trcLoader := trust.TRCLoader{
Dir: filepath.Join(globalCfg.General.ConfigDir, "certs"),
DB: trustDB,
}
trcLoaderTask := periodic.Start(periodic.Func{
Task: func(ctx context.Context) {
trcDirs := filepath.Join(globalCfg.General.ConfigDir, "certs")
res, err := trust.LoadTRCs(ctx, trcDirs, trustDB)
res, err := trcLoader.Load(ctx)
if err != nil {
log.SafeInfo(log.FromCtx(ctx), "TRC loading failed", "err", err)
}
Expand All @@ -172,7 +175,7 @@ func realMain(ctx context.Context) error {
},
TaskName: "daemon_trc_loader",
}, 10*time.Second, 10*time.Second)
defer trcLoader.Stop()
defer trcLoaderTask.Stop()

var drkeyClientEngine *sd_drkey.ClientEngine
if globalCfg.DRKeyLevel2DB.Connection != "" {
Expand Down
55 changes: 52 additions & 3 deletions private/trust/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"
"path/filepath"
"sync"
"time"

"github.com/scionproto/scion/pkg/private/serrors"
Expand Down Expand Up @@ -112,9 +113,22 @@ func LoadChains(ctx context.Context, dir string, db DB) (LoadResult, error) {
}

// LoadTRCs loads all *.trc located in a directory in the database. This
// function exits on the first encountered error. TRCs with a not before time
// in the future are ignored.
// function exits on the first encountered error. TRCs with a not before time in
// the future are ignored.
//
// This function is not recommended for repeated use as it will read all TRC
// files in a directory on every invocation. Consider using a TRCLoader if you
// want to monitor a directory for new TRCs.
func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) {
return loadTRCs(ctx, dir, db, nil)
}

func loadTRCs(
ctx context.Context,
dir string,
db DB,
ignoreFiles map[string]struct{},
) (LoadResult, error) {
if _, err := os.Stat(dir); err != nil {
return LoadResult{}, serrors.WrapNoStack("stating directory", err, "dir", dir)
}
Expand All @@ -127,6 +141,10 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) {
res := LoadResult{Ignored: map[string]error{}}
// TODO(roosd): should probably be a transaction.
for _, f := range files {
// ignore as per request of the caller
if _, ok := ignoreFiles[f]; ok {
continue
}
raw, err := os.ReadFile(f)
if err != nil {
return res, serrors.WrapNoStack("reading TRC", err, "file", f)
Expand All @@ -148,10 +166,41 @@ func LoadTRCs(ctx context.Context, dir string, db DB) (LoadResult, error) {
return res, serrors.WrapNoStack("adding TRC to DB", err, "file", f)
}
if !inserted {
res.Ignored[f] = serrors.JoinNoStack(ErrAlreadyExists, err)
res.Ignored[f] = ErrAlreadyExists
continue
}
res.Loaded = append(res.Loaded, f)
}
return res, nil
}

// TRCLoader loads TRCs from a directory and stores them in the database. It
// tracks files that it has already loaded and does not load them again.
type TRCLoader struct {
Dir string
DB DB

seen map[string]struct{}
mtx sync.Mutex
}

// Load loads all TRCs from the directory into database. Files that have been
// loaded by a previous Load invocation are silently ignored.
func (l *TRCLoader) Load(ctx context.Context) (LoadResult, error) {
l.mtx.Lock()
defer l.mtx.Unlock()
if l.seen == nil {
l.seen = make(map[string]struct{})
}

result, err := loadTRCs(ctx, l.Dir, l.DB, l.seen)
for _, f := range result.Loaded {
l.seen[f] = struct{}{}
}
for f, err := range result.Ignored {
if errors.Is(err, ErrAlreadyExists) {
l.seen[f] = struct{}{}
}
}
return result, err
}
42 changes: 42 additions & 0 deletions private/trust/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,45 @@ func TestLoadTRCs(t *testing.T) {
})
}
}

func TestTRCLoaderLoad(t *testing.T) {
dir := genCrypto(t)

testCases := map[string]struct {
inputDir string
setupDB func(ctrl *gomock.Controller) trust.DB
test func(t *testing.T, loader *trust.TRCLoader)
}{
"repeated load": {
inputDir: filepath.Join(dir, "ISD1/trcs"),
setupDB: func(ctrl *gomock.Controller) trust.DB {
db := mock_trust.NewMockDB(ctrl)
db.EXPECT().InsertTRC(gomock.Any(), gomock.Any()).Times(2).Return(
true, nil,
)
return db
},
test: func(t *testing.T, loader *trust.TRCLoader) {
res, err := loader.Load(context.Background())
require.NoError(t, err)
assert.Len(t, res.Loaded, 2)
res, err = loader.Load(context.Background())
require.NoError(t, err)
assert.Len(t, res.Loaded, 0)
},
},
}
for name, tc := range testCases {
name, tc := name, tc
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := tc.setupDB(ctrl)
loader := &trust.TRCLoader{
DB: db,
Dir: tc.inputDir,
}
tc.test(t, loader)
})
}
}

0 comments on commit fbd9eb0

Please sign in to comment.