Skip to content

Commit

Permalink
all: replace errors.Join w core.JoinErr
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Nov 22, 2024
1 parent b62b675 commit edf7e21
Show file tree
Hide file tree
Showing 27 changed files with 100 additions and 72 deletions.
3 changes: 2 additions & 1 deletion intra/backend/core_iptree.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"sync"

"github.com/celzero/firestack/intra/core"
"github.com/celzero/firestack/intra/log"
"github.com/k-sone/critbitgo"
)
Expand Down Expand Up @@ -450,7 +451,7 @@ func ip2cidr(ippOrCidr string) (*net.IPNet, error) {
ipaddr = ip
} else {
log.W("iptree: ip2cidr: cidr %v / ipp %v / ip %v", err, err1, err2)
return nil, errors.Join(err, err1, err2)
return nil, core.JoinErr(err, err1, err2)
}
ip := ipaddr.AsSlice()
mask := net.CIDRMask(ipaddr.BitLen(), ipaddr.BitLen())
Expand Down
4 changes: 2 additions & 2 deletions intra/backend/ipn_pipkeygen.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"strings"

"github.com/celzero/firestack/intra/core"
brsa "github.com/celzero/firestack/intra/core/brsa"
"github.com/celzero/firestack/intra/log"
// "github.com/cloudflare/circl/blindsign/blindrsa"
Expand Down Expand Up @@ -112,7 +112,7 @@ func NewPipKey(pubjwk string, msgOrExistingState string) (PipKey, error) {
c, err1 := brsa.NewClient(brsa.SHA384PSSDeterministic, pub)
v, err2 := brsa.NewVerifier(brsa.SHA384PSSDeterministic, pub)
if err1 != nil || err2 != nil {
err := errors.Join(err1, err2)
err := core.JoinErr(err1, err2)
log.E("pipkey: new: sha384-pss-det verifier err %v", err)
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions intra/core/async.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ outer:
select {
case r := <-ch:
if r.err != nil {
errs = errors.Join(errs, r.err)
errs = JoinErr(errs, r.err)
} else {
return r.t, r.i, r.err
}
case <-time.After(timeout):
errs = errors.Join(errs, errTimeout)
errs = JoinErr(errs, errTimeout)
break outer
}
}
Expand Down
41 changes: 36 additions & 5 deletions intra/core/closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
package core

import (
"fmt"
"io"
"net"
"os"
"reflect"
"strings"
"syscall"

"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
Expand Down Expand Up @@ -230,13 +230,21 @@ func OneErr(errs ...error) error {
}

func JoinErr(errs ...error) error {
if len(errs) <= 0 {
var all []error
for _, err := range errs {
if err == nil {
continue
}
all = append(all, err)
}
if len(all) <= 0 {
return nil
}
if len(errs) == 1 {
return errs[0]
if len(all) == 1 {
return all[0]
}
return fmt.Errorf("%v", errs)

return &errMult{errs: all}
}

func JoinErrIf(y bool, errs ...error) error {
Expand All @@ -245,3 +253,26 @@ func JoinErrIf(y bool, errs ...error) error {
}
return nil
}

type errMult struct {
errs []error
}

func (e *errMult) Error() string {
if len(e.errs) <= 0 {
return "<nil>"
} else if len(e.errs) == 1 {
return e.errs[0].Error()
}

b := strings.Builder{}
for _, err := range e.errs {
_, _ = b.WriteString(err.Error())
_, _ = b.WriteString(" | ")
}
return b.String()
}

func (e *errMult) Unwrap() []error {
return e.errs
}
2 changes: 1 addition & 1 deletion intra/core/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func (a agingconn) canread() error {
}
})
}
return errors.Join(ctlErr, checkErr) // may return nil
return JoinErr(ctlErr, checkErr) // may return nil
}

