Skip to content

Commit

Permalink
remove secure from TLSConfig Load
Browse files Browse the repository at this point in the history
Signed-off-by: Arvindh <[email protected]>
  • Loading branch information
arvindh123 committed Mar 12, 2024
1 parent c7d1db3 commit 8f4e658
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 55 deletions.
11 changes: 5 additions & 6 deletions pkg/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (

"github.com/absmach/mproxy"
"github.com/absmach/mproxy/pkg/session"
mptls "github.com/absmach/mproxy/pkg/tls"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -118,7 +117,7 @@ func NewProxy(config mproxy.Config, handler session.Handler, logger *slog.Logger
}

func (p Proxy) Listen(ctx context.Context) error {
tlsCfg, secure, err := p.config.TLSConfig.Load()
tlsCfg, err := p.config.TLSConfig.Load()
if err != nil {
return err
}
Expand All @@ -128,11 +127,11 @@ func (p Proxy) Listen(ctx context.Context) error {
return err
}

if secure > mptls.WithoutTLS {
if tlsCfg != nil {
l = tls.NewListener(l, tlsCfg)
}

p.logger.Info(fmt.Sprintf("HTTP proxy server started at %s%s %s", p.config.Address, p.config.PrefixPath, secure.String()))
p.logger.Info(fmt.Sprintf("HTTP proxy server started at %s%s %s", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()))

var server http.Server
g, ctx := errgroup.WithContext(ctx)
Expand All @@ -150,9 +149,9 @@ func (p Proxy) Listen(ctx context.Context) error {
return server.Close()
})
if err := g.Wait(); err != nil {
p.logger.Info(fmt.Sprintf("HTTP proxy server at %s%s %s exiting with errors", p.config.Address, p.config.PrefixPath, secure.String()), slog.String("error", err.Error()))
p.logger.Info(fmt.Sprintf("HTTP proxy server at %s%s %s exiting with errors", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()), slog.String("error", err.Error()))
} else {
p.logger.Info(fmt.Sprintf("HTTP proxy server at %s%s %s exiting...", p.config.Address, p.config.PrefixPath, secure.String()))
p.logger.Info(fmt.Sprintf("HTTP proxy server at %s%s %s exiting...", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()))
}
return nil
}
10 changes: 5 additions & 5 deletions pkg/mqtt/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (p Proxy) handle(ctx context.Context, inbound net.Conn) {

// Listen of the server, this will block.
func (p Proxy) Listen(ctx context.Context) error {
tlsCfg, secure, err := p.config.TLSConfig.Load()
tlsCfg, err := p.config.TLSConfig.Load()
if err != nil {
return err
}
Expand All @@ -85,11 +85,11 @@ func (p Proxy) Listen(ctx context.Context) error {
return err
}

if secure > mptls.WithoutTLS {
if tlsCfg != nil {
l = tls.NewListener(l, tlsCfg)
}

p.logger.Info(fmt.Sprintf("MQTT proxy server started at %s %s", p.config.Address, secure.String()))
p.logger.Info(fmt.Sprintf("MQTT proxy server started at %s %s", p.config.Address, p.config.TLSConfig.Security()))
g, ctx := errgroup.WithContext(ctx)

// Acceptor loop
Expand All @@ -103,9 +103,9 @@ func (p Proxy) Listen(ctx context.Context) error {
return l.Close()
})
if err := g.Wait(); err != nil {
p.logger.Info(fmt.Sprintf("MQTT proxy server at %s %s exiting with errors", p.config.Address, secure.String()), slog.String("error", err.Error()))
p.logger.Info(fmt.Sprintf("MQTT proxy server at %s %s exiting with errors", p.config.Address, p.config.TLSConfig.Security()), slog.String("error", err.Error()))
} else {
p.logger.Info(fmt.Sprintf("MQTT proxy server at %s %s exiting...", p.config.Address, secure.String()))
p.logger.Info(fmt.Sprintf("MQTT proxy server at %s %s exiting...", p.config.Address, p.config.TLSConfig.Security()))
}
return nil
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/mqtt/websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (p Proxy) pass(ctx context.Context, in *websocket.Conn) {
}

func (p Proxy) Listen(ctx context.Context) error {
tlsCfg, secure, err := p.config.TLSConfig.Load()
tlsCfg, err := p.config.TLSConfig.Load()
if err != nil {
return err
}
Expand All @@ -112,7 +112,7 @@ func (p Proxy) Listen(ctx context.Context) error {
return err
}

if secure > mptls.WithoutTLS {
if tlsCfg != nil {
l = tls.NewListener(l, tlsCfg)
}

Expand All @@ -126,16 +126,16 @@ func (p Proxy) Listen(ctx context.Context) error {
g.Go(func() error {
return server.Serve(l)
})
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server started at %s%s %s", p.config.Address, p.config.PrefixPath, secure.String()))
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server started at %s%s %s", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()))

g.Go(func() error {
<-ctx.Done()
return server.Close()
})
if err := g.Wait(); err != nil {
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server at %s%s %s exiting with errors", p.config.Address, p.config.PrefixPath, secure.String()), slog.String("error", err.Error()))
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server at %s%s %s exiting with errors", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()), slog.String("error", err.Error()))
} else {
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server at %s%s %s exiting...", p.config.Address, p.config.PrefixPath, secure.String()))
p.logger.Info(fmt.Sprintf("MQTT websocket proxy server at %s%s %s exiting...", p.config.Address, p.config.PrefixPath, p.config.TLSConfig.Security()))
}
return nil
}
92 changes: 53 additions & 39 deletions pkg/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,54 +58,68 @@ func (c *Config) EnvParse(opts env.Options) error {
}

// Load return a TLS configuration that can be used in TLS servers.
func (c *Config) Load() (*tls.Config, Security, error) {
func (c *Config) Load() (*tls.Config, error) {
if c.CertFile == "" || c.KeyFile == "" {
return nil, nil
}

tlsConfig := &tls.Config{}
secure := WithoutTLS
if c.CertFile != "" || c.KeyFile != "" {
certificate, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, secure, errors.Join(errLoadCerts, err)

certificate, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
if err != nil {
return nil, errors.Join(errLoadCerts, err)
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{certificate},
}

// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return nil, errors.Join(errLoadServerCA, err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
tlsConfig = &tls.Config{
Certificates: []tls.Certificate{certificate},
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, errAppendCA
}
secure = WithTLS
}

// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return nil, secure, errors.Join(errLoadServerCA, err)
// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return nil, errors.Join(errLoadClientCA, err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return nil, secure, errAppendCA
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, errAppendCA
}

// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return nil, secure, errors.Join(errLoadClientCA, err)
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
if len(c.ClientValidation.ValidationMethods) > 0 {
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return nil, secure, errAppendCA
}
secure = WithmTLS
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
if len(c.ClientValidation.ValidationMethods) > 0 {
tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate
secure = WithmTLSVerify
}
}
return tlsConfig, nil
}

func (c *Config) Security() Security {
if c.CertFile != "" && c.KeyFile != "" {
return WithoutTLS
}

if c.ClientCAFile != "" {
if len(c.ClientValidation.ValidationMethods) > 0 {
return WithmTLSVerify
}
return WithmTLS
}
return tlsConfig, secure, nil

return WithTLS
}

// ClientCert returns client certificate.
Expand Down

0 comments on commit 8f4e658

Please sign in to comment.