diff --git a/enricher/epss/epss.go b/enricher/epss/epss.go index f8551b05e..87fd37a60 100644 --- a/enricher/epss/epss.go +++ b/enricher/epss/epss.go @@ -11,7 +11,7 @@ import ( "net/http" "net/url" "path" - "sort" + "slices" "strconv" "strings" "time" @@ -20,6 +20,7 @@ import ( "github.com/quay/claircore" "github.com/quay/claircore/enricher" + "github.com/quay/claircore/internal/httputil" "github.com/quay/claircore/libvuln/driver" "github.com/quay/claircore/pkg/tmp" ) @@ -29,6 +30,8 @@ var ( _ driver.EnrichmentUpdater = (*Enricher)(nil) ) +// EPSSItem represents a single entry in the EPSS feed, containing information +// about a CVE's Exploit Prediction Scoring System (EPSS) score and percentile. type EPSSItem struct { ModelVersion string `json:"modelVersion"` Date string `json:"date"` @@ -39,11 +42,11 @@ type EPSSItem struct { const ( // Type is the type of data returned from the Enricher's Enrich method. - Type = `message/vnd.clair.map.vulnerability; enricher=clair.epss schema=https://csrc.nist.gov/schema/nvd/feed/1.1/cvss-v3.x.json` + Type = `message/vnd.clair.map.vulnerability; enricher=clair.epss schema=https://csrc.nist.gov/schema/nvd/baseURL/1.1/cvss-v3.x.json` - // DefaultFeed is the default place to look for EPSS feeds. + // DefaultBaseURL is the default place to look for EPSS feeds. // epss_scores-YYYY-MM-DD.csv.gz needs to be specified to get all data - DefaultFeed = `https://epss.cyentia.com/` + DefaultBaseURL = `https://epss.cyentia.com/` // epssName is the name of the enricher epssName = `clair.epss` @@ -55,13 +58,13 @@ const ( type Enricher struct { driver.NoopUpdater c *http.Client - feed *url.URL + baseURL *url.URL feedPath string } // Config is the configuration for Enricher. type Config struct { - FeedRoot *string `json:"feed_root" yaml:"feed_root"` + BaseURL *string `json:"url" yaml:"url"` } func (e *Enricher) Configure(ctx context.Context, f driver.ConfigUnmarshaler, c *http.Client) error { @@ -76,18 +79,18 @@ func (e *Enricher) Configure(ctx context.Context, f driver.ConfigUnmarshaler, c if err := f(&cfg); err != nil { return err } - if cfg.FeedRoot != nil { + if cfg.BaseURL != nil { // validate the URL format - if _, err := url.Parse(*cfg.FeedRoot); err != nil { - return fmt.Errorf("invalid URL format for FeedRoot: %w", err) + if _, err := url.Parse(*cfg.BaseURL); err != nil { + return fmt.Errorf("invalid URL format for BaseURL: %w", err) } // only .gz file is supported - if strings.HasSuffix(*cfg.FeedRoot, ".gz") { - //overwrite feedPath is cfg provides another feed path - e.feedPath = *cfg.FeedRoot + if strings.HasSuffix(*cfg.BaseURL, ".gz") { + //overwrite feedPath is cfg provides another baseURL path + e.feedPath = *cfg.BaseURL } else { - return fmt.Errorf("invalid feed root: expected a '.gz' file, but got '%q'", *cfg.FeedRoot) + return fmt.Errorf("invalid baseURL root: expected a '.gz' file, but got '%q'", *cfg.BaseURL) } } @@ -98,10 +101,6 @@ func (e *Enricher) Configure(ctx context.Context, f driver.ConfigUnmarshaler, c func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.Fingerprint) (io.ReadCloser, driver.Fingerprint, error) { ctx = zlog.ContextWithValues(ctx, "component", "enricher/epss/Enricher/FetchEnrichment") - if e.feedPath == "" || !strings.HasSuffix(e.feedPath, ".gz") { - return nil, "", fmt.Errorf("invalid feed path: %q must be non-empty and end with '.gz'", e.feedPath) - } - out, err := tmp.NewFile("", "epss.") if err != nil { return nil, "", err @@ -126,8 +125,8 @@ func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.F } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, "", fmt.Errorf("unable to fetch file: received status %d", resp.StatusCode) + if err = httputil.CheckResponse(resp, http.StatusOK); err != nil { + return nil, "", fmt.Errorf("unable to fetch file: %w", err) } etag := resp.Header.Get("etag") @@ -149,7 +148,7 @@ func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.F defer gzipReader.Close() csvReader := csv.NewReader(gzipReader) - csvReader.FieldsPerRecord = -1 // Allow variable-length fields + csvReader.FieldsPerRecord = 2 // assume metadata is always in the first line record, err := csvReader.Read() @@ -160,40 +159,45 @@ func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.F var modelVersion, date string for _, field := range record { field = strings.TrimSpace(field) - if strings.HasPrefix(field, "#") { - field = strings.TrimPrefix(field, "#") + field = strings.TrimPrefix(strings.TrimSpace(field), "#") + key, value, found := strings.Cut(field, ":") + if !found { + return nil, "", fmt.Errorf("unexpected metadata field format: %q", field) } - kv := strings.SplitN(field, ":", 2) - if len(kv) == 2 { - switch strings.TrimSpace(kv[0]) { - case "model_version": - modelVersion = strings.TrimSpace(kv[1]) - case "score_date": - date = strings.TrimSpace(kv[1]) - } + switch key { + case "model_version": + modelVersion = value + case "score_date": + date = value } } if modelVersion == "" || date == "" { return nil, "", fmt.Errorf("missing metadata fields in record: %v", record) } + csvReader.Comment = '#' + csvReader.FieldsPerRecord = 3 // Expect exactly 3 fields per record - csvReader.Comment = '#' // Ignore subsequent comment lines + if modelVersion == "" || date == "" { + return nil, "", fmt.Errorf("missing metadata fields in record: %v", record) + } + // Read and validate header line record, err = csvReader.Read() if err != nil { return nil, "", fmt.Errorf("unable to read header line: %w", err) } - if len(record) < 3 || record[0] != "cve" || record[1] != "epss" || record[2] != "percentile" { + + expectedHeaders := []string{"cve", "epss", "percentile"} + if !slices.Equal(record, expectedHeaders) { return nil, "", fmt.Errorf("unexpected CSV headers: %v", record) } - headers := record enc := json.NewEncoder(out) totalCVEs := 0 for { - record, err = csvReader.Read() + record, err := csvReader.Read() if errors.Is(err, io.EOF) { break } @@ -201,18 +205,13 @@ func (e *Enricher) FetchEnrichment(ctx context.Context, prevFingerprint driver.F return nil, "", fmt.Errorf("unable to read line in CSV: %w", err) } - if len(record) != len(headers) { - zlog.Warn(ctx).Str("record", fmt.Sprintf("%v", record)).Msg("skipping record with mismatched fields") - continue - } - - r, err := newItemFeed(record, headers, modelVersion, date) + r, err := newItemFeed(record, modelVersion, date) if err != nil { - zlog.Warn(ctx).Str("record", fmt.Sprintf("%v", record)).Msg("skipping invalid record") + zlog.Warn(ctx).Err(err).Msg("skipping invalid record") continue } - if err = enc.Encode(&r); err != nil { + if err := enc.Encode(&r); err != nil { return nil, "", fmt.Errorf("unable to write JSON line to file: %w", err) } totalCVEs++ @@ -267,9 +266,9 @@ func currentFeedURL() string { formattedDate := currentDate.Format("2006-01-02") filePath := fmt.Sprintf("epss_scores-%s.csv.gz", formattedDate) - feedURL, err := url.Parse(DefaultFeed) + feedURL, err := url.Parse(DefaultBaseURL) if err != nil { - panic(fmt.Errorf("invalid default feed URL: %w", err)) + panic(fmt.Errorf("invalid default baseURL URL: %w", err)) } feedURL.Path = path.Join(feedURL.Path, filePath) @@ -316,7 +315,7 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla for m := range t { ts = append(ts, m) } - sort.Strings(ts) + slices.Sort(ts) cveKey := strings.Join(ts, "_") @@ -339,9 +338,6 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla } for _, r := range rec { - if _, exists := m[id]; !exists { - m[id] = []json.RawMessage{} - } m[id] = append(m[id], r.Enrichment) } } @@ -357,29 +353,25 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla return Type, []json.RawMessage{b}, nil } -func newItemFeed(record []string, headers []string, modelVersion string, scoreDate string) (driver.EnrichmentRecord, error) { - if len(record) != len(headers) { - return driver.EnrichmentRecord{}, fmt.Errorf("record and headers length mismatch") +func newItemFeed(record []string, modelVersion string, scoreDate string) (driver.EnrichmentRecord, error) { + // Assuming record has already been validated to have 3 fields + if len(record) != 3 { + return driver.EnrichmentRecord{}, fmt.Errorf("unexpected record length: %d", len(record)) } var item EPSSItem - for i, value := range record { - switch headers[i] { - case "cve": - item.CVE = value - case "epss": - if f, err := strconv.ParseFloat(value, 64); err == nil { - item.EPSS = f - } else { - return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for epss: %w", err) - } - case "percentile": - if f, err := strconv.ParseFloat(value, 64); err == nil { - item.Percentile = f - } else { - return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for percentile: %w", err) - } - } + item.CVE = record[0] + + if f, err := strconv.ParseFloat(record[1], 64); err == nil { + item.EPSS = f + } else { + return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for epss: %w", err) + } + + if f, err := strconv.ParseFloat(record[2], 64); err == nil { + item.Percentile = f + } else { + return driver.EnrichmentRecord{}, fmt.Errorf("invalid float for percentile: %w", err) } item.ModelVersion = modelVersion @@ -391,7 +383,7 @@ func newItemFeed(record []string, headers []string, modelVersion string, scoreDa } r := driver.EnrichmentRecord{ - Tags: []string{item.CVE}, // CVE field should be set + Tags: []string{item.CVE}, Enrichment: enrichment, } diff --git a/enricher/epss/epss_test.go b/enricher/epss/epss_test.go index 8e7182fdc..042a6711c 100644 --- a/enricher/epss/epss_test.go +++ b/enricher/epss/epss_test.go @@ -15,10 +15,10 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/quay/zlog" "github.com/quay/claircore" "github.com/quay/claircore/libvuln/driver" - "github.com/quay/zlog" ) func TestConfigure(t *testing.T) { @@ -38,7 +38,7 @@ func TestConfigure(t *testing.T) { Config: func(i interface{}) error { cfg := i.(*Config) s := "http://example.com/" - cfg.FeedRoot = &s + cfg.BaseURL = &s return nil }, Check: func(t *testing.T, err error) { @@ -58,11 +58,11 @@ func TestConfigure(t *testing.T) { }, }, { - Name: "BadURL", // Malformed URL in FeedRoot + Name: "BadURL", // Malformed URL in BaseURL Config: func(i interface{}) error { cfg := i.(*Config) s := "http://[notaurl:/" - cfg.FeedRoot = &s + cfg.BaseURL = &s return nil }, Check: func(t *testing.T, err error) { @@ -72,11 +72,11 @@ func TestConfigure(t *testing.T) { }, }, { - Name: "ValidGZURL", // Proper .gz URL in FeedRoot + Name: "ValidGZURL", // Proper .gz URL in BaseURL Config: func(i interface{}) error { cfg := i.(*Config) s := "http://example.com/epss_scores-2024-10-25.csv.gz" - cfg.FeedRoot = &s + cfg.BaseURL = &s return nil }, Check: func(t *testing.T, err error) { @@ -169,7 +169,7 @@ func mockServer(t *testing.T) *httptest.Server { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch path.Ext(r.URL.Path) { - case ".gz": // only gz feed is supported + case ".gz": // only gz baseURL is supported w.Header().Set("etag", etagValue) f, err := os.Open(filepath.Join(root, "data.csv")) @@ -207,7 +207,7 @@ func (tc fetchTestcase) Run(ctx context.Context, srv *httptest.Server) func(*tes t.Fatal("expected Config type for i, but got a different type") } u := srv.URL + "/data.csv.gz" - cfg.FeedRoot = &u + cfg.BaseURL = &u return nil } @@ -259,7 +259,7 @@ func (tc parseTestcase) Run(ctx context.Context, srv *httptest.Server) func(*tes t.Fatal("assertion failed") } u := srv.URL + "/data.csv.gz" - cfg.FeedRoot = &u + cfg.BaseURL = &u return nil } if err := e.Configure(ctx, f, srv.Client()); err != nil { @@ -313,7 +313,7 @@ func TestEnrich(t *testing.T) { t.Fatal("assertion failed") } u := srv.URL + "/data.csv.gz" - cfg.FeedRoot = &u + cfg.BaseURL = &u return nil } if err := e.Configure(ctx, f, srv.Client()); err != nil {