Skip to content

Commit

Permalink
refactor: 重写ssh
Browse files Browse the repository at this point in the history
  • Loading branch information
devhaozi committed Oct 19, 2024
1 parent 1a7f679 commit 47b92a8
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 178 deletions.
29 changes: 9 additions & 20 deletions internal/service/ssh.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package service

import (
"bytes"
"context"
"net/http"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/spf13/cast"
"go.uber.org/zap"

"github.com/TheTNB/panel/internal/app"
"github.com/TheTNB/panel/internal/biz"
"github.com/TheTNB/panel/internal/data"
"github.com/TheTNB/panel/internal/http/request"
Expand Down Expand Up @@ -74,11 +75,9 @@ func (s *SSHService) Session(w http.ResponseWriter, r *http.Request) {
cast.ToString(info["password"]),
)
client, err := ssh.NewSSHClient(config)

if err != nil {
_ = ws.WriteControl(websocket.CloseMessage,
[]byte(err.Error()), time.Now().Add(time.Second))
ErrorSystem(w)
return
}
defer client.Close()
Expand All @@ -87,38 +86,28 @@ func (s *SSHService) Session(w http.ResponseWriter, r *http.Request) {
if err != nil {
_ = ws.WriteControl(websocket.CloseMessage,
[]byte(err.Error()), time.Now().Add(time.Second))
ErrorSystem(w)
return
}
defer turn.Close()

var bufPool = sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
var logBuff = bufPool.Get().(*bytes.Buffer)
logBuff.Reset()
defer bufPool.Put(logBuff)

sshCtx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background())
wg := sync.WaitGroup{}
wg.Add(2)

go func() {
defer wg.Done()
if err = turn.LoopRead(logBuff, sshCtx); err != nil {
ErrorSystem(w)
if err = turn.Handle(ctx); err != nil {
app.Logger.Error("读取 ssh 数据失败", zap.Error(err))
return
}
}()
go func() {
defer wg.Done()
if err = turn.SessionWait(); err != nil {
ErrorSystem(w)
return
if err = turn.Wait(); err != nil {
app.Logger.Error("保持 ssh 会话失败", zap.Error(err))
}
cancel()
}()
wg.Wait()

wg.Wait()
}
21 changes: 11 additions & 10 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package ssh

import (
"os"
"time"

"golang.org/x/crypto/ssh"

"github.com/TheTNB/panel/pkg/io"
)

type AuthMethod int8
Expand Down Expand Up @@ -45,11 +44,12 @@ func ClientConfigPublicKey(hostAddr, user, keyPath string) *ClientConfig {
}

func NewSSHClient(conf *ClientConfig) (*ssh.Client, error) {
config := &ssh.ClientConfig{
Timeout: conf.Timeout,
User: conf.User,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
config := &ssh.ClientConfig{}
config.SetDefaults()
config.Timeout = conf.Timeout
config.User = conf.User
config.HostKeyCallback = ssh.InsecureIgnoreHostKey()

switch conf.AuthMethod {
case PASSWORD:
config.Auth = []ssh.AuthMethod{ssh.Password(conf.Password)}
Expand All @@ -60,18 +60,19 @@ func NewSSHClient(conf *ClientConfig) (*ssh.Client, error) {
}
config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
}
c, err := ssh.Dial("tcp", conf.HostAddr, config)
c, err := ssh.Dial("tcp", conf.HostAddr, config) // TODO support ipv6
if err != nil {
return nil, err
}

return c, nil
}

func getKey(keyPath string) (ssh.Signer, error) {
key, err := io.Read(keyPath)
key, err := os.ReadFile(keyPath)
if err != nil {
return nil, err
}

return ssh.ParsePrivateKey([]byte(key))
return ssh.ParsePrivateKey(key)
}
107 changes: 36 additions & 71 deletions pkg/ssh/turn.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package ssh

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -13,29 +11,30 @@ import (
"golang.org/x/crypto/ssh"
)

const (
MsgData = '1'
MsgResize = '2'
)
type MessageResize struct {
Resize bool `json:"resize"`
Columns int `json:"columns"`
Rows int `json:"rows"`
}

type Turn struct {
StdinPipe io.WriteCloser
Session *ssh.Session
WsConn *websocket.Conn
stdin io.WriteCloser
session *ssh.Session
ws *websocket.Conn
}

func NewTurn(wsConn *websocket.Conn, sshClient *ssh.Client) (*Turn, error) {
sess, err := sshClient.NewSession()
func NewTurn(ws *websocket.Conn, client *ssh.Client) (*Turn, error) {
sess, err := client.NewSession()
if err != nil {
return nil, err
}

stdinPipe, err := sess.StdinPipe()
stdin, err := sess.StdinPipe()
if err != nil {
return nil, err
}

turn := &Turn{StdinPipe: stdinPipe, Session: sess, WsConn: wsConn}
turn := &Turn{stdin: stdin, session: sess, ws: ws}
sess.Stdout = turn
sess.Stderr = turn

Expand All @@ -44,18 +43,18 @@ func NewTurn(wsConn *websocket.Conn, sshClient *ssh.Client) (*Turn, error) {
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
if err := sess.RequestPty("xterm", 150, 30, modes); err != nil {
if err = sess.RequestPty("xterm", 150, 80, modes); err != nil {
return nil, err
}
if err := sess.Shell(); err != nil {
if err = sess.Shell(); err != nil {
return nil, err
}

return turn, nil
}

func (t *Turn) Write(p []byte) (n int, err error) {
writer, err := t.WsConn.NextWriter(websocket.BinaryMessage)
writer, err := t.ws.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err
}
Expand All @@ -65,76 +64,42 @@ func (t *Turn) Write(p []byte) (n int, err error) {
}

func (t *Turn) Close() error {
if t.Session != nil {
t.Session.Close()
if t.session != nil {
_ = t.session.Close()
}

return t.WsConn.Close()
return t.ws.Close()
}

func (t *Turn) Read(p []byte) (n int, err error) {
for {
msgType, reader, err := t.WsConn.NextReader()
if err != nil {
return 0, err
}
if msgType != websocket.BinaryMessage {
continue
}

return reader.Read(p)
}
}

func (t *Turn) LoopRead(logBuff *bytes.Buffer, context context.Context) error {
func (t *Turn) Handle(context context.Context) error {
var resize MessageResize
for {
select {
case <-context.Done():
return errors.New("LoopRead exit")
return errors.New("ssh context done exit")
default:
_, wsData, err := t.WsConn.ReadMessage()
_, data, err := t.ws.ReadMessage()
if err != nil {
return fmt.Errorf("reading webSocket message err:%s", err)
return fmt.Errorf("reading ws message err: %v", err)
}
body := decode(wsData[1:])
switch wsData[0] {
case MsgResize:
var args Resize
err := json.Unmarshal(body, &args)
if err != nil {
return fmt.Errorf("ssh pty resize windows err:%s", err)
}
if args.Columns > 0 && args.Rows > 0 {
if err := t.Session.WindowChange(args.Rows, args.Columns); err != nil {
return fmt.Errorf("ssh pty resize windows err:%s", err)

// 判断是否是 resize 消息
if err = json.Unmarshal(data, &resize); err == nil {
if resize.Resize && resize.Columns > 0 && resize.Rows > 0 {
if err = t.session.WindowChange(resize.Rows, resize.Columns); err != nil {
return fmt.Errorf("change window size err: %v", err)
}
}
case MsgData:
if _, err := t.StdinPipe.Write(body); err != nil {
return fmt.Errorf("StdinPipe write err:%s", err)
}
if _, err := logBuff.Write(body); err != nil {
return fmt.Errorf("logBuff write err:%s", err)
}
continue
}
}
}
}

func (t *Turn) SessionWait() error {
if err := t.Session.Wait(); err != nil {
return err
if _, err = t.stdin.Write(data); err != nil {
return fmt.Errorf("writing ws message to stdin err: %v", err)
}
}
}

return nil
}

func decode(p []byte) []byte {
decodeString, _ := base64.StdEncoding.DecodeString(string(p))
return decodeString
}

type Resize struct {
Columns int
Rows int
func (t *Turn) Wait() error {
return t.session.Wait()
}
6 changes: 4 additions & 2 deletions web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
"dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4",
"@vueuse/core": "^11.1.0",
"@xterm/addon-attach": "^0.11.0",
"@xterm/addon-clipboard": "^0.1.0",
"@xterm/addon-fit": "^0.10.0",
"@xterm/addon-web-links": "^0.11.0",
"@xterm/addon-webgl": "^0.18.0",
"@xterm/xterm": "^5.5.0",
"axios": "^1.7.7",
"crypto-js": "^4.2.0",
"echarts": "^5.5.1",
"install": "^0.13.0",
"lodash-es": "^4.17.21",
Expand All @@ -44,7 +47,6 @@
"@iconify/vue": "^4.1.2",
"@rushstack/eslint-patch": "^1.10.4",
"@tsconfig/node20": "^20.1.4",
"@types/crypto-js": "^4.2.2",
"@types/lodash-es": "^4.17.12",
"@types/luxon": "^3.4.2",
"@types/node": "^20.16.11",
Expand Down
Loading

0 comments on commit 47b92a8

Please sign in to comment.