Skip to content

Commit

Permalink
add sagemaker-hyperpod compute type to resolve its pods via VPC ENI (k…
Browse files Browse the repository at this point in the history
…ubernetes-sigs#3886)

* add sagemaker-hyperpod compute type to resolve its pods via VPC ENI

* consolidate fargate/hyperpod pod flags in resolveViaCascadedLookup into isNonEc2Pod flag

* introduce PodsByComputeType struct
  • Loading branch information
amber-liu-amzn authored Oct 23, 2024
1 parent 75b5793 commit 2e1688b
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 18 deletions.
61 changes: 43 additions & 18 deletions pkg/networking/pod_eni_info_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ const (
// EC2:DescribeNetworkInterface supports up to 200 filters per call.
describeNetworkInterfacesFiltersLimit = 200

labelEKSComputeType = "eks.amazonaws.com/compute-type"
labelEKSComputeType = "eks.amazonaws.com/compute-type"
labelSageMakerComputeType = "sagemaker.amazonaws.com/compute-type"
)

// PodENIInfoResolver is responsible for resolve the AWS VPC ENI that supports pod network.
Expand Down Expand Up @@ -141,20 +142,20 @@ func (r *defaultPodENIInfoResolver) saveENIInfosToCache(pods []k8s.PodInfo, eniI
}

func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
podsOnEc2, podsOnFargate, err := r.classifyPodsByComputeType(ctx, pods)
podsByComputeType, err := r.classifyPodsByComputeType(ctx, pods)
if err != nil {
return nil, err
}
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
if len(podsOnEc2) > 0 {
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsOnEc2, false)
if len(podsByComputeType.ec2Pods) > 0 {
eniInfoByPodKeyEc2, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.ec2Pods, false)
if err != nil {
return nil, err
}
eniInfoByPodKey = eniInfoByPodKeyEc2
}
if len(podsOnFargate) > 0 {
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsOnFargate, true)
if len(podsByComputeType.fargatePods) > 0 {
eniInfoByPodKeyFargate, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.fargatePods, true)
if err != nil {
return nil, err
}
Expand All @@ -164,17 +165,28 @@ func (r *defaultPodENIInfoResolver) resolvePodsViaCascadedLookup(ctx context.Con
}
}
}
if len(podsByComputeType.sageMakerHyperPodPods) > 0 {
eniInfoByPodKeySageMakerHyperPod, err := r.resolveViaCascadedLookup(ctx, podsByComputeType.sageMakerHyperPodPods, true)
if err != nil {
return nil, err
}
if len(eniInfoByPodKeySageMakerHyperPod) > 0 {
for podKey, eniInfo := range eniInfoByPodKeySageMakerHyperPod {
eniInfoByPodKey[podKey] = eniInfo
}
}
}
return eniInfoByPodKey, nil
}

