Skip to content

Commit

Permalink
Merge pull request #200 from nspcc-dev/170-tls-fix
Browse files Browse the repository at this point in the history
Use TLS protocol
  • Loading branch information
roman-khimov authored Apr 27, 2024
2 parents f48958c + c4f9507 commit f024e79
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 216 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ This document outlines major changes between releases.

## [Unreleased]

### Updating from 0.8.3

Notice that configuration parameters in the `server` section were reorganized.
For example e.g.`server.schema` and `tls-listen-limit` were removed, and some
others were moved inside the array `endpoints`. Check your configuration with
the help of the [gate-configuration.md](./docs/gate-configuration.md) and
[config](./config/config.yaml). Also, flags in the command arguments were
changed.

## [0.8.3] - 2024-03-25

### Fixed
Expand Down
228 changes: 124 additions & 104 deletions cmd/neofs-rest-gw/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,30 @@ const (
cfgWalletAddress = "wallet.address"
cfgWalletPassphrase = "wallet.passphrase"

// Config section for autogenerated flags.
cfgServerSection = "server."
// Server endpoints.
cfgServerSection = "server."
cfgServerEndpoints = cfgServerSection + "endpoints"

cfgTLSEnabled = "tls.enabled"
cfgTLSKeyFile = "tls.key"
cfgTLSCertFile = "tls.certificate"
cfgTLSCertCAFile = "tls.ca-certificate"

cfgEndpointAddress = "address"
cfgEndpointExternalAddress = "external-address"
cfgEndpointKeepAlive = "keep-alive"
cfgEndpointReadTimeout = "read-timeout"
cfgEndpointWriteTimeout = "write-timeout"

// Command line args.
cmdHelp = "help"
cmdVersion = "version"
cmdPprof = "pprof"
cmdMetrics = "metrics"
cmdWallet = "wallet"
cmdAddress = "address"
cmdConfig = "config"
cmdHelp = "help"
cmdVersion = "version"
cmdPprof = "pprof"
cmdMetrics = "metrics"
cmdWallet = "wallet"
cmdAddress = "address"
cmdConfig = "config"
cmdListenAddress = "listen-address"

baseURL = "/v1"
)
Expand Down Expand Up @@ -123,9 +136,20 @@ func config() *viper.Viper {

peers := flagSet.StringArrayP(cmdPeers, "p", nil, "NeoFS nodes")

flagSet.String(cmdListenAddress, "localhost:8080", "set the main address to listen")
flagSet.String(cfgTLSCertFile, "", "TLS certificate file to use; note that if you want to start HTTPS server, you should also set up --"+cmdListenAddress+" and --"+cfgTLSKeyFile)
flagSet.String(cfgTLSKeyFile, "", "TLS key file to use; note that if you want to start HTTPS server, you should also set up --"+cmdListenAddress+" and --"+cfgTLSCertFile)
flagSet.Duration(cfgEndpointKeepAlive, 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
flagSet.Duration(cfgEndpointReadTimeout, 30*time.Second, "maximum duration before timing out read of the request")
flagSet.Duration(cfgEndpointWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response")
flagSet.String(cfgEndpointExternalAddress, "localhost:8090", "the IP and port to be shown in the API documentation")

// init server flags
BindDefaultFlags(flagSet)

if err := bindServerFlags(v, flagSet); err != nil {
panic(fmt.Errorf("bind server flags: %w", err))
}
// set defaults:
// pool
v.SetDefault(cfgPoolErrorThreshold, defaultPoolErrorThreshold)
Expand Down Expand Up @@ -191,6 +215,34 @@ func config() *viper.Viper {
return v
}

func bindServerFlags(v *viper.Viper, flags *pflag.FlagSet) error {
// This key is used only to check if the address comes from the command arguments.
if err := v.BindPFlag(cmdListenAddress, flags.Lookup(cmdListenAddress)); err != nil {
return err
}

if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgEndpointAddress, flags.Lookup(cmdListenAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgEndpointExternalAddress, flags.Lookup(cfgEndpointExternalAddress)); err != nil {
return err
}
if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgEndpointKeepAlive, flags.Lookup(cfgEndpointKeepAlive)); err != nil {
return err
}
if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgEndpointReadTimeout, flags.Lookup(cfgEndpointReadTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgEndpointWriteTimeout, flags.Lookup(cfgEndpointWriteTimeout)); err != nil {
return err
}
if err := v.BindPFlag(cfgServerEndpoints+".0."+cfgTLSKeyFile, flags.Lookup(cfgTLSKeyFile)); err != nil {
return err
}

return v.BindPFlag(cfgServerEndpoints+".0."+cfgTLSCertFile, flags.Lookup(cfgTLSCertFile))
}

func init() {
for _, flagName := range serverFlags {
cfgName := cfgServerSection + flagName
Expand All @@ -200,24 +252,10 @@ func init() {
}

var serverFlags = []string{
FlagScheme,
FlagCleanupTimeout,
FlagGracefulTimeout,
FlagMaxHeaderSize,
FlagListenAddress,
FlagListenLimit,
FlagKeepAlive,
FlagReadTimeout,
FlagWriteTimeout,
FlagTLSListenAddress,
FlagTLSCertificate,
FlagTLSKey,
FlagTLSCa,
FlagTLSListenLimit,
FlagTLSKeepAlive,
FlagTLSReadTimeout,
FlagTLSWriteTimeout,
FlagExternalAddress,
}

var bindings = map[string]string{
Expand Down Expand Up @@ -250,6 +288,10 @@ func validateConfig(cfg *viper.Viper, logger *zap.Logger) {

for _, providedKey := range cfg.AllKeys() {
if !strings.HasPrefix(providedKey, cfgPeers) {
if strings.HasPrefix(providedKey, cfgServerEndpoints) {
// Do not validate `Endpoints` section.
continue
}
if _, ok := knownConfigParams[providedKey]; !ok {
logger.Warn("unknown config parameter", zap.String("key", providedKey))
}
Expand Down Expand Up @@ -404,104 +446,82 @@ func newLogger(v *viper.Viper) *zap.Logger {

// ServerConfig contains parsed config for the Echo server.
type ServerConfig struct {
EnabledListeners []string
CleanupTimeout time.Duration
GracefulTimeout time.Duration
MaxHeaderSize int

ListenAddress string
ListenLimit int
KeepAlive time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration

TLSListenAddress string
TLSListenLimit int
TLSKeepAlive time.Duration
TLSReadTimeout time.Duration
TLSWriteTimeout time.Duration
TLSCertificate string
TLSCertificateKey string
TLSCACertificate string

ExternalAddress string
CleanupTimeout time.Duration
GracefulTimeout time.Duration
MaxHeaderSize int
ListenLimit int
Endpoints []EndpointInfo
}

const (
FlagScheme = "scheme"
FlagCleanupTimeout = "cleanup-timeout"
FlagGracefulTimeout = "graceful-timeout"
FlagMaxHeaderSize = "max-header-size"
FlagListenAddress = "listen-address"
FlagListenLimit = "listen-limit"
FlagKeepAlive = "keep-alive"
FlagReadTimeout = "read-timeout"
FlagWriteTimeout = "write-timeout"
FlagTLSListenAddress = "tls-listen-address"
FlagTLSCertificate = "tls-certificate"
FlagTLSKey = "tls-key"
FlagTLSCa = "tls-ca"
FlagTLSListenLimit = "tls-listen-limit"
FlagTLSKeepAlive = "tls-keep-alive"
FlagTLSReadTimeout = "tls-read-timeout"
FlagTLSWriteTimeout = "tls-write-timeout"
FlagExternalAddress = "external-address"
FlagCleanupTimeout = "cleanup-timeout"
FlagGracefulTimeout = "graceful-timeout"
FlagMaxHeaderSize = "max-header-size"
FlagListenLimit = "listen-limit"
)

var defaultSchemes []string

func init() {
defaultSchemes = []string{
schemeHTTP,
}
}

func BindDefaultFlags(flagSet *pflag.FlagSet) {
flagSet.StringSlice(FlagScheme, defaultSchemes, "the listeners to enable, this can be repeated and defaults to the schemes in the swagger spec")

flagSet.Duration(FlagCleanupTimeout, 10*time.Second, "grace period for which to wait before killing idle connections")
flagSet.Duration(FlagGracefulTimeout, 15*time.Second, "grace period for which to wait before shutting down the server")
flagSet.Int(FlagMaxHeaderSize, 1000000, "controls the maximum number of bytes the server will read parsing the request header's keys and values, including the request line. It does not limit the size of the request body")

flagSet.String(FlagListenAddress, "localhost:8080", "the IP and port to listen on")
flagSet.Int(FlagListenLimit, 0, "limit the number of outstanding requests")
flagSet.Duration(FlagKeepAlive, 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
flagSet.Duration(FlagReadTimeout, 30*time.Second, "maximum duration before timing out read of the request")
flagSet.Duration(FlagWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response")

flagSet.String(FlagTLSListenAddress, "localhost:8081", "the IP and port to listen on")
flagSet.String(FlagTLSCertificate, "", "the certificate file to use for secure connections")
flagSet.String(FlagTLSKey, "", "the private key file to use for secure connections (without passphrase)")
flagSet.String(FlagTLSCa, "", "the certificate authority certificate file to be used with mutual tls auth")
flagSet.Int(FlagTLSListenLimit, 0, "limit the number of outstanding requests")
flagSet.Duration(FlagTLSKeepAlive, 3*time.Minute, "sets the TCP keep-alive timeouts on accepted connections. It prunes dead TCP connections ( e.g. closing laptop mid-download)")
flagSet.Duration(FlagTLSReadTimeout, 30*time.Second, "maximum duration before timing out read of the request")
flagSet.Duration(FlagTLSWriteTimeout, 30*time.Second, "maximum duration before timing out write of the response")

flagSet.String(FlagExternalAddress, "localhost:8090", "the IP and port to be shown in the API documentation")
}

func serverConfig(v *viper.Viper) *ServerConfig {
return &ServerConfig{
EnabledListeners: v.GetStringSlice(cfgServerSection + FlagScheme),
CleanupTimeout: v.GetDuration(cfgServerSection + FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(cfgServerSection + FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(cfgServerSection + FlagMaxHeaderSize),

ListenAddress: v.GetString(cfgServerSection + FlagListenAddress),
ListenLimit: v.GetInt(cfgServerSection + FlagListenLimit),
KeepAlive: v.GetDuration(cfgServerSection + FlagKeepAlive),
ReadTimeout: v.GetDuration(cfgServerSection + FlagReadTimeout),
WriteTimeout: v.GetDuration(cfgServerSection + FlagWriteTimeout),
CleanupTimeout: v.GetDuration(cfgServerSection + FlagCleanupTimeout),
GracefulTimeout: v.GetDuration(cfgServerSection + FlagGracefulTimeout),
MaxHeaderSize: v.GetInt(cfgServerSection + FlagMaxHeaderSize),
ListenLimit: v.GetInt(cfgServerSection + FlagListenLimit),
Endpoints: fetchEndpoints(v),
}
}

TLSListenAddress: v.GetString(cfgServerSection + FlagTLSListenAddress),
TLSListenLimit: v.GetInt(cfgServerSection + FlagTLSListenLimit),
TLSKeepAlive: v.GetDuration(cfgServerSection + FlagTLSKeepAlive),
TLSReadTimeout: v.GetDuration(cfgServerSection + FlagTLSReadTimeout),
TLSWriteTimeout: v.GetDuration(cfgServerSection + FlagTLSWriteTimeout),
func fetchEndpoints(v *viper.Viper) []EndpointInfo {
var servers []EndpointInfo

if v.IsSet(cmdListenAddress) {
key := cfgServerEndpoints + ".0."
// If this address is set, we don't use config file to set other parameters.
serverInfo := EndpointInfo{
Address: v.GetString(key + cfgEndpointAddress),
ExternalAddress: v.GetString(key + cfgEndpointExternalAddress),
KeepAlive: v.GetDuration(key + cfgEndpointKeepAlive),
ReadTimeout: v.GetDuration(key + cfgEndpointReadTimeout),
WriteTimeout: v.GetDuration(key + cfgEndpointWriteTimeout),
}
keyFile := v.GetString(key + cfgTLSKeyFile)
certFile := v.GetString(key + cfgTLSCertFile)
if keyFile != "" && certFile != "" {
// If TLS key and certificate are set in the command arguments, we enable TLS.
serverInfo.TLS.Enabled = true
serverInfo.TLS.KeyFile = keyFile
serverInfo.TLS.CertFile = certFile
}
servers = append(servers, serverInfo)
} else {
for i := 0; ; i++ {
key := cfgServerEndpoints + "." + strconv.Itoa(i) + "."

ExternalAddress: v.GetString(cfgServerSection + FlagExternalAddress),
var serverInfo EndpointInfo
serverInfo.Address = v.GetString(key + cfgEndpointAddress)
if serverInfo.Address == "" {
break
}
serverInfo.ExternalAddress = v.GetString(key + cfgEndpointExternalAddress)
serverInfo.KeepAlive = v.GetDuration(key + cfgEndpointKeepAlive)
serverInfo.ReadTimeout = v.GetDuration(key + cfgEndpointReadTimeout)
serverInfo.WriteTimeout = v.GetDuration(key + cfgEndpointWriteTimeout)
serverInfo.TLS.Enabled = v.GetBool(key + cfgTLSEnabled)
serverInfo.TLS.KeyFile = v.GetString(key + cfgTLSKeyFile)
serverInfo.TLS.CertFile = v.GetString(key + cfgTLSCertFile)
serverInfo.TLS.CertCAFile = v.GetString(key + cfgTLSCertCAFile)

servers = append(servers, serverInfo)
}
}

return servers
}

func newNeofsAPI(ctx context.Context, logger *zap.Logger, v *viper.Viper) (*handlers.RestAPI, error) {
Expand Down
4 changes: 2 additions & 2 deletions cmd/neofs-rest-gw/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ func getDefaultConfig(node string) *viper.Viper {
v.SetDefault(cfgPeers+".0.address", node)
v.SetDefault(cfgPeers+".0.weight", 1)
v.SetDefault(cfgPeers+".0.priority", 1)
v.SetDefault(cfgServerSection+FlagListenAddress, testListenAddress)
v.SetDefault(cfgServerSection+FlagWriteTimeout, 60*time.Second)
v.SetDefault(cfgServerSection+cmdListenAddress, testListenAddress)
v.SetDefault(cfgServerSection+cfgEndpointWriteTimeout, 60*time.Second)

return v
}
Expand Down
Loading

0 comments on commit f024e79

Please sign in to comment.