func logev(err error) log.LogFn {
Expand Down
4 changes: 2 additions & 2 deletions intra/core/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ func setttl(c MinConn, v4 bool) (err error) {
if raw4 != nil {
err1 := raw4.SetControlMessage(ipv4.FlagTTL, true)
err2 := raw4.SetTTL(ttl)
err = errors.Join(err1, err2)
err = JoinErr(err1, err2)
} else if raw6 != nil {
err1 := raw6.SetControlMessage(ipv6.FlagHopLimit, true)
err2 := raw6.SetHopLimit(ttl)
err = errors.Join(err1, err2)
err = JoinErr(err1, err2)
}
return
}
Expand Down
8 changes: 4 additions & 4 deletions intra/dialers/cdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect
log.V("commondial: ip %s works for %s", confirmed, remote)
return conn, nil
}
errs = errors.Join(errs, err)
errs = core.JoinErr(errs, err)
ips.Disconfirm(confirmed)
logwd(err)("rdial: commondial: confirmed %s for %s failed; err %v",
confirmed, remote, err)
Expand All @@ -114,7 +114,7 @@ func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect
if dontretry {
if !confirmedIPOK {
log.E("commondial: ip %s not ok for %s", confirmed, raddr)
errs = errors.Join(errs, errNoIps)
errs = core.JoinErr(errs, errNoIps)
}
return nil, errs
}
Expand All @@ -134,7 +134,7 @@ func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect
for _, ip := range allips {
end := time.Since(start)
if end > dialRetryTimeout {
errs = errors.Join(errs, errRetryTimeout)
errs = core.JoinErr(errs, errRetryTimeout)
log.D("commondial: timeout %s for %s", end, raddr)
break
}
Expand All @@ -150,7 +150,7 @@ func commondial2[D rdials, C rconns](d D, network, laddr, raddr string, connect
log.I("commondial: ip %s works for %s", ip, remote)
return conn, nil
}
errs = errors.Join(errs, err)
errs = core.JoinErr(errs, err)
logwd(err)("rdial: commondial: ip %s for %s failed; err %v", ip, remote, err)
} else {
log.W("commondial: ip %s not ok for %s", ip, raddr)
Expand Down
24 changes: 12 additions & 12 deletions intra/dialers/pdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package dialers

import (
"errors"
"net"
"net/netip"
"time"
Expand Down Expand Up @@ -40,29 +39,30 @@ func ProxyDial(d proxy.Dialer, network, addr string) (net.Conn, error) {
}

// ProxyDials tries to connect to addr using each dialer in dd
func ProxyDials(dd []proxy.Dialer, network, addr string) (c net.Conn, err error) {
func ProxyDials(dd []proxy.Dialer, network, addr string) (c net.Conn, errs error) {
start := time.Now()
tot := len(dd)
for i, d := range dd {
if time.Since(start) > dialRetryTimeout {
err = errors.Join(err, errRetryTimeout)
errs = core.JoinErr(errs, errRetryTimeout)
break
}
c, err = ProxyDial(d, network, addr)
if c == nil && err == nil {
err = errors.Join(err, errNoConn)
conn, err := ProxyDial(d, network, addr)
c = conn
if conn == nil && err == nil {
errs = core.JoinErr(errs, errNoConn)
} else if err != nil {
clos(c)
clos(conn)
log.W("pdial: trying %s dialer of %d / %d to %s", network, i, tot, addr)
err = errors.Join(err)
} else if c != nil {
err = nil
errs = core.JoinErr(errs, err)
} else if conn != nil {
errs = nil
return
}
}
if c == nil && err == nil {
if c == nil {
log.W("pdial: no dialer (sz: %d) succeeded for %s", tot, addr)
return nil, errNoDialer
return nil, core.OneErr(errs, errNoDialer)
}
return
}
5 changes: 2 additions & 3 deletions intra/dialers/retrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
package dialers

import (
"errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -475,7 +474,7 @@ func (r *retrier) CloseWrite() error {
// Close closes the connection and the read and write flags.
func (r *retrier) Close() error {
// also close the read and write flags
return errors.Join(r.CloseRead(), r.CloseWrite())
return core.JoinErr(r.CloseRead(), r.CloseWrite())
}

// LocalAddr behaves slightly strangely: its value may change as a
Expand Down Expand Up @@ -529,5 +528,5 @@ func (r *retrier) SetWriteDeadline(t time.Time) error {
func (r *retrier) SetDeadline(t time.Time) error {
e1 := r.SetReadDeadline(t)
e2 := r.SetWriteDeadline(t)
return errors.Join(e1, e2)
return core.JoinErr(e1, e2)
}
14 changes: 7 additions & 7 deletions intra/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ package intra

import (
"context"
"errors"
"strings"
"sync"

x "github.com/celzero/firestack/intra/backend"
"github.com/celzero/firestack/intra/core"
"github.com/celzero/firestack/intra/dns53"
"github.com/celzero/firestack/intra/dnscrypt"
"github.com/celzero/firestack/intra/dnsx"
Expand Down Expand Up @@ -40,7 +40,7 @@ func AddDNSProxy(t Tunnel, id, ip, port string) error {
p, perr := t.internalProxies()
r, rerr := t.internalResolver()
if rerr != nil || perr != nil {
return errors.Join(rerr, perr)
return core.JoinErr(rerr, perr)
}
ctx := t.internalCtx()
if dns, err := dns53.NewTransport(ctx, id, ip, port, p); err != nil {
Expand All @@ -63,7 +63,7 @@ func SetSystemDNS(t Tunnel, ipcsv string) error {
n := len(ipcsv)
if r == nil || p == nil || n <= 0 {
log.W("dns: cannot set system dns; n: %d, errs: %v %v", n, rerr, perr)
return errors.Join(dnsx.ErrAddFailed, rerr, perr)
return core.JoinErr(dnsx.ErrAddFailed, rerr, perr)
}

// if the ipcsv is localhost, use loopback addresses.
Expand Down Expand Up @@ -130,7 +130,7 @@ func AddProxyDNS(t Tunnel, p x.Proxy) error {
pxr, perr := t.internalProxies()
r, rerr := t.internalResolver()
if rerr != nil || perr != nil {
return errors.Join(rerr, perr)
return core.JoinErr(rerr, perr)
}
ctx := t.internalCtx()
ipOrHostCsv := p.DNS() // may return csv(host:port), csv(ip:port), csv(ips), csv(host)
Expand Down Expand Up @@ -166,7 +166,7 @@ func AddDoHTransport(t Tunnel, id, url, ips string) error {
pxr, perr := t.internalProxies()
r, rerr := t.internalResolver()
if rerr != nil || perr != nil {
return errors.Join(rerr, perr)
return core.JoinErr(rerr, perr)
}
ctx := t.internalCtx()
split := []string{}
Expand All @@ -186,7 +186,7 @@ func AddODoHTransport(t Tunnel, id, endpoint, resolver, epips string) error {
pxr, perr := t.internalProxies()
r, rerr := t.internalResolver()
if rerr != nil || perr != nil {
return errors.Join(rerr, perr)
return core.JoinErr(rerr, perr)
}
ctx := t.internalCtx()
split := []string{}
Expand All @@ -205,7 +205,7 @@ func AddDoTTransport(t Tunnel, id, url, ips string) error {
pxr, perr := t.internalProxies()
r, rerr := t.internalResolver()
if rerr != nil || perr != nil {
return errors.Join(rerr, perr)
return core.JoinErr(rerr, perr)
}
ctx := t.internalCtx()
split := []string{}
Expand Down
3 changes: 1 addition & 2 deletions intra/dns53/goos.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package dns53

import (
"context"
"errors"
"net"
"net/netip"
"time"
Expand Down Expand Up @@ -119,7 +118,7 @@ func (t *goosr) send(msg *dns.Msg) (ans *dns.Msg, elapsed time.Duration, qerr *d
log.D("dns53: goosr: go resolver (why? %v) for %s => %s", errl, host, ips)
ans, err = xdns.AQuadAForQuery(msg, ips...)
} else {
err = errors.Join(errl, errc)
err = core.JoinErr(errl, errc)
}
// TODO: if len(ips) <= 0 synthesize a NXDOMAIN?
}
Expand Down
6 changes: 3 additions & 3 deletions intra/dns53/ipmapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (m *ipmapper) queryIP2(_ context.Context, network, host, uid string) ([]net
}

if err4 != nil || err6 != nil {
errs := errors.Join(err4, err6)
errs := core.JoinErr(err4, err6)
log.E("ipmapper: lookup: query %s err %v", host, errs)
return nil, errs
}
Expand Down Expand Up @@ -164,14 +164,14 @@ func (m *ipmapper) queryIP2(_ context.Context, network, host, uid string) ([]net
}

if lerr4 != nil && lerr6 != nil { // all errors
errs := errors.Join(lerr4, lerr6)
errs := core.JoinErr(lerr4, lerr6)
log.E("ipmapper: lookup: %s: err %v", host, errs)
return nil, errs
} else if noval4 && noval6 { // typecast failed or no answer
log.E("ipmapper: lookup: no answers for %s; len(4)? %d len(6)? %d", host, len(r4), len(r6))
return nil, errNoAns
} else if len(r4) <= 0 && len(r6) <= 0 { // empty answer
errs := errors.Join(errNoAns, lerr4, lerr6)
errs := core.JoinErr(errNoAns, lerr4, lerr6)
log.E("ipmapper: lookup: no answers for %s, err %v", host, errs)
return nil, errs
}
Expand Down
4 changes: 2 additions & 2 deletions intra/dnsx/cacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (t *ctransport) fetch(network string, q *dns.Msg, summary *x.DNSSummary, cb
// return cached/barriered response, instead return an error
inhangover := t.hangover.Exceeds(httl)
if inhangover {
err = errors.Join(err, errHangover)
err = core.JoinErr(err, errHangover)
log.W("cache: barrier: hangover(k: %s); discard ans (has? %t)", key, hasans)
if cachehit {
fillSummary(cachedres.s, fsmm)
Expand All @@ -385,7 +385,7 @@ func (t *ctransport) fetch(network string, q *dns.Msg, summary *x.DNSSummary, cb
fillSummary(cachedsmm, fsmm) // cachedsmm may itself be fsmm
}

return fres, errors.Join(err, ferr)
return fres, core.JoinErr(err, ferr)
}

// check if underlying transport can connect fine, if not treat cache
Expand Down
4 changes: 2 additions & 2 deletions intra/doh/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,11 @@ func newTransport(ctx context.Context, typ, id, rawurl, otargeturl string, addrs
proxy := rawurl // may be empty
configurl, err := url.Parse(odohconfigdns)
if err != nil || configurl == nil || configurl.Hostname() == "" {
return nil, errors.Join(errNoOdohConfigUrl, err)
return nil, core.JoinErr(errNoOdohConfigUrl, err)
}
targeturl, err := url.Parse(otargeturl)
if err != nil || targeturl == nil || targeturl.Hostname() == "" {
return nil, errors.Join(errNoOdohTarget, err)
return nil, core.JoinErr(errNoOdohTarget, err)
}
proxyurl, _ := url.Parse(proxy) // ignore err as proxy may be empty

Expand Down
Loading

0 comments on commit edf7e21

Please sign in to comment.