Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nick/neos 1168 add ssh tunneling support to mongodb #2182

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions backend/pkg/mongoconnect/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"errors"
"fmt"
"log/slog"
"net/url"
"sync"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
Expand Down Expand Up @@ -34,16 +35,15 @@
client *mongo.Client
clientMu sync.Mutex

details *connstring.ConnString
// tunnel *sshtunnel.Sshtunnel
details *ConnectionDetails

// logger *slog.Logger
logger *slog.Logger
}

var _ DbContainer = &WrappedMongoClient{}

func newWrappedMongoClient(details *connstring.ConnString) *WrappedMongoClient {
return &WrappedMongoClient{details: details}
func newWrappedMongoClient(details *ConnectionDetails, logger *slog.Logger) *WrappedMongoClient {
return &WrappedMongoClient{details: details, logger: logger}

Check warning on line 46 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L45-L46

Added lines #L45 - L46 were not covered by tests
}

func (w *WrappedMongoClient) Open(ctx context.Context) (*mongo.Client, error) {
Expand All @@ -52,12 +52,21 @@
if w.client != nil {
return w.client, nil
}
// todo: tunneling

if w.details.Tunnel != nil {
ready, err := w.details.Tunnel.Start(w.logger)
if err != nil {
return nil, err

Check warning on line 59 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L56-L59

Added lines #L56 - L59 were not covered by tests
}
<-ready
w.logger.Info("tunnel is now ready", "isopen", w.details.Tunnel.IsOpen())

Check warning on line 62 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L61-L62

Added lines #L61 - L62 were not covered by tests
}
serverAPI := options.ServerAPI(options.ServerAPIVersion1)
w.logger.Info("connecting to mongo instance", "url", w.details.String())

Check warning on line 65 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L65

Added line #L65 was not covered by tests
opts := options.Client().ApplyURI(w.details.String()).SetServerAPIOptions(serverAPI)
client, err := mongo.Connect(ctx, opts)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to connect to mongo instance: %w", err)

Check warning on line 69 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L69

Added line #L69 was not covered by tests
}
w.client = client
return client, nil
Expand All @@ -71,7 +80,12 @@
}
client := w.client
w.client = nil
return client.Disconnect(ctx)
err := client.Disconnect(ctx)
if w.details.Tunnel != nil && w.details.Tunnel.IsOpen() {
w.logger.Debug("closing tunnel...")
w.details.Tunnel.Close()

Check warning on line 86 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L83-L86

Added lines #L83 - L86 were not covered by tests
}
return err

Check warning on line 88 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L88

Added line #L88 was not covered by tests
}

var _ Interface = &Connector{}
Expand All @@ -94,7 +108,7 @@
if err != nil {
return nil, err
}
wrappedclient := newWrappedMongoClient(details.Details)
wrappedclient := newWrappedMongoClient(details, logger)

Check warning on line 111 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L111

Added line #L111 was not covered by tests
return wrappedclient, nil
}

Expand All @@ -107,7 +121,16 @@
return c.Tunnel
}
func (c *ConnectionDetails) String() string {
// todo: add tunnel support
if c.Tunnel != nil && c.Tunnel.IsOpen() {
localhost, port := c.Tunnel.GetLocalHostPort()
parseUrl, err := url.Parse(c.Details.String())
if err != nil {
return "" // todo

Check warning on line 128 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L124-L128

Added lines #L124 - L128 were not covered by tests
}
parseUrl.Host = fmt.Sprintf("%s:%d", localhost, port)
parseUrl.Scheme = "mongodb"
return parseUrl.String()

Check warning on line 132 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L130-L132

Added lines #L130 - L132 were not covered by tests
}
return c.Details.String()
}

Expand All @@ -119,7 +142,6 @@
if cc == nil {
return nil, errors.New("cc was nil, expected *mgmtv1alpha1.ConnectionConfig")
}

mongoConfig := cc.GetMongoConfig()
if mongoConfig == nil {
return nil, fmt.Errorf("mongo config was nil, expected ConnectionConfig to contain valid MongoConfig")
Expand All @@ -142,13 +164,16 @@
}, nil
}

