From 10e5457139ed34bc5cb8fac4634c3b8c31321d3d Mon Sep 17 00:00:00 2001 From: Mark Pashmfouroush Date: Sun, 31 Mar 2024 15:28:01 +0100 Subject: [PATCH] random fixes for preventing memory/conn leak Signed-off-by: Mark Pashmfouroush --- go.mod | 4 +- go.sum | 6 ++ ips.go | 4 +- relay.go | 172 ++++++++++++++++++++++++++------------------------- tcp.go | 8 +++ udp.go | 74 +++++++++++++++------- volunteer.sh | 138 ----------------------------------------- 7 files changed, 160 insertions(+), 246 deletions(-) create mode 100644 tcp.go delete mode 100644 volunteer.sh diff --git a/go.mod b/go.mod index 866f00d..e9d8900 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/uoosef/bepass-relay -go 1.19 +go 1.22 require github.com/gaissmai/cidrtree v0.1.4 + +require github.com/peterbourgon/ff/v4 v4.0.0-alpha.4 diff --git a/go.sum b/go.sum index 2b8fd2a..5a94ac6 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,8 @@ github.com/gaissmai/cidrtree v0.1.4 h1:/aYnv1LIwjtSDHNr1eNN99WJeh6vLrB+Sgr1tRMhHDc= github.com/gaissmai/cidrtree v0.1.4/go.mod h1:nrjEeeMZmvoJpLcSvZ3qIVFxw/+9GHKi7wDHHmHKGRI= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/peterbourgon/ff/v4 v4.0.0-alpha.4 h1:aiqS8aBlF9PsAKeMddMSfbwp3smONCn3UO8QfUg0Z7Y= +github.com/peterbourgon/ff/v4 v4.0.0-alpha.4/go.mod h1:H/13DK46DKXy7EaIxPhk2Y0EC8aubKm35nBjBe8AAGc= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/ips.go b/ips.go index 409545f..69b959f 100644 --- a/ips.go +++ b/ips.go @@ -9,6 +9,8 @@ import ( var ( // List of torrent trackers torrentTrackers = []netip.Prefix{ + netip.MustParsePrefix("127.0.0.0/8"), + netip.MustParsePrefix("::1/128"), netip.MustParsePrefix("93.158.213.92/32"), netip.MustParsePrefix("102.223.180.235/32"), netip.MustParsePrefix("23.134.88.6/32"), @@ -101,7 +103,6 @@ var ( // List of Cloudflare IP ranges cfRanges = []netip.Prefix{ - netip.MustParsePrefix("127.0.0.0/8"), netip.MustParsePrefix("103.21.244.0/22"), netip.MustParsePrefix("103.22.200.0/22"), netip.MustParsePrefix("103.31.4.0/22"), @@ -116,7 +117,6 @@ var ( netip.MustParsePrefix("190.93.240.0/20"), netip.MustParsePrefix("197.234.240.0/22"), netip.MustParsePrefix("198.41.128.0/17"), - netip.MustParsePrefix("::1/128"), netip.MustParsePrefix("2400:cb00::/32"), netip.MustParsePrefix("2405:8100::/32"), netip.MustParsePrefix("2405:b500::/32"), diff --git a/relay.go b/relay.go index d422d4e..5ae2d3d 100644 --- a/relay.go +++ b/relay.go @@ -3,97 +3,83 @@ package main import ( "bufio" - "flag" + "context" + "errors" "fmt" "io" - "log" + "log/slog" "net" "net/netip" + "os" + "os/signal" "strings" -) - -const BUFFER_SIZE = 256 * 1024 - -type Server struct { - host string - port string -} - -type Client struct { - conn net.Conn -} + "syscall" -type Config struct { - Host string - Port string -} + "github.com/peterbourgon/ff/v4" + "github.com/peterbourgon/ff/v4/ffhelp" +) -func New(config *Config) *Server { - return &Server{ - host: config.Host, - port: config.Port, - } -} +const BUFFER_SIZE = 2048 -func (server *Server) Run() { - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", server.host, server.port)) +func run(ctx context.Context, l *slog.Logger, bind string) error { + listener, err := net.Listen("tcp", bind) if err != nil { - log.Fatal(err) + return err } - defer func() { - _ = listener.Close() - }() + defer listener.Close() for { - conn, err := listener.Accept() - if err != nil { - log.Fatal(err) + select { + case <-ctx.Done(): + return nil + default: + conn, err := listener.Accept() + if err != nil { + l.Error("failed to accept connection", "error", err.Error()) + continue + } + + src := netip.MustParseAddrPort(conn.RemoteAddr().String()) + + // Check if srcIP is in the whitelist + if !connFilter.isSourceAllowed(src.Addr()) { + l.Debug("blocked connection", "address", src) + conn.Close() + continue + } + + go handleConnection(l, conn) } - - src, err := netip.ParseAddrPort(conn.RemoteAddr().String()) - if err != nil { - log.Printf("unable to parse host %v", conn.RemoteAddr()) - _ = conn.Close() - continue - } - - // Check if srcIP is in the whitelist - if !connFilter.isSourceAllowed(src.Addr()) { - log.Printf("blocked connection from: %v", src) - conn.Close() - continue - } - - go (&Client{conn: conn}).handleRequest() } } -func (client *Client) handleRequest() { - defer func() { - _ = client.conn.Close() - }() - reader := bufio.NewReader(client.conn) +func handleConnection(l *slog.Logger, lConn net.Conn) { + reader := bufio.NewReader(lConn) + header, _ := reader.ReadBytes(byte(13)) if len(header) < 1 { + lConn.Close() return } + inputHeader := strings.Split(string(header[:len(header)-1]), "@") if len(inputHeader) < 2 { + lConn.Close() return } + network := "tcp" if inputHeader[0] == "udp" { network = "udp" } - address := strings.Replace(inputHeader[1], "$", ":", -1) - if strings.Contains(address, "temp-mail.org") { - return - } + address := strings.Replace(inputHeader[1], "$", ":", -1) dh, _, err := net.SplitHostPort(address) if err != nil { + lConn.Close() return } + // check if ip is not blocked blockFlag := false addr, err := netip.ParseAddr(dh) @@ -101,6 +87,7 @@ func (client *Client) handleRequest() { // the host may not be an IP, try to resolve it ips, err := net.LookupIP(dh) if err != nil { + lConn.Close() return } @@ -112,32 +99,26 @@ func (client *Client) handleRequest() { blockFlag = !addr.IsValid() || !connFilter.isDestinationAllowed(addr) if blockFlag { - log.Printf("destination host is blocked: %s\n", address) + l.Debug("destination host is blocked", "address", address) + lConn.Close() return } - if network == "udp" { - handleUDPOverTCP(client.conn, address) - return - } - - // transmit data - log.Printf("%s Dialing to %s...\n", network, address) + switch network { + case "tcp": + rConn, err := net.Dial(network, address) + if err != nil { + l.Error("failed to dial", "protocol", network, "address", address, "error", err.Error()) + lConn.Close() + return + } - rConn, err := net.Dial(network, address) + go handleTCP(lConn, rConn) - if err != nil { - log.Println(fmt.Errorf("failed to connect to socket: %v", err)) - return + case "udp": + go handleUDPOverTCP(l, lConn, address) } - - defer func() { - _ = rConn.Close() - }() - - // transmit data - go Copy(client.conn, rConn) - Copy(rConn, client.conn) + l.Debug("relaying connection", "protocol", network, "address", address) } // Copy reads from src and writes to dst until either EOF is reached on src or @@ -154,10 +135,33 @@ func Copy(src io.Reader, dst io.Writer) { } func main() { - var config Config - flag.StringVar(&config.Host, "b", "0.0.0.0", "Server IP address") - flag.StringVar(&config.Port, "p", "6666", "Server Port number") - flag.Parse() - server := New(&config) - server.Run() + fs := ff.NewFlagSet("bepass-relay") + var ( + verbose = fs.Bool('v', "verbose", "enable verbose logging") + bind = fs.String('b', "bind", "0.0.0.0:6666", "bind address") + ) + + err := ff.Parse(fs, os.Args[1:]) + switch { + case errors.Is(err, ff.ErrHelp): + fmt.Fprintf(os.Stderr, "%s\n", ffhelp.Flags(fs)) + os.Exit(0) + case err != nil: + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + l := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) + + if *verbose { + l = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + } + + ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + if err := run(ctx, l, *bind); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + <-ctx.Done() } diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..2925106 --- /dev/null +++ b/tcp.go @@ -0,0 +1,8 @@ +package main + +import "net" + +func handleTCP(lConn, rConn net.Conn) { + go Copy(lConn, rConn) + Copy(rConn, lConn) +} diff --git a/udp.go b/udp.go index 68d1ba5..8eab72c 100644 --- a/udp.go +++ b/udp.go @@ -2,8 +2,11 @@ package main import ( - "log" + "errors" + "io" + "log/slog" "net" + "time" ) var ( @@ -12,30 +15,42 @@ var ( ) // readFromConn reads data from a net.Conn and sends it to a channel. -func readFromConn(conn net.Conn, c chan<- []byte) { +func readFromConn(l *slog.Logger, conn net.Conn, c chan<- []byte) { + defer conn.Close() defer close(c) // Close the channel when done. - buf := make([]byte, 32*1024) + buf := make([]byte, 2048) for { + if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil { + return + } + n, err := conn.Read(buf) - if n > 0 { - c <- append([]byte(nil), buf[:n]...) // Send a copy of the slice. + if err != nil && errors.Is(err, io.EOF) { + return } + if err != nil { - log.Printf("Connection closed: %v--->%v\r\n", conn.LocalAddr(), conn.RemoteAddr()) + l.Debug("connection closed", "protocol", "udp", "address", conn.RemoteAddr(), "error", err.Error()) return } + + if n > 0 { + c <- append([]byte(nil), buf[:n]...) // Send a copy of the slice. + } } } // handleUDPOverTCP handles UDP-over-TCP traffic. -func handleUDPOverTCP(conn net.Conn, destination string) { +func handleUDPOverTCP(l *slog.Logger, conn net.Conn, destination string) { + // On return, delete the destination from the map of active tunnels defer delete(activeTunnels, destination) - writeToWebsocketChannel := make(chan []byte) - activeTunnels[destination] = writeToWebsocketChannel + // Store a byte channel in the map of active tunnels. The data read + // from the UDP socket is sent on this channel. + activeTunnels[destination] = make(chan []byte) wsReadDataChan := make(chan []byte) - go readFromConn(conn, wsReadDataChan) + go readFromConn(l, conn, wsReadDataChan) for { select { @@ -43,25 +58,34 @@ func handleUDPOverTCP(conn net.Conn, destination string) { if dataFromWS == nil || len(dataFromWS) < 8 { return } - if udpWriteChan, err := getOrCreateUDPChan(destination, string(dataFromWS[:8])); err == nil { - udpWriteChan <- dataFromWS - } else { - log.Printf("Unable to create connection to destination network: %v\r\n", err) + + udpWriteChan, err := getOrCreateUDPChan(l, destination, string(dataFromWS[:8])) + if err != nil { + l.Debug("unable to connect to destination", "protocol", "udp", "address", destination, "error", err.Error()) + continue } + + udpWriteChan <- dataFromWS + case dataFromUDP := <-activeTunnels[destination]: - if dataFromUDP != nil { - _, err := conn.Write(dataFromUDP) - if err != nil { - log.Printf("Unable to write on destination network: %v\r\n", err) + if dataFromUDP == nil { + continue + } + + if err := conn.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { return } + + if _, err := conn.Write(dataFromUDP); err != nil { + l.Debug("can't write to socket", "protocol", "udp", "address", destination, "error", err.Error()) + return } } } } // getOrCreateUDPChan returns an existing UDP channel or creates a new one. -func getOrCreateUDPChan(destination, header string) (chan []byte, error) { +func getOrCreateUDPChan(l *slog.Logger, destination, header string) (chan []byte, error) { channelID := destination + header if udpWriteChan, ok := udpToTCPChannels[channelID]; ok { return udpWriteChan, nil @@ -74,27 +98,35 @@ func getOrCreateUDPChan(destination, header string) (chan []byte, error) { udpToTCPChannels[channelID] = make(chan []byte) udpReadChanFromConn := make(chan []byte) - go readFromConn(udpConn, udpReadChanFromConn) + go readFromConn(l, udpConn, udpReadChanFromConn) go func() { defer func() { delete(udpToTCPChannels, channelID) - _ = udpConn.Close() + udpConn.Close() }() + for { select { case dataFromWS := <-udpToTCPChannels[channelID]: if len(dataFromWS) < 8 { return } + + if err := udpConn.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { + return + } + _, err := udpConn.Write(dataFromWS[8:]) if err != nil { return } + case dataFromUDP := <-udpReadChanFromConn: if dataFromUDP == nil { return } + if c, ok := activeTunnels[destination]; ok { c <- append([]byte(header[6:]), dataFromUDP...) } diff --git a/volunteer.sh b/volunteer.sh deleted file mode 100644 index d1ae001..0000000 --- a/volunteer.sh +++ /dev/null @@ -1,138 +0,0 @@ -#!/bin/bash - -# Function to print characters with delay -print_with_delay() { - local text="$1" - local delay="$2" - for ((i = 0; i < ${#text}; i++)); do - echo -n "${text:i:1}" - sleep "$delay" - done -} - -# Function to uninstall the service -uninstall_service() { - sudo systemctl stop cfb.service - sudo systemctl disable cfb.service - sudo rm /etc/systemd/system/cfb.service - sudo rm -rf /opt/cf-bepass - echo "Service has been uninstalled." -} - -# Introduction animation -echo -print_with_delay "**** Thanks for Becoming a Volunteer Maintainer ****" 0.03 -echo - -# Display options -echo -echo "Select an option:" -echo "------------------------------" -echo "1) Install service" -echo "2) Uninstall service" -echo "3) Set IP On Worker" -echo "------------------------------" -read -p "Please select: " option - -if [[ $option == "1" ]]; then - # Check the operating system - if [[ $(uname -s) != "Linux" ]]; then - echo "Not supported OS: $(uname -s)" - exit 1 - fi - - # Step 1: Install Golang and clone the repository - sudo apt-get update - sudo apt-get install -y golang git - sudo mkdir -p /opt - cd /opt || exit 1 - sudo git clone https://github.com/uoosef/cf-bepass.git - cd cf-bepass || exit 1 - CGO_ENABLED=0 go build -ldflags '-s -w' -trimpath *.go - - # Step 2: Create a systemd service for Bepass - cat > /etc/systemd/system/cfb.service <