Skip to content

Commit

Permalink
multiple port support
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke Lombardi authored and Luke Lombardi committed Feb 18, 2025
1 parent f96f1bb commit e7c7b4f
Show file tree
Hide file tree
Showing 10 changed files with 722 additions and 187 deletions.
34 changes: 22 additions & 12 deletions pkg/abstractions/pod/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"sort"
"strconv"
"sync"
"time"

Expand All @@ -27,7 +28,7 @@ const (

type container struct {
id string
address string
addressMap map[int32]string
connections int
}

Expand Down Expand Up @@ -160,9 +161,16 @@ func (pb *PodProxyBuffer) handleConnection(conn *connection) {
}
defer clientConn.Close()

containerConn, err := network.ConnectToHost(request.Context(), container.address, containerDialTimeoutDurationS, pb.tailscale, pb.tsConfig)
portStr := conn.ctx.Param("port")
port, err := strconv.Atoi(portStr)
if err != nil {
log.Error().Msgf("Error dialing pod container %s: %s", container.address, err.Error())
conn.ctx.String(http.StatusBadRequest, "Invalid port")
return
}

containerConn, err := network.ConnectToHost(request.Context(), container.addressMap[int32(port)], containerDialTimeoutDurationS, pb.tailscale, pb.tsConfig)
if err != nil {
log.Error().Msgf("Error dialing pod container %s: %s", container.addressMap[int32(port)], err.Error())
return
}
defer containerConn.Close()
Expand All @@ -179,7 +187,7 @@ func (pb *PodProxyBuffer) handleConnection(conn *connection) {
// Ensure the request URL is correctly formatted for the proxy.
// We'll set container.address to the Host and put subPath into the Path field.
request.URL.Scheme = "http"
request.URL.Host = container.address
request.URL.Host = container.addressMap[int32(port)]

// Get subPath, ensure it starts with a slash, and assign it to the path portion.
subPath := conn.ctx.Param("subPath")
Expand Down Expand Up @@ -242,7 +250,7 @@ func (pb *PodProxyBuffer) discoverContainers() {
return
}

containerAddress, err := pb.containerRepo.GetContainerAddress(cs.ContainerId)
addressMap, err := pb.containerRepo.GetContainerAddressMap(cs.ContainerId)
if err != nil {
return
}
Expand All @@ -254,14 +262,16 @@ func (pb *PodProxyBuffer) discoverContainers() {

connections := currentConnections

if pb.checkContainerAvailable(containerAddress) {
availableContainersChan <- container{
id: cs.ContainerId,
address: containerAddress,
connections: connections,
}
for _, port := range pb.stubConfig.Ports {
if pb.checkContainerAvailable(addressMap[int32(port)]) {
availableContainersChan <- container{
id: cs.ContainerId,
addressMap: addressMap,
connections: connections,
}

return
return
}
}
}(containerState)
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/common/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var (
schedulerContainerConfig string = "scheduler:container:config:%s"
schedulerContainerState string = "scheduler:container:state:%s"
schedulerContainerAddress string = "scheduler:container:container_addr:%s"
schedulerContainerAddressMap string = "scheduler:container:container_addr_map:%s"
schedulerContainerIndex string = "scheduler:container:index:%s"
schedulerContainerWorkerIndex string = "scheduler:container:worker:index:%s"
schedulerContainerWorkspaceIndex string = "scheduler:container:workspace:index:%s"
Expand Down Expand Up @@ -141,6 +142,10 @@ func (rk *redisKeys) SchedulerContainerAddress(containerId string) string {
return fmt.Sprintf(schedulerContainerAddress, containerId)
}

func (rk *redisKeys) SchedulerContainerAddressMap(containerId string) string {
return fmt.Sprintf(schedulerContainerAddressMap, containerId)
}

func (rk *redisKeys) SchedulerWorkerAddress(containerId string) string {
return fmt.Sprintf(schedulerWorkerAddress, containerId)
}
Expand Down
28 changes: 28 additions & 0 deletions pkg/gateway/services/repository/container_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ func (s *ContainerRepositoryService) SetContainerAddress(ctx context.Context, re
return &pb.SetContainerAddressResponse{Ok: true}, nil
}

func (s *ContainerRepositoryService) SetContainerAddressMap(ctx context.Context, req *pb.SetContainerAddressMapRequest) (*pb.SetContainerAddressMapResponse, error) {
addressMap := make(map[int32]string)
for k, v := range req.AddressMap {
addressMap[int32(k)] = v
}

err := s.containerRepo.SetContainerAddressMap(req.ContainerId, addressMap)
if err != nil {
return &pb.SetContainerAddressMapResponse{Ok: false, ErrorMsg: err.Error()}, nil
}

return &pb.SetContainerAddressMapResponse{Ok: true}, nil
}

func (s *ContainerRepositoryService) GetContainerAddressMap(ctx context.Context, req *pb.GetContainerAddressMapRequest) (*pb.GetContainerAddressMapResponse, error) {
addressMap, err := s.containerRepo.GetContainerAddressMap(req.ContainerId)
if err != nil {
return &pb.GetContainerAddressMapResponse{Ok: false, ErrorMsg: err.Error()}, nil
}

protoMap := make(map[int32]string)
for k, v := range addressMap {
protoMap[int32(k)] = v
}

return &pb.GetContainerAddressMapResponse{Ok: true, AddressMap: protoMap}, nil
}

func (s *ContainerRepositoryService) SetWorkerAddress(ctx context.Context, req *pb.SetWorkerAddressRequest) (*pb.SetWorkerAddressResponse, error) {
err := s.containerRepo.SetWorkerAddress(req.ContainerId, req.Address)
if err != nil {
Expand Down
24 changes: 24 additions & 0 deletions pkg/gateway/services/repository/container_repo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ service ContainerRepositoryService {
returns (SetContainerExitCodeResponse);
rpc SetContainerAddress(SetContainerAddressRequest)
returns (SetContainerAddressResponse);
rpc SetContainerAddressMap(SetContainerAddressMapRequest)
returns (SetContainerAddressMapResponse);
rpc GetContainerAddressMap(GetContainerAddressMapRequest)
returns (GetContainerAddressMapResponse);
rpc SetWorkerAddress(SetWorkerAddressRequest)
returns (SetWorkerAddressResponse);
rpc UpdateCheckpointState(UpdateCheckpointStateRequest)
Expand Down Expand Up @@ -70,6 +74,26 @@ message SetContainerAddressResponse {
string error_msg = 2;
}

message SetContainerAddressMapRequest {
string container_id = 1;
map<int32, string> address_map = 2;
}

message SetContainerAddressMapResponse {
bool ok = 1;
string error_msg = 2;
}

message GetContainerAddressMapRequest {
string container_id = 1;
}

message GetContainerAddressMapResponse {
bool ok = 1;
map<int32, string> address_map = 2;
string error_msg = 3;
}

message SetWorkerAddressRequest {
string container_id = 1;
string address = 2;
Expand Down
2 changes: 2 additions & 0 deletions pkg/repository/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type ContainerRepository interface {
DeleteContainerState(containerId string) error
SetWorkerAddress(containerId string, addr string) error
GetWorkerAddress(ctx context.Context, containerId string) (string, error)
SetContainerAddressMap(containerId string, addressMap map[int32]string) error
GetContainerAddressMap(containerId string) (map[int32]string, error)
SetContainerStateWithConcurrencyLimit(quota *types.ConcurrencyLimit, request *types.ContainerRequest) error
GetActiveContainersByStubId(stubId string) ([]types.ContainerState, error)
GetActiveContainersByWorkspaceId(workspaceId string) ([]types.ContainerState, error)
Expand Down
33 changes: 33 additions & 0 deletions pkg/repository/container_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package repository

import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -237,6 +238,38 @@ func (cr *ContainerRedisRepository) GetContainerAddress(containerId string) (str
return cr.rdb.Get(context.TODO(), common.RedisKeys.SchedulerContainerAddress(containerId)).Result()
}

func (cr *ContainerRedisRepository) SetContainerAddressMap(containerId string, addressMap map[int32]string) error {
data, err := json.Marshal(addressMap)
if err != nil {
return fmt.Errorf("failed to marshal addressMap for container %s: %w", containerId, err)
}

err = cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerContainerAddressMap(containerId), data, 0).Err()
if err != nil {
return fmt.Errorf("failed to set container addressMap for container %s: %w", containerId, err)
}

return nil
}

func (cr *ContainerRedisRepository) GetContainerAddressMap(containerId string) (map[int32]string, error) {
data, err := cr.rdb.Get(context.TODO(), common.RedisKeys.SchedulerContainerAddressMap(containerId)).Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}

return nil, fmt.Errorf("failed to get container addressMap for container %s: %w", containerId, err)
}

addressMap := make(map[int32]string)
if err := json.Unmarshal(data, &addressMap); err != nil {
return nil, fmt.Errorf("failed to unmarshal addressMap for container %s: %w", containerId, err)
}

return addressMap, nil
}

func (cr *ContainerRedisRepository) SetWorkerAddress(containerId string, addr string) error {
return cr.rdb.Set(context.TODO(), common.RedisKeys.SchedulerWorkerAddress(containerId), addr, 0).Err()
}
Expand Down
65 changes: 48 additions & 17 deletions pkg/worker/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,30 @@ func (s *Worker) RunContainer(ctx context.Context, request *types.ContainerReque
}
}

bindPort, err := getRandomFreePort()
if err != nil {
return err
// Determine how many ports we need
portsToExpose := len(request.Ports)
if portsToExpose == 0 {
portsToExpose = 1
request.Ports = []uint32{uint32(containerInnerPort)}
}
log.Info().Str("container_id", containerId).Msgf("acquired port: %d", bindPort)

bindPorts := make([]int, 0, portsToExpose)
for i := 0; i < portsToExpose; i++ {
bindPort, err := getRandomFreePort()
if err != nil {
return err
}
bindPorts = append(bindPorts, bindPort)
}

log.Info().Str("container_id", containerId).Msgf("acquired ports: %v", bindPorts)

// Read spec from bundle
initialBundleSpec, _ := s.readBundleConfig(request.ImageId, request.IsBuildRequest())

opts := &ContainerOptions{
BundlePath: bundlePath,
BindPort: bindPort,
BindPorts: bindPorts,
InitialSpec: initialBundleSpec,
}

Expand All @@ -239,15 +251,31 @@ func (s *Worker) RunContainer(ctx context.Context, request *types.ContainerReque

// Set an address (ip:port) for the pod/container in Redis. Depending on the stub type,
// gateway may need to directly interact with this pod/container.
containerAddr := fmt.Sprintf("%s:%d", s.podAddr, bindPort)
containerAddr := fmt.Sprintf("%s:%d", s.podAddr, opts.BindPorts[0])
_, err = handleGRPCResponse(s.containerRepoClient.SetContainerAddress(context.Background(), &pb.SetContainerAddressRequest{
ContainerId: request.ContainerId,
Address: containerAddr,
}))
if err != nil {
return err
}
log.Info().Str("container_id", containerId).Msg("set container address")
log.Info().Str("container_id", containerId).Msgf("set container address: %s", containerAddr)

// Set container address map
addressMap := make(map[int32]string)
for idx, containerPort := range request.Ports {
addressMap[int32(containerPort)] = fmt.Sprintf("%s:%d", s.podAddr, opts.BindPorts[idx])
}
_, err = handleGRPCResponse(s.containerRepoClient.SetContainerAddressMap(context.Background(), &pb.SetContainerAddressMapRequest{
ContainerId: request.ContainerId,
AddressMap: addressMap,
}))

if err != nil {
return err
}

log.Info().Str("container_id", containerId).Msgf("set container address map: %v", addressMap)

go s.containerWg.Add(1)

Expand Down Expand Up @@ -356,7 +384,7 @@ func (s *Worker) specFromRequest(request *types.ContainerRequest, options *Conta
return nil, err
}

containerHostname := fmt.Sprintf("%s:%d", s.podAddr, options.BindPort)
containerHostname := fmt.Sprintf("%s:%d", s.podAddr, options.BindPorts[0])
containerHostnamePath := filepath.Join(checkpointSignalDir(request.ContainerId), checkpointContainerHostnameFileName)
err = os.WriteFile(containerHostnamePath, []byte(containerHostname), 0644)
if err != nil {
Expand Down Expand Up @@ -449,7 +477,7 @@ func (s *Worker) getContainerEnvironment(request *types.ContainerRequest, option
// Most of these env vars are required to communicate with the gateway and vice versa
env := []string{
fmt.Sprintf("BIND_PORT=%d", containerInnerPort),
fmt.Sprintf("CONTAINER_HOSTNAME=%s", fmt.Sprintf("%s:%d", s.podAddr, options.BindPort)),
fmt.Sprintf("CONTAINER_HOSTNAME=%s", fmt.Sprintf("%s:%d", s.podAddr, options.BindPorts[0])),
fmt.Sprintf("CONTAINER_ID=%s", request.ContainerId),
fmt.Sprintf("BETA9_GATEWAY_HOST=%s", os.Getenv("BETA9_GATEWAY_HOST")),
fmt.Sprintf("BETA9_GATEWAY_PORT=%s", os.Getenv("BETA9_GATEWAY_PORT")),
Expand Down Expand Up @@ -561,16 +589,19 @@ func (s *Worker) spawn(request *types.ContainerRequest, spec *specs.Spec, output
return
}

innerPort := containerInnerPort
if len(request.Ports) > 0 {
innerPort = int(request.Ports[0])
portsToExpose := len(request.Ports)
if portsToExpose == 0 {
portsToExpose = 1
request.Ports = []uint32{uint32(containerInnerPort)}
}

// Expose the bind port
err = s.containerNetworkManager.ExposePort(containerId, opts.BindPort, innerPort)
if err != nil {
log.Error().Str("container_id", containerId).Msgf("failed to expose container bind port: %v", err)
return
// Expose the bind ports
for idx, bindPort := range opts.BindPorts {
err = s.containerNetworkManager.ExposePort(containerId, bindPort, int(request.Ports[idx]))
if err != nil {
log.Error().Str("container_id", containerId).Msgf("failed to expose container bind port: %v", err)
return
}
}

if request.RequiresGPU() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type ContainerInstance struct {

type ContainerOptions struct {
BundlePath string
BindPort int
BindPorts []int
InitialSpec *specs.Spec
}

Expand Down
Loading

0 comments on commit e7c7b4f

Please sign in to comment.