diff --git a/src/go/guestagent/main.go b/src/go/guestagent/main.go index a93451f5d12..cad62aea705 100644 --- a/src/go/guestagent/main.go +++ b/src/go/guestagent/main.go @@ -110,7 +110,7 @@ func main() { var portTracker tracker.Tracker forwarder := forwarder.NewWSLProxyForwarder("/run/wsl-proxy.sock") - portTracker = tracker.NewAPITracker(forwarder, tracker.GatewayBaseURL, *tapIfaceIP, *adminInstall) + portTracker = tracker.NewAPITracker(ctx, forwarder, tracker.GatewayBaseURL, *tapIfaceIP, *adminInstall, *enableIptables) // Manually register the port for K8s API, we would // only want to send this manual port mapping if both // of the following conditions are met: @@ -183,7 +183,6 @@ func main() { err := kube.WatchForServices(ctx, *configPath, k8sServiceListenerIP, - *enableIptables, portTracker) if err != nil { return fmt.Errorf("error watching services: %w", err) diff --git a/src/go/guestagent/pkg/kube/watcher_linux.go b/src/go/guestagent/pkg/kube/watcher_linux.go index 0e86e9912f5..b8ff465c013 100644 --- a/src/go/guestagent/pkg/kube/watcher_linux.go +++ b/src/go/guestagent/pkg/kube/watcher_linux.go @@ -58,7 +58,6 @@ func WatchForServices( ctx context.Context, configPath string, k8sServiceListenerIP net.IP, - enableListeners bool, portTracker tracker.Tracker, ) error { // These variables are shared across the different states @@ -154,24 +153,6 @@ func WatchForServices( continue case event := <-eventCh: if event.deleted { - if enableListeners { - for port := range event.portMapping { - if err := portTracker.RemoveListener(ctx, k8sServiceListenerIP, int(port)); err != nil { - log.Errorw("failed to close listener", log.Fields{ - "error": err, - "ports": event.portMapping, - "namespace": event.namespace, - "name": event.name, - }) - } - } - - log.Debugf("kubernetes service: deleted listener %s/%s:%v", - event.namespace, event.name, event.portMapping) - - continue - } - if err := portTracker.Remove(string(event.UID)); err != nil { log.Errorw("failed to delete a port from tracker", log.Fields{ "error": err, @@ -185,23 +166,6 @@ func WatchForServices( event.namespace, event.name, event.portMapping) } } else { - if enableListeners { - for port := range event.portMapping { - if err := portTracker.AddListener(ctx, k8sServiceListenerIP, int(port)); err != nil { - log.Errorw("failed to create listener", log.Fields{ - "error": err, - "ports": event.portMapping, - "namespace": event.namespace, - "name": event.name, - }) - } - } - - log.Debugf("kubernetes service: started listener %s/%s:%v", - event.namespace, event.name, event.portMapping) - - continue - } portMapping, err := createPortMapping(event.portMapping, k8sServiceListenerIP) if err != nil { log.Errorw("failed to create port mapping", log.Fields{ diff --git a/src/go/guestagent/pkg/kube/watcher_stub.go b/src/go/guestagent/pkg/kube/watcher_stub.go index 1b68e4fb274..0a73862d421 100644 --- a/src/go/guestagent/pkg/kube/watcher_stub.go +++ b/src/go/guestagent/pkg/kube/watcher_stub.go @@ -27,7 +27,6 @@ func WatchForServices( ctx context.Context, configPath string, k8sServiceListenerIP net.IP, - enableListeners bool, portTracker tracker.Tracker, ) error { return fmt.Errorf("not implemented for non-linux") diff --git a/src/go/guestagent/pkg/tracker/apitracker.go b/src/go/guestagent/pkg/tracker/apitracker.go index 1f0a06984a4..d183a24a261 100644 --- a/src/go/guestagent/pkg/tracker/apitracker.go +++ b/src/go/guestagent/pkg/tracker/apitracker.go @@ -14,9 +14,11 @@ limitations under the License. package tracker import ( + "context" "errors" "fmt" "net" + "strconv" "github.com/Masterminds/log-go" "github.com/containers/gvisor-tap-vsock/pkg/types" @@ -42,8 +44,10 @@ var ( // and unexposing the ports on the host. This should only be used when // the Rancher Desktop networking is enabled and the privileged service is disabled. type APITracker struct { + context context.Context wslProxyForwarder forwarder.Forwarder isAdmin bool + enableListeners bool baseURL string tapInterfaceIP string portStorage *portStorage @@ -52,10 +56,12 @@ type APITracker struct { } // NewAPITracker creates a new instance of a API Tracker. -func NewAPITracker(wslProxyForwarder forwarder.Forwarder, baseURL, tapIfaceIP string, isAdmin bool) *APITracker { +func NewAPITracker(ctx context.Context, wslProxyForwarder forwarder.Forwarder, baseURL, tapIfaceIP string, isAdmin, enableListeners bool) *APITracker { return &APITracker{ + context: ctx, wslProxyForwarder: wslProxyForwarder, isAdmin: isAdmin, + enableListeners: enableListeners, baseURL: baseURL, tapInterfaceIP: tapIfaceIP, portStorage: newPortStorage(), @@ -82,6 +88,18 @@ func (a *APITracker) Add(containerID string, portMap nat.PortMap) error { continue } + if a.enableListeners { + hostPort, err := strconv.Atoi(portBinding.HostPort) + if err != nil { + log.Errorf("error converting hostPort: %s", err) + continue + } + if err := a.AddListener(a.context, net.IP(portBinding.HostIP), hostPort); err != nil { + log.Errorf("creating listener for %s and %s failed: %s", portBinding.HostIP, portBinding.HostPort, err) + continue + } + } + log.Debugf("calling /services/forwarder/expose API for the following port binding: %+v", portBinding) err = a.apiForwarder.Expose( @@ -142,6 +160,18 @@ func (a *APITracker) Remove(containerID string) error { continue } + if a.enableListeners { + hostPort, err := strconv.Atoi(portBinding.HostPort) + if err != nil { + log.Errorf("error converting hostPort: %s", err) + continue + } + if err := a.RemoveListener(a.context, net.IP(portBinding.HostIP), hostPort); err != nil { + log.Errorf("removing listener for %s and %s failed: %s", portBinding.HostIP, portBinding.HostPort, err) + continue + } + } + log.Debugf("calling /services/forwarder/expose API for the following port binding: %+v", portBinding) err = a.apiForwarder.Unexpose( diff --git a/src/go/guestagent/pkg/tracker/apitracker_test.go b/src/go/guestagent/pkg/tracker/apitracker_test.go index 97383363839..3d5e2714b94 100644 --- a/src/go/guestagent/pkg/tracker/apitracker_test.go +++ b/src/go/guestagent/pkg/tracker/apitracker_test.go @@ -14,6 +14,7 @@ limitations under the License. package tracker_test import ( + "context" "encoding/json" "fmt" "net/http" @@ -55,7 +56,7 @@ func TestBasicAdd(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping := nat.PortMap{ "80/tcp": []nat.PortBinding{ { @@ -91,7 +92,7 @@ func TestAddOverride(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping := nat.PortMap{ "80/tcp": []nat.PortBinding{ { @@ -182,7 +183,7 @@ func TestAddWithError(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping := nat.PortMap{ "80/tcp": []nat.PortBinding{ { @@ -277,7 +278,7 @@ func TestGet(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) err := apiTracker.Add(containerID, portMapping) require.NoError(t, err) @@ -305,7 +306,7 @@ func TestRemove(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping1 := nat.PortMap{ "80/tcp": []nat.PortBinding{ { @@ -369,7 +370,7 @@ func TestRemoveWithError(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping := nat.PortMap{ "80/tcp": []nat.PortBinding{ @@ -429,7 +430,7 @@ func TestRemoveAll(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping1 := nat.PortMap{ "80/tcp": []nat.PortBinding{ @@ -493,7 +494,7 @@ func TestRemoveAllWithError(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false) portMapping1 := nat.PortMap{ "80/tcp": []nat.PortBinding{ @@ -573,7 +574,7 @@ func TestNonAdminInstall(t *testing.T) { testSrv := httptest.NewServer(mux) defer testSrv.Close() - apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, false) + apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, false, false) portMapping := nat.PortMap{ "1025/tcp": []nat.PortBinding{