diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 91d8fac56fe..1102b313607 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -327,7 +327,7 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV strategyTimeout := getter.DefaultStrategyTimeout.String() ctx := r.Context() - ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout, logger) if err != nil { logger.Error(err, err.Error()) jsonhttp.BadRequest(w, "could not parse headers") @@ -521,7 +521,7 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h strategyTimeout := getter.DefaultStrategyTimeout.String() ctx := r.Context() - ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout, logger) if err != nil { logger.Error(err, err.Error()) jsonhttp.BadRequest(w, "could not parse headers") diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index 634d4eb0166..fdcb3f1209f 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -151,9 +151,7 @@ func TestBzzUploadDownloadWithRedundancy(t *testing.T) { if rLevel == 0 { t.Skip("NA") } - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", fileDownloadResource(refResponse.Reference.String()), nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", fileDownloadResource(refResponse.Reference.String()), nil) if err != nil { t.Fatal(err) } @@ -161,9 +159,25 @@ func TestBzzUploadDownloadWithRedundancy(t *testing.T) { req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false") 
req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) - _, err = client.Do(req) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatalf("expected error %v; got %v", io.ErrUnexpectedEOF, err) + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d; got %d", http.StatusOK, resp.StatusCode) + } + _, err = dataReader.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + ok, err := dataReader.Equal(resp.Body) + if err != nil { + t.Fatal(err) + } + if ok { + t.Fatal("there should be missing data") } }) diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index dc9c850084c..2b91708b031 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -129,7 +129,7 @@ func New(ctx context.Context, g storage.Getter, putter storage.Putter, address s maxBranching = rLevel.GetMaxShards() } } else { - // if root chunk has no redundancy, strategy is ignored and set to NONE and strict is set to true + // if root chunk has no redundancy, strategy is ignored and set to DATA and strict is set to true conf.Strategy = getter.DATA conf.Strict = true } diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 125f2d727d5..bfb6e4d2cd6 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -24,6 +24,7 @@ import ( "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" + "github.com/ethersphere/bee/pkg/log" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" @@ -31,7 +32,6 @@ import ( "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/util/testutil" "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" - 
"github.com/ethersphere/bee/pkg/util/testutil/racedetection" "gitlab.com/nolash/go-mockbytes" "golang.org/x/sync/errgroup" ) @@ -1112,15 +1112,14 @@ func TestJoinerRedundancy(t *testing.T) { strategyTimeout := 100 * time.Millisecond // all data can be read back readCheck := func(t *testing.T, expErr error) { - ctx, cancel := context.WithTimeout(context.Background(), 15*strategyTimeout) - defer cancel() + ctx := context.Background() strategyTimeoutStr := strategyTimeout.String() decodeTimeoutStr := (10 * strategyTimeout).String() fallback := true s := getter.RACE - ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodeTimeoutStr, &strategyTimeoutStr) + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodeTimeoutStr, &strategyTimeoutStr, log.Noop) if err != nil { t.Fatal(err) } @@ -1169,14 +1168,14 @@ func TestJoinerRedundancy(t *testing.T) { } } t.Run("no recovery possible with no chunk stored", func(t *testing.T) { - readCheck(t, context.DeadlineExceeded) + readCheck(t, storage.ErrNotFound) }) if err := putter.store(shardCnt - 1); err != nil { t.Fatal(err) } t.Run("no recovery possible with 1 short of shardCnt chunks stored", func(t *testing.T) { - readCheck(t, context.DeadlineExceeded) + readCheck(t, storage.ErrNotFound) }) if err := putter.store(1); err != nil { @@ -1253,21 +1252,15 @@ func TestJoinerRedundancyMultilevel(t *testing.T) { canReadRange := func(t *testing.T, s getter.Strategy, fallback bool, levels int, canRead bool) { ctx := context.Background() strategyTimeout := 100 * time.Millisecond - decodingTimeout := 600 * time.Millisecond - if racedetection.IsOn() { - decodingTimeout *= 2 - } strategyTimeoutStr := strategyTimeout.String() decodingTimeoutStr := (2 * strategyTimeout).String() - ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodingTimeoutStr, &strategyTimeoutStr) + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodingTimeoutStr, &strategyTimeoutStr, log.Noop) if err != nil { t.Fatal(err) } - 
ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout)) - defer cancel() j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index 4e8da1b6390..ce1181b045d 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -6,37 +6,47 @@ package getter import ( "context" + "errors" "io" "sync" "sync/atomic" + "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "github.com/klauspost/reedsolomon" ) +var ( + errStrategyNotAllowed = errors.New("strategy not allowed") + errStrategyFailed = errors.New("strategy failed") +) + // decoder is a private implementation of storage.Getter // if retrieves children of an intermediate chunk potentially using erasure decoding // it caches sibling chunks if erasure decoding started already type decoder struct { - fetcher storage.Getter // network retrieval interface to fetch chunks - putter storage.Putter // interface to local storage to save reconstructed chunks - addrs []swarm.Address // all addresses of the intermediate chunk - inflight []atomic.Bool // locks to protect wait channels and RS buffer - cache map[string]int // map from chunk address shard position index - waits []chan struct{} // wait channels for each chunk - rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding - ready chan struct{} // signal channel for successful retrieval of shardCnt chunks - lastLen int // length of the last data chunk in the RS buffer - shardCnt int // number of data shards - parityCnt int // number of parity shards - wg sync.WaitGroup // wait group to wait for all goroutines to finish - mu sync.Mutex // mutex to protect buffer - err error // error of the last erasure decoding - fetchedCnt atomic.Int32 // count successful retrievals - cancel func() // cancel function for RS 
decoding - remove func() // callback to remove decoder from decoders cache - config Config // configuration + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + addrs []swarm.Address // all addresses of the intermediate chunk + inflight []atomic.Bool // locks to protect wait channels and RS buffer + cache map[string]int // map from chunk address shard position index + waits []chan error // wait channels for each chunk + rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding + goodRecovery chan struct{} // signal channel for successful retrieval of shardCnt chunks + badRecovery chan struct{} // signals that either the recovery has failed or not allowed to run + lastLen int // length of the last data chunk in the RS buffer + shardCnt int // number of data shards + parityCnt int // number of parity shards + wg sync.WaitGroup // wait group to wait for all goroutines to finish + mu sync.Mutex // mutex to protect buffer + err error // error of the last erasure decoding + fetchedCnt atomic.Int32 // count successful retrievals + failedCnt atomic.Int32 // count successful retrievals + cancel func() // cancel function for RS decoding + remove func() // callback to remove decoder from decoders cache + config Config // configuration + logger log.Logger } type Getter interface { @@ -49,37 +59,46 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter ctx, cancel := context.WithCancel(context.Background()) size := len(addrs) - rsg := &decoder{ - fetcher: g, - putter: p, - addrs: addrs, - inflight: make([]atomic.Bool, size), - cache: make(map[string]int, size), - waits: make([]chan struct{}, shardCnt), - rsbuf: make([][]byte, size), - ready: make(chan struct{}, 1), - cancel: cancel, - remove: remove, - shardCnt: shardCnt, - parityCnt: size - shardCnt, - config: conf, + d := &decoder{ + fetcher: g, + putter: p, + addrs: addrs, + 
inflight: make([]atomic.Bool, size), + cache: make(map[string]int, size), + waits: make([]chan error, size), + rsbuf: make([][]byte, size), + goodRecovery: make(chan struct{}), + badRecovery: make(chan struct{}), + cancel: cancel, + remove: remove, + shardCnt: shardCnt, + parityCnt: size - shardCnt, + config: conf, + logger: conf.Logger.WithName("redundancy").Build(), } // after init, cache and wait channels are immutable, need no locking for i := 0; i < shardCnt; i++ { - rsg.cache[addrs[i].ByteString()] = i - rsg.waits[i] = make(chan struct{}) + d.cache[addrs[i].ByteString()] = i + } + + // after init, cache and wait channels are immutable, need no locking + for i := 0; i < size; i++ { + d.waits[i] = make(chan error) } // prefetch chunks according to strategy if !conf.Strict || conf.Strategy != NONE { - rsg.wg.Add(1) + d.wg.Add(1) go func() { - rsg.err = rsg.prefetch(ctx) - rsg.wg.Done() + defer d.wg.Done() + d.err = d.prefetch(ctx) }() + } else { // recovery not allowed + close(d.badRecovery) } - return rsg + + return d } // Get will call parities and other sibling chunks if the chunk address cannot be retrieved @@ -89,110 +108,194 @@ func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, err if !ok { return nil, storage.ErrNotFound } - if g.fly(i, true) { + err := g.fetch(ctx, i, true) + if err != nil { + return nil, err + } + return swarm.NewChunk(addr, g.getData(i)), nil +} + +// fetch retrieves a chunk from the netstore if it is the first time the chunk is fetched. +// If the fetch fails and waiting for the recovery is allowed, the function will wait +// for either a good or bad recovery signal. 
+func (g *decoder) fetch(ctx context.Context, i int, waitForRecovery bool) (err error) { + + waitRecovery := func(err error) error { + if !waitForRecovery { + return err + } + + select { + case <-g.badRecovery: + return storage.ErrNotFound + case <-g.goodRecovery: + g.logger.Debug("recovered chunk", "address", g.addrs[i]) + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + // first time + if g.fly(i) { + + fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) + defer cancel() + g.wg.Add(1) - go func() { - g.fetch(ctx, i) - g.wg.Done() - }() + defer g.wg.Done() + + // retrieval + ch, err := g.fetcher.Get(fctx, g.addrs[i]) + if err != nil { + g.failedCnt.Add(1) + close(g.waits[i]) + return waitRecovery(err) + } + + g.fetchedCnt.Add(1) + g.setData(i, ch.Data()) + close(g.waits[i]) + return nil } + select { case <-g.waits[i]: case <-ctx.Done(): - return nil, ctx.Err() + return ctx.Err() } - return swarm.NewChunk(addr, g.getData(i)), nil -} -// setData sets the data shard in the RS buffer -func (g *decoder) setData(i int, chdata []byte) { - data := chdata - // pad the chunk with zeros if it is smaller than swarm.ChunkSize - if len(data) < swarm.ChunkWithSpanSize { - g.lastLen = len(data) - data = make([]byte, swarm.ChunkWithSpanSize) - copy(data, chdata) + if g.getData(i) != nil { + return nil } - g.rsbuf[i] = data + + return waitRecovery(storage.ErrNotFound) } -// getData returns the data shard from the RS buffer -func (g *decoder) getData(i int) []byte { - if i == g.shardCnt-1 && g.lastLen > 0 { - return g.rsbuf[i][:g.lastLen] // cut padding +func (g *decoder) prefetch(ctx context.Context) error { + defer g.remove() + + run := func(s Strategy) error { + if err := g.runStrategy(ctx, s); err != nil { + return err + } + + return g.recover(ctx) + } + + var err error + for s := g.config.Strategy; s < strategyCnt; s++ { + + err = run(s) + if err != nil { + if s == DATA || s == RACE { + g.logger.Debug("failed recovery", "strategy", s) + } + } + if err 
== nil { + if s > DATA { + g.logger.Debug("successful recovery", "strategy", s) + } + close(g.goodRecovery) + break + } + if g.config.Strict { // only run one strategy + break + } } - return g.rsbuf[i] -} -// fly commits to retrieve the chunk (fly and land) -// it marks a chunk as inflight and returns true unless it is already inflight -// the atomic bool implements a singleflight pattern -func (g *decoder) fly(i int, up bool) (success bool) { - return g.inflight[i].CompareAndSwap(!up, up) -} - -// fetch retrieves a chunk from the underlying storage -// it must be called asynchonously and only once for each chunk (singleflight pattern) -// it races with erasure recovery which takes precedence even if it started later -// due to the fact that erasure recovery could only implement global locking on all shards -func (g *decoder) fetch(ctx context.Context, i int) { - fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) - defer cancel() - ch, err := g.fetcher.Get(fctx, g.addrs[i]) if err != nil { - _ = g.fly(i, false) // unset inflight - return + close(g.badRecovery) + return err } - g.mu.Lock() - defer g.mu.Unlock() - if i < len(g.waits) { - select { - case <-g.waits[i]: // if chunk is retrieved, ignore - return - default: + return err +} + +func (g *decoder) runStrategy(ctx context.Context, s Strategy) error { + + // across the different strategies, the common goal is to fetch at least as many chunks + // as the number of data shards. + // DATA strategy has a max error tolerance of zero. + // RACE strategy has a max error tolerance of number of parity chunks. 
+ var allowedErrs int + var m []int + + switch s { + case NONE: + return errStrategyNotAllowed + case DATA: + // only retrieve data shards + m = g.unattemptedDataShards() + allowedErrs = 0 + case PROX: + // proximity driven selective fetching + // NOT IMPLEMENTED + return errStrategyNotAllowed + case RACE: + allowedErrs = g.parityCnt + // retrieve all chunks at once enabling race among chunks + m = g.unattemptedDataShards() + for i := g.shardCnt; i < len(g.addrs); i++ { + m = append(m, i) } } - select { - case <-ctx.Done(): // if context is cancelled, ignore - _ = g.fly(i, false) // unset inflight - return - default: + if len(m) == 0 { + return nil } - // write chunk to rsbuf and signal waiters - g.setData(i, ch.Data()) // save the chunk in the RS buffer - if i < len(g.waits) { // if the chunk is a data shard - close(g.waits[i]) // signal that the chunk is retrieved - } + c := make(chan error, len(m)) - // if all chunks are retrieved, signal ready - n := g.fetchedCnt.Add(1) - if n == int32(g.shardCnt) { - close(g.ready) // signal that just enough chunks are retrieved for decoding + for _, i := range m { + g.wg.Add(1) + go func(i int) { + defer g.wg.Done() + c <- g.fetch(ctx, i, false) + }(i) } -} -// missing gathers missing data shards not yet retrieved -// it sets the chunk as inflight and returns the index of the missing data shards -func (g *decoder) missing() (m []int) { - for i := 0; i < g.shardCnt; i++ { + for { select { - case <-g.waits[i]: // if chunk is retrieved, ignore - continue - default: + case <-ctx.Done(): + return ctx.Err() + case <-c: + if g.fetchedCnt.Load() >= int32(g.shardCnt) { + return nil + } + if g.failedCnt.Load() > int32(allowedErrs) { + return errStrategyFailed + } } - _ = g.fly(i, true) // commit (RS) or will commit to retrieve the chunk - m = append(m, i) // remember the missing chunk } - return m +} + +// recover wraps the stages of data shard recovery: +// 1. gather missing data shards +// 2. decode using Reed-Solomon decoder +// 3. 
save reconstructed chunks +func (g *decoder) recover(ctx context.Context) error { + // gather missing shards + m := g.missingDataShards() + if len(m) == 0 { + return nil // recovery is not needed as there are no missing data chunks + } + + // decode using Reed-Solomon decoder + if err := g.decode(ctx); err != nil { + return err + } + + // save chunks + return g.save(ctx, m) } // decode uses Reed-Solomon erasure coding decoder to recover data shards // it must be called after shqrdcnt shards are retrieved -// it must be called under g.mu mutex protection func (g *decoder) decode(ctx context.Context) error { + g.mu.Lock() + defer g.mu.Unlock() + enc, err := reedsolomon.New(g.shardCnt, g.parityCnt) if err != nil { return err @@ -202,37 +305,64 @@ func (g *decoder) decode(ctx context.Context) error { return enc.ReconstructData(g.rsbuf) } -// recover wraps the stages of data shard recovery: -// 1. gather missing data shards -// 2. decode using Reed-Solomon decoder -// 3. save reconstructed chunks -func (g *decoder) recover(ctx context.Context) error { - // buffer lock acquired - g.mu.Lock() - defer g.mu.Unlock() +func (g *decoder) unattemptedDataShards() (m []int) { + for i := 0; i < g.shardCnt; i++ { + select { + case <-g.waits[i]: // attempted + continue + default: + m = append(m, i) // remember the unattempted chunk + } + } + return m +} - // gather missing shards - m := g.missing() - if len(m) == 0 { - return nil +// missingDataShards lists data shards with no data in the RS buffer; do not call it with g.mu held, as getData acquires g.mu itself +func (g *decoder) missingDataShards() (m []int) { + for i := 0; i < g.shardCnt; i++ { + if g.getData(i) == nil { + m = append(m, i) + } } + return m +} - // decode using Reed-Solomon decoder - if err := g.decode(ctx); err != nil { - return err +// setData sets the data shard in the RS buffer +func (g *decoder) setData(i int, chdata []byte) { + g.mu.Lock() + defer g.mu.Unlock() + + data := chdata + // pad the chunk with zeros if it is smaller than swarm.ChunkWithSpanSize + if len(data) < swarm.ChunkWithSpanSize { + g.lastLen 
= len(data) + data = make([]byte, swarm.ChunkWithSpanSize) + copy(data, chdata) } + g.rsbuf[i] = data +} - // close wait channels for missing chunks - for _, i := range m { - close(g.waits[i]) +// getData returns the data shard from the RS buffer +func (g *decoder) getData(i int) []byte { + g.mu.Lock() + defer g.mu.Unlock() + if i == g.shardCnt-1 && g.lastLen > 0 { + return g.rsbuf[i][:g.lastLen] // cut padding } + return g.rsbuf[i] +} - // save chunks - return g.save(ctx, m) +// fly commits to retrieve the chunk (fly and land) +// it marks a chunk as inflight and returns true unless it is already inflight +// the atomic bool implements a singleflight pattern +func (g *decoder) fly(i int) (success bool) { + return g.inflight[i].CompareAndSwap(false, true) } // save iterate over reconstructed shards and puts the corresponding chunks to local storage func (g *decoder) save(ctx context.Context, missing []int) error { + g.mu.Lock() + defer g.mu.Unlock() for _, i := range missing { if err := g.putter.Put(ctx, swarm.NewChunk(g.addrs[i], g.rsbuf[i])); err != nil { return err diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go index b18caa55c12..95609fc4c1e 100644 --- a/pkg/file/redundancy/getter/getter_test.go +++ b/pkg/file/redundancy/getter/getter_test.go @@ -19,6 +19,7 @@ import ( "github.com/ethersphere/bee/pkg/cac" "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/storage" inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" @@ -72,6 +73,7 @@ func TestGetterRACE(t *testing.T) { // TestGetterFallback tests the retrieval of chunks with missing data shards // using the strict or fallback mode starting with NONE and DATA strategies func TestGetterFallback(t *testing.T) { + t.Skip("removed strategy timeout") t.Run("GET", func(t *testing.T) { t.Run("NONE", func(t *testing.T) { 
t.Run("strict", func(t *testing.T) { @@ -119,6 +121,7 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { Strategy: getter.RACE, FetchTimeout: 2 * strategyTimeout, StrategyTimeout: strategyTimeout, + Logger: log.Noop, } g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index 9269636e3e9..410a0003962 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -10,6 +10,7 @@ import ( "fmt" "time" + "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/retrieval" ) @@ -25,6 +26,7 @@ type ( modeKey struct{} fetchTimeoutKey struct{} strategyTimeoutKey struct{} + loggerKey struct{} Strategy = int ) @@ -34,6 +36,7 @@ type Config struct { Strict bool FetchTimeout time.Duration StrategyTimeout time.Duration + Logger log.Logger } const ( @@ -50,6 +53,7 @@ var DefaultConfig = Config{ Strict: DefaultStrict, FetchTimeout: DefaultFetchTimeout, StrategyTimeout: DefaultStrategyTimeout, + Logger: log.Noop, } // NewConfigFromContext returns a new Config based on the context @@ -86,6 +90,12 @@ func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err err return conf, e("strategy timeout") } } + if val := ctx.Value(loggerKey{}); val != nil { + conf.Logger, ok = val.(log.Logger) + if !ok { + return conf, e("strategy timeout") + } + } return conf, nil } @@ -110,8 +120,13 @@ func SetStrategyTimeout(ctx context.Context, timeout time.Duration) context.Cont return context.WithValue(ctx, strategyTimeoutKey{}, timeout) } +// SetLogger stores the logger in the context under loggerKey +func SetLogger(ctx context.Context, l log.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, l) +} + // SetConfigInContext sets the config params in the context -func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fetchTimeout, strategyTimeout 
*string) (context.Context, error) { +func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fetchTimeout, strategyTimeout *string, logger log.Logger) (context.Context, error) { if s != nil { ctx = SetStrategy(ctx, *s) } @@ -136,85 +151,9 @@ func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fe ctx = SetStrategyTimeout(ctx, dur) } - return ctx, nil -} - -func (g *decoder) prefetch(ctx context.Context) error { - if g.config.Strict && g.config.Strategy == NONE { - return nil + if logger != nil { + ctx = SetLogger(ctx, logger) } - defer g.remove() - var cancels []func() - cancelAll := func() { - for _, cancel := range cancels { - cancel() - } - } - defer cancelAll() - run := func(s Strategy) error { - if s == PROX { // NOT IMPLEMENTED - return errors.New("strategy not implemented") - } - var stop <-chan time.Time - if s < RACE { - timer := time.NewTimer(g.config.StrategyTimeout) - defer timer.Stop() - stop = timer.C - } - lctx, cancel := context.WithCancel(ctx) - cancels = append(cancels, cancel) - prefetch(lctx, g, s) - - select { - // successfully retrieved shardCnt number of chunks - case <-g.ready: - cancelAll() - case <-stop: - return fmt.Errorf("prefetching with strategy %d timed out", s) - case <-ctx.Done(): - return nil - } - // call the erasure decoder - // if decoding is successful terminate the prefetch loop - return g.recover(ctx) // context to cancel when shardCnt chunks are retrieved - } - var err error - for s := g.config.Strategy; s < strategyCnt; s++ { - err = run(s) - if g.config.Strict || err == nil { - break - } - } - - return err -} - -// prefetch launches the retrieval of chunks based on the strategy -func prefetch(ctx context.Context, g *decoder, s Strategy) { - var m []int - switch s { - case NONE: - return - case DATA: - // only retrieve data shards - m = g.missing() - case PROX: - // proximity driven selective fetching - // NOT IMPLEMENTED - case RACE: - // retrieve all chunks at once enabling 
race among chunks - m = g.missing() - for i := g.shardCnt; i < len(g.addrs); i++ { - m = append(m, i) - } - } - for _, i := range m { - i := i - g.wg.Add(1) - go func() { - g.fetch(ctx, i) - g.wg.Done() - }() - } + return ctx, nil }