Skip to content

Commit

Permalink
feat: VEC-370 add --tls-hostname-override flag (#14)
Browse files Browse the repository at this point in the history
* feat: VEC-370 add --tls-hostname-override flag
  • Loading branch information
jdogmcsteezy authored Oct 1, 2024
1 parent 0b1336a commit 34ddb64
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 84 deletions.
13 changes: 7 additions & 6 deletions cmd/flags/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (cf *ClientFlags) NewClientFlagSet() *pflag.FlagSet {
flagSet.VarP(&cf.AuthCredentials.Password, AuthPassword, "P", "The AVS password for the specified user. If a password is not provided you will be prompted. Additionally can be set using the environment variable ASVEC_PASSWORD.") //nolint:lll // For readability
flagSet.VarP(&cf.AuthCredentials, AuthCredentials, "C", "The AVS user and password used to authenticate. Additionally can be set using the environment variable ASVEC_CREDENTIALS. If a password is not provided you will be prompted. This flag is provided in addition to --user and --password") //nolint:lll // For readability
flagSet.DurationVar(&cf.Timeout, Timeout, time.Second*5, "The timeout to use for each request to AVS") //nolint:lll // For readability
flagSet.AddFlagSet(cf.newTLSFlagSet(func(s string) string { return s }))
flagSet.AddFlagSet(cf.newTLSFlagSet())

return flagSet
}
Expand All @@ -52,11 +52,12 @@ func (cf *ClientFlags) NewSLogAttr() []any {
slog.String(ListenerName, cf.ListenerName.String()),
slog.String(AuthUser, cf.AuthCredentials.User.String()),
slog.String(AuthPassword, logPass),
slog.Bool(TLSCaFile, cf.TLSRootCAFile != nil),
slog.Bool(TLSCaPath, cf.TLSRootCAPath != nil),
slog.Bool(TLSCertFile, cf.TLSCertFile != nil),
slog.Bool(TLSKeyFile, cf.TLSKeyFile != nil),
slog.Bool(TLSKeyFilePass, cf.TLSKeyFilePass != nil),
slog.Bool(TLSCaFile, cf.RootCAFile != nil),
slog.Bool(TLSCaPath, cf.RootCAPath != nil),
slog.Bool(TLSCertFile, cf.CertFile != nil),
slog.Bool(TLSKeyFile, cf.KeyFile != nil),
slog.Bool(TLSKeyFilePass, cf.KeyFilePass != nil),
slog.String(TLSHostnameOverride, cf.HostnameOverride),
slog.Duration(Timeout, cf.Timeout),
}
}
5 changes: 5 additions & 0 deletions cmd/flags/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,20 @@ const (
TLSCertFile = "tls-certfile"
TLSKeyFile = "tls-keyfile"
TLSKeyFilePass = "tls-keyfile-password" //nolint:gosec // Not a credential
TLSHostnameOverride = "tls-hostname-override"

// TODO Replace short flag constants with variables
DimensionShort = "d"
VectorFieldShort = "f"
DistanceMetricShort = "m"
NamespaceShort = "n"
SetShort = "s"
IndexNameShort = "i"
VectorShort = "v"
KeyStrShort = "k"
KeyIntShort = "t"
MaxDataColWidthShort = "w"
YesShort = "y"

DefaultIPv4 = "127.0.0.1"
DefaultPort = 5000
Expand Down
61 changes: 39 additions & 22 deletions cmd/flags/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,53 +10,70 @@ import (

//nolint:govet // Padding not a concern for a CLI
type TLSFlags struct {
TLSProtocols commonFlags.TLSProtocolsFlag
TLSRootCAFile commonFlags.CertFlag
TLSRootCAPath commonFlags.CertPathFlag
TLSCertFile commonFlags.CertFlag
TLSKeyFile commonFlags.CertFlag
TLSKeyFilePass commonFlags.PasswordFlag
Protocols commonFlags.TLSProtocolsFlag
RootCAFile commonFlags.CertFlag
RootCAPath commonFlags.CertPathFlag
CertFile commonFlags.CertFlag
KeyFile commonFlags.CertFlag
KeyFilePass commonFlags.PasswordFlag
HostnameOverride string
}

func NewTLSFlags() *TLSFlags {
return &TLSFlags{
TLSProtocols: commonFlags.NewDefaultTLSProtocolsFlag(),
Protocols: commonFlags.NewDefaultTLSProtocolsFlag(),
}
}

// newTLSFlagSet returns a new pflag.FlagSet with TLS flags defined. Values
// are stored in the TLSFlags struct.
func (tf *TLSFlags) newTLSFlagSet(fmtUsage commonFlags.UsageFormatter) *pflag.FlagSet {
func (tf *TLSFlags) newTLSFlagSet() *pflag.FlagSet {
f := &pflag.FlagSet{}

f.Var(&tf.TLSRootCAFile, "tls-cafile", fmtUsage("The CA used when connecting to AVS."))
f.Var(&tf.TLSRootCAPath, "tls-capath", fmtUsage("A path containing CAs for connecting to AVS."))
f.Var(&tf.TLSCertFile, "tls-certfile", fmtUsage("The certificate file for mutual TLS authentication with AVS."))
f.Var(&tf.TLSKeyFile, "tls-keyfile", fmtUsage("The key file used for mutual TLS authentication with AVS."))
f.Var(&tf.TLSKeyFilePass, "tls-keyfile-password", fmtUsage("The password used to decrypt the key-file if encrypted."))
f.Var(&tf.TLSProtocols, "tls-protocols", fmtUsage(
f.Var(&tf.RootCAFile, TLSCaFile, "The CA used when connecting to AVS.")
f.Var(&tf.RootCAPath, TLSCaPath, "A path containing CAs for connecting to AVS.")
f.Var(&tf.CertFile, TLSCertFile, "The certificate file for mutual TLS authentication with AVS.")
f.Var(&tf.KeyFile, TLSKeyFile, "The key file used for mutual TLS authentication with AVS.")
f.Var(&tf.KeyFilePass, TLSKeyFilePass, "The password used to decrypt the key-file if encrypted.")
f.Var(&tf.Protocols, TLSProtocols,
"Set the TLS protocol selection criteria. This format is the same as"+
" Apache's SSLProtocol documented at https://httpd.apache.org/docs/current/mod/mod_ssl.html#ssl protocol.",
))
)
f.StringVar(
&tf.HostnameOverride,
TLSHostnameOverride,
"",
"The hostname to use when validating the server certificate.",
)

return f
}

func (tf *TLSFlags) NewTLSConfig() (*tls.Config, error) {
rootCA := [][]byte{}

if len(tf.TLSRootCAFile) != 0 {
rootCA = append(rootCA, tf.TLSRootCAFile)
if len(tf.RootCAFile) != 0 {
rootCA = append(rootCA, tf.RootCAFile)
}

rootCA = append(rootCA, tf.TLSRootCAPath...)
rootCA = append(rootCA, tf.RootCAPath...)

return commonClient.NewTLSConfig(
tlsConfig, err := commonClient.NewTLSConfig(
rootCA,
tf.TLSCertFile,
tf.TLSKeyFile,
tf.TLSKeyFilePass,
tf.CertFile,
tf.KeyFile,
tf.KeyFilePass,
0,
0,
).NewGoTLSConfig()

if err != nil {
return nil, err
}

if tf.HostnameOverride != "" {
tlsConfig.ServerName = tf.HostnameOverride
}

return tlsConfig, nil
}
8 changes: 4 additions & 4 deletions cmd/indexCreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ var indexCreateFlags = &struct {

func newIndexCreateFlagSet() *pflag.FlagSet {
flagSet := &pflag.FlagSet{}
flagSet.BoolVarP(&indexCreateFlags.yes, flags.Yes, "y", false, "When true do not prompt for confirmation.")
flagSet.BoolVarP(&indexCreateFlags.yes, flags.Yes, flags.YesShort, false, "When true do not prompt for confirmation.")
flagSet.StringVar(&indexCreateFlags.inputFile, flags.InputFile, StdIn, "A yaml file containing IndexDefinitions created using \"asvec index list --yaml\"") //nolint:lll // For readability
flagSet.StringVarP(&indexCreateFlags.namespace, flags.Namespace, flags.NamespaceShort, "", "The namespace for the index.") //nolint:lll // For readability
flagSet.VarP(&indexCreateFlags.set, flags.Set, flags.SetShort, "The sets for the index.") //nolint:lll // For readability //nolint:lll // For readability
flagSet.StringVarP(&indexCreateFlags.indexName, flags.IndexName, flags.IndexNameShort, "", "The name of the index.") //nolint:lll // For readability
flagSet.StringVarP(&indexCreateFlags.vectorField, flags.VectorField, "f", "", "The name of the vector field.") //nolint:lll // For readability
flagSet.Uint32VarP(&indexCreateFlags.dimensions, flags.Dimension, "d", 0, "The dimension of the vector field.") //nolint:lll // For readability
flagSet.VarP(&indexCreateFlags.distanceMetric, flags.DistanceMetric, "m", fmt.Sprintf("The distance metric for the index. Valid values: %s", strings.Join(flags.DistanceMetricEnum(), ", "))) //nolint:lll // For readability
flagSet.StringVarP(&indexCreateFlags.vectorField, flags.VectorField, flags.VectorFieldShort, "", "The name of the vector field.") //nolint:lll // For readability
flagSet.Uint32VarP(&indexCreateFlags.dimensions, flags.Dimension, flags.DimensionShort, 0, "The dimension of the vector field.") //nolint:lll // For readability
flagSet.VarP(&indexCreateFlags.distanceMetric, flags.DistanceMetric, flags.DistanceMetricShort, fmt.Sprintf("The distance metric for the index. Valid values: %s", strings.Join(flags.DistanceMetricEnum(), ", "))) //nolint:lll // For readability
flagSet.StringToStringVar(&indexCreateFlags.indexLabels, flags.IndexLabels, nil, "Optional labels to assign to the index. Example: \"model=all-MiniLM-L6-v2,foo=bar\"") //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.storageNamespace, flags.StorageNamespace, "Optional storage namespace where the index is stored. Defaults to the index namespace.") //nolint:lll // For readability
flagSet.Var(&indexCreateFlags.storageSet, flags.StorageSet, "Optional storage set where the index is stored. Defaults to the index name.") //nolint:lll // For readability
Expand Down
6 changes: 6 additions & 0 deletions cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ func createClientFromFlags(clientFlags *flags.ClientFlags) (*avs.Client, error)
ctx, hosts, clientFlags.ListenerName.Val, isLoadBalancer, creds, tlsConfig, logger,
)
if err != nil {
if strings.Contains(err.Error(), "because it doesn't contain any IP SANs") {
view.Printf("Hint: Failed to verify because of certificate hostname mismatch.")
view.Printf("Hint: Either correctly set your certificate SAN or use --%s", flags.TLSHostnameOverride)
}

logger.Error("failed to create AVS client", slog.Any("error", err))

return nil, err
}

Expand Down
55 changes: 55 additions & 0 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1787,3 +1787,58 @@ func (suite *CmdTestSuite) TestEnvVars() {

suite.NoError(err, "err: %s, stdout: %s, stderr: %s", err, stdout, stderr)
}

func (suite *CmdTestSuite) TestTLSHostnameOverride_Success() {
if suite.AvsTLSConfig == nil {
suite.T().Skip("Not a TLS suite")
}

newSuiteFlags := []string{}
for _, flag := range suite.SuiteFlags {
if strings.Contains(flag, tests.CreateFlagStr(flags.Host, "")) || strings.Contains(flag, tests.CreateFlagStr(flags.Seeds, "")) {
flagSplit := strings.Split(flag, " ")

flagSplit[1] = "127.0.0.1:10000" // For tls the certs only work with localhost not 127.0.0.1
flag = strings.Join(flagSplit, " ")
}

newSuiteFlags = append(newSuiteFlags, flag)
}

newSuiteFlags = append(newSuiteFlags, tests.CreateFlagStr(flags.TLSHostnameOverride, "localhost"))

suite.Logger.Debug("suite flags", slog.Any("flags", newSuiteFlags))
asvecCmd := strings.Split("index ls --log-level debug --timeout 10s", " ")
asvecCmd = append(asvecCmd, strings.Split(strings.Join(newSuiteFlags, " "), " ")...)

stdout, stderr, err := suite.RunCmd(asvecCmd...)

suite.NoError(err, "err: %s, stdout: %s, stderr: %s", err, stdout, stderr)
}

func (suite *CmdTestSuite) TestTLSHostnameOverride_Failure() {
if suite.AvsTLSConfig == nil {
suite.T().Skip("Not a TLS suite")
}

newSuiteFlags := []string{}
for _, flag := range suite.SuiteFlags {
if strings.Contains(flag, tests.CreateFlagStr(flags.Host, "")) || strings.Contains(flag, tests.CreateFlagStr(flags.Seeds, "")) {
flagSplit := strings.Split(flag, " ")

flagSplit[1] = "127.0.0.1:10000" // For tls the certs only work with localhost not 127.0.0.1
flag = strings.Join(flagSplit, " ")
}

newSuiteFlags = append(newSuiteFlags, flag)
}

suite.Logger.Debug("suite flags", slog.Any("flags", newSuiteFlags))
asvecCmd := strings.Split("index ls --log-level debug --timeout 10s", " ")
asvecCmd = append(asvecCmd, strings.Split(strings.Join(newSuiteFlags, " "), " ")...)

stdout, stderr, err := suite.RunCmd(asvecCmd...)
suite.Error(err, "err: %s, stdout: %s, stderr: %s", err, stdout, stderr)
suite.Contains(stdout, "Hint: Failed to verify because of certificate hostname mismatch.")
suite.Contains(stdout, "Hint: Either correctly set your certificate SAN or use")
}
27 changes: 13 additions & 14 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,49 @@ module asvec

go 1.22.5

// replace github.com/aerospike/tools-common-go => ../tools-common-go

require (
github.com/aerospike/avs-client-go v0.0.0-20240906211641-97c1df4005ae
github.com/aerospike/avs-client-go v0.0.0-20241001202601-cea4c0a9a32d
github.com/aerospike/tools-common-go v0.0.0-20240927170813-c352c1917359
github.com/jedib0t/go-pretty/v6 v6.5.9
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
golang.org/x/term v0.22.0
golang.org/x/term v0.24.0
google.golang.org/protobuf v1.34.2
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/aerospike/aerospike-client-go/v7 v7.6.0 // indirect
github.com/aerospike/aerospike-client-go/v7 v7.7.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/sagikazarmark/locafero v0.6.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
go.opentelemetry.io/otel/metric v1.27.0 // indirect
go.opentelemetry.io/otel/trace v1.27.0 // indirect
go.uber.org/goleak v1.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.22.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240725223205-93522f1f2a9f // indirect
google.golang.org/grpc v1.65.0 // indirect
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
golang.org/x/net v0.29.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240930140551-af27646dc61f // indirect
google.golang.org/grpc v1.67.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
Loading

0 comments on commit 34ddb64

Please sign in to comment.