From 6cc4f18bcc724d0c8b060cb985af25d620f8951d Mon Sep 17 00:00:00 2001 From: RTann Date: Mon, 4 Dec 2023 11:24:43 -0800 Subject: [PATCH] chore: batch insert vulns --- pkg/vulndump/batch_loader.go | 83 +++++++++++++++++++++++++++++++ pkg/vulndump/batch_loader_test.go | 62 +++++++++++++++++++++++ pkg/vulndump/loader.go | 46 ++++++++++++++--- pkg/vulndump/loader_test.go | 28 +++++++++++ 4 files changed, 211 insertions(+), 8 deletions(-) create mode 100644 pkg/vulndump/batch_loader.go create mode 100644 pkg/vulndump/batch_loader_test.go create mode 100644 pkg/vulndump/loader_test.go diff --git a/pkg/vulndump/batch_loader.go b/pkg/vulndump/batch_loader.go new file mode 100644 index 000000000..2f8b4313e --- /dev/null +++ b/pkg/vulndump/batch_loader.go @@ -0,0 +1,83 @@ +package vulndump + +import ( + "encoding/json" + "io" + + "github.com/stackrox/scanner/database" +) + +const defaultBatchSize = 10_000 + +// osLoader batch loads OS-level vulnerabilities into a buffer. +type osLoader struct { + rc io.ReadCloser + + dec *json.Decoder + + batchSize int + buf []database.Vulnerability + done bool + err error +} + +func newOSLoader(source io.ReadCloser) (*osLoader, error) { + dec := json.NewDecoder(source) + // Read the initial token, "[". + _, err := dec.Token() + if err != nil { + return nil, err + } + + return &osLoader{ + rc: source, + dec: dec, + batchSize: defaultBatchSize, + buf: make([]database.Vulnerability, 0, defaultBatchSize), + }, nil +} + +// Next loads the next batch of vulnerabilities and returns +// whether it was successful or not. +func (l *osLoader) Next() bool { + if l.done || l.err != nil { + return false + } + + l.buf = l.buf[:0] + for i := 0; i < l.batchSize; i++ { + if !l.dec.More() { + // JSON array has no more values. + l.done = true + return true + } + l.buf = append(l.buf, database.Vulnerability{}) + if err := l.dec.Decode(&l.buf[i]); err != nil { + l.err = err + return false + } + } + + return true +} + +// Vulns returns the "next" bath of vulnerabilities. +// It is expected to be called once for each successful call to Next. +func (l *osLoader) Vulns() []database.Vulnerability { + return l.buf +} + +// Err returns the error associated with loading vulnerabilities. +// It is expected to be non-nil upon an unsuccessful call to Next. +func (l *osLoader) Err() error { + return l.err +} + +// Close closes the loader. +func (l *osLoader) Close() error { + l.buf = nil // hint to GC to clean this. + // Don't bother reading the final token, "]", + // as it is possible there was a failure reading + // the JSON array. Just close the file. + return l.rc.Close() +} diff --git a/pkg/vulndump/batch_loader_test.go b/pkg/vulndump/batch_loader_test.go new file mode 100644 index 000000000..9488d1a2a --- /dev/null +++ b/pkg/vulndump/batch_loader_test.go @@ -0,0 +1,62 @@ +package vulndump + +import ( + "archive/zip" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/stackrox/rox/pkg/utils" + "github.com/stackrox/scanner/pkg/ziputil" + "github.com/stretchr/testify/require" +) + +const defURL = "https://definitions.stackrox.io/93AEC554-29EE-4E24-96D6-744092A98444/diff.zip" + +func mustFetchOSVulns(b *testing.B) *os.File { + f, err := os.Create(filepath.Join(b.TempDir(), "vulns.zip")) + require.NoError(b, err) + + c := &http.Client{Timeout: 30 * time.Second} + resp, err := c.Get(defURL) + require.NoError(b, err) + defer utils.IgnoreError(resp.Body.Close) + + _, err = io.Copy(f, resp.Body) + require.NoError(b, err) + + return f +} + +func BenchmarkOSLoader(b *testing.B) { + f := mustFetchOSVulns(b) + defer utils.IgnoreError(f.Close) + + zipR, err := zip.OpenReader(f.Name()) + require.NoError(b, err) + vulnsF, err := ziputil.OpenFile(&zipR.Reader, OSVulnsFileName) + require.NoError(b, err) + + runtime.GC() + + loader, err := newOSLoader(vulnsF) + require.NoError(b, err) + defer func() { + require.NoError(b, loader.Close()) + }() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var n int + for loader.Next() { + vulns := loader.Vulns() + n += len(vulns) + } + require.NoError(b, loader.Err()) + b.Logf("Loaded %d vulns", n) + } +} diff --git a/pkg/vulndump/loader.go b/pkg/vulndump/loader.go index 3b0d85cc9..7b76266b9 100644 --- a/pkg/vulndump/loader.go +++ b/pkg/vulndump/loader.go @@ -117,6 +117,15 @@ func LoadOSVulnsFromDump(zipR *zip.Reader) ([]database.Vulnerability, error) { return vulns, nil } +func getOSLoader(zipR *zip.Reader) (*osLoader, error) { + osVulnsFile, err := ziputil.OpenFile(zipR, OSVulnsFileName) + if err != nil { + return nil, errors.Wrap(err, "opening OS vulns file") + } + + return newOSLoader(osVulnsFile) +} + func renew(sig *concurrency.Signal, db database.Datastore, interval time.Duration, expiration time.Time, instanceName string) { // Give a buffer for this instance to renew the lock expirationDuration := time.Until(expiration) - 10*time.Second @@ -188,16 +197,27 @@ func startVulnLoad(manifest *Manifest, db database.Datastore, updateInterval tim func loadOSVulns(zipR *zip.Reader, db database.Datastore) error { log.Info("Loading OS vulns...") - osVulns, err := LoadOSVulnsFromDump(zipR) + loader, err := getOSLoader(zipR) if err != nil { return err } - log.Infof("Done loading OS vulns. There are %d vulns to insert into the DB", len(osVulns)) + defer utils.IgnoreError(loader.Close) + + var n int + for loader.Next() { + osVulns := loader.Vulns() + + n += len(osVulns) + log.Infof("Loaded %d OS-level vulns. Total OS-level vulns: %d", len(osVulns), n) - if err := db.InsertVulnerabilities(osVulns); err != nil { - return errors.Wrap(err, "inserting vulns into the DB") + if err := db.InsertVulnerabilities(osVulns); err != nil { + return errors.Wrap(err, "inserting vulns into the DB") + } + } + if loader.Err() != nil { + return loader.Err() } - log.Info("Done inserting OS vulns into the DB") + return nil } @@ -275,12 +295,12 @@ func UpdateFromVulnDump(zipPath string, db database.Datastore, updateInterval ti } defer utils.IgnoreError(zipR.Close) - log.Info("Loading manifest...") + log.Info("Loading vulnerability manifest") manifest, err := LoadManifestFromDump(&zipR.Reader) if err != nil { return err } - log.Info("Loaded manifest") + log.Info("Successfully loaded vulnerability manifest") if db != nil { performUpdate, finishFn, err := startVulnLoad(manifest, db, updateInterval, instanceName) @@ -288,6 +308,8 @@ func UpdateFromVulnDump(zipPath string, db database.Datastore, updateInterval ti return errors.Wrap(err, "error beginning vuln loading") } if performUpdate { + log.Info("Loading OS-level vulnerabilities") + if err := loadRHELv2Vulns(db, &zipR.Reader, repoToCPE); err != nil { _ = finishFn(err) return errors.Wrap(err, "error loading RHEL vulns") @@ -301,15 +323,23 @@ func UpdateFromVulnDump(zipPath string, db database.Datastore, updateInterval ti if err := finishFn(nil); err != nil { return errors.Wrap(err, "error ending vuln loading") } + + log.Info("Loaded OS-level vulnerabilities successfully") } } + log.Info("Loading application-level vulnerabilities") errorList := errorhelpers.NewErrorList("loading application-level caches") for _, appCache := range caches { if err := loadApplicationUpdater(appCache, manifest, &zipR.Reader); err != nil { errorList.AddError(errors.Wrapf(err, "error loading into in-mem cache %q", appCache.Dir())) } } + if err := errorList.ToError(); err != nil { + return err + } - return errorList.ToError() + log.Info("Successfully loaded application-level vulnerabilities") + + return nil } diff --git a/pkg/vulndump/loader_test.go b/pkg/vulndump/loader_test.go new file mode 100644 index 000000000..7f332c74d --- /dev/null +++ b/pkg/vulndump/loader_test.go @@ -0,0 +1,28 @@ +package vulndump + +import ( + "archive/zip" + "runtime" + "testing" + + "github.com/stackrox/rox/pkg/utils" + "github.com/stretchr/testify/require" +) + +func BenchmarkLoadOSVulnsFromDump(b *testing.B) { + f := mustFetchOSVulns(b) + defer utils.IgnoreError(f.Close) + + zipR, err := zip.OpenReader(f.Name()) + require.NoError(b, err) + defer utils.IgnoreError(zipR.Close) + + runtime.GC() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + vulns, err := LoadOSVulnsFromDump(&zipR.Reader) + require.NoError(b, err) + b.Logf("Loaded %d vulns", len(vulns)) + } +}