From a11e7756eba485c1eac1d09d3aa415df12c76b47 Mon Sep 17 00:00:00 2001 From: Juan Hernandez Date: Wed, 8 Nov 2023 13:24:02 +0100 Subject: [PATCH] Add `--api-listener-address` This patch adds a new `--api-listener-address` command line option to all the servers. This will be used to specify the address where the server listens. By default it will be `localhost:8000`. Related: https://issues.redhat.com/browse/MGMT-16113 Signed-off-by: Juan Hernandez --- .../server/start_deployment_manager_server.go | 29 +++- internal/cmd/server/start_metadata_server.go | 29 +++- internal/network/flags.go | 52 +++++++ internal/network/listener.go | 105 ++++++++++++++ internal/network/listener_test.go | 134 ++++++++++++++++++ internal/network/suite_test.go | 41 ++++++ 6 files changed, 382 insertions(+), 8 deletions(-) create mode 100644 internal/network/flags.go create mode 100644 internal/network/listener.go create mode 100644 internal/network/listener_test.go create mode 100644 internal/network/suite_test.go diff --git a/internal/cmd/server/start_deployment_manager_server.go b/internal/cmd/server/start_deployment_manager_server.go index e3fd1ae0a..a3e0bf423 100644 --- a/internal/cmd/server/start_deployment_manager_server.go +++ b/internal/cmd/server/start_deployment_manager_server.go @@ -26,6 +26,7 @@ import ( "github.com/openshift-kni/oran-o2ims/internal/authorization" "github.com/openshift-kni/oran-o2ims/internal/exit" "github.com/openshift-kni/oran-o2ims/internal/logging" + "github.com/openshift-kni/oran-o2ims/internal/network" "github.com/openshift-kni/oran-o2ims/internal/service" ) @@ -41,6 +42,7 @@ func DeploymentManagerServer() *cobra.Command { flags := result.Flags() authentication.AddFlags(flags) authorization.AddFlags(flags) + network.AddListenerFlags(flags, network.APIListener, network.APIAddress) _ = flags.String( cloudIDFlagName, "", @@ -252,12 +254,31 @@ func (c *DeploymentManagerServerCommand) run(cmd *cobra.Command, argv []string) objectAdapter, ).Methods(http.MethodGet) - // Start the server: - err = http.ListenAndServe(":8080", router) + // Start the API server: + apiListener, err := network.NewListener(). + SetLogger(logger). + SetFlags(flags, network.APIListener). + Build() if err != nil { logger.Error( - "server finished with error", - "error", err, + "Failed to to create API listener", + slog.String("error", err.Error()), + ) + return exit.Error(1) + } + logger.Info( + "API listening", + slog.String("address", apiListener.Addr().String()), + ) + apiServer := http.Server{ + Addr: apiListener.Addr().String(), + Handler: router, + } + err = apiServer.Serve(apiListener) + if err != nil { + logger.Error( + "API server finished with error", + slog.String("error", err.Error()), ) return exit.Error(1) } diff --git a/internal/cmd/server/start_metadata_server.go b/internal/cmd/server/start_metadata_server.go index a40a7f889..577d0a2b1 100644 --- a/internal/cmd/server/start_metadata_server.go +++ b/internal/cmd/server/start_metadata_server.go @@ -23,6 +23,7 @@ import ( "github.com/openshift-kni/oran-o2ims/internal" "github.com/openshift-kni/oran-o2ims/internal/exit" + "github.com/openshift-kni/oran-o2ims/internal/network" "github.com/openshift-kni/oran-o2ims/internal/service" ) @@ -36,6 +37,7 @@ func MetadataServer() *cobra.Command { RunE: c.run, } flags := result.Flags() + network.AddListenerFlags(flags, network.APIListener, network.APIAddress) _ = flags.String( cloudIDFlagName, "", @@ -157,12 +159,31 @@ func (c *MetadataServerCommand) run(cmd *cobra.Command, argv []string) error { cloudInfoAdapter, ).Methods(http.MethodGet) - // Start the server: - err = http.ListenAndServe(":8080", router) + // Start the API server: + apiListener, err := network.NewListener(). + SetLogger(logger). + SetFlags(flags, network.APIListener). + Build() + if err != nil { + logger.Error( + "Failed to to create API listener", + slog.String("error", err.Error()), + ) + return exit.Error(1) + } + logger.Info( + "API listening", + slog.String("address", apiListener.Addr().String()), + ) + apiServer := http.Server{ + Addr: apiListener.Addr().String(), + Handler: router, + } + err = apiServer.Serve(apiListener) if err != nil { logger.Error( - "server finished with error", - "error", err, + "API server finished with error", + slog.String("error", err.Error()), ) return exit.Error(1) } diff --git a/internal/network/flags.go b/internal/network/flags.go new file mode 100644 index 000000000..7eaa04bad --- /dev/null +++ b/internal/network/flags.go @@ -0,0 +1,52 @@ +/* +Copyright 2023 Red Hat Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +*/ + +package network + +import ( + "fmt" + "strings" + + "github.com/spf13/pflag" +) + +// AddListenerFlags adds to the given flag set the flags needed to configure a network listener. It +// receives the name of the listerner and the default address. For example, to configure an API +// listener: +// +// network.AddListenerFlags("API", "localhost:8000") +// +// The name will be converted to lower case to generate a prefix for the flags, and will be used +// unchanged as a prefix for the help text. The above example will result in the following flags: +// +// --api-listener-address string API listen address. (default "localhost:8000") +func AddListenerFlags(set *pflag.FlagSet, name, addr string) { + _ = set.String( + listenerFlagName(name, listenerAddrFlagSuffix), + addr, + fmt.Sprintf("%s listen address.", name), + ) +} + +// Names of the flags: +const ( + listenerAddrFlagSuffix = "listener-address" +) + +// listenerFlagName calculates a complete flag name from a listener name and a flag name suffix. +// For example, if the listener name is 'API' and the flag name suffix is 'listener-address' it +// returns 'api-listener-address'. +func listenerFlagName(name, suffix string) string { + return fmt.Sprintf("%s-%s", strings.ToLower(name), suffix) +} diff --git a/internal/network/listener.go b/internal/network/listener.go new file mode 100644 index 000000000..3fe78f28a --- /dev/null +++ b/internal/network/listener.go @@ -0,0 +1,105 @@ +/* +Copyright 2023 Red Hat Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +*/ + +package network + +import ( + "fmt" + "log/slog" + "net" + + "github.com/spf13/pflag" +) + +// ListenerBuilder contains the data and logic needed to create a network listener. Don't create +// instances of this object directly, use the NewListener function instead. +type ListenerBuilder struct { + logger *slog.Logger + network string + address string +} + +// NewListener creates a builder that can then used to configure and create a network listener. +func NewListener() *ListenerBuilder { + return &ListenerBuilder{ + network: "tcp", + } +} + +// SetLogger sets the logger that the listener will use to send messages to the log. This is +// mandatory. +func (b *ListenerBuilder) SetLogger(value *slog.Logger) *ListenerBuilder { + b.logger = value + return b +} + +// SetFlags sets the command line flags that should be used to configure the listener. +// +// The name is used to select the options when there are multiple listeners. For example, if it +// is 'API' then it will only take into accounts the flags starting with '--api'. +// +// This is optional. +func (b *ListenerBuilder) SetFlags(flags *pflag.FlagSet, name string) *ListenerBuilder { + if flags != nil { + listenerAddrFlagName := listenerFlagName(name, listenerAddrFlagSuffix) + value, err := flags.GetString(listenerAddrFlagName) + if err == nil { + b.SetAddress(value) + } + } + return b +} + +// SetNetwork sets the network. This is optional and the default is TCP. +func (b *ListenerBuilder) SetNetwork(value string) *ListenerBuilder { + b.network = value + return b +} + +// SetAddress sets the listen address. This is mandatory. +func (b *ListenerBuilder) SetAddress(value string) *ListenerBuilder { + b.address = value + return b +} + +// Build uses the data stored in the builder to create a new network listener. +func (b *ListenerBuilder) Build() (result net.Listener, err error) { + // Check parameters: + if b.logger == nil { + err = fmt.Errorf("logger is mandatory") + return + } + if b.network == "" { + err = fmt.Errorf("network is mandatory") + return + } + if b.address == "" { + err = fmt.Errorf("address is mandatory") + return + } + + // Create and populate the object: + result, err = net.Listen(b.network, b.address) + return +} + +// Common listener names: +const ( + APIListener = "API" +) + +// Common listener addresses: +const ( + APIAddress = "localhost:8000" +) diff --git a/internal/network/listener_test.go b/internal/network/listener_test.go new file mode 100644 index 000000000..4aad2d0c3 --- /dev/null +++ b/internal/network/listener_test.go @@ -0,0 +1,134 @@ +/* +Copyright (c) 2023 Red Hat, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +*/ + +package network + +import ( + "os" + "path/filepath" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" + "github.com/spf13/pflag" +) + +var _ = Describe("Listener", func() { + var tmp string + + BeforeEach(func() { + // In order to avoid TCP port conflicts these tests will use only Unix sockets + // created in this temporary directory: + var err error + tmp, err = os.MkdirTemp("", "*.sockets") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmp) + Expect(err).ToNot(HaveOccurred()) + }) + + It("Can't be created without a logger", func() { + address := filepath.Join(tmp, "my.socket") + listener, err := NewListener(). + SetNetwork("unix"). + SetAddress(address). + Build() + Expect(err).To(HaveOccurred()) + Expect(listener).To(BeNil()) + msg := err.Error() + Expect(msg).To(ContainSubstring("logger")) + Expect(msg).To(ContainSubstring("mandatory")) + }) + + It("Can't be created without an address", func() { + listener, err := NewListener(). + SetLogger(logger). + SetNetwork("unix"). + Build() + Expect(err).To(HaveOccurred()) + Expect(listener).To(BeNil()) + msg := err.Error() + Expect(msg).To(ContainSubstring("address")) + Expect(msg).To(ContainSubstring("mandatory")) + }) + + It("Can't be created with an incorrect address", func() { + listener, err := NewListener(). + SetLogger(logger). + SetAddress("junk"). + Build() + Expect(err).To(HaveOccurred()) + Expect(listener).To(BeNil()) + msg := err.Error() + Expect(msg).To(ContainSubstring("junk")) + }) + + It("Uses the given address", func() { + address := filepath.Join(tmp, "my.socket") + listener, err := NewListener(). + SetLogger(logger). + SetNetwork("unix"). + SetAddress(address). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(listener).ToNot(BeNil()) + Expect(listener.Addr().String()).To(Equal(address)) + }) + + It("Honors the address flag", func() { + // Prepare the flags: + address := filepath.Join(tmp, "my.socket") + flags := pflag.NewFlagSet("", pflag.ContinueOnError) + AddListenerFlags(flags, "my", "localhost:80") + err := flags.Parse([]string{ + "--my-listener-address", address, + }) + Expect(err).ToNot(HaveOccurred()) + + // Create the listener: + listener, err := NewListener(). + SetLogger(logger). + SetNetwork("unix"). + SetFlags(flags, "my"). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(listener).ToNot(BeNil()) + Expect(listener.Addr().String()).To(Equal(address)) + }) + + It("Ignores flags for other listeners", func() { + // Prepare the flags: + myAddress := filepath.Join(tmp, "my.socket") + yourAddress := filepath.Join(tmp, "your.socket") + flags := pflag.NewFlagSet("", pflag.ContinueOnError) + AddListenerFlags(flags, "my", "localhost:80") + AddListenerFlags(flags, "your", "localhost:81") + err := flags.Parse([]string{ + "--my-listener-address", myAddress, + "--your-listener-address", yourAddress, + }) + Expect(err).ToNot(HaveOccurred()) + + // Create the listener: + listener, err := NewListener(). + SetLogger(logger). + SetNetwork("unix"). + SetFlags(flags, "my"). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(listener).ToNot(BeNil()) + Expect(listener.Addr().String()).To(Equal(myAddress)) + }) +}) diff --git a/internal/network/suite_test.go b/internal/network/suite_test.go new file mode 100644 index 000000000..1b9f2fa74 --- /dev/null +++ b/internal/network/suite_test.go @@ -0,0 +1,41 @@ +/* +Copyright 2023 Red Hat Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +*/ + +package network + +import ( + "log/slog" + "testing" + + . "github.com/onsi/ginkgo/v2/dsl/core" + . "github.com/onsi/gomega" +) + +func TestNetwork(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Network") +} + +// Logger used for tests: +var logger *slog.Logger + +var _ = BeforeSuite(func() { + // Create a logger that writes to the Ginkgo writer, so that the log messages will be + // attached to the output of the right test: + options := &slog.HandlerOptions{ + Level: slog.LevelDebug, + } + handler := slog.NewJSONHandler(GinkgoWriter, options) + logger = slog.New(handler) +})