diff --git a/cmd/wireproxy/main.go b/cmd/wireproxy/main.go index 5e89fc3..0f8bfda 100644 --- a/cmd/wireproxy/main.go +++ b/cmd/wireproxy/main.go @@ -45,11 +45,9 @@ func main() { exePath := executablePath() unveilOrPanic("/", "r") unveilOrPanic(exePath, "x") - if err := protect.UnveilBlock(); err != nil { - log.Fatal(err) - } // only allow standard stdio operation, file reading, networking, and exec + // also remove unveil permission to lock unveil pledgeOrPanic("stdio rpath inet dns proc exec") isDaemonProcess := len(os.Args) > 1 && os.Args[1] == daemonProcess diff --git a/routine.go b/routine.go index 8fcbe8c..d7a0f92 100644 --- a/routine.go +++ b/routine.go @@ -37,6 +37,11 @@ type RoutineSpawner interface { SpawnRoutine(vt *VirtualTun) } +type addressPort struct { + address string + port uint16 +} + // LookupAddr lookups a hostname. // DNS traffic may or may not be routed depending on VirtualTun's setting func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, error) { @@ -47,29 +52,7 @@ func (d VirtualTun) LookupAddr(ctx context.Context, name string) ([]string, erro } } -// ResolveAddrPort resolves a hostname and returns an AddrPort. -// DNS traffic may or may not be routed depending on VirtualTun's setting -func (d VirtualTun) ResolveAddrPort(saddr string) (*netip.AddrPort, error) { - name, sport, err := net.SplitHostPort(saddr) - if err != nil { - return nil, err - } - - addr, err := d.ResolveAddrWithContext(context.Background(), name) - if err != nil { - return nil, err - } - - port, err := strconv.Atoi(sport) - if err != nil || port < 0 || port > 65535 { - return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")} - } - - addrPort := netip.AddrPortFrom(*addr, uint16(port)) - return &addrPort, nil -} - -// ResolveAddrPort resolves a hostname and returns an AddrPort. +// ResolveAddrPortWithContext resolves a hostname and returns an AddrPort. // DNS traffic may or may not be routed depending on VirtualTun's setting func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*netip.Addr, error) { addrs, err := d.LookupAddr(ctx, name) @@ -101,7 +84,7 @@ func (d VirtualTun) ResolveAddrWithContext(ctx context.Context, name string) (*n return &addr, nil } -// ResolveAddrPort resolves a hostname and returns an IP. +// Resolve resolves a hostname and returns an IP. // DNS traffic may or may not be routed depending on VirtualTun's setting func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { addr, err := d.ResolveAddrWithContext(ctx, name) @@ -112,6 +95,30 @@ func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, return ctx, addr.AsSlice(), nil } +func parseAddressPort(endpoint string) (*addressPort, error) { + name, sport, err := net.SplitHostPort(endpoint) + if err != nil { + return nil, err + } + + port, err := strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errors.New("port must be numeric")} + } + + return &addressPort{address: name, port: uint16(port)}, nil +} + +func (d VirtualTun) resolveToAddrPort(endpoint *addressPort) (*netip.AddrPort, error) { + addr, err := d.ResolveAddrWithContext(context.Background(), endpoint.address) + if err != nil { + return nil, err + } + + addrPort := netip.AddrPortFrom(*addr, endpoint.port) + return &addrPort, nil +} + // Spawns a socks5 server. func (config *Socks5Config) SpawnRoutine(vt *VirtualTun) { conf := &socks5.Config{Dial: vt.tnet.DialContext, Resolver: vt} @@ -150,8 +157,16 @@ func connForward(bufSize int, from io.ReadWriteCloser, to io.ReadWriteCloser) { } // tcpClientForward starts a new connection via wireguard and forward traffic from `conn` -func tcpClientForward(tnet *netstack.Net, target *net.TCPAddr, conn net.Conn) { - sconn, err := tnet.DialTCP(target) +func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { + target, err := vt.resolveToAddrPort(raddr) + if err != nil { + errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + return + } + + tcpAddr := TCPAddrFromAddrPort(*target) + + sconn, err := vt.tnet.DialTCP(tcpAddr) if err != nil { errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error()) return @@ -163,11 +178,10 @@ func tcpClientForward(tnet *netstack.Net, target *net.TCPAddr, conn net.Conn) { // Spawns a local TCP server which acts as a proxy to the specified target func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) { - raddr, err := vt.ResolveAddrPort(conf.Target) + raddr, err := parseAddressPort(conf.Target) if err != nil { log.Fatal(err) } - tcpAddr := TCPAddrFromAddrPort(*raddr) server, err := net.ListenTCP("tcp", conf.BindAddress) if err != nil { @@ -179,13 +193,21 @@ func (conf *TCPClientTunnelConfig) SpawnRoutine(vt *VirtualTun) { if err != nil { log.Fatal(err) } - go tcpClientForward(vt.tnet, tcpAddr, conn) + go tcpClientForward(vt, raddr, conn) } } // tcpServerForward starts a new connection locally and forward traffic from `conn` -func tcpServerForward(target *net.TCPAddr, conn net.Conn) { - sconn, err := net.DialTCP("tcp", nil, target) +func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { + target, err := vt.resolveToAddrPort(raddr) + if err != nil { + errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + return + } + + tcpAddr := TCPAddrFromAddrPort(*target) + + sconn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil { errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) return @@ -197,11 +219,10 @@ func tcpServerForward(target *net.TCPAddr, conn net.Conn) { // Spawns a TCP server on wireguard which acts as a proxy to the specified target func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { - raddr, err := vt.ResolveAddrPort(conf.Target) + raddr, err := parseAddressPort(conf.Target) if err != nil { log.Fatal(err) } - tcpAddr := TCPAddrFromAddrPort(*raddr) addr := &net.TCPAddr{Port: conf.ListenPort} server, err := vt.tnet.ListenTCP(addr) @@ -214,6 +235,6 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { if err != nil { log.Fatal(err) } - go tcpServerForward(tcpAddr, conn) + go tcpServerForward(vt, raddr, conn) } }