-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
57ef498
commit 118df7b
Showing
6 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
package ssh | ||
|
||
type ForwardMessage = forwardMessage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
// Copyright 2025 Canonical. | ||
|
||
package ssh | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
|
||
"github.com/canonical/jimm/v3/internal/openfga" | ||
"github.com/gliderlabs/ssh" | ||
"github.com/juju/zaputil/zapctx" | ||
"go.uber.org/zap" | ||
gossh "golang.org/x/crypto/ssh" | ||
) | ||
|
||
// JUJU_SSH_DEFAULT_PORT is the default port we expect the juju controllers to respond on. | ||
const JUJU_SSH_DEFAULT_PORT = 2223 | ||
|
||
// Resolver is the interface with the methods needed by the ssh jump server to route request. | ||
type Resolver interface { | ||
// GetAddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. | ||
GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error) | ||
} | ||
|
||
// fowardMessage is the struct holding the information about the jump message received by the ssh client. | ||
type forwardMessage struct { | ||
DestAddr string | ||
DestPort uint32 | ||
SrcAddr string | ||
SrcPort uint32 | ||
} | ||
|
||
// Server is the custom struct to embed the gliderlabs.ssh server and a resolver. | ||
type Server struct { | ||
*ssh.Server | ||
|
||
resolver Resolver | ||
} | ||
|
||
// NewJumpSSHServer creates the jump server struct. | ||
func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) { | ||
zapctx.Info(ctx, "NewSSHServer") | ||
server := Server{ | ||
Server: &ssh.Server{ | ||
Addr: fmt.Sprintf(":%d", port), | ||
ChannelHandlers: map[string]ssh.ChannelHandler{ | ||
"direct-tcpip": directTCPIPHandler(resolver), | ||
}, | ||
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { | ||
return true | ||
}, | ||
}, | ||
resolver: resolver, | ||
} | ||
|
||
return server, nil | ||
} | ||
|
||
func directTCPIPHandler(resolver Resolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { | ||
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { | ||
d := forwardMessage{} | ||
|
||
k := newChan.ExtraData() | ||
|
||
if err := gossh.Unmarshal(k, &d); err != nil { | ||
zapctx.Error(ctx, "Failed to parse channel data", zap.Error(err)) | ||
newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data") | ||
return | ||
} | ||
|
||
dest := fmt.Sprintf("%s:%d", d.DestAddr, d.DestPort) | ||
if d.DestPort == 0 { | ||
d.DestPort = JUJU_SSH_DEFAULT_PORT | ||
} | ||
addr, err := resolver.GetAddrFromModelUUID(ctx, openfga.User{}, dest) | ||
|
||
// this is temporary. The way we dial to the controller will heavily change. | ||
client, err := gossh.Dial("tcp", fmt.Sprintf("%s:%d", addr, d.DestPort), &gossh.ClientConfig{ | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PasswordCallback(func() (secret string, err error) { | ||
return "jwt", nil | ||
}), | ||
}, | ||
}) | ||
if err != nil { | ||
zapctx.Error(ctx, fmt.Sprintf("Failed to connect to %s: %v", dest, err), zap.Error(err)) | ||
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("Failed to connect to %s: %v", dest, err)) | ||
return | ||
} | ||
|
||
dChan, reqs, err := client.OpenChannel("direct-tcpip", gossh.Marshal(d)) | ||
if err != nil { | ||
zapctx.Error(ctx, "Failed to open destination channel", zap.Error(err)) | ||
newChan.Reject(gossh.ConnectionFailed, "Failed to open destination channel") | ||
return | ||
} | ||
|
||
go gossh.DiscardRequests(reqs) | ||
|
||
ch, reqs, err := newChan.Accept() | ||
if err != nil { | ||
dChan.Close() | ||
return | ||
} | ||
|
||
go gossh.DiscardRequests(reqs) | ||
|
||
go func() { | ||
defer ch.Close() | ||
defer dChan.Close() | ||
io.Copy(ch, dChan) | ||
}() | ||
go func() { | ||
defer ch.Close() | ||
defer dChan.Close() | ||
io.Copy(dChan, ch) | ||
}() | ||
zapctx.Info(ctx, fmt.Sprintf("Proxying connection from %s:%d to %s:%d \n", d.SrcAddr, d.SrcPort, d.DestAddr, d.DestPort)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
package ssh_test | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"encoding/pem" | ||
"fmt" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/canonical/jimm/v3/internal/openfga" | ||
"github.com/canonical/jimm/v3/internal/ssh" | ||
"github.com/canonical/jimm/v3/internal/utils" | ||
qt "github.com/frankban/quicktest" | ||
"github.com/frankban/quicktest/qtsuite" | ||
|
||
gliderssh "github.com/gliderlabs/ssh" | ||
gossh "golang.org/x/crypto/ssh" | ||
|
||
"testing" | ||
) | ||
|
||
type resolver struct{} | ||
|
||
func (r resolver) GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelName string) (string, error) { | ||
return "", nil | ||
} | ||
|
||
type sshSuite struct { | ||
destinationJujuSSHServer gliderssh.Server | ||
destinationServerPort int | ||
jumpSSHServer ssh.Server | ||
jumpServerPort int | ||
privateKey gossh.Signer | ||
testF func(fm ssh.ForwardMessage) | ||
received chan bool | ||
} | ||
|
||
func (s *sshSuite) Init(c *qt.C) { | ||
s.received = make(chan bool) | ||
port, err := utils.GetFreePort() | ||
c.Assert(err, qt.IsNil) | ||
s.destinationServerPort = port | ||
s.destinationJujuSSHServer = gliderssh.Server{ | ||
Addr: fmt.Sprintf(":%d", port), | ||
ChannelHandlers: map[string]gliderssh.ChannelHandler{ | ||
"direct-tcpip": func(srv *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) { | ||
d := ssh.ForwardMessage{} | ||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { | ||
newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data") | ||
return | ||
} | ||
newChan.Accept() | ||
s.testF(d) | ||
s.received <- true | ||
}, | ||
}, | ||
} | ||
go func() { s.destinationJujuSSHServer.ListenAndServe() }() | ||
s.destinationServerPort, err = strconv.Atoi(strings.Split(s.destinationJujuSSHServer.Addr, ":")[1]) | ||
c.Assert(err, qt.IsNil) | ||
|
||
port, err = utils.GetFreePort() | ||
c.Assert(err, qt.IsNil) | ||
s.jumpServerPort = port | ||
s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{}) | ||
c.Assert(err, qt.IsNil) | ||
go func() { s.jumpSSHServer.ListenAndServe() }() | ||
|
||
k, err := rsa.GenerateKey(rand.Reader, 2048) | ||
c.Assert(err, qt.IsNil) | ||
keyPEM := pem.EncodeToMemory( | ||
&pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(k), | ||
}, | ||
) | ||
|
||
s.privateKey, err = gossh.ParsePrivateKey(keyPEM) | ||
c.Assert(err, qt.IsNil) | ||
} | ||
|
||
// CleanUp doesn't exist in qtsuite, so it needs to be called manually | ||
func (s *sshSuite) CleanUp(c *qt.C) { | ||
err := s.destinationJujuSSHServer.Close() | ||
c.Assert(err, qt.IsNil) | ||
err = s.jumpSSHServer.Close() | ||
c.Assert(err, qt.IsNil) | ||
} | ||
|
||
func (s *sshSuite) TestSSHJump(c *qt.C) { | ||
defer s.CleanUp(c) | ||
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
defer client.Close() | ||
|
||
// send forward message | ||
msg := ssh.ForwardMessage{ | ||
DestAddr: "model1", | ||
DestPort: uint32(s.destinationServerPort), | ||
SrcAddr: "localhost", | ||
SrcPort: 0, | ||
} | ||
s.testF = func(fm ssh.ForwardMessage) { | ||
c.Assert(fm.DestAddr, qt.Equals, "model1") | ||
} | ||
ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) | ||
c.Assert(err, qt.IsNil) | ||
defer ch.Close() | ||
select { | ||
case <-s.received: | ||
case <-time.After(100 * time.Millisecond): | ||
c.Fatalf("ssh jump test timeout") | ||
} | ||
} | ||
|
||
func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { | ||
defer s.CleanUp(c) | ||
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort+1), &gossh.ClientConfig{ | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.ErrorMatches, ".*connect: connection refused.*") | ||
} | ||
|
||
func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { | ||
defer s.CleanUp(c) | ||
|
||
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
|
||
// send forward message | ||
msg := ssh.ForwardMessage{ | ||
DestAddr: "model1", | ||
DestPort: uint32(s.destinationServerPort + 1), | ||
SrcAddr: "localhost", | ||
SrcPort: 0, | ||
} | ||
s.testF = func(fm ssh.ForwardMessage) { | ||
c.Assert(fm.DestAddr, qt.Equals, "model1") | ||
} | ||
_, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) | ||
c.Assert(err, qt.ErrorMatches, ".*connect failed.*") | ||
|
||
} | ||
|
||
func TestIdentityManager(t *testing.T) { | ||
qtsuite.Run(qt.New(t), &sshSuite{}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters