Skip to content

Commit

Permalink
use grpc code status to handle case list agent method not implemented
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
Future Outlier committed Dec 19, 2023
1 parent 1cd0e48 commit ef9cfe8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/single-binary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ jobs:
cpu: "0"
memory: "0"
EOF
flytectl demo start --image flyte-sandbox-bundled:local --disable-agent --imagePullPolicy Never
flytectl demo start --image flyte-sandbox-bundled:local --imagePullPolicy Never
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State)

case arrayCore.PhaseAbortSubTasks:
fallthrough

Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,10 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
messageCollector.Collect(childIdx, err.Error())
} else {
externalResources = append(externalResources, &core.ExternalResource{
ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
Index: uint32(originalIdx),
RetryAttempt: uint32(retryAttempt),
Phase: core.PhaseAborted,
ExternalID: stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(),
Index: uint32(originalIdx),
RetryAttempt: uint32(retryAttempt),
Phase: core.PhaseAborted,
})
}
}
Expand Down
29 changes: 18 additions & 11 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (

"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -36,7 +38,7 @@ type Plugin struct {
cfg *Config
getClient GetClientFunc
connectionCache map[*Agent]*grpc.ClientConn
agentRegistry map[string]map[bool]*Agent // map[taskType][isSync] => Agent
agentRegistry map[string]*Agent // map[taskType] => Agent
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -326,14 +328,13 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte
return context.WithTimeout(ctx, timeout)
}

func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, getAgentMetadataClientFunc GetAgentMetadataClientFunc) (map[string]map[bool]*Agent, error) {
agentRegistry := make(map[string]map[bool]*Agent)
func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, getAgentMetadataClientFunc GetAgentMetadataClientFunc) (map[string]*Agent, error) {
agentRegistry := make(map[string]*Agent)
var agentDeployments []*Agent

// Ensure that the old configuration is backward compatible
for taskType, agentID := range cfg.AgentForTaskTypes {
agentRegistry[taskType] = make(map[bool]*Agent)
agentRegistry[taskType][false] = cfg.Agents[agentID]
agentRegistry[taskType] = cfg.Agents[agentID]
}

if len(cfg.DefaultAgent.Endpoint) != 0 {
Expand All @@ -351,19 +352,25 @@ func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.Clien

res, err := client.ListAgent(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpc_status, ok := status.FromError(err)
if grpc_status.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
continue
}

if !ok {
return nil, fmt.Errorf("failed to list agent with a non-gRPC error : [%v]", err)
}

return nil, fmt.Errorf("failed to list agent with error: [%v]", err)
}

agents := res.GetAgents()
for _, agent := range agents {
supportedTaskTypes := agent.SupportedTaskTypes
isSync := agent.IsSync

for _, supportedTaskType := range supportedTaskTypes {
if _, ok := agentRegistry[supportedTaskType]; !ok {
agentRegistry[supportedTaskType] = make(map[bool]*Agent)
}
agentRegistry[supportedTaskType][isSync] = agentDeployment
agentRegistry[supportedTaskType] = agentDeployment
}
}
}
Expand Down

0 comments on commit ef9cfe8

Please sign in to comment.