Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: avoid deprecated NATS API #470

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions cmd/ssh-portal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
// get nats server connection
nconn, err := nats.Connect(cmd.NATSServer,
nc, err := nats.Connect(cmd.NATSServer,
nats.Name("ssh-portal"),
// exit on connection close
nats.ClosedHandler(func(_ *nats.Conn) {
Expand All @@ -52,10 +52,6 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't connect to NATS server: %v", err)
}
nc, err := nats.NewEncodedConn(nconn, "json")
if err != nil {
return fmt.Errorf("couldn't get encoded conn: %v", err)
}
defer nc.Close()
// start listening on TCP port
l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort))
Expand Down Expand Up @@ -83,7 +79,15 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
eg.Go(func() error {
// start serving SSH connection requests
return sshserver.Serve(
ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled, cmd.Banner)
ctx,
log,
nc,
l,
c,
hostkeys,
cmd.LogAccessEnabled,
cmd.Banner,
)
})
return eg.Wait()
}
15 changes: 7 additions & 8 deletions internal/sshportalapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ServeNATS(
wg := sync.WaitGroup{}
wg.Add(1)
// connect to NATS server
nconn, err := nats.Connect(natsURL,
nc, err := nats.Connect(natsURL,
nats.Name("ssh-portal-api"),
// synchronise exiting ServeNATS()
nats.ClosedHandler(func(_ *nats.Conn) {
Expand All @@ -67,14 +67,13 @@ func ServeNATS(
if err != nil {
return fmt.Errorf("couldn't connect to NATS server: %v", err)
}
nc, err := nats.NewEncodedConn(nconn, "json")
if err != nil {
return fmt.Errorf("couldn't get encoded conn: %v", err)
}
defer nc.Close()
// set up request/response callback for sshportal
_, err = nc.QueueSubscribe(bus.SubjectSSHAccessQuery, queue,
sshportal(ctx, log, nc, p, l, k))
// configure callback
_, err = nc.QueueSubscribe(
bus.SubjectSSHAccessQuery,
queue,
sshportal(ctx, log, nc, p, l, k),
)
if err != nil {
return fmt.Errorf("couldn't subscribe to queue: %v", err)
}
Expand Down
28 changes: 21 additions & 7 deletions internal/sshportalapi/sshportal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sshportalapi

import (
"context"
"encoding/json"
"errors"
"log/slog"
"time"
Expand All @@ -23,20 +24,30 @@ var (
})
)

var (
falseResponse = []byte(`false`)
trueResponse = []byte(`true`)
)

func sshportal(
ctx context.Context,
log *slog.Logger,
c *nats.EncodedConn,
c *nats.Conn,
p *rbac.Permission,
l LagoonDBService,
k KeycloakService,
) nats.Handler {
return func(_, replySubject string, query *bus.SSHAccessQuery) {
) nats.MsgHandler {
return func(msg *nats.Msg) {
var realmRoles, userGroups []string
// set up tracing and update metrics
ctx, span := otel.Tracer(pkgName).Start(ctx, bus.SubjectSSHAccessQuery)
defer span.End()
requestsCounter.Inc()
var query bus.SSHAccessQuery
if err := json.Unmarshal(msg.Data, &query); err != nil {
log.Warn("couldn't unmarshal query", slog.Any("query", msg.Data))
return
}
log := log.With(slog.Any("query", query))
// sanity check the query
if query.SSHFingerprint == "" || query.NamespaceName == "" {
Expand All @@ -48,7 +59,7 @@ func sshportal(
if err != nil {
if errors.Is(err, lagoondb.ErrNoResult) {
log.Warn("unknown namespace name", slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand All @@ -65,7 +76,7 @@ func sshportal(
log.Warn("ID mismatch in environment identification",
slog.Any("env", env),
slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand All @@ -75,7 +86,7 @@ func sshportal(
if err != nil {
if errors.Is(err, lagoondb.ErrNoResult) {
log.Debug("unknown SSH Fingerprint", slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand Down Expand Up @@ -115,10 +126,13 @@ func sshportal(
ok := p.UserCanSSHToEnvironment(
ctx, env, realmRoles, userGroups, groupNameProjectIDsMap)
var logMsg string
var response []byte
if ok {
logMsg = "SSH access authorized"
response = trueResponse
} else {
logMsg = "SSH access not authorized"
response = falseResponse
}
log.Info(logMsg,
slog.Int("environmentID", env.ID),
Expand All @@ -127,7 +141,7 @@ func sshportal(
slog.String("projectName", env.ProjectName),
slog.String("userUUID", user.UUID.String()),
)
if err = c.Publish(replySubject, ok); err != nil {
if err = c.Publish(msg.Reply, response); err != nil {
log.Error("couldn't publish reply",
slog.String("userUUID", user.UUID.String()),
slog.Any("error", err))
Expand Down
27 changes: 27 additions & 0 deletions internal/sshportalapi/sshportal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package sshportalapi

import (
"encoding/json"
"testing"
)

func TestResponseMarshal(t *testing.T) {
var testCases = map[string]struct {
input []byte
expect bool
}{
"true": {input: trueResponse, expect: true},
"false": {input: falseResponse, expect: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
var value bool
if err := json.Unmarshal(tc.input, &value); err != nil {
tt.Fatalf("error unmarshaling data %v to bool", tc.input)
}
if value != tc.expect {
tt.Fatalf("expected %v, got %v", tc.expect, value)
}
})
}
}
22 changes: 17 additions & 5 deletions internal/sshserver/authhandler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sshserver

import (
"encoding/json"
"log/slog"
"time"

Expand Down Expand Up @@ -40,8 +41,11 @@ var (

// pubKeyAuth returns a ssh.PublicKeyHandler which queries the remote
// ssh-portal-api for Lagoon SSH authorization.
func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
c *k8s.Client) ssh.PublicKeyHandler {
func pubKeyAuth(
log *slog.Logger,
nc *nats.Conn,
c *k8s.Client,
) ssh.PublicKeyHandler {
return func(ctx ssh.Context, key ssh.PublicKey) bool {
authAttemptsTotal.Inc()
log := log.With(slog.String("sessionID", ctx.SessionID()))
Expand All @@ -60,21 +64,29 @@ func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
}
// construct ssh access query
fingerprint := gossh.FingerprintSHA256(pubKey)
q := bus.SSHAccessQuery{
queryData, err := json.Marshal(bus.SSHAccessQuery{
SSHFingerprint: fingerprint,
NamespaceName: ctx.User(),
ProjectID: pid,
EnvironmentID: eid,
SessionID: ctx.SessionID(),
})
if err != nil {
log.Warn("couldn't marshal NATS request", slog.Any("error", err))
return false
}
// send query
var ok bool
err = nc.Request(bus.SubjectSSHAccessQuery, q, &ok, natsTimeout)
msg, err := nc.Request(bus.SubjectSSHAccessQuery, queryData, natsTimeout)
if err != nil {
log.Warn("couldn't make NATS request", slog.Any("error", err))
return false
}
// handle response
var ok bool
if err := json.Unmarshal(msg.Data, &ok); err != nil {
log.Warn("couldn't unmarshal response", slog.Any("response", msg.Data))
return false
}
if !ok {
log.Debug("SSH access not authorized",
slog.String("fingerprint", fingerprint),
Expand Down
2 changes: 1 addition & 1 deletion internal/sshserver/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig {
func Serve(
ctx context.Context,
log *slog.Logger,
nc *nats.EncodedConn,
nc *nats.Conn,
l net.Listener,
c *k8s.Client,
hostKeys [][]byte,
Expand Down
Loading