Skip to content

Commit

Permalink
Merge pull request #10 from jhernand/add_listener_flags
Browse files Browse the repository at this point in the history
Add `--api-listener-address`
  • Loading branch information
jhernand authored Nov 8, 2023
2 parents 1eafd77 + a11e775 commit 5364ccc
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 8 deletions.
29 changes: 25 additions & 4 deletions internal/cmd/server/start_deployment_manager_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/openshift-kni/oran-o2ims/internal/authorization"
"github.com/openshift-kni/oran-o2ims/internal/exit"
"github.com/openshift-kni/oran-o2ims/internal/logging"
"github.com/openshift-kni/oran-o2ims/internal/network"
"github.com/openshift-kni/oran-o2ims/internal/service"
)

Expand All @@ -41,6 +42,7 @@ func DeploymentManagerServer() *cobra.Command {
flags := result.Flags()
authentication.AddFlags(flags)
authorization.AddFlags(flags)
network.AddListenerFlags(flags, network.APIListener, network.APIAddress)
_ = flags.String(
cloudIDFlagName,
"",
Expand Down Expand Up @@ -252,12 +254,31 @@ func (c *DeploymentManagerServerCommand) run(cmd *cobra.Command, argv []string)
objectAdapter,
).Methods(http.MethodGet)

// Start the server:
err = http.ListenAndServe(":8080", router)
// Start the API server:
apiListener, err := network.NewListener().
SetLogger(logger).
SetFlags(flags, network.APIListener).
Build()
if err != nil {
logger.Error(
"server finished with error",
"error", err,
"Failed to to create API listener",
slog.String("error", err.Error()),
)
return exit.Error(1)
}
logger.Info(
"API listening",
slog.String("address", apiListener.Addr().String()),
)
apiServer := http.Server{
Addr: apiListener.Addr().String(),
Handler: router,
}
err = apiServer.Serve(apiListener)
if err != nil {
logger.Error(
"API server finished with error",
slog.String("error", err.Error()),
)
return exit.Error(1)
}
Expand Down
29 changes: 25 additions & 4 deletions internal/cmd/server/start_metadata_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/openshift-kni/oran-o2ims/internal"
"github.com/openshift-kni/oran-o2ims/internal/exit"
"github.com/openshift-kni/oran-o2ims/internal/network"
"github.com/openshift-kni/oran-o2ims/internal/service"
)

