diff --git a/intra/core/barrier.go b/intra/core/barrier.go index bbd1d11e..b901aad6 100644 --- a/intra/core/barrier.go +++ b/intra/core/barrier.go @@ -19,6 +19,11 @@ import ( "time" ) +const ( + Anew = iota + Shared +) + // V is an in-flight or completed Barrier.Do V type V struct { wg sync.WaitGroup @@ -66,7 +71,7 @@ func (ba *Barrier) addLocked(k string) *V { // sure that only one execution is in-flight for a given key at a // time. If a duplicate comes in, the duplicate caller waits for the // original to complete and receives the same results. -func (ba *Barrier) Do(k string, me func() (any, error)) *V { +func (ba *Barrier) Do(k string, me func() (any, error)) (*V, int) { ba.mu.Lock() c, ok := ba.getLocked(k) if ok { @@ -74,7 +79,7 @@ func (ba *Barrier) Do(k string, me func() (any, error)) *V { c.N.Add(1) c.wg.Wait() // wait for the in-flight req to complete - return c + return c, Shared } c = ba.addLocked(k) ba.mu.Unlock() @@ -82,5 +87,5 @@ func (ba *Barrier) Do(k string, me func() (any, error)) *V { c.Val, c.Err = me() c.wg.Done() // unblock all waiters - return c + return c, Anew } diff --git a/intra/dnsx/cacher.go b/intra/dnsx/cacher.go index 3b9cd839..b4073618 100644 --- a/intra/dnsx/cacher.go +++ b/intra/dnsx/cacher.go @@ -300,18 +300,11 @@ func (t *ctransport) Type() string { } func (t *ctransport) fetch(network string, q []byte, msg *dns.Msg, summary *Summary, cb *cache, key string) (r []byte, err error) { - sendRequest := func(async bool) ([]byte, error) { - var finalsumm *Summary - if async { - finalsumm = new(Summary) - } else { - finalsumm = summary - } - + sendRequest := func(finalsumm *Summary) ([]byte, error) { finalsumm.ID = t.Transport.ID() finalsumm.Type = t.Transport.Type() - rv := t.reqbarrier.Do(key, func() (any, error) { + rv, st := t.reqbarrier.Do(key, func() (any, error) { ans, err := t.Transport.Query(network, q, finalsumm) cb.put(ans, finalsumm) return &cres{ans: xdns.AsMsg(ans), s: finalsumm}, err @@ -322,9 +315,14 @@ func (t *ctransport) fetch(network string, q []byte, msg *dns.Msg, summary *Summ return nil, errCacheResponseMismatch } + // if rv is "shared", then use this req's summary over the "shared" one + if st == core.Shared { + cachedres.s = finalsumm + } + finalres, origsumm, finalerr := asResponse(msg, cachedres, true) // fill summary regardless of errors - origsumm.FillInto(finalsumm) // origsumm may be equal to s + origsumm.FillInto(finalsumm) // origsumm may be equal to finalsumm return finalres, errors.Join(rv.Err, finalerr) } @@ -353,7 +351,7 @@ func (t *ctransport) fetch(network string, q []byte, msg *dns.Msg, summary *Summ // fallthrough to sendRequest } else if cachedsummary != nil { if !isfresh { // not fresh, fetch in the background - go sendRequest(true) + go sendRequest(new(Summary)) } // change summary fields to reflect cached response, except for latency cachedsummary.FillInto(summary) @@ -362,7 +360,8 @@ func (t *ctransport) fetch(network string, q []byte, msg *dns.Msg, summary *Summ return } // else: fallthrough to sendRequest } - return sendRequest(false) // summary is filled by underlying transport + + return sendRequest(summary) // summary is filled by underlying transport } func (t *ctransport) Query(network string, q []byte, summary *Summary) ([]byte, error) { diff --git a/intra/ipn/wgnet.go b/intra/ipn/wgnet.go index 4c6c6414..58c1acee 100644 --- a/intra/ipn/wgnet.go +++ b/intra/ipn/wgnet.go @@ -573,7 +573,7 @@ func (tnet *wgtun) DialContext(ctx context.Context, network, address string) (ne } } - rv := tnet.reqbarrier.Do(host, func() (any, error) { + rv, _ := tnet.reqbarrier.Do(host, func() (any, error) { return tnet.LookupContextHost(ctx, host) }) if rv.Err != nil {