diff --git a/agent/agent.go b/agent/agent.go index 24c0b896..67a02aca 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -8,6 +8,7 @@ package agent import ( "bufio" + "bytes" "encoding/binary" "fmt" "io" @@ -54,6 +55,9 @@ type Options struct { // can call Close before shutting down. // Optional. ShutdownCleanup bool + + // Key is remote access credentials. Maxlength 64 characters + Key string } // Listen starts the gops agent on a host process. Once agent started, users @@ -107,12 +111,18 @@ func Listen(opts Options) error { return err } - go listen() + key := opts.Key + if len(key) > 64 { + return fmt.Errorf("Key maxlength is 64 characters") + } + + go listen(key) return nil } -func listen() { +func listen(key string) { buf := make([]byte, 1) + keyBuf := make([]byte, 64) for { fd, err := listener.Accept() if err != nil { @@ -122,10 +132,28 @@ func listen() { } continue } + if key != "" { + if _, err := fd.Read(keyBuf); err != nil { + fmt.Fprintf(os.Stderr, "gops: %v", err) + continue + } + if !verify(keyBuf, key) { + fmt.Fprintf(os.Stderr, "gops: access denied. client: %s\n", fd.RemoteAddr()) + fd.Write([]byte{0}) // login failed + fd.Write([]byte("access denied. Please set right GOPS_KEY\n")) + fd.Close() + continue + + } else { + fd.Write([]byte{1}) // login success + } + } + if _, err := fd.Read(buf); err != nil { fmt.Fprintf(os.Stderr, "gops: %v", err) continue } + if err := handle(fd, buf); err != nil { fmt.Fprintf(os.Stderr, "gops: %v", err) continue @@ -175,6 +203,11 @@ func formatBytes(val uint64) string { return fmt.Sprintf("%d bytes", val) } +func verify(input []byte, key string) bool { + input = bytes.TrimRight(input, string([]byte{0})) + return bytes.Compare(input, []byte(key)) == 0 +} + func handle(conn io.ReadWriter, msg []byte) error { switch msg[0] { case signal.StackTrace: diff --git a/agent/agent_test.go b/agent/agent_test.go index fab74e6c..a4ed948f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -6,6 +6,7 @@ package agent import ( "os" + "strings" "testing" ) @@ -76,3 +77,27 @@ func TestFormatBytes(t *testing.T) { } } } + +func TestVerify(t *testing.T) { + inputFunc := func(text string) []byte { + input := make([]byte, 64) + copy(input, []byte(text)) + return input + } + + tests := []struct { + input []byte + key string + want bool + }{ + {inputFunc("abc"), "abc", true}, + {inputFunc("ab"), "abc", false}, + {[]byte{0x01}, "abc", false}, + {inputFunc(strings.Repeat("1", 64)), strings.Repeat("1", 64), true}, + } + for _, tt := range tests { + if got := verify(tt.input, tt.key); got != tt.want { + t.Errorf("verify(%v, %v) = %v; want %v", tt.input, tt.key, got, tt.want) + } + } +} diff --git a/cmd.go b/cmd.go index 63e27f94..a6819a6a 100644 --- a/cmd.go +++ b/cmd.go @@ -206,6 +206,22 @@ func cmdLazy(addr net.TCPAddr, c byte, params ...byte) (io.Reader, error) { if err != nil { return nil, err } + key := os.Getenv("GOPS_KEY") + if key != "" { + keyBuf := make([]byte, 64) + restBuf := make([]byte, 1) + copy(keyBuf, []byte(key)) + if _, err := conn.Write(keyBuf); err != nil { + return nil, err + } + if _, err := conn.Read(restBuf); err != nil { + return nil, err + } + if restBuf[0] == 0 { // login failed. + return conn, nil + } + } + buf := []byte{c} buf = append(buf, params...) if _, err := conn.Write(buf); err != nil { diff --git a/gops b/gops new file mode 100755 index 00000000..0d8d5b9e Binary files /dev/null and b/gops differ