var destination *sshtunnel.Endpoint // todo
destination, err := getEndpointFromMongoConnectionConfig(mongoConfig)
if err != nil {
return nil, err

Check warning on line 169 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L167-L169

Added lines #L167 - L169 were not covered by tests
}
authmethod, err := sshtunnel.GetTunnelAuthMethodFromSshConfig(tunnelCfg.GetAuthentication())
if err != nil {
return nil, err
}
var publickey ssh.PublicKey
if tunnelCfg.GetKnownHostPublicKey() == "" {
if tunnelCfg.GetKnownHostPublicKey() != "" {

Check warning on line 176 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L176

Added line #L176 was not covered by tests
publickey, err = sshtunnel.ParseSshKey(tunnelCfg.GetKnownHostPublicKey())
if err != nil {
return nil, err
Expand All @@ -166,7 +191,6 @@
if err != nil {
return nil, err
}
_ = connDetails

return &ConnectionDetails{
Tunnel: tunnel,
Expand All @@ -181,3 +205,11 @@
}
return connstring.ParseAndValidate(dburl)
}

func getEndpointFromMongoConnectionConfig(config *mgmtv1alpha1.MongoConnectionConfig) (*sshtunnel.Endpoint, error) {
details, err := getGeneralDbConnectConfigFromMongo(config)
if err != nil {
return nil, err

Check warning on line 212 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L209-L212

Added lines #L209 - L212 were not covered by tests
}
return sshtunnel.NewEndpointWithUser(details.Hosts[0], -1, details.Username), nil

Check warning on line 214 in backend/pkg/mongoconnect/connector.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/mongoconnect/connector.go#L214

Added line #L214 was not covered by tests
}
3 changes: 3 additions & 0 deletions backend/pkg/sshtunnel/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@

// Returns the stringified endpoint sans user
func (endpoint *Endpoint) String() string {
if endpoint.Port < 0 {
return endpoint.Host

Check warning on line 32 in backend/pkg/sshtunnel/endpoint.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/sshtunnel/endpoint.go#L32

Added line #L32 was not covered by tests
}
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}
7 changes: 6 additions & 1 deletion backend/pkg/sshtunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
}
}

func (t *Sshtunnel) IsOpen() bool {
return t.isOpen

Check warning on line 66 in backend/pkg/sshtunnel/tunnel.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/sshtunnel/tunnel.go#L65-L66

Added lines #L65 - L66 were not covered by tests
}

