From 50038541688e12595abc0c361660cbd6e3b7a46d Mon Sep 17 00:00:00 2001 From: Acha Bill Date: Mon, 4 Mar 2024 15:07:44 +0100 Subject: [PATCH] fix: unit test --- cmd/bee/cmd/split.go | 53 ++++++++++++++++----------------------- cmd/bee/cmd/split_test.go | 8 ++++++ 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/cmd/bee/cmd/split.go b/cmd/bee/cmd/split.go index d74650e066a..cee464b0d1a 100644 --- a/cmd/bee/cmd/split.go +++ b/cmd/bee/cmd/split.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/redundancy" @@ -21,16 +22,15 @@ import ( // putter is a putter that stores all the split chunk addresses of a file type putter struct { - c chan swarm.Chunk + cb func(chunk swarm.Chunk) error } -func (s *putter) Put(ctx context.Context, chunk swarm.Chunk) error { - s.c <- chunk - return nil +func (s *putter) Put(_ context.Context, chunk swarm.Chunk) error { + return s.cb(chunk) } -func newPutter() *putter { +func newPutter(cb func(ch swarm.Chunk) error) *putter { return &putter{ - c: make(chan swarm.Chunk), + cb: cb, } } @@ -98,23 +98,19 @@ func splitRefs(cmd *cobra.Command) { logger.Info("splitting", "file", inputFileName, "rLevel", rLevel) logger.Info("writing output", "file", outputFileName) - store := newPutter() + var refs []string + store := newPutter(func(ch swarm.Chunk) error { + refs = append(refs, ch.Address().String()) + return nil + }) writer, err := os.OpenFile(outputFileName, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { return fmt.Errorf("open output file: %w", err) } defer writer.Close() - var refs []string - go func() { - for chunk := range store.c { - refs = append(refs, chunk.Address().String()) - } - }() - p := requestPipelineFn(store, false, redundancy.Level(rLevel)) rootRef, err := p(context.Background(), reader) - close(store.c) if err != nil { return fmt.Errorf("pipeline: %w", err) } @@ -192,28 +188,23 @@ func splitChunks(cmd *cobra.Command) { logger.Info("splitting", "file", inputFileName, "rLevel", rLevel) logger.Info("writing output", "dir", outputDir) - store := newPutter() - ctx, cancel := context.WithCancel(context.Background()) - var chunksCount int64 - go func() { - for chunk := range store.c { - filePath := filepath.Join(outputDir, chunk.Address().String()) - err := os.WriteFile(filePath, chunk.Data(), 0644) - if err != nil { - logger.Error(err, "write chunk") - cancel() - } - chunksCount++ + var chunksCount atomic.Int64 + store := newPutter(func(chunk swarm.Chunk) error { + filePath := filepath.Join(outputDir, chunk.Address().String()) + err := os.WriteFile(filePath, chunk.Data(), 0644) + if err != nil { + return err } - }() + chunksCount.Add(1) + return nil + }) p := requestPipelineFn(store, false, redundancy.Level(rLevel)) - rootRef, err := p(ctx, reader) - close(store.c) + rootRef, err := p(context.Background(), reader) if err != nil { return fmt.Errorf("pipeline: %w", err) } - logger.Info("done", "root", rootRef.String(), "chunks", chunksCount) + logger.Info("done", "root", rootRef.String(), "chunks", chunksCount.Load()) return nil }, } diff --git a/cmd/bee/cmd/split_test.go b/cmd/bee/cmd/split_test.go index e848497b77d..3dde5c386f2 100644 --- a/cmd/bee/cmd/split_test.go +++ b/cmd/bee/cmd/split_test.go @@ -14,6 +14,7 @@ import ( "os" "path" "path/filepath" + "sync" "testing" "github.com/ethersphere/bee/cmd/bee/cmd" @@ -108,6 +109,10 @@ func TestDBSplitChunks(t *testing.T) { if err != nil { t.Fatal(err) } + + if len(entries) != len(putter.chunks) { + t.Fatal("number of chunks does not match") + } for _, entry := range entries { ref := entry.Name() if _, ok := putter.chunks[ref]; !ok { @@ -149,9 +154,12 @@ func compare(path string, chunk swarm.Chunk) (error, bool) { type putter struct { chunks map[string]swarm.Chunk + mu sync.Mutex } func (s *putter) Put(_ context.Context, chunk swarm.Chunk) error { + s.mu.Lock() + defer s.mu.Unlock() s.chunks[chunk.Address().String()] = chunk return nil }