From 45abc068b061c84279ea2beca1ac9ea6bb9d7310 Mon Sep 17 00:00:00 2001 From: Son Roy Almerol Date: Mon, 11 Nov 2024 13:29:34 -0500 Subject: [PATCH] use channels for loop waits instead --- cmd/windows_agent/local_drives.go | 47 ++++++++++ cmd/windows_agent/service.go | 141 +++++++++++++----------------- internal/agent/sftp/config.go | 10 +-- internal/agent/sftp/sftp.go | 63 ++++++++----- internal/utils/local_drives.go | 16 ---- internal/utils/wait.go | 14 +++ internal/utils/waitgroup.go | 12 --- 7 files changed, 164 insertions(+), 139 deletions(-) create mode 100644 cmd/windows_agent/local_drives.go delete mode 100644 internal/utils/local_drives.go create mode 100644 internal/utils/wait.go delete mode 100644 internal/utils/waitgroup.go diff --git a/cmd/windows_agent/local_drives.go b/cmd/windows_agent/local_drives.go new file mode 100644 index 0000000..bf179da --- /dev/null +++ b/cmd/windows_agent/local_drives.go @@ -0,0 +1,47 @@ +//go:build windows + +package main + +import ( + "fmt" + "os" + + "github.com/sonroyaalmerol/pbs-plus/internal/agent/sftp" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +type Drive struct { + Letter string + ErrorChan chan string +} + +func getLocalDrives() (r []Drive) { + for _, drive := range "ABCDEFGHIJKLMNOPQRSTUVWXYZ" { + f, err := os.Open(string(drive) + ":\\") + if err == nil { + r = append(r, Drive{Letter: string(drive)}) + f.Close() + } + } + return +} + +func (drive *Drive) serveSFTP(p *agentService) error { + rune := []rune(drive.Letter)[0] + sftpConfig, err := sftp.InitializeSFTPConfig(p.svc, drive.Letter) + if err != nil { + return fmt.Errorf("Unable to initialize SFTP config: %s", err) + } + if err := sftpConfig.PopulateKeys(); err != nil { + return fmt.Errorf("Unable to populate SFTP keys: %s", err) + } + + port, err := utils.DriveLetterPort(rune) + if err != nil { + return fmt.Errorf("Unable to map letter to port: %s", err) + } + + go sftp.Serve(p.ctx, drive.ErrorChan, sftpConfig, "0.0.0.0", port, drive.Letter) + + return nil +} diff --git a/cmd/windows_agent/service.go b/cmd/windows_agent/service.go index b804bd6..aa497c1 100644 --- a/cmd/windows_agent/service.go +++ b/cmd/windows_agent/service.go @@ -8,12 +8,10 @@ import ( _ "embed" "fmt" "net/http" - "sync" "time" "github.com/kardianos/service" "github.com/sonroyaalmerol/pbs-plus/internal/agent" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/sftp" "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" "github.com/sonroyaalmerol/pbs-plus/internal/syslog" "github.com/sonroyaalmerol/pbs-plus/internal/utils" @@ -30,7 +28,6 @@ type PingResp struct { type agentService struct { svc service.Service - wg sync.WaitGroup ctx context.Context cancel context.CancelFunc } @@ -38,60 +35,35 @@ type agentService struct { func (p *agentService) Start(s service.Service) error { p.ctx, p.cancel = context.WithCancel(context.Background()) - go p.runLoop() + go p.startPing() + go p.run() return nil } func (p *agentService) startPing() { - firstPing := true - lastCheck := time.Now() - for { - select { - case <-p.ctx.Done(): - utils.SetEnvironment("PBS_AGENT_STATUS", "Agent service is not running") - return - default: - if time.Since(lastCheck) > time.Second*5 || firstPing { - firstPing = false - - var pingResp PingResp - pingErr := agent.ProxmoxHTTPRequest(http.MethodGet, "/api2/json/ping", nil, &pingResp) - if pingErr != nil { - utils.SetEnvironment("PBS_AGENT_STATUS", fmt.Sprintf("Error - (%s)", pingErr.Error())) - } else if !pingResp.Data.Pong { - utils.SetEnvironment("PBS_AGENT_STATUS", "Error - server did not return expected data") - } else { - utils.SetEnvironment("PBS_AGENT_STATUS", "Connected") - } - lastCheck = time.Now() - } + ping := func() { + var pingResp PingResp + pingErr := agent.ProxmoxHTTPRequest(http.MethodGet, "/api2/json/ping", nil, &pingResp) + if pingErr != nil { + utils.SetEnvironment("PBS_AGENT_STATUS", fmt.Sprintf("Error - (%s)", pingErr.Error())) + } else if !pingResp.Data.Pong { + utils.SetEnvironment("PBS_AGENT_STATUS", "Error - server did not return expected data") + } else { + utils.SetEnvironment("PBS_AGENT_STATUS", "Connected") } } -} - -func (p *agentService) runLoop() { - logger, err := syslog.InitializeLogger(p.svc) - if err != nil { - utils.SetEnvironment("PBS_AGENT_STATUS", fmt.Sprintf("Failed to initialize logger -> %s", err.Error())) - return - } - go p.startPing() + ping() for { - p.run() - wgDone := utils.WaitChan(&p.wg) - + retryWait := utils.WaitChan(time.Second * 5) select { case <-p.ctx.Done(): - snapshots.CloseAllSnapshots() + utils.SetEnvironment("PBS_AGENT_STATUS", "Agent service is not running") return - case <-wgDone: - utils.SetEnvironment("PBS_AGENT_STATUS", "Unexpected shutdown - restarting SSH endpoints") - logger.Error("SSH endpoints stopped unexpectedly. Restarting...") - p.wg = sync.WaitGroup{} - time.Sleep(5 * time.Second) + case <-retryWait: + ping() } } } @@ -104,54 +76,61 @@ func (p *agentService) run() { return } - firstUrlCheck := true - lastCheck := time.Now() -waitUrl: - for { - select { - case <-p.ctx.Done(): - return - default: - if time.Since(lastCheck) > time.Second*5 || firstUrlCheck { - firstUrlCheck = false - key, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\PBSPlus\Config`, registry.QUERY_VALUE) - if err == nil { - defer key.Close() - - if serverUrl, _, err := key.GetStringValue("ServerURL"); err == nil && serverUrl != "" { - break waitUrl - } - } - lastCheck = time.Now() + urlExists := func() bool { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\PBSPlus\Config`, registry.QUERY_VALUE) + if err == nil { + defer key.Close() + + if serverUrl, _, err := key.GetStringValue("ServerURL"); err == nil && serverUrl != "" { + return true } } + + return false } - drives := utils.GetLocalDrives() - for _, driveLetter := range drives { - rune := []rune(driveLetter)[0] - sftpConfig, err := sftp.InitializeSFTPConfig(p.svc, driveLetter) - if err != nil { - logger.Error(fmt.Sprintf("Unable to initialize SFTP config: %s", err)) - continue - } - if err := sftpConfig.PopulateKeys(); err != nil { - logger.Error(fmt.Sprintf("Unable to populate SFTP keys: %s", err)) - continue + if !urlExists() { + for !urlExists() { + retryWait := utils.WaitChan(time.Second * 5) + select { + case <-p.ctx.Done(): + return + case <-retryWait: + } } + } - port, err := utils.DriveLetterPort(rune) - if err != nil { - logger.Error(fmt.Sprintf("Unable to map letter to port: %s", err)) - continue + drives := getLocalDrives() + for _, drive := range drives { + drive.ErrorChan = make(chan string) + err = drive.serveSFTP(p) + for err != nil { + logger.Errorf("Drive SFTP error: %v", err) + retryWait := utils.WaitChan(time.Second * 5) + select { + case <-p.ctx.Done(): + return + case <-retryWait: + err = drive.serveSFTP(p) + } } - p.wg.Add(1) go func() { - sftp.Serve(p.ctx, sftpConfig, "0.0.0.0", port, driveLetter) - p.wg.Done() + defer close(drive.ErrorChan) + + for { + select { + case <-p.ctx.Done(): + return + case err := <-drive.ErrorChan: + logger.Errorf("SFTP %s drive error: %s", drive.Letter, err) + } + } }() } + + <-p.ctx.Done() + snapshots.CloseAllSnapshots() } func (p *agentService) Stop(s service.Service) error { diff --git a/internal/agent/sftp/config.go b/internal/agent/sftp/config.go index 6990e3b..f6dcc74 100644 --- a/internal/agent/sftp/config.go +++ b/internal/agent/sftp/config.go @@ -17,7 +17,6 @@ import ( "time" "github.com/kardianos/service" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" "github.com/sonroyaalmerol/pbs-plus/internal/utils" "golang.org/x/crypto/ssh" "golang.org/x/sys/windows/registry" @@ -36,15 +35,10 @@ func (s *SFTPConfig) GetRegistryKey() string { return fmt.Sprintf("Software\\PBSPlus\\Config\\SFTP-%s", s.BasePath) } -var logger *syslog.Logger - func InitializeSFTPConfig(svc service.Service, driveLetter string) (*SFTPConfig, error) { var err error - if logger == nil { - logger, err = syslog.InitializeLogger(svc) - if err != nil { - return nil, fmt.Errorf("InitializeLogger: failed to initialize logger -> %w", err) - } + if err != nil { + return nil, fmt.Errorf("InitializeLogger: failed to initialize logger -> %w", err) } baseKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, "Software\\PBSPlus\\Config", registry.QUERY_VALUE) diff --git a/internal/agent/sftp/sftp.go b/internal/agent/sftp/sftp.go index 9148bc5..0c688aa 100644 --- a/internal/agent/sftp/sftp.go +++ b/internal/agent/sftp/sftp.go @@ -9,57 +9,78 @@ import ( "net" "net/url" "strings" + "time" "github.com/pkg/sftp" "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" "golang.org/x/crypto/ssh" ) -func Serve(ctx context.Context, sftpConfig *SFTPConfig, address, port string, driveLetter string) { - listenAt := fmt.Sprintf("%s:%s", address, port) - listener, err := net.Listen("tcp", listenAt) - if err != nil { - logger.Error(fmt.Sprintf("Port is already in use! Failed to listen on %s: %v", listenAt, err)) - return +func Serve(ctx context.Context, errChan chan string, sftpConfig *SFTPConfig, address, port string, driveLetter string) { + var listener net.Listener + + listening := false + + listen := func() { + var err error + listenAt := fmt.Sprintf("%s:%s", address, port) + listener, err = net.Listen("tcp", listenAt) + if err != nil { + errChan <- fmt.Sprintf("Port is already in use! Failed to listen on %s: %v", listenAt, err) + return + } + + listening = true } - defer listener.Close() - logger.Infof("Listening on %v\n", listener.Addr()) + listen() + + for !listening { + retryWait := utils.WaitChan(time.Second * 5) + select { + case <-ctx.Done(): + return + case <-retryWait: + listen() + } + } + + defer listener.Close() for { select { case <-ctx.Done(): - logger.Info("Context cancelled. Terminating SFTP listener.") return default: conn, err := listener.Accept() if err != nil { - logger.Error(fmt.Sprintf("failed to accept connection: %v", err)) + errChan <- fmt.Sprintf("failed to accept connection: %v", err) continue } - go handleConnection(ctx, conn, sftpConfig, driveLetter) + go handleConnection(ctx, errChan, conn, sftpConfig, driveLetter) } } } -func handleConnection(ctx context.Context, conn net.Conn, sftpConfig *SFTPConfig, driveLetter string) { +func handleConnection(ctx context.Context, errChan chan string, conn net.Conn, sftpConfig *SFTPConfig, driveLetter string) { defer conn.Close() server, err := url.Parse(sftpConfig.Server) if err != nil { - logger.Error(fmt.Sprintf("failed to parse server IP: %v", err)) + errChan <- fmt.Sprintf("failed to parse server IP: %v", err) return } if !strings.Contains(conn.RemoteAddr().String(), server.Hostname()) { - logger.Error(fmt.Sprintf("WARNING: an unregistered client has attempted to connect: %s", conn.RemoteAddr().String())) + errChan <- fmt.Sprintf("WARNING: an unregistered client has attempted to connect: %s", conn.RemoteAddr().String()) return } sconn, chans, reqs, err := ssh.NewServerConn(conn, sftpConfig.ServerConfig) if err != nil { - logger.Error(fmt.Sprintf("failed to perform SSH handshake: %v", err)) + errChan <- fmt.Sprintf("failed to perform SSH handshake: %v", err) return } defer sconn.Close() @@ -81,7 +102,7 @@ func handleConnection(ctx context.Context, conn net.Conn, sftpConfig *SFTPConfig go handleRequests(ctx, requests, sftpRequest) if requested, ok := <-sftpRequest; ok && requested { - go handleSFTP(ctx, channel, driveLetter) + go handleSFTP(ctx, errChan, channel, driveLetter) } else { channel.Close() } @@ -119,19 +140,19 @@ func handleRequests(ctx context.Context, requests <-chan *ssh.Request, sftpReque } } -func handleSFTP(ctx context.Context, channel ssh.Channel, driveLetter string) { +func handleSFTP(ctx context.Context, errChan chan string, channel ssh.Channel, driveLetter string) { defer channel.Close() snapshot, err := snapshots.Snapshot(driveLetter) if err != nil { - logger.Error(fmt.Sprintf("failed to initialize snapshot: %v", err)) + errChan <- fmt.Sprintf("failed to initialize snapshot: %v", err) return } sftpHandler, err := NewSftpHandler(ctx, driveLetter, snapshot) if err != nil { snapshot.Close() - logger.Error(fmt.Sprintf("failed to initialize handler: %v", err)) + errChan <- fmt.Sprintf("failed to initialize handler: %v", err) return } @@ -142,7 +163,5 @@ func handleSFTP(ctx context.Context, channel ssh.Channel, driveLetter string) { server.Close() }() - if err := server.Serve(); err != nil { - logger.Infof("sftp server completed with error: %s", err) - } + _ = server.Serve() } diff --git a/internal/utils/local_drives.go b/internal/utils/local_drives.go deleted file mode 100644 index c6e708a..0000000 --- a/internal/utils/local_drives.go +++ /dev/null @@ -1,16 +0,0 @@ -package utils - -import ( - "os" -) - -func GetLocalDrives() (r []string) { - for _, drive := range "ABCDEFGHIJKLMNOPQRSTUVWXYZ" { - f, err := os.Open(string(drive) + ":\\") - if err == nil { - r = append(r, string(drive)) - f.Close() - } - } - return -} diff --git a/internal/utils/wait.go b/internal/utils/wait.go new file mode 100644 index 0000000..0286bcc --- /dev/null +++ b/internal/utils/wait.go @@ -0,0 +1,14 @@ +package utils + +import ( + "time" +) + +func WaitChan(duration time.Duration) <-chan struct{} { + done := make(chan struct{}) + go func() { + time.Sleep(duration) + close(done) + }() + return done +} diff --git a/internal/utils/waitgroup.go b/internal/utils/waitgroup.go deleted file mode 100644 index 6fcd6af..0000000 --- a/internal/utils/waitgroup.go +++ /dev/null @@ -1,12 +0,0 @@ -package utils - -import "sync" - -func WaitChan(wg *sync.WaitGroup) <-chan struct{} { - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - return done -}