Skip to content

Commit

Permalink
Move Add/RemoveListener to APITracker
Browse files Browse the repository at this point in the history
When enableListeners flag is enabled we would also
call the host-switch service API in addition to adding
and removing the net listeners.

Signed-off-by: Nino Kodabande <[email protected]>
  • Loading branch information
Nino-K committed Sep 10, 2024
1 parent 54483db commit dc7e712
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 49 deletions.
3 changes: 1 addition & 2 deletions src/go/guestagent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func main() {
var portTracker tracker.Tracker

forwarder := forwarder.NewWSLProxyForwarder("/run/wsl-proxy.sock")
portTracker = tracker.NewAPITracker(forwarder, tracker.GatewayBaseURL, *tapIfaceIP, *adminInstall)
portTracker = tracker.NewAPITracker(ctx, forwarder, tracker.GatewayBaseURL, *tapIfaceIP, *adminInstall, *enableIptables)
// Manually register the port for K8s API, we would
// only want to send this manual port mapping if both
// of the following conditions are met:
Expand Down Expand Up @@ -183,7 +183,6 @@ func main() {
err := kube.WatchForServices(ctx,
*configPath,
k8sServiceListenerIP,
*enableIptables,
portTracker)
if err != nil {
return fmt.Errorf("error watching services: %w", err)
Expand Down
36 changes: 0 additions & 36 deletions src/go/guestagent/pkg/kube/watcher_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ func WatchForServices(
ctx context.Context,
configPath string,
k8sServiceListenerIP net.IP,
enableListeners bool,
portTracker tracker.Tracker,
) error {
// These variables are shared across the different states
Expand Down Expand Up @@ -154,24 +153,6 @@ func WatchForServices(
continue
case event := <-eventCh:
if event.deleted {
if enableListeners {
for port := range event.portMapping {
if err := portTracker.RemoveListener(ctx, k8sServiceListenerIP, int(port)); err != nil {
log.Errorw("failed to close listener", log.Fields{
"error": err,
"ports": event.portMapping,
"namespace": event.namespace,
"name": event.name,
})
}
}

log.Debugf("kubernetes service: deleted listener %s/%s:%v",
event.namespace, event.name, event.portMapping)

continue
}

if err := portTracker.Remove(string(event.UID)); err != nil {
log.Errorw("failed to delete a port from tracker", log.Fields{
"error": err,
Expand All @@ -185,23 +166,6 @@ func WatchForServices(
event.namespace, event.name, event.portMapping)
}
} else {
if enableListeners {
for port := range event.portMapping {
if err := portTracker.AddListener(ctx, k8sServiceListenerIP, int(port)); err != nil {
log.Errorw("failed to create listener", log.Fields{
"error": err,
"ports": event.portMapping,
"namespace": event.namespace,
"name": event.name,
})
}
}

log.Debugf("kubernetes service: started listener %s/%s:%v",
event.namespace, event.name, event.portMapping)

continue
}
portMapping, err := createPortMapping(event.portMapping, k8sServiceListenerIP)
if err != nil {
log.Errorw("failed to create port mapping", log.Fields{
Expand Down
1 change: 0 additions & 1 deletion src/go/guestagent/pkg/kube/watcher_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func WatchForServices(
ctx context.Context,
configPath string,
k8sServiceListenerIP net.IP,
enableListeners bool,
portTracker tracker.Tracker,
) error {
return fmt.Errorf("not implemented for non-linux")
Expand Down
32 changes: 31 additions & 1 deletion src/go/guestagent/pkg/tracker/apitracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ limitations under the License.
package tracker

import (
"context"
"errors"
"fmt"
"net"
"strconv"

"github.com/Masterminds/log-go"
"github.com/containers/gvisor-tap-vsock/pkg/types"
Expand All @@ -42,8 +44,10 @@ var (
// and unexposing the ports on the host. This should only be used when
// the Rancher Desktop networking is enabled and the privileged service is disabled.
type APITracker struct {
context context.Context
wslProxyForwarder forwarder.Forwarder
isAdmin bool
enableListeners bool
baseURL string
tapInterfaceIP string
portStorage *portStorage
Expand All @@ -52,10 +56,12 @@ type APITracker struct {
}

// NewAPITracker creates a new instance of a API Tracker.
func NewAPITracker(wslProxyForwarder forwarder.Forwarder, baseURL, tapIfaceIP string, isAdmin bool) *APITracker {
func NewAPITracker(ctx context.Context, wslProxyForwarder forwarder.Forwarder, baseURL, tapIfaceIP string, isAdmin, enableListeners bool) *APITracker {
return &APITracker{
context: ctx,
wslProxyForwarder: wslProxyForwarder,
isAdmin: isAdmin,
enableListeners: enableListeners,
baseURL: baseURL,
tapInterfaceIP: tapIfaceIP,
portStorage: newPortStorage(),
Expand All @@ -82,6 +88,18 @@ func (a *APITracker) Add(containerID string, portMap nat.PortMap) error {
continue
}

if a.enableListeners {
hostPort, err := strconv.Atoi(portBinding.HostPort)
if err != nil {
log.Errorf("error converting hostPort: %s", err)
continue
}
if err := a.AddListener(a.context, net.IP(portBinding.HostIP), hostPort); err != nil {
log.Errorf("creating listener for %s and %s failed: %s", portBinding.HostIP, portBinding.HostPort, err)
continue
}
}

log.Debugf("calling /services/forwarder/expose API for the following port binding: %+v", portBinding)

err = a.apiForwarder.Expose(
Expand Down Expand Up @@ -142,6 +160,18 @@ func (a *APITracker) Remove(containerID string) error {
continue
}

if a.enableListeners {
hostPort, err := strconv.Atoi(portBinding.HostPort)
if err != nil {
log.Errorf("error converting hostPort: %s", err)
continue
}
if err := a.RemoveListener(a.context, net.IP(portBinding.HostIP), hostPort); err != nil {
log.Errorf("removing listener for %s and %s failed: %s", portBinding.HostIP, portBinding.HostPort, err)
continue
}
}

log.Debugf("calling /services/forwarder/expose API for the following port binding: %+v", portBinding)

err = a.apiForwarder.Unexpose(
Expand Down
19 changes: 10 additions & 9 deletions src/go/guestagent/pkg/tracker/apitracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
package tracker_test

import (
"context"
"encoding/json"
"fmt"
"net/http"
Expand Down Expand Up @@ -55,7 +56,7 @@ func TestBasicAdd(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)
portMapping := nat.PortMap{
"80/tcp": []nat.PortBinding{
{
Expand Down Expand Up @@ -91,7 +92,7 @@ func TestAddOverride(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)
portMapping := nat.PortMap{
"80/tcp": []nat.PortBinding{
{
Expand Down Expand Up @@ -182,7 +183,7 @@ func TestAddWithError(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)
portMapping := nat.PortMap{
"80/tcp": []nat.PortBinding{
{
Expand Down Expand Up @@ -277,7 +278,7 @@ func TestGet(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)
err := apiTracker.Add(containerID, portMapping)
require.NoError(t, err)

Expand Down Expand Up @@ -305,7 +306,7 @@ func TestRemove(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)
portMapping1 := nat.PortMap{
"80/tcp": []nat.PortBinding{
{
Expand Down Expand Up @@ -369,7 +370,7 @@ func TestRemoveWithError(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)

portMapping := nat.PortMap{
"80/tcp": []nat.PortBinding{
Expand Down Expand Up @@ -429,7 +430,7 @@ func TestRemoveAll(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)

portMapping1 := nat.PortMap{
"80/tcp": []nat.PortBinding{
Expand Down Expand Up @@ -493,7 +494,7 @@ func TestRemoveAllWithError(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, true)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, true, false)

portMapping1 := nat.PortMap{
"80/tcp": []nat.PortBinding{
Expand Down Expand Up @@ -573,7 +574,7 @@ func TestNonAdminInstall(t *testing.T) {
testSrv := httptest.NewServer(mux)
defer testSrv.Close()

apiTracker := tracker.NewAPITracker(&testForwarder{}, testSrv.URL, hostSwitchIP, false)
apiTracker := tracker.NewAPITracker(context.Background(), &testForwarder{}, testSrv.URL, hostSwitchIP, false, false)

portMapping := nat.PortMap{
"1025/tcp": []nat.PortBinding{
Expand Down

0 comments on commit dc7e712

Please sign in to comment.