// After a tunnel has started, this will return the auto-generated port (if 0 was passed in)
func (t *Sshtunnel) GetLocalHostPort() (host string, port int) {
return t.local.Host, t.local.Port
Expand Down Expand Up @@ -108,6 +112,7 @@
t.isOpen = false
go func() {
t.shutdowns.Range(func(key, value any) bool {
logger.Debug("shutting down tunnel session", "key", key)

Check warning on line 115 in backend/pkg/sshtunnel/tunnel.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/sshtunnel/tunnel.go#L115

Added line #L115 was not covered by tests
sd, ok := value.(chan any)
if ok {
sd <- struct{}{}
Expand Down Expand Up @@ -208,7 +213,7 @@
if err != nil {
return nil, err
}
logger.Debug(fmt.Sprintf("conntected to %s", addr))
logger.Debug(fmt.Sprintf("[ssh-client] conntected to %s", addr))

Check warning on line 216 in backend/pkg/sshtunnel/tunnel.go

View check run for this annotation

Codecov / codecov/patch

backend/pkg/sshtunnel/tunnel.go#L216

Added line #L216 was not covered by tests
s.sshclient = client
return client, nil
}
Expand Down
2 changes: 1 addition & 1 deletion backend/pkg/sshtunnel/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func getPlaintextPrivateKeyAuthMethod(keyBytes []byte) (ssh.AuthMethod, error) {

func ParseSshKey(keyString string) (ssh.PublicKey, error) {
// Parse the key
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString)) //nolint
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString)) //nolint:dogsled
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ import Spinner from '@/components/Spinner';
import RequiredLabel from '@/components/labels/RequiredLabel';
import PermissionsDialog from '@/components/permissions/PermissionsDialog';
import { useAccount } from '@/components/providers/account-provider';
import {
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from '@/components/ui/accordion';
import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert';
import { Button } from '@/components/ui/button';
import {
Expand All @@ -16,7 +22,11 @@ import {
FormMessage,
} from '@/components/ui/form';
import { Input } from '@/components/ui/input';
import { MongoDbFormValues } from '@/yup-validations/connections';
import { Textarea } from '@/components/ui/textarea';
import {
EditMongoDbFormContext,
MongoDbFormValues,
} from '@/yup-validations/connections';
import { yupResolver } from '@hookform/resolvers/yup';
import {
CheckConnectionConfigResponse,
Expand All @@ -38,7 +48,7 @@ export default function MongoDbForm(props: Props): ReactElement {
const { connectionId, defaultValues, onSaved, onSaveFailed } = props;
const { account } = useAccount();

const form = useForm<MongoDbFormValues>({
const form = useForm<MongoDbFormValues, EditMongoDbFormContext>({
resolver: yupResolver(MongoDbFormValues),
mode: 'onChange',
values: defaultValues,
Expand Down Expand Up @@ -137,6 +147,133 @@ export default function MongoDbForm(props: Props): ReactElement {
)}
/>

<Accordion type="single" collapsible className="w-full">
<AccordionItem value="bastion">
<AccordionTrigger> Bastion Host Configuration</AccordionTrigger>
<AccordionContent className="flex flex-col gap-4 p-2">
<div className="text-sm">
This section is optional and only necessary if your database is
not publicly accessible to the internet.
</div>
<FormField
control={form.control}
name="tunnel.host"
render={({ field }) => (
<FormItem>
<FormLabel>Host</FormLabel>
<FormDescription>
The hostname of the bastion server that will be used for
SSH tunneling.
</FormDescription>
<FormControl>
<Input placeholder="bastion.example.com" {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="tunnel.port"
render={({ field }) => (
<FormItem>
<FormLabel>Port</FormLabel>
<FormDescription>
The port of the bastion host. Typically this is port 22.
</FormDescription>
<FormControl>
<Input
type="number"
placeholder="22"
{...field}
onChange={(e) => {
field.onChange(e.target.valueAsNumber);
}}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="tunnel.user"
render={({ field }) => (
<FormItem>
<FormLabel>User</FormLabel>
<FormDescription>
The name of the user that will be used to authenticate.
</FormDescription>
<FormControl>
<Input {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="tunnel.privateKey"
render={({ field }) => (
<FormItem>
<FormLabel>Private Key</FormLabel>
<FormDescription>
The private key that will be used to authenticate against
the SSH server. If using passphrase auth, provide that in
the appropriate field below.
</FormDescription>
<FormControl>
<Textarea {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="tunnel.passphrase"
render={({ field }) => (
<FormItem>
<FormLabel>Passphrase / Private Key Password</FormLabel>
<FormDescription>
The passphrase that will be used to authenticate with. If
the SSH Key provided above is encrypted, provide the
password for it here.
</FormDescription>
<FormControl>
<Input type="password" {...field} />
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="tunnel.knownHostPublicKey"
render={({ field }) => (
<FormItem>
<FormLabel>Known Host Public Key</FormLabel>
<FormDescription>
The public key of the host that will be expected when
connecting to the tunnel. This should be in the format
like what is found in the `~/.ssh/known_hosts` file,
excluding the hostname. If this is not provided, any host
public key will be accepted.
</FormDescription>
<FormControl>
<Input
placeholder="ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAlkjd9s7aJkfdLk3jSLkfj2lk3j2lkfj2l3kjf2lkfj2l"
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</AccordionContent>
</AccordionItem>
</Accordion>

<PermissionsDialog
checkResponse={
validationResponse ?? new CheckConnectionConfigResponse({})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,33 @@ export function getConnectionComponentDetails(
connectionName: connection.name,
url: connection.connectionConfig.config.value.connectionConfig
.value,
tunnel: {
host:
connection.connectionConfig.config.value.tunnel?.host ?? '',
port:
connection.connectionConfig.config.value.tunnel?.port ?? 22,
knownHostPublicKey:
connection.connectionConfig.config.value.tunnel
?.knownHostPublicKey ?? '',
user:
connection.connectionConfig.config.value.tunnel?.user ?? '',
passphrase:
connection.connectionConfig.config.value.tunnel &&
connection.connectionConfig.config.value.tunnel.authentication
? getPassphraseFromSshAuthentication(
connection.connectionConfig.config.value.tunnel
.authentication
) ?? ''
: '',
privateKey:
connection.connectionConfig.config.value.tunnel &&
connection.connectionConfig.config.value.tunnel.authentication
? getPrivateKeyFromSshAuthentication(
connection.connectionConfig.config.value.tunnel
.authentication
) ?? ''
: '',
},
}}
onSaved={(resp) => onSaved(resp)}
onSaveFailed={onSaveFailed}
Expand Down
1 change: 1 addition & 0 deletions frontend/apps/web/app/(mgmt)/[account]/connections/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ function buildMongoConnectionConfig(
case: 'url',
value: values.url,
},
tunnel: getTunnelConfig(values.tunnel),
});

return mongoconfig;
Expand Down
Loading