Skip to content

Commit

Permalink
Merge pull request #161 from Snowflake-Labs/tmerz-tls-config
Browse files Browse the repository at this point in the history
add another hook and flag to set tls.Config
  • Loading branch information
sfc-gh-tmerz authored Sep 14, 2022
2 parents 8a5b29c + 05768e3 commit 4ce1d37
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 7 deletions.
59 changes: 55 additions & 4 deletions cmd/proxy-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package server

import (
"context"
"crypto/tls"
"fmt"
"net"
"os"

Expand All @@ -32,6 +34,7 @@ import (
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/proxy/server"
"github.com/Snowflake-Labs/sansshell/telemetry"
"google.golang.org/grpc/credentials"
)

// runState encapsulates all of the variable state needed
Expand All @@ -42,6 +45,7 @@ type runState struct {
policy string
clientPolicy string
credSource string
tlsConfig *tls.Config
hostport string
justification bool
justificationFunc func(string) error
Expand Down Expand Up @@ -88,6 +92,14 @@ func WithClientPolicy(policy string) Option {
})
}

// WithTlsConfig applies a supplied tls.Config object to the gRPC server.
func WithTlsConfig(tlsConfig *tls.Config) Option {
return optionFunc(func(r *runState) error {
r.tlsConfig = tlsConfig
return nil
})
}

// WithCredSource applies a registered credential source with the mtls package.
func WithCredSource(credSource string) Option {
return optionFunc(func(r *runState) error {
Expand Down Expand Up @@ -196,14 +208,17 @@ func Run(ctx context.Context, opts ...Option) {
}
}

serverCreds, err := mtls.LoadServerCredentials(ctx, rs.credSource)
serverCreds, err := extractServerTransportCredentialsFromRunState(ctx, rs)

if err != nil {
rs.logger.Error(err, "mtls.LoadServerCredentials", "credsource", rs.credSource)
rs.logger.Error(err, "unable to extract transport credentials from runstate for the server", "credsource", rs.credSource)
os.Exit(1)
}
clientCreds, err := mtls.LoadClientCredentials(ctx, rs.credSource)

clientCreds, err := extractClientTransportCredentialsFromRunState(ctx, rs)

if err != nil {
rs.logger.Error(err, "mtls.LoadClientCredentials", "credsource", rs.credSource)
rs.logger.Error(err, "unable to extract transport credentials from runstate for the client", "credsource", rs.credSource)
os.Exit(1)
}

Expand Down Expand Up @@ -297,3 +312,39 @@ func Run(ctx context.Context, opts ...Option) {
os.Exit(1)
}
}

// extractClientTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified
func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) {
var creds credentials.TransportCredentials
var err error
if rs.credSource != "" && rs.tlsConfig != nil {
return nil, fmt.Errorf("both credSource and tlsConfig are defined for the client")
}
if rs.credSource != "" {
creds, err = mtls.LoadClientCredentials(ctx, rs.credSource)
if err != nil {
return nil, err
}
} else {
creds = credentials.NewTLS(rs.tlsConfig)
}
return creds, nil
}

// extractServerTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified
func extractServerTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) {
var creds credentials.TransportCredentials
var err error
if rs.credSource != "" && rs.tlsConfig != nil {
return nil, fmt.Errorf("both credSource and tlsConfig are defined for the server")
}
if rs.credSource != "" {
creds, err = mtls.LoadServerCredentials(ctx, rs.credSource)
if err != nil {
return nil, err
}
} else {
creds = credentials.NewTLS(rs.tlsConfig)
}
return creds, nil
}
35 changes: 32 additions & 3 deletions cmd/sansshell-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package server

import (
"context"
"crypto/tls"
"fmt"
"os"

"github.com/go-logr/logr"
Expand All @@ -29,6 +31,7 @@ import (
"github.com/Snowflake-Labs/sansshell/auth/mtls"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"github.com/Snowflake-Labs/sansshell/server"
"google.golang.org/grpc/credentials"
)

// runState encapsulates all of the variable state needed
Expand All @@ -37,6 +40,7 @@ import (
type runState struct {
logger logr.Logger
credSource string
tlsConfig *tls.Config
hostport string
policy string
justification bool
Expand Down Expand Up @@ -74,6 +78,14 @@ func WithPolicy(policy string) Option {
})
}

// WithTlsConfig applies a supplied tls.Config object to the gRPC server.
func WithTlsConfig(tlsConfig *tls.Config) Option {
return optionFunc(func(r *runState) error {
r.tlsConfig = tlsConfig
return nil
})
}

// WithCredSource applies a registered credential source with the mtls package.
func WithCredSource(credSource string) Option {
return optionFunc(func(r *runState) error {
Expand Down Expand Up @@ -160,11 +172,10 @@ func Run(ctx context.Context, opts ...Option) {
os.Exit(1)
}
}
creds, err := extractTransportCredentialsFromRunState(ctx, rs)

creds, err := mtls.LoadServerCredentials(ctx, rs.credSource)
if err != nil {
rs.logger.Error(err, "mtls.LoadServerCredentials", "credsource", rs.credSource)
os.Exit(1)
rs.logger.Error(err, "unable to extract transport credentials from runstate", "credsource", rs.credSource)
}

justificationHook := rpcauth.HookIf(rpcauth.JustificationHook(rs.justificationFunc), func(input *rpcauth.RPCAuthInput) bool {
Expand Down Expand Up @@ -193,3 +204,21 @@ func Run(ctx context.Context, opts ...Option) {
os.Exit(1)
}
}

// extractTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified
func extractTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) {
var creds credentials.TransportCredentials
var err error
if rs.credSource != "" && rs.tlsConfig != nil {
return nil, fmt.Errorf("both credSource and tlsConfig are defined")
}
if rs.credSource != "" {
creds, err = mtls.LoadServerCredentials(ctx, rs.credSource)
if err != nil {
return nil, err
}
} else {
creds = credentials.NewTLS(rs.tlsConfig)
}
return creds, nil
}

0 comments on commit 4ce1d37

Please sign in to comment.