From e78bb497d20bd0422c4398a8f6d69a95b6bb6f82 Mon Sep 17 00:00:00 2001
From: Rafael Raposo <rafaelraposo@spotify.com>
Date: Fri, 28 Jun 2024 12:25:43 +0200
Subject: [PATCH] Add Max Parallelism into flytewf

Signed-off-by: Rafael Raposo <rafaelraposo@spotify.com>
---
 .../pkg/workflowengine/impl/k8s_executor.go   |  4 ++
 .../workflowengine/impl/k8s_executor_test.go  | 68 +++++++++++++++++++
 2 files changed, 72 insertions(+)

diff --git a/flyteadmin/pkg/workflowengine/impl/k8s_executor.go b/flyteadmin/pkg/workflowengine/impl/k8s_executor.go
index 163a58cab3..39653fae9f 100644
--- a/flyteadmin/pkg/workflowengine/impl/k8s_executor.go
+++ b/flyteadmin/pkg/workflowengine/impl/k8s_executor.go
@@ -59,6 +59,10 @@ func (e K8sWorkflowExecutor) Execute(ctx context.Context, data interfaces.Execut
 		flyteWf.ConsoleURL = consoleURL
 	}
 
+	if data.ExecutionParameters.ExecutionConfig.MaxParallelism > 0 {
+		flyteWf.ExecutionConfig.MaxParallelism = uint32(data.ExecutionParameters.ExecutionConfig.MaxParallelism)
+	}
+
 	executionTargetSpec := executioncluster.ExecutionTargetSpec{
 		Project:               data.ExecutionID.Project,
 		Domain:                data.ExecutionID.Domain,
diff --git a/flyteadmin/pkg/workflowengine/impl/k8s_executor_test.go b/flyteadmin/pkg/workflowengine/impl/k8s_executor_test.go
index b384ebbcaf..b90d8c0a8d 100644
--- a/flyteadmin/pkg/workflowengine/impl/k8s_executor_test.go
+++ b/flyteadmin/pkg/workflowengine/impl/k8s_executor_test.go
@@ -234,6 +234,74 @@ func TestExecute_AlreadyExists(t *testing.T) {
 	assert.Equal(t, resp.Cluster, clusterID)
 }
 
+func TestExecute_MaxParallelism(t *testing.T) {
+	fakeFlyteWorkflow := FakeFlyteWorkflow{}
+	fakeFlyteWorkflow.createCallback = func(flyteWorkflow *v1alpha1.FlyteWorkflow, opts v1.CreateOptions) (*v1alpha1.FlyteWorkflow, error) {
+		assert.Equal(t, flyteWf, flyteWorkflow)
+		assert.Empty(t, opts)
+		return nil, nil
+	}
+	fakeFlyteWF.flyteWorkflowsCallback = func(ns string) v1alpha12.FlyteWorkflowInterface {
+		assert.Equal(t, namespace, ns)
+		return &fakeFlyteWorkflow
+	}
+
+	mockApplicationConfig := runtimeMocks.MockApplicationProvider{}
+	mockApplicationConfig.SetTopLevelConfig(runtimeInterfaces.ApplicationConfig{
+		UseOffloadedWorkflowClosure: false,
+	})
+	mockRuntime := runtimeMocks.NewMockConfigurationProvider(&mockApplicationConfig, nil, nil, nil, nil, nil)
+
+	mockBuilder := mocks.FlyteWorkflowBuilder{}
+	workflowClosure := core.CompiledWorkflowClosure{
+		Primary: &core.CompiledWorkflow{
+			Template: &core.WorkflowTemplate{
+				Id: &core.Identifier{
+					Project: "p",
+					Domain:  "d",
+					Name:    "n",
+					Version: "version",
+				},
+			},
+		},
+	}
+	mockBuilder.OnBuildMatch(mock.MatchedBy(func(wfClosure *core.CompiledWorkflowClosure) bool {
+		return proto.Equal(wfClosure, &workflowClosure)
+	}), mock.MatchedBy(func(inputs *core.LiteralMap) bool {
+		return proto.Equal(inputs, testInputs)
+	}), mock.MatchedBy(func(executionID *core.WorkflowExecutionIdentifier) bool {
+		return proto.Equal(executionID, execID)
+	}), namespace).Return(flyteWf, nil)
+	executor := K8sWorkflowExecutor{
+		config:           mockRuntime,
+		workflowBuilder:  &mockBuilder,
+		executionCluster: getFakeExecutionCluster(),
+	}
+
+	resp, err := executor.Execute(context.TODO(), interfaces.ExecutionData{
+		Namespace:               namespace,
+		ExecutionID:             execID,
+		ReferenceWorkflowName:   "ref_workflow_name",
+		ReferenceLaunchPlanName: "ref_lp_name",
+		WorkflowClosure:         &workflowClosure,
+		ExecutionParameters: interfaces.ExecutionParameters{
+			Inputs: testInputs,
+			ExecutionConfig: &admin.WorkflowExecutionConfig{
+				SecurityContext: &core.SecurityContext{
+					RunAs: &core.Identity{
+						IamRole:           testRoleSc,
+						K8SServiceAccount: testK8sServiceAccountSc,
+					},
+				},
+				MaxParallelism: 10,
+			},
+		},
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, resp.Cluster, clusterID)
+	assert.Equal(t, flyteWf.ExecutionConfig.MaxParallelism, uint32(10))
+}
+
 func TestExecute_MiscError(t *testing.T) {
 	fakeFlyteWorkflow := FakeFlyteWorkflow{}
 	fakeFlyteWorkflow.createCallback = func(flyteWorkflow *v1alpha1.FlyteWorkflow, opts v1.CreateOptions) (*v1alpha1.FlyteWorkflow, error) {