func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isFargateNode bool) (map[types.NamespacedName]ENIInfo, error) {
func (r *defaultPodENIInfoResolver) resolveViaCascadedLookup(ctx context.Context, pods []k8s.PodInfo, isNonEc2Pod bool) (map[types.NamespacedName]ENIInfo, error) {
eniInfoByPodKey := make(map[types.NamespacedName]ENIInfo)
resolveFuncs := []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
r.resolveViaPodENIAnnotation,
r.resolveViaNodeENIs,
// TODO, add support for kubenet CNI plugin(kops) by resolve via routeTable.
}
if isFargateNode {
if isNonEc2Pod {
resolveFuncs = []func(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error){
r.resolveViaVPCENIs,
}
Expand Down Expand Up @@ -281,6 +293,7 @@ func (r *defaultPodENIInfoResolver) resolveViaNodeENIs(ctx context.Context, pods

// resolveViaVPCENIs tries to resolve pod ENI by matching podIP against ENIs in vpc.
// with EKS fargate pods, podIP is supported by an ENI in vpc.
// with SageMaker HyperPod pods, podIP is supported by the visible cross-account ENI in customer vpc.
func (r *defaultPodENIInfoResolver) resolveViaVPCENIs(ctx context.Context, pods []k8s.PodInfo) (map[types.NamespacedName]ENIInfo, error) {
podKeysByIP := make(map[string][]types.NamespacedName, len(pods))
for _, pod := range pods {
Expand Down Expand Up @@ -388,33 +401,45 @@ func (r *defaultPodENIInfoResolver) isPodSupportedByNodeENI(pod k8s.PodInfo, nod
return false
}

// classifyPodsByComputeType classifies in to ec2 and fargate groups
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) ([]k8s.PodInfo, []k8s.PodInfo, error) {
podsOnFargate := make([]k8s.PodInfo, 0, len(pods))
podsOnEc2 := make([]k8s.PodInfo, 0, len(pods))
// PodsByComputeType groups pods based on their compute type (EC2, Fargate, SageMaker HyperPod)
type PodsByComputeType struct {
ec2Pods []k8s.PodInfo
fargatePods []k8s.PodInfo
sageMakerHyperPodPods []k8s.PodInfo
}

// classifyPodsByComputeType classifies in to ec2, fargate and sagemaker-hyperpod groups
func (r *defaultPodENIInfoResolver) classifyPodsByComputeType(ctx context.Context, pods []k8s.PodInfo) (PodsByComputeType, error) {
var podsByComputeType PodsByComputeType
nodeNameByComputeType := make(map[string]string)
for _, pod := range pods {
if _, exists := nodeNameByComputeType[pod.NodeName]; exists {
if nodeNameByComputeType[pod.NodeName] == "fargate" {
podsOnFargate = append(podsOnFargate, pod)
podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod)
} else if nodeNameByComputeType[pod.NodeName] == "sagemaker-hyperpod" {
podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod)
} else {
podsOnEc2 = append(podsOnEc2, pod)
podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod)
}
}

nodeKey := types.NamespacedName{Name: pod.NodeName}
node := &corev1.Node{}
if err := r.k8sClient.Get(ctx, nodeKey, node); err != nil {
return nil, nil, err
return PodsByComputeType{}, err
}
if node.Labels[labelEKSComputeType] == "fargate" {
podsOnFargate = append(podsOnFargate, pod)
podsByComputeType.fargatePods = append(podsByComputeType.fargatePods, pod)
nodeNameByComputeType[pod.NodeName] = "fargate"
} else if node.Labels[labelSageMakerComputeType] == "hyperpod" {
podsByComputeType.sageMakerHyperPodPods = append(podsByComputeType.sageMakerHyperPodPods, pod)
nodeNameByComputeType[pod.NodeName] = "sagemaker-hyperpod"
} else {
podsOnEc2 = append(podsOnEc2, pod)
podsByComputeType.ec2Pods = append(podsByComputeType.ec2Pods, pod)
nodeNameByComputeType[pod.NodeName] = "ec2"
}
}
return podsOnEc2, podsOnFargate, nil
return podsByComputeType, nil
}

// computePodENIInfoCacheKey computes the cacheKey for pod's ENIInfo cache.
Expand Down
181 changes: 181 additions & 0 deletions pkg/networking/pod_eni_info_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,187 @@ func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_Fargate(t *testing.
}
}

func Test_defaultPodENIInfoResolver_resolveViaCascadedLookup_SageMakerHyperPod(t *testing.T) {
hyperPodNodeA := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "hyperpod-i-04442beca624ba65b",
Labels: map[string]string{
"sagemaker.amazonaws.com/compute-type": "hyperpod",
},
},
Spec: corev1.NodeSpec{
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04442beca624ba65b",
},
}
hyperPodNodeB := &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "hyperpod-i-04159267183583d03",
Labels: map[string]string{
"sagemaker.amazonaws.com/compute-type": "hyperpod",
},
},
Spec: corev1.NodeSpec{
ProviderID: "aws:///usw2-az2/sagemaker/cluster/hyperpod-xxxxxxxxxxxx-i-04159267183583d03",
},
}
type describeNetworkInterfacesAsListCall struct {
req *ec2sdk.DescribeNetworkInterfacesInput
resp []ec2types.NetworkInterface
err error
}
type fetchNodeInstancesCall struct {
nodes []*corev1.Node
nodeInstanceByNodeKey map[types.NamespacedName]*ec2types.Instance
err error
}
type env struct {
nodes []*corev1.Node
}
type fields struct {
describeNetworkInterfacesAsListCalls []describeNetworkInterfacesAsListCall
fetchNodeInstancesCalls []fetchNodeInstancesCall
}
type args struct {
pods []k8s.PodInfo
}
tests := []struct {
name string
env env
fields fields
args args
want map[types.NamespacedName]ENIInfo
wantErr error
}{
{
name: "all pod's ENI resolved via VPC's ENIs",
env: env{
nodes: []*corev1.Node{hyperPodNodeA, hyperPodNodeB},
},
fields: fields{
describeNetworkInterfacesAsListCalls: []describeNetworkInterfacesAsListCall{
{
req: &ec2sdk.DescribeNetworkInterfacesInput{
Filters: []ec2types.Filter{
{
Name: awssdk.String("vpc-id"),
Values: []string{"vpc-0d6d9ee10bd062dcc"},
},
{
Name: awssdk.String("addresses.private-ip-address"),
Values: []string{"192.168.128.151", "192.168.128.152"},
},
},
},
resp: []ec2types.NetworkInterface{
{
NetworkInterfaceId: awssdk.String("eni-c"),
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
{
PrivateIpAddress: awssdk.String("192.168.128.150"),
},
{
PrivateIpAddress: awssdk.String("192.168.128.151"),
},
},
Groups: []ec2types.GroupIdentifier{
{
GroupId: awssdk.String("sg-c-1"),
},
},
},
{
NetworkInterfaceId: awssdk.String("eni-d"),
PrivateIpAddresses: []ec2types.NetworkInterfacePrivateIpAddress{
{
PrivateIpAddress: awssdk.String("192.168.128.152"),
},
{
PrivateIpAddress: awssdk.String("192.168.128.153"),
},
},
Groups: []ec2types.GroupIdentifier{
{
GroupId: awssdk.String("sg-d-1"),
},
},
},
},
},
},
},
args: args{
pods: []k8s.PodInfo{
{
Key: types.NamespacedName{Namespace: "default", Name: "pod-1"},
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc01"),
NodeName: "hyperpod-i-04442beca624ba65b",
PodIP: "192.168.128.151",
},
{
Key: types.NamespacedName{Namespace: "default", Name: "pod-2"},
UID: types.UID("2d8740a6-f4b1-4074-a91c-f0084ec0bc02"),
NodeName: "hyperpod-i-04159267183583d03",
PodIP: "192.168.128.152",
},
},
},
want: map[types.NamespacedName]ENIInfo{
types.NamespacedName{Namespace: "default", Name: "pod-1"}: {
NetworkInterfaceID: "eni-c",
SecurityGroups: []string{"sg-c-1"},
},
types.NamespacedName{Namespace: "default", Name: "pod-2"}: {
NetworkInterfaceID: "eni-d",
SecurityGroups: []string{"sg-d-1"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ec2Client := services.NewMockEC2(ctrl)
for _, call := range tt.fields.describeNetworkInterfacesAsListCalls {
ec2Client.EXPECT().DescribeNetworkInterfacesAsList(gomock.Any(), call.req).Return(call.resp, call.err)
}
k8sSchema := runtime.NewScheme()
clientgoscheme.AddToScheme(k8sSchema)
k8sClient := fake.NewClientBuilder().WithScheme(k8sSchema).Build()
for _, node := range tt.env.nodes {
assert.NoError(t, k8sClient.Create(context.Background(), node.DeepCopy()))
}
nodeInfoProvider := NewMockNodeInfoProvider(ctrl)
for _, call := range tt.fields.fetchNodeInstancesCalls {
updatedNodes := make([]*corev1.Node, 0, len(call.nodes))
for _, node := range call.nodes {
updatedNode := &corev1.Node{}
assert.NoError(t, k8sClient.Get(context.Background(), k8s.NamespacedName(node), updatedNode))
updatedNodes = append(updatedNodes, updatedNode)
}
nodeInfoProvider.EXPECT().FetchNodeInstances(gomock.Any(), gomock.InAnyOrder(updatedNodes)).Return(call.nodeInstanceByNodeKey, call.err)
}
r := &defaultPodENIInfoResolver{
ec2Client: ec2Client,
k8sClient: k8sClient,
nodeInfoProvider: nodeInfoProvider,
vpcID: "vpc-0d6d9ee10bd062dcc",
logger: logr.New(&log.NullLogSink{}),
describeNetworkInterfacesIPChunkSize: 2,
}

got, err := r.resolveViaCascadedLookup(context.Background(), tt.args.pods, true)
if tt.wantErr != nil {
assert.EqualError(t, err, tt.wantErr.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

func Test_defaultPodENIInfoResolver_resolveViaPodENIAnnotation(t *testing.T) {
type describeNetworkInterfacesAsListCall struct {
req *ec2sdk.DescribeNetworkInterfacesInput
Expand Down

0 comments on commit 2e1688b

Please sign in to comment.