Expand All @@ -36,6 +37,7 @@ func MetadataServer() *cobra.Command {
RunE: c.run,
}
flags := result.Flags()
network.AddListenerFlags(flags, network.APIListener, network.APIAddress)
_ = flags.String(
cloudIDFlagName,
"",
Expand Down Expand Up @@ -157,12 +159,31 @@ func (c *MetadataServerCommand) run(cmd *cobra.Command, argv []string) error {
cloudInfoAdapter,
).Methods(http.MethodGet)

// Start the server:
err = http.ListenAndServe(":8080", router)
// Start the API server:
apiListener, err := network.NewListener().
SetLogger(logger).
SetFlags(flags, network.APIListener).
Build()
if err != nil {
logger.Error(
"Failed to to create API listener",
slog.String("error", err.Error()),
)
return exit.Error(1)
}
logger.Info(
"API listening",
slog.String("address", apiListener.Addr().String()),
)
apiServer := http.Server{
Addr: apiListener.Addr().String(),
Handler: router,
}
err = apiServer.Serve(apiListener)
if err != nil {
logger.Error(
"server finished with error",
"error", err,
"API server finished with error",
slog.String("error", err.Error()),
)
return exit.Error(1)
}
Expand Down
52 changes: 52 additions & 0 deletions internal/network/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
Copyright 2023 Red Hat Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/

package network

import (
"fmt"
"strings"

"github.com/spf13/pflag"
)

// AddListenerFlags adds to the given flag set the flags needed to configure a network listener. It
// receives the name of the listerner and the default address. For example, to configure an API
// listener:
//
// network.AddListenerFlags("API", "localhost:8000")
//
// The name will be converted to lower case to generate a prefix for the flags, and will be used
// unchanged as a prefix for the help text. The above example will result in the following flags:
//
// --api-listener-address string API listen address. (default "localhost:8000")
func AddListenerFlags(set *pflag.FlagSet, name, addr string) {
_ = set.String(
listenerFlagName(name, listenerAddrFlagSuffix),
addr,
fmt.Sprintf("%s listen address.", name),
)
}

// Names of the flags:
const (
listenerAddrFlagSuffix = "listener-address"
)

// listenerFlagName calculates a complete flag name from a listener name and a flag name suffix.
// For example, if the listener name is 'API' and the flag name suffix is 'listener-address' it
// returns 'api-listener-address'.
func listenerFlagName(name, suffix string) string {
return fmt.Sprintf("%s-%s", strings.ToLower(name), suffix)
}
105 changes: 105 additions & 0 deletions internal/network/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
Copyright 2023 Red Hat Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/

package network

import (
"fmt"
"log/slog"
"net"

"github.com/spf13/pflag"
)

// ListenerBuilder contains the data and logic needed to create a network listener. Don't create
// instances of this object directly, use the NewListener function instead.
type ListenerBuilder struct {
logger *slog.Logger
network string
address string
}

// NewListener creates a builder that can then used to configure and create a network listener.
func NewListener() *ListenerBuilder {
return &ListenerBuilder{
network: "tcp",
}
}

// SetLogger sets the logger that the listener will use to send messages to the log. This is
// mandatory.
func (b *ListenerBuilder) SetLogger(value *slog.Logger) *ListenerBuilder {
b.logger = value
return b
}

// SetFlags sets the command line flags that should be used to configure the listener.
//
// The name is used to select the options when there are multiple listeners. For example, if it
// is 'API' then it will only take into accounts the flags starting with '--api'.
//
// This is optional.
func (b *ListenerBuilder) SetFlags(flags *pflag.FlagSet, name string) *ListenerBuilder {
if flags != nil {
listenerAddrFlagName := listenerFlagName(name, listenerAddrFlagSuffix)
value, err := flags.GetString(listenerAddrFlagName)
if err == nil {
b.SetAddress(value)
}
}
return b
}

// SetNetwork sets the network. This is optional and the default is TCP.
func (b *ListenerBuilder) SetNetwork(value string) *ListenerBuilder {
b.network = value
return b
}

// SetAddress sets the listen address. This is mandatory.
func (b *ListenerBuilder) SetAddress(value string) *ListenerBuilder {
b.address = value
return b
}

// Build uses the data stored in the builder to create a new network listener.
func (b *ListenerBuilder) Build() (result net.Listener, err error) {
// Check parameters:
if b.logger == nil {
err = fmt.Errorf("logger is mandatory")
return
}
if b.network == "" {
err = fmt.Errorf("network is mandatory")
return
}
if b.address == "" {
err = fmt.Errorf("address is mandatory")
return
}

// Create and populate the object:
result, err = net.Listen(b.network, b.address)
return
}

// Common listener names:
const (
APIListener = "API"
)

// Common listener addresses:
const (
APIAddress = "localhost:8000"
)
134 changes: 134 additions & 0 deletions internal/network/listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
Copyright (c) 2023 Red Hat, Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied. See the License for the specific language governing permissions and limitations under the
License.
*/

package network

import (
"os"
"path/filepath"

. "github.com/onsi/ginkgo/v2/dsl/core"
. "github.com/onsi/gomega"
"github.com/spf13/pflag"
)

var _ = Describe("Listener", func() {
var tmp string

BeforeEach(func() {
// In order to avoid TCP port conflicts these tests will use only Unix sockets
// created in this temporary directory:
var err error
tmp, err = os.MkdirTemp("", "*.sockets")
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
err := os.RemoveAll(tmp)
Expect(err).ToNot(HaveOccurred())
})

It("Can't be created without a logger", func() {
address := filepath.Join(tmp, "my.socket")
listener, err := NewListener().
SetNetwork("unix").
SetAddress(address).
Build()
Expect(err).To(HaveOccurred())
Expect(listener).To(BeNil())
msg := err.Error()
Expect(msg).To(ContainSubstring("logger"))
Expect(msg).To(ContainSubstring("mandatory"))
})

It("Can't be created without an address", func() {
listener, err := NewListener().
SetLogger(logger).
SetNetwork("unix").
Build()
Expect(err).To(HaveOccurred())
Expect(listener).To(BeNil())
msg := err.Error()
Expect(msg).To(ContainSubstring("address"))
Expect(msg).To(ContainSubstring("mandatory"))
})

It("Can't be created with an incorrect address", func() {
listener, err := NewListener().
SetLogger(logger).
SetAddress("junk").
Build()
Expect(err).To(HaveOccurred())
Expect(listener).To(BeNil())
msg := err.Error()
Expect(msg).To(ContainSubstring("junk"))
})

It("Uses the given address", func() {
address := filepath.Join(tmp, "my.socket")
listener, err := NewListener().
SetLogger(logger).
SetNetwork("unix").
SetAddress(address).
Build()
Expect(err).ToNot(HaveOccurred())
Expect(listener).ToNot(BeNil())
Expect(listener.Addr().String()).To(Equal(address))
})

It("Honors the address flag", func() {
// Prepare the flags:
address := filepath.Join(tmp, "my.socket")
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
AddListenerFlags(flags, "my", "localhost:80")
err := flags.Parse([]string{
"--my-listener-address", address,
})
Expect(err).ToNot(HaveOccurred())

// Create the listener:
listener, err := NewListener().
SetLogger(logger).
SetNetwork("unix").
SetFlags(flags, "my").
Build()
Expect(err).ToNot(HaveOccurred())
Expect(listener).ToNot(BeNil())
Expect(listener.Addr().String()).To(Equal(address))
})

It("Ignores flags for other listeners", func() {
// Prepare the flags:
myAddress := filepath.Join(tmp, "my.socket")
yourAddress := filepath.Join(tmp, "your.socket")
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
AddListenerFlags(flags, "my", "localhost:80")
AddListenerFlags(flags, "your", "localhost:81")
err := flags.Parse([]string{
"--my-listener-address", myAddress,
"--your-listener-address", yourAddress,
})
Expect(err).ToNot(HaveOccurred())

// Create the listener:
listener, err := NewListener().
SetLogger(logger).
SetNetwork("unix").
SetFlags(flags, "my").
Build()
Expect(err).ToNot(HaveOccurred())
Expect(listener).ToNot(BeNil())
Expect(listener.Addr().String()).To(Equal(myAddress))
})
})
Loading

0 comments on commit 5364ccc

Please sign in to comment.