From 34f4e080f0dc72ac01fb6235eff07e2dfa63e1ee Mon Sep 17 00:00:00 2001 From: daynewlee Date: Sun, 17 Nov 2024 14:38:00 -0600 Subject: [PATCH] epss: test Enrich() Signed-off-by: daynewlee --- enricher/epss/epss.go | 62 ++++++++++---- enricher/epss/epss_test.go | 146 ++++++++++++++++++++++---------- enricher/epss/testdata/data.csv | 2 +- 3 files changed, 149 insertions(+), 61 deletions(-) diff --git a/enricher/epss/epss.go b/enricher/epss/epss.go index 1d3cdec10..4151493c4 100644 --- a/enricher/epss/epss.go +++ b/enricher/epss/epss.go @@ -18,6 +18,7 @@ import ( "path" "regexp" "sort" + "strconv" "strings" "time" ) @@ -224,38 +225,49 @@ func (e *Enricher) sourceURL() { func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *claircore.VulnerabilityReport) (string, []json.RawMessage, error) { ctx = zlog.ContextWithValues(ctx, "component", "enricher/epss/Enricher/Enrich") - - // We return any CVSS blobs for CVEs mentioned in the free-form parts of the - // vulnerability. m := make(map[string][]json.RawMessage) - erCache := make(map[string][]driver.EnrichmentRecord) + for id, v := range r.Vulnerabilities { t := make(map[string]struct{}) - ctx := zlog.ContextWithValues(ctx, - "vuln", v.Name) + ctx := zlog.ContextWithValues(ctx, "vuln", v.Name) + for _, elem := range []string{ v.Description, v.Name, v.Links, } { - for _, m := range cveRegexp.FindAllString(elem, -1) { + // Check if the element is non-empty before running the regex + if elem == "" { + zlog.Debug(ctx).Str("element", elem).Msg("skipping empty element") + continue + } + + matches := cveRegexp.FindAllString(elem, -1) + if len(matches) == 0 { + zlog.Debug(ctx).Str("element", elem).Msg("no CVEs found in element") + continue + } + for _, m := range matches { t[m] = struct{}{} } } + + // Skip if no CVEs were found if len(t) == 0 { + zlog.Debug(ctx).Msg("no CVEs found in vulnerability metadata") continue } + ts := make([]string, 0, len(t)) for m := range t { ts = append(ts, m) } - zlog.Debug(ctx). - Strs("cve", ts). - Msg("found CVEs") - sort.Strings(ts) + cveKey := strings.Join(ts, "_") + zlog.Debug(ctx).Str("cve_key", cveKey).Strs("cve", ts).Msg("generated CVE cache key") + rec, ok := erCache[cveKey] if !ok { var err error @@ -265,16 +277,27 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla } erCache[cveKey] = rec } - zlog.Debug(ctx). - Int("count", len(rec)). - Msg("found records") + + zlog.Debug(ctx).Int("count", len(rec)).Msg("found records") + + // Skip if no enrichment records are found + if len(rec) == 0 { + zlog.Debug(ctx).Strs("cve", ts).Msg("no enrichment records found for CVEs") + continue + } + for _, r := range rec { + if _, exists := m[id]; !exists { + m[id] = []json.RawMessage{} + } m[id] = append(m[id], r.Enrichment) } } + if len(m) == 0 { return Type, nil, nil } + b, err := json.Marshal(m) if err != nil { return Type, nil, err @@ -283,9 +306,14 @@ func (e *Enricher) Enrich(ctx context.Context, g driver.EnrichmentGetter, r *cla } func newItemFeed(record []string, headers []string) (driver.EnrichmentRecord, error) { - item := make(map[string]string) + item := make(map[string]interface{}) // Use interface{} to allow mixed types for i, value := range record { - item[headers[i]] = value + // epss details are numeric values + if f, err := strconv.ParseFloat(value, 64); err == nil { + item[headers[i]] = f + } else { + item[headers[i]] = value + } } enrichment, err := json.Marshal(item) @@ -294,7 +322,7 @@ func newItemFeed(record []string, headers []string) (driver.EnrichmentRecord, er } r := driver.EnrichmentRecord{ - Tags: []string{item["cve"]}, + Tags: []string{item["cve"].(string)}, // Ensure the "cve" field is a string Enrichment: enrichment, } diff --git a/enricher/epss/epss_test.go b/enricher/epss/epss_test.go index 78dc790fb..48bbbb40d 100644 --- a/enricher/epss/epss_test.go +++ b/enricher/epss/epss_test.go @@ -1,12 +1,13 @@ package epss import ( + "bufio" "compress/gzip" "context" - "encoding/csv" "encoding/json" "errors" - "fmt" + "github.com/google/go-cmp/cmp" + "github.com/quay/claircore" "github.com/quay/claircore/libvuln/driver" "github.com/quay/zlog" "io" @@ -273,72 +274,131 @@ func (tc parseTestcase) Run(ctx context.Context, srv *httptest.Server) func(*tes } type fakeGetter struct { - items []map[string]string - res []driver.EnrichmentRecord + items []driver.EnrichmentRecord } -func (f *fakeGetter) GetEnrichment(ctx context.Context, tags []string) ([]driver.EnrichmentRecord, error) { - id := tags[0] - for _, item := range f.items { - if value, ok := item["id"]; ok && value == id { - enrichment, err := json.Marshal(item) - if err != nil { - return nil, fmt.Errorf("failed to encode enrichment: %w", err) - } - r := driver.EnrichmentRecord{ - Tags: []string{item["cve"]}, - Enrichment: enrichment, +func (g *fakeGetter) GetEnrichment(ctx context.Context, cves []string) ([]driver.EnrichmentRecord, error) { + var results []driver.EnrichmentRecord + for _, cve := range cves { + for _, item := range g.items { + for _, tag := range item.Tags { + if tag == cve { + results = append(results, item) + break + } } - f.res = []driver.EnrichmentRecord{r} } } - return f.res, nil + return results, nil } -func parseCSV(filePath string) ([]map[string]string, error) { +func TestEnrich(t *testing.T) { + t.Parallel() + ctx := zlog.Test(context.Background(), t) + data, err := parseCSV("testdata/data.csv") + if err != nil { + t.Fatal(err) + } + g := &fakeGetter{items: data} + r := &claircore.VulnerabilityReport{ + Vulnerabilities: map[string]*claircore.Vulnerability{ + "-1": { + Description: "This is a fake vulnerability that doesn't have a CVE.", + }, + "1": { + Description: "This is a fake vulnerability that looks like CVE-2022-34667.", + }, + "6004": { + Description: "CVE-2024-9972 is here", + }, + "6005": { + Description: "CVE-2024-9986 is awesome", + }, + }, + } + e := &Enricher{} + kind, es, err := e.Enrich(ctx, g, r) + if err != nil { + t.Error(err) + } + if got, want := kind, Type; got != want { + t.Errorf("got: %q, want: %q", got, want) + } + want := map[string][]map[string]interface{}{ + "1": { + { + "cve": "CVE-2022-34667", + "epss": float64(0.00073), + "percentile": float64(0.32799), + }, + }, + "6004": { + { + "cve": "CVE-2024-9972", + "epss": float64(0.00091), + "percentile": float64(0.39923), + }, + }, + "6005": { + { + "cve": "CVE-2024-9986", + "epss": float64(0.00165), + "percentile": float64(0.53867), + }, + }, + } + + got := map[string][]map[string]interface{}{} + if err := json.Unmarshal(es[0], &got); err != nil { + t.Error(err) + } else { + log.Printf("Got: %+v\n", got) + + if !cmp.Equal(got, want) { + t.Error(cmp.Diff(got, want)) + } + } +} + +func parseCSV(filePath string) ([]driver.EnrichmentRecord, error) { file, err := os.Open(filePath) if err != nil { return nil, err } defer file.Close() - reader := csv.NewReader(file) - reader.Comma = ',' // Set comma as the delimiter (can be customized) + scanner := bufio.NewScanner(file) + var records []driver.EnrichmentRecord + var headers []string - var items []map[string]string - var headers []string // Declare headers outside the if block + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) - for { - record, err := reader.Read() - if err == io.EOF { - break // Reached end of file - } - if err != nil { - return nil, err // Handle other errors - } - - if len(record) == 0 || strings.HasPrefix(record[0], "#") { - continue // Skip comment or empty lines + // Skip empty lines and metadata lines + if line == "" || strings.HasPrefix(line, "#") { + continue } - if items == nil { - // Store headers on first data line and initialize items slice - items = make([]map[string]string, 0) + record := strings.Split(line, ",") + if headers == nil { headers = record continue } if len(record) != len(headers) { - log.Printf("warning: skipping line with mismatched fields: %s\n", record) - continue // Skip lines with mismatched number of fields + log.Printf("warning: skipping line with mismatched fields: %s\n", line) + continue } - item := make(map[string]string) - for i, value := range record { - item[headers[i]] = value // Use headers as map keys + r, err := newItemFeed(record, headers) + if err != nil { + return nil, err } - items = append(items, item) + records = append(records, r) } - return items, nil + if err := scanner.Err(); err != nil { + return nil, err + } + return records, nil } diff --git a/enricher/epss/testdata/data.csv b/enricher/epss/testdata/data.csv index 55c1eedc8..306044130 100644 --- a/enricher/epss/testdata/data.csv +++ b/enricher/epss/testdata/data.csv @@ -28,4 +28,4 @@ CVE-2024-9983,0.00090,0.39372 CVE-2024-9984,0.00091,0.39923 CVE-2024-9985,0.00091,0.39923 CVE-2024-9986,0.00165,0.53867 -CVE-2024-9987,0.00043,0.09778 +CVE-2024-9987,0.00043,0.09778 \ No newline at end of file