From ef034231ef36d907b760b9f7affc7e5bd5dfad05 Mon Sep 17 00:00:00 2001 From: "J. Victor Martins" Date: Sun, 28 Apr 2024 21:30:12 -0700 Subject: [PATCH] jsonblob: Support iteration with rangefunc style Signed-off-by: J. Victor Martins --- libvuln/jsonblob/jsonblob.go | 77 ++++++++++++++++++ libvuln/jsonblob/jsonblob_test.go | 126 ++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+) diff --git a/libvuln/jsonblob/jsonblob.go b/libvuln/jsonblob/jsonblob.go index 5d80aff73..629ee13d9 100644 --- a/libvuln/jsonblob/jsonblob.go +++ b/libvuln/jsonblob/jsonblob.go @@ -44,6 +44,64 @@ type Store struct { latest map[driver.UpdateKind]uuid.UUID } +type iter2[X, Y any] func(yield func(X, Y) bool) + +// RecordIter iterates over records of an update operation. +type RecordIter iter2[*claircore.Vulnerability, *driver.EnrichmentRecord] + +// OperationIter iterates over operations, offering a nested iterator for records. +type OperationIter iter2[*driver.UpdateOperation, RecordIter] + +// Iterate iterates over each record serialized in the [io.Reader] grouping by +// update operations. It returns an OperationIter, which is an iterator over each +// update operation with a nested iterator for the associated vulnerability +// entries, and an error function, to check for iteration errors. +func Iterate(r io.Reader) (OperationIter, func() error) { + var err error + var de diskEntry + + d := json.NewDecoder(r) + err = d.Decode(&de) + + it := func(yield func(*driver.UpdateOperation, RecordIter) bool) { + for err == nil { + op := &driver.UpdateOperation{ + Ref: de.Ref, + Updater: de.Updater, + Fingerprint: de.Fingerprint, + Date: de.Date, + Kind: de.Kind, + } + it := func(yield func(*claircore.Vulnerability, *driver.EnrichmentRecord) bool) { + var vuln *claircore.Vulnerability + var en *driver.EnrichmentRecord + for err == nil && op.Ref == de.Ref { + vuln, en, err = de.Unmarshal() + if err != nil || !yield(vuln, en) { + break + } + err = d.Decode(&de) + } + } + if !yield(op, it) { + break + } + for err == nil && op.Ref == de.Ref { + err = d.Decode(&de) + } + } + } + + errF := func() error { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + return it, errF +} + // Load reads in all the records serialized in the provided [io.Reader]. func Load(ctx context.Context, r io.Reader) (*Loader, error) { l := Loader{ @@ -252,6 +310,25 @@ type diskEntry struct { Kind driver.UpdateKind } +// Unmarshal parses the JSON-encoded vulnerability or enrichment record encoded +// in the disk entry, based on the update kind. +func (de *diskEntry) Unmarshal() (v *claircore.Vulnerability, e *driver.EnrichmentRecord, err error) { + switch de.Kind { + case driver.VulnerabilityKind: + v = &claircore.Vulnerability{} + if err = json.Unmarshal(de.Vuln.buf, v); err != nil { + return + } + case driver.EnrichmentKind: + e = &driver.EnrichmentRecord{} + err = json.Unmarshal(de.Enrichment.buf, e) + if err != nil { + return + } + } + return +} + // Entries returns a map containing all the Entries stored by calls to // UpdateVulnerabilities. // diff --git a/libvuln/jsonblob/jsonblob_test.go b/libvuln/jsonblob/jsonblob_test.go index 141c6a7db..fac5eece2 100644 --- a/libvuln/jsonblob/jsonblob_test.go +++ b/libvuln/jsonblob/jsonblob_test.go @@ -118,3 +118,129 @@ func TestEnrichments(t *testing.T) { } t.Logf("wrote:\n%s", buf.String()) } + +func TestIterationWithBreak(t *testing.T) { + ctx := context.Background() + a, err := New() + if err != nil { + t.Fatal(err) + } + + var want, got struct { + V []*claircore.Vulnerability + E []driver.EnrichmentRecord + } + + want.V = test.GenUniqueVulnerabilities(10, "test") + ref, err := a.UpdateVulnerabilities(ctx, "test", "", want.V) + if err != nil { + t.Error(err) + } + t.Logf("ref: %v", ref) + + // We will break after getting vulnerabilities. + test.GenEnrichments(15) + ref, err = a.UpdateEnrichments(ctx, "test", "", want.E) + if err != nil { + t.Error(err) + } + t.Logf("ref: %v", ref) + + var buf bytes.Buffer + defer func() { + t.Logf("wrote:\n%s", buf.String()) + }() + r, w := io.Pipe() + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { defer w.Close(); return a.Store(w) }) + eg.Go(func() error { + i, iErr := Iterate(io.TeeReader(r, &buf)) + i(func(o *driver.UpdateOperation, i RecordIter) bool { + i(func(v *claircore.Vulnerability, e *driver.EnrichmentRecord) bool { + switch o.Kind { + case driver.VulnerabilityKind: + got.V = append(got.V, v) + case driver.EnrichmentKind: + got.E = append(got.E, *e) + default: + t.Errorf("unnexpected kind: %s", o.Kind) + } + return true + }) + // Stop the operation iter, effectively skipping enrichments. + return false + }) + return iErr() + }) + if err := eg.Wait(); err != nil { + t.Error(err) + } + if !cmp.Equal(got, want) { + t.Error(cmp.Diff(got, want)) + } +} + +func TestIterationWithSkip(t *testing.T) { + ctx := context.Background() + a, err := New() + if err != nil { + t.Fatal(err) + } + + var want, got struct { + V []*claircore.Vulnerability + E []driver.EnrichmentRecord + } + + want.V = test.GenUniqueVulnerabilities(10, "test") + ref, err := a.UpdateVulnerabilities(ctx, "test", "", want.V) + if err != nil { + t.Error(err) + } + t.Logf("ref: %v", ref) + + // We will skip the updater "skip this". + test.GenUniqueVulnerabilities(10, "skip this") + + want.E = test.GenEnrichments(15) + ref, err = a.UpdateEnrichments(ctx, "test", "", want.E) + if err != nil { + t.Error(err) + } + t.Logf("ref: %v", ref) + + var buf bytes.Buffer + defer func() { + t.Logf("wrote:\n%s", buf.String()) + }() + r, w := io.Pipe() + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { defer w.Close(); return a.Store(w) }) + eg.Go(func() error { + i, iErr := Iterate(io.TeeReader(r, &buf)) + i(func(o *driver.UpdateOperation, i RecordIter) bool { + if o.Updater == "skip this" { + return true + } + i(func(v *claircore.Vulnerability, e *driver.EnrichmentRecord) bool { + switch o.Kind { + case driver.VulnerabilityKind: + got.V = append(got.V, v) + case driver.EnrichmentKind: + got.E = append(got.E, *e) + default: + t.Errorf("unnexpected kind: %s", o.Kind) + } + return true + }) + return true + }) + return iErr() + }) + if err := eg.Wait(); err != nil { + t.Error(err) + } + if !cmp.Equal(got, want) { + t.Error(cmp.Diff(got, want)) + } +}