diff --git a/.github/workflows/flyteidl-buf-publish.yml b/.github/workflows/flyteidl-buf-publish.yml new file mode 100644 index 0000000000..3a6915db6f --- /dev/null +++ b/.github/workflows/flyteidl-buf-publish.yml @@ -0,0 +1,18 @@ +name: Publish flyteidl Buf Package + +on: + push: + branches: + - master + paths: + - 'flyteidl/**' +jobs: + buf: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: bufbuild/buf-setup-action@v1 + - uses: bufbuild/buf-push-action@v1 + with: + buf_token: ${{ secrets.BUF_TOKEN }} + input: 'flyteidl/protos' diff --git a/README.md b/README.md index 8dd663230f..ba0a5537e0 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ Flyte is an **open-source orchestrator** that facilitates building production-gr OpenSSF Best Practices Flyte Helm Chart - - Flyte Twitter + + Flyte Slack

diff --git a/charts/flyte-binary/README.md b/charts/flyte-binary/README.md index 3ea5cd2f6c..d6d9a3bb6a 100644 --- a/charts/flyte-binary/README.md +++ b/charts/flyte-binary/README.md @@ -21,7 +21,11 @@ Chart for basic single Flyte executable deployment | clusterResourceTemplates.labels | object | `{}` | | | commonAnnotations | object | `{}` | | | commonLabels | object | `{}` | | -| configuration.agentService | object | `{}` | | +| configuration.agentService.defaultAgent.defaultTimeout | string | `"10s"` | | +| configuration.agentService.defaultAgent.endpoint | string | `"dns:///flyteagent.flyte.svc.cluster.local:8000"` | | +| configuration.agentService.defaultAgent.insecure | bool | `true` | | +| configuration.agentService.defaultAgent.timeouts.GetTask | string | `"10s"` | | +| configuration.agentService.supportedTaskTypes[0] | string | `"default_task"` | | | configuration.annotations | object | `{}` | | | configuration.auth.authorizedUris | list | `[]` | | | configuration.auth.clientSecretsExternalSecretRef | string | `""` | | @@ -103,9 +107,9 @@ Chart for basic single Flyte executable deployment | deployment.waitForDB.image.pullPolicy | string | `"IfNotPresent"` | | | deployment.waitForDB.image.repository | string | `"postgres"` | | | deployment.waitForDB.image.tag | string | `"15-alpine"` | | -| enabled_plugins.tasks | object | `{"task-plugins":{"default-for-task-types":{"container":"container","container_array":"k8s-array","sidecar":"sidecar"},"enabled-plugins":["container","sidecar","k8s-array"]}}` | Tasks specific configuration [structure](https://pkg.go.dev/github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config#GetConfig) | -| enabled_plugins.tasks.task-plugins | object | `{"default-for-task-types":{"container":"container","container_array":"k8s-array","sidecar":"sidecar"},"enabled-plugins":["container","sidecar","k8s-array"]}` | Plugins configuration, 
[structure](https://pkg.go.dev/github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config#TaskPluginConfig) | -| enabled_plugins.tasks.task-plugins.enabled-plugins | list | `["container","sidecar","k8s-array"]` | [Enabled Plugins](https://pkg.go.dev/github.com/lyft/flyteplugins/go/tasks/config#Config). Enable sagemaker*, athena if you install the backend plugins | +| enabled_plugins.tasks | object | `{"task-plugins":{"default-for-task-types":{"container":"container","container_array":"k8s-array","sidecar":"sidecar"},"enabled-plugins":["container","sidecar","k8s-array","agent-service"]}}` | Tasks specific configuration [structure](https://pkg.go.dev/github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config#GetConfig) | +| enabled_plugins.tasks.task-plugins | object | `{"default-for-task-types":{"container":"container","container_array":"k8s-array","sidecar":"sidecar"},"enabled-plugins":["container","sidecar","k8s-array","agent-service"]}` | Plugins configuration, [structure](https://pkg.go.dev/github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config#TaskPluginConfig) | +| enabled_plugins.tasks.task-plugins.enabled-plugins | list | `["container","sidecar","k8s-array","agent-service"]` | [Enabled Plugins](https://pkg.go.dev/github.com/lyft/flyteplugins/go/tasks/config#Config). 
Enable sagemaker*, athena if you install the backend plugins | | flyte-core-components.admin.disableClusterResourceManager | bool | `false` | | | flyte-core-components.admin.disableScheduler | bool | `false` | | | flyte-core-components.admin.disabled | bool | `false` | | diff --git a/charts/flyte-binary/values.yaml b/charts/flyte-binary/values.yaml index 5e555ac890..84ffe00112 100644 --- a/charts/flyte-binary/values.yaml +++ b/charts/flyte-binary/values.yaml @@ -151,11 +151,16 @@ configuration: # tag CoPilot sidecar image tag tag: v1.9.4 # FLYTECOPILOT_TAG # agentService Flyte Agent configuration - agentService: {} + agentService: + defaultAgent: + endpoint: "dns:///flyteagent.flyte.svc.cluster.local:8000" + insecure: true + timeouts: + GetTask: 10s + defaultTimeout: 10s # Uncomment and modify to include configuration for Flyte Agent - # defaultGrpcEndpoint: agent-service.agent-namespace:8000 - # supportedTaskTypes: - # - custom_task_type + supportedTaskTypes: + - default_task # externalConfigMap Specify an existing, external ConfigMap to use as configuration for Flyte # If set, no Flyte configuration will be generated by this chart externalConfigMap: "" @@ -382,8 +387,7 @@ enabled_plugins: - container - sidecar - k8s-array - # -- Uncomment to enable agent service - # - agent-service + - agent-service default-for-task-types: container: container sidecar: sidecar diff --git a/deployment/sandbox-binary/flyte_sandbox_binary_helm_generated.yaml b/deployment/sandbox-binary/flyte_sandbox_binary_helm_generated.yaml index 7c1524af06..bb5e0cd675 100644 --- a/deployment/sandbox-binary/flyte_sandbox_binary_helm_generated.yaml +++ b/deployment/sandbox-binary/flyte_sandbox_binary_helm_generated.yaml @@ -108,6 +108,7 @@ data: - container - sidecar - k8s-array + - agent-service plugins: logs: kubernetes-enabled: false @@ -122,6 +123,15 @@ data: kubernetes-enabled: false cloudwatch-enabled: false stackdriver-enabled: false + agent-service: + defaultAgent: + defaultTimeout: 10s + 
endpoint: dns:///flyteagent.flyte.svc.cluster.local:8000 + insecure: true + timeouts: + GetTask: 10s + supportedTaskTypes: + - default_task 002-database.yaml: | database: postgres: @@ -357,7 +367,7 @@ spec: app.kubernetes.io/instance: flyte app.kubernetes.io/component: flyte-binary annotations: - checksum/configuration: 528ce4a42638a7810c99802dfd49525967db2a99dbc1019544b7799de2490b61 + checksum/configuration: d220769393e7acbe0372fdccbf3d588797864ec934661f08912e88ec084cdfde checksum/configuration-secret: d5d93f4e67780b21593dc3799f0f6682aab0765e708e4020939975d14d44f929 checksum/cluster-resource-templates: 7dfa59f3d447e9c099b8f8ffad3af466fecbc9cf9f8c97295d9634254a55d4ae spec: diff --git a/docker/sandbox-bundled/Makefile b/docker/sandbox-bundled/Makefile index 709c04caf2..9ae4197673 100644 --- a/docker/sandbox-bundled/Makefile +++ b/docker/sandbox-bundled/Makefile @@ -19,6 +19,7 @@ flyte: manifests: mkdir -p manifests helm dependency update ../../charts/flyte-sandbox + helm dependency update ../../charts/flyteagent kustomize build \ --enable-helm \ --load-restrictor=LoadRestrictionsNone \ @@ -27,6 +28,10 @@ manifests: --enable-helm \ --load-restrictor=LoadRestrictionsNone \ kustomize/dev > manifests/dev.yaml + kustomize build \ + --enable-helm \ + --load-restrictor=LoadRestrictionsNone \ + kustomize/complete-agent > manifests/complete-agent.yaml .PHONY: build build: flyte manifests diff --git a/docker/sandbox-bundled/bootstrap/cmd/bootstrap/main.go b/docker/sandbox-bundled/bootstrap/cmd/bootstrap/main.go index 82bae88151..6afee0e82c 100644 --- a/docker/sandbox-bundled/bootstrap/cmd/bootstrap/main.go +++ b/docker/sandbox-bundled/bootstrap/cmd/bootstrap/main.go @@ -17,13 +17,15 @@ const ( clusterResourceTemplatesConfigMapName = "flyte-sandbox-extra-cluster-resource-templates" deploymentName = "flyte-sandbox" devModeEnvVar = "FLYTE_DEV" + disableAgentModeEnvVar = "DISABLE_AGENT" dockerHost = "host.docker.internal" namespace = "flyte" // Template paths - devTemplatePath = 
"/var/lib/rancher/k3s/server/manifests-staging/dev.yaml" - fullTemplatePath = "/var/lib/rancher/k3s/server/manifests-staging/complete.yaml" - renderedManifestPath = "/var/lib/rancher/k3s/server/manifests/flyte.yaml" + devTemplatePath = "/var/lib/rancher/k3s/server/manifests-staging/dev.yaml" + fullTemplatePath = "/var/lib/rancher/k3s/server/manifests-staging/complete.yaml" + fullAgentTemplatePath = "/var/lib/rancher/k3s/server/manifests-staging/complete-agent.yaml" + renderedManifestPath = "/var/lib/rancher/k3s/server/manifests/flyte.yaml" ) func main() { @@ -35,7 +37,11 @@ func main() { } else { // If we are not running in dev mode, look for user-specified configuration // to load into the sandbox deployment - tmplPath = fullTemplatePath + tmplPath = fullAgentTemplatePath + if os.Getenv(disableAgentModeEnvVar) == "True" { + tmplPath = fullTemplatePath + } + cOpts := config.LoaderOpts{ ConfigurationConfigMapName: configurationConfigMapName, ClusterResourceTemplatesConfigMapName: clusterResourceTemplatesConfigMapName, diff --git a/docker/sandbox-bundled/kustomize/complete-agent/kustomization.yaml b/docker/sandbox-bundled/kustomize/complete-agent/kustomization.yaml new file mode 100644 index 0000000000..3c6d5c6e5e --- /dev/null +++ b/docker/sandbox-bundled/kustomize/complete-agent/kustomization.yaml @@ -0,0 +1,12 @@ +helmGlobals: + chartHome: ../../../../charts +helmCharts: +- name: flyte-sandbox + releaseName: flyte-sandbox + namespace: flyte +- name: flyteagent + releaseName: flyteagent + namespace: flyte +namespace: flyte +resources: +- ../namespace.yaml diff --git a/docker/sandbox-bundled/manifests/complete-agent.yaml b/docker/sandbox-bundled/manifests/complete-agent.yaml new file mode 100644 index 0000000000..b18c5ac84e --- /dev/null +++ b/docker/sandbox-bundled/manifests/complete-agent.yaml @@ -0,0 +1,1912 @@ +apiVersion: v1 +kind: Namespace +metadata: + name: flyte +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + labels: + app.kubernetes.io/instance: 
flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox + namespace: flyte +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +--- +apiVersion: v1 +automountServiceAccountToken: true +kind: ServiceAccount +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: minio + helm.sh/chart: minio-12.1.1 + name: flyte-sandbox-minio + namespace: flyte +secrets: +- name: flyte-sandbox-minio +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + labels: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyteagent + helm.sh/chart: flyteagent-v0.1.10 + name: flyteagent + namespace: flyte +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +rules: +- apiGroups: + - "" + resourceNames: + - kubernetes-dashboard-key-holder + - kubernetes-dashboard-certs + - kubernetes-dashboard-csrf + resources: + - secrets + verbs: + - get + - update + - delete +- apiGroups: + - "" + resourceNames: + - kubernetes-dashboard-settings + resources: + - configmaps + verbs: + - get + - update +- apiGroups: + - "" + resourceNames: + - heapster + - dashboard-metrics-scraper + resources: + - services + verbs: + - proxy +- apiGroups: + - "" + resourceNames: + - heapster + - 
'http:heapster:' + - 'https:heapster:' + - dashboard-metrics-scraper + - http:dashboard-metrics-scraper + resources: + - services/proxy + verbs: + - get +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-cluster-role + namespace: flyte +rules: +- apiGroups: + - "" + resources: + - namespaces + - resourcequotas + - secrets + verbs: + - create + - get + - list + - patch + - update +- apiGroups: + - "" + resources: + - pods + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - "" + resources: + - events + verbs: + - create + - delete + - patch + - update +- apiGroups: + - "" + resources: + - podtemplates + verbs: + - get + - list + - watch +- apiGroups: + - flyte.lyft.com + resources: + - flyteworkflows + verbs: + - create + - delete + - deletecollection + - get + - list + - patch + - post + - update + - watch +- apiGroups: + - apiextensions.k8s.io + resources: + - customresourcedefinitions + verbs: + - create + - get + - list +- apiGroups: + - admissionregistration.k8s.io + resources: + - mutatingwebhookconfigurations + verbs: + - create + - get + - list + - patch + - update +- apiGroups: + - '*' + resources: + - '*' + verbs: + - '*' +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard-readonly +rules: +- apiGroups: + - "" + resources: + - configmaps + - endpoints + - persistentvolumeclaims + - pods + - replicationcontrollers + - replicationcontrollers/scale + - serviceaccounts + - services + - 
nodes + - persistentvolumeclaims + - persistentvolumes + verbs: + - get + - list + - watch +- apiGroups: + - "" + resources: + - bindings + - events + - limitranges + - namespaces/status + - pods/log + - pods/status + - replicationcontrollers/status + - resourcequotas + - resourcequotas/status + verbs: + - get + - list + - watch +- apiGroups: + - "" + resources: + - namespaces + verbs: + - get + - list + - watch +- apiGroups: + - apps + resources: + - daemonsets + - deployments + - deployments/scale + - replicasets + - replicasets/scale + - statefulsets + verbs: + - get + - list + - watch +- apiGroups: + - autoscaling + resources: + - horizontalpodautoscalers + verbs: + - get + - list + - watch +- apiGroups: + - batch + resources: + - cronjobs + - jobs + verbs: + - get + - list + - watch +- apiGroups: + - extensions + resources: + - daemonsets + - deployments + - deployments/scale + - ingresses + - networkpolicies + - replicasets + - replicasets/scale + - replicationcontrollers/scale + verbs: + - get + - list + - watch +- apiGroups: + - policy + resources: + - poddisruptionbudgets + verbs: + - get + - list + - watch +- apiGroups: + - networking.k8s.io + resources: + - networkpolicies + - ingresses + verbs: + - get + - list + - watch +- apiGroups: + - storage.k8s.io + resources: + - storageclasses + - volumeattachments + verbs: + - get + - list + - watch +- apiGroups: + - rbac.authorization.k8s.io + resources: + - clusterrolebindings + - clusterroles + - roles + - rolebindings + verbs: + - get + - list + - watch +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: flyte-sandbox-kubernetes-dashboard 
+subjects: +- kind: ServiceAccount + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-cluster-role-binding + namespace: flyte +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: flyte-sandbox-cluster-role +subjects: +- kind: ServiceAccount + name: flyte-sandbox + namespace: flyte +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard-readonly +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: ClusterRole + name: flyte-sandbox-kubernetes-dashboard-readonly +subjects: +- kind: ServiceAccount + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +--- +apiVersion: v1 +data: + namespace.yaml: | + apiVersion: v1 + kind: Namespace + metadata: + name: '{{ namespace }}' +kind: ConfigMap +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-cluster-resource-templates + namespace: flyte +--- +apiVersion: v1 +data: + 000-core.yaml: | + admin: + endpoint: localhost:8089 + insecure: true + catalog-cache: + endpoint: localhost:8081 + insecure: true + type: datacatalog + cluster_resources: + standaloneDeployment: false + templatePath: /etc/flyte/cluster-resource-templates + logger: + show-source: true + level: 6 + propeller: + create-flyteworkflow-crd: true + 
webhook: + certDir: /var/run/flyte/certs + localCert: true + secretName: flyte-sandbox-webhook-secret + serviceName: flyte-sandbox-webhook + servicePort: 443 + flyte: + admin: + disableClusterResourceManager: false + disableScheduler: false + disabled: false + seedProjects: + - flytesnacks + dataCatalog: + disabled: false + propeller: + disableWebhook: false + disabled: false + 001-plugins.yaml: | + tasks: + task-plugins: + default-for-task-types: + container: container + container_array: k8s-array + sidecar: sidecar + enabled-plugins: + - container + - sidecar + - k8s-array + - agent-service + plugins: + logs: + kubernetes-enabled: true + kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} + cloudwatch-enabled: false + stackdriver-enabled: false + k8s: + co-pilot: + image: "cr.flyte.org/flyteorg/flytecopilot:v1.9.4" + k8s-array: + logs: + config: + kubernetes-enabled: true + kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} + cloudwatch-enabled: false + stackdriver-enabled: false + agent-service: + defaultAgent: + defaultTimeout: 10s + endpoint: dns:///flyteagent.flyte.svc.cluster.local:8000 + insecure: true + timeouts: + GetTask: 10s + supportedTaskTypes: + - default_task + 002-database.yaml: | + database: + postgres: + username: postgres + host: flyte-sandbox-postgresql + port: 5432 + dbname: flyte + options: "sslmode=disable" + 003-storage.yaml: | + propeller: + rawoutput-prefix: s3://my-s3-bucket/data + storage: + type: stow + stow: + kind: s3 + config: + region: us-east-1 + disable_ssl: true + v2_signing: true + endpoint: http://flyte-sandbox-minio.flyte:9000 + auth_type: accesskey + container: my-s3-bucket + 100-inline-config.yaml: | + plugins: + k8s: + default-env-vars: + - FLYTE_AWS_ENDPOINT: http://flyte-sandbox-minio.flyte:9000 + - FLYTE_AWS_ACCESS_KEY_ID: minio + - 
FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage + storage: + signedURL: + stowConfigOverride: + endpoint: http://localhost:30002 + task_resources: + defaults: + cpu: 500m + ephemeralStorage: 0 + gpu: 0 + memory: 1Gi + limits: + cpu: 0 + ephemeralStorage: 0 + gpu: 0 + memory: 0 +kind: ConfigMap +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-config + namespace: flyte +--- +apiVersion: v1 +data: + config.yml: |- + health: + storagedriver: + enabled: true + interval: 10s + threshold: 3 + http: + addr: :5000 + debug: + addr: :5001 + prometheus: + enabled: false + path: /metrics + headers: + X-Content-Type-Options: + - nosniff + log: + fields: + service: registry + storage: + cache: + blobdescriptor: inmemory + version: 0.1 +kind: ConfigMap +metadata: + labels: + app: docker-registry + chart: docker-registry-2.2.2 + heritage: Helm + release: flyte-sandbox + name: flyte-sandbox-docker-registry-config + namespace: flyte +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: flyte-sandbox-extra-cluster-resource-templates + namespace: flyte +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: flyte-sandbox-extra-config + namespace: flyte +--- +apiVersion: v1 +data: + envoy.yaml: | + admin: + access_log_path: /dev/stdout + static_resources: + listeners: + - address: + socket_address: + address: 0.0.0.0 + port_value: 8000 + filter_chains: + - filters: + - name: envoy.filters.network.http_connection_manager + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager + stat_prefix: ingress_http + codec_type: AUTO + upgrade_configs: + - upgrade_type: websocket + route_config: + name: local_route + virtual_hosts: + - name: backend + domains: + - "*" + routes: + - match: + path: "/" + redirect: + path_redirect: "/console/" + - 
match: + prefix: "/.well-known" + route: + cluster: flyte + - match: + prefix: "/__webpack_hmr" + route: + cluster: flyte + - match: + prefix: "/api" + route: + cluster: flyte + - match: + prefix: "/callback" + route: + cluster: flyte + - match: + prefix: "/config" + route: + cluster: flyte + - match: + prefix: "/console" + route: + cluster: flyte + - match: + prefix: "/healthcheck" + route: + cluster: flyte + - match: + prefix: "/login" + route: + cluster: flyte + - match: + prefix: "/logout" + route: + cluster: flyte + - match: + prefix: "/me" + route: + cluster: flyte + - match: + prefix: "/oauth2" + route: + cluster: flyte + - match: + prefix: "/v1" + route: + cluster: flyte + - match: + prefix: "/flyteidl.service.AdminService" + route: + cluster: flyte_grpc + - match: + prefix: "/flyteidl.service.AuthMetadataService" + route: + cluster: flyte_grpc + - match: + prefix: "/flyteidl.service.DataProxyService" + route: + cluster: flyte_grpc + - match: + prefix: "/flyteidl.service.IdentityService" + route: + cluster: flyte_grpc + - match: + prefix: "/grpc.health.v1.Health" + route: + cluster: flyte_grpc + - match: + prefix: "/flyteidl.service.SignalService" + route: + cluster: flyte_grpc + - match: + path: "/kubernetes-dashboard" + redirect: + path_redirect: "/kubernetes-dashboard/" + - match: + prefix: "/kubernetes-dashboard/" + route: + cluster: kubernetes-dashboard + prefix_rewrite: / + - match: + path: "/minio" + redirect: + path_redirect: "/minio/" + - match: + prefix: "/minio/" + route: + cluster: minio + prefix_rewrite: / + http_filters: + - name: envoy.filters.http.router + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router + clusters: + - name: flyte + connect_timeout: 0.25s + type: STRICT_DNS + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: flyte + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: flyte-sandbox-http + port_value: 8088 + - name: flyte_grpc + connect_timeout: 
0.25s + type: STRICT_DNS + lb_policy: ROUND_ROBIN + http2_protocol_options: {} + load_assignment: + cluster_name: flyte_grpc + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: flyte-sandbox-grpc + port_value: 8089 + - name: kubernetes-dashboard + connect_timeout: 0.25s + type: STRICT_DNS + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: kubernetes-dashboard + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: flyte-sandbox-kubernetes-dashboard + port_value: 80 + - name: minio + connect_timeout: 0.25s + type: STRICT_DNS + lb_policy: ROUND_ROBIN + load_assignment: + cluster_name: minio + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: flyte-sandbox-minio + port_value: 9001 +kind: ConfigMap +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-proxy-config + namespace: flyte +--- +apiVersion: v1 +data: null +kind: ConfigMap +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: kubernetes-dashboard-settings + namespace: flyte +--- +apiVersion: v1 +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-config-secret + namespace: flyte +stringData: + 012-database-secrets.yaml: | + database: + postgres: + password: "postgres" + 013-storage-secrets.yaml: | + storage: + stow: + config: + access_key_id: "minio" + secret_key: "miniostorage" +type: Opaque +--- +apiVersion: v1 +data: + haSharedSecret: 
R2JRWFVRYThnRFVLbHpuSA== + proxyPassword: "" + proxyUsername: "" +kind: Secret +metadata: + labels: + app: docker-registry + chart: docker-registry-2.2.2 + heritage: Helm + release: flyte-sandbox + name: flyte-sandbox-docker-registry-secret + namespace: flyte +type: Opaque +--- +apiVersion: v1 +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard-certs + namespace: flyte +type: Opaque +--- +apiVersion: v1 +data: + root-password: bWluaW9zdG9yYWdl + root-user: bWluaW8= +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: minio + helm.sh/chart: minio-12.1.1 + name: flyte-sandbox-minio + namespace: flyte +type: Opaque +--- +apiVersion: v1 +data: + postgres-password: cG9zdGdyZXM= +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: postgresql + helm.sh/chart: postgresql-12.1.9 + name: flyte-sandbox-postgresql + namespace: flyte +type: Opaque +--- +apiVersion: v1 +data: + username: User +kind: Secret +metadata: + name: flyteagent + namespace: flyte +type: Opaque +--- +apiVersion: v1 +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: kubernetes-dashboard-csrf + namespace: flyte +type: Opaque +--- +apiVersion: v1 +kind: Secret +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: 
kubernetes-dashboard-key-holder + namespace: flyte +type: Opaque +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app: docker-registry + chart: docker-registry-2.2.2 + heritage: Helm + release: flyte-sandbox + name: flyte-sandbox-docker-registry + namespace: flyte +spec: + ports: + - name: http-5000 + nodePort: 30000 + port: 5000 + protocol: TCP + targetPort: 5000 + selector: + app: docker-registry + release: flyte-sandbox + type: NodePort +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-grpc + namespace: flyte +spec: + ports: + - name: grpc + nodePort: null + port: 8089 + targetPort: grpc + selector: + app.kubernetes.io/component: flyte-binary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-http + namespace: flyte +spec: + ports: + - name: http + nodePort: null + port: 8088 + targetPort: http + selector: + app.kubernetes.io/component: flyte-binary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/component: kubernetes-dashboard + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + kubernetes.io/cluster-service: "true" + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +spec: + ports: + - name: http + port: 80 + targetPort: 
http + selector: + app.kubernetes.io/component: kubernetes-dashboard + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: kubernetes-dashboard + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: minio + helm.sh/chart: minio-12.1.1 + name: flyte-sandbox-minio + namespace: flyte +spec: + externalTrafficPolicy: Cluster + ports: + - name: minio-api + nodePort: 30002 + port: 9000 + targetPort: minio-api + - name: minio-console + port: 9001 + targetPort: minio-console + selector: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: minio + type: NodePort +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: postgresql + helm.sh/chart: postgresql-12.1.9 + name: flyte-sandbox-postgresql + namespace: flyte +spec: + externalTrafficPolicy: Cluster + ports: + - name: tcp-postgresql + nodePort: 30001 + port: 5432 + targetPort: tcp-postgresql + selector: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: postgresql + sessionAffinity: None + type: NodePort +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: postgresql + helm.sh/chart: postgresql-12.1.9 + service.alpha.kubernetes.io/tolerate-unready-endpoints: "true" + name: flyte-sandbox-postgresql-hl + namespace: flyte +spec: + clusterIP: None + ports: + - name: tcp-postgresql + port: 5432 + targetPort: tcp-postgresql + publishNotReadyAddresses: true + selector: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: postgresql + type: ClusterIP +--- +apiVersion: 
v1 +kind: Service +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-proxy + namespace: flyte +spec: + ports: + - name: http + nodePort: 30080 + port: 8000 + protocol: TCP + selector: + app.kubernetes.io/component: proxy + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + type: NodePort +--- +apiVersion: v1 +kind: Service +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox-webhook + namespace: flyte +spec: + ports: + - name: webhook + port: 443 + targetPort: webhook + selector: + app.kubernetes.io/component: flyte-binary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + type: ClusterIP +--- +apiVersion: v1 +kind: Service +metadata: + annotations: + projectcontour.io/upstream-protocol.h2c: grpc + labels: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyteagent + helm.sh/chart: flyteagent-v0.1.10 + name: flyteagent + namespace: flyte +spec: + ports: + - name: agent-grpc + port: 8000 + protocol: TCP + targetPort: agent-grpc + selector: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/name: flyteagent + type: ClusterIP +--- +apiVersion: v1 +kind: PersistentVolume +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-db-storage + namespace: flyte +spec: + accessModes: + - ReadWriteOnce + capacity: + storage: 1Gi + hostPath: + path: /var/lib/flyte/storage/db + storageClassName: manual +--- 
+apiVersion: v1 +kind: PersistentVolume +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-minio-storage + namespace: flyte +spec: + accessModes: + - ReadWriteOnce + capacity: + storage: 1Gi + hostPath: + path: /var/lib/flyte/storage/minio + storageClassName: manual +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-db-storage + namespace: flyte +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 1Gi + storageClassName: manual + volumeName: flyte-sandbox-db-storage +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-minio-storage + namespace: flyte +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 1Gi + storageClassName: manual + volumeName: flyte-sandbox-minio-storage +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-binary-v0.1.10 + name: flyte-sandbox + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/component: flyte-binary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + strategy: + type: Recreate + template: + metadata: + annotations: + checksum/cluster-resource-templates: 
6fd9b172465e3089fcc59f738b92b8dc4d8939360c19de8ee65f68b0e7422035 + checksum/configuration: b765a68950c83acd0c069dac2a6569cf2b0f0f76a2760eea3561d1d04d6be831 + checksum/configuration-secret: 09216ffaa3d29e14f88b1f30af580d02a2a5e014de4d750b7f275cc07ed4e914 + labels: + app.kubernetes.io/component: flyte-binary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + spec: + containers: + - args: + - start + - --config + - /etc/flyte/config.d/*.yaml + env: + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + image: flyte-binary:sandbox + imagePullPolicy: Never + livenessProbe: + httpGet: + path: /healthcheck + port: http + name: flyte + ports: + - containerPort: 8088 + name: http + - containerPort: 8089 + name: grpc + - containerPort: 9443 + name: webhook + readinessProbe: + httpGet: + path: /healthcheck + port: http + volumeMounts: + - mountPath: /etc/flyte/cluster-resource-templates + name: cluster-resource-templates + - mountPath: /etc/flyte/config.d + name: config + - mountPath: /var/run/flyte + name: state + initContainers: + - args: + - | + until pg_isready \ + -h flyte-sandbox-postgresql \ + -p 5432 \ + -U postgres + do + echo waiting for database + sleep 0.1 + done + command: + - sh + - -ec + image: bitnami/postgresql:sandbox + imagePullPolicy: Never + name: wait-for-db + serviceAccountName: flyte-sandbox + volumes: + - name: cluster-resource-templates + projected: + sources: + - configMap: + name: flyte-sandbox-cluster-resource-templates + - configMap: + name: flyte-sandbox-extra-cluster-resource-templates + - name: config + projected: + sources: + - configMap: + name: flyte-sandbox-config + - secret: + name: flyte-sandbox-config-secret + - configMap: + name: flyte-sandbox-extra-config + - emptyDir: {} + name: state +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + 
app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-buildkit + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/component: buildkit + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + template: + metadata: + labels: + app.kubernetes.io/component: buildkit + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + spec: + containers: + - args: + - --addr + - unix:///run/buildkit/buildkitd.sock + - --addr + - tcp://0.0.0.0:30003 + image: moby/buildkit:sandbox + imagePullPolicy: Never + livenessProbe: + exec: + command: + - buildctl + - debug + - workers + initialDelaySeconds: 5 + periodSeconds: 30 + name: buildkit + ports: + - containerPort: 30003 + name: tcp + protocol: TCP + readinessProbe: + exec: + command: + - buildctl + - debug + - workers + initialDelaySeconds: 5 + periodSeconds: 30 + securityContext: + privileged: true + hostNetwork: true +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app: docker-registry + chart: docker-registry-2.2.2 + heritage: Helm + release: flyte-sandbox + name: flyte-sandbox-docker-registry + namespace: flyte +spec: + minReadySeconds: 5 + replicas: 1 + selector: + matchLabels: + app: docker-registry + release: flyte-sandbox + template: + metadata: + annotations: + checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 + checksum/secret: 0ee1553aec7c03152a0a44e7b1a82985795774412a779f7b607a57e59f42c8ef + labels: + app: docker-registry + release: flyte-sandbox + spec: + containers: + - command: + - /bin/registry + - serve + - /etc/docker/registry/config.yml + env: + - name: REGISTRY_HTTP_SECRET + valueFrom: + secretKeyRef: + key: haSharedSecret + name: flyte-sandbox-docker-registry-secret + - name: REGISTRY_STORAGE_FILESYSTEM_ROOTDIRECTORY + value: /var/lib/registry + image: 
registry:sandbox + imagePullPolicy: Never + livenessProbe: + httpGet: + path: / + port: 5000 + name: docker-registry + ports: + - containerPort: 5000 + readinessProbe: + httpGet: + path: / + port: 5000 + resources: {} + volumeMounts: + - mountPath: /etc/docker/registry + name: flyte-sandbox-docker-registry-config + - mountPath: /var/lib/registry/ + name: data + securityContext: + fsGroup: 1000 + runAsUser: 1000 + volumes: + - configMap: + name: flyte-sandbox-docker-registry-config + name: flyte-sandbox-docker-registry-config + - emptyDir: {} + name: data +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/component: kubernetes-dashboard + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + name: flyte-sandbox-kubernetes-dashboard + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/component: kubernetes-dashboard + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: kubernetes-dashboard + strategy: + rollingUpdate: + maxSurge: 0 + maxUnavailable: 1 + type: RollingUpdate + template: + metadata: + annotations: null + labels: + app.kubernetes.io/component: kubernetes-dashboard + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: kubernetes-dashboard + app.kubernetes.io/version: 2.7.0 + helm.sh/chart: kubernetes-dashboard-6.0.0 + spec: + containers: + - args: + - --namespace=flyte + - --metrics-provider=none + - --enable-insecure-login + - --enable-skip-login + image: kubernetesui/dashboard:sandbox + imagePullPolicy: Never + livenessProbe: + httpGet: + path: / + port: 9090 + scheme: HTTP + initialDelaySeconds: 30 + timeoutSeconds: 30 + name: kubernetes-dashboard + ports: + - containerPort: 9090 + name: http + protocol: TCP + resources: + limits: + cpu: 2 + memory: 200Mi + requests: 
+ cpu: 100m + memory: 200Mi + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + runAsGroup: 2001 + runAsUser: 1001 + volumeMounts: + - mountPath: /certs + name: kubernetes-dashboard-certs + - mountPath: /tmp + name: tmp-volume + securityContext: + seccompProfile: + type: RuntimeDefault + serviceAccountName: flyte-sandbox-kubernetes-dashboard + volumes: + - name: kubernetes-dashboard-certs + secret: + secretName: flyte-sandbox-kubernetes-dashboard-certs + - emptyDir: {} + name: tmp-volume +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: minio + helm.sh/chart: minio-12.1.1 + name: flyte-sandbox-minio + namespace: flyte +spec: + selector: + matchLabels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: minio + strategy: + type: Recreate + template: + metadata: + annotations: + checksum/credentials-secret: c199ac45f9d95d97966921c814d6c8b38cbf7416458e19cbe6d001a04c264448 + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: minio + helm.sh/chart: minio-12.1.1 + spec: + affinity: + nodeAffinity: null + podAffinity: null + podAntiAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - podAffinityTerm: + labelSelector: + matchLabels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: minio + topologyKey: kubernetes.io/hostname + weight: 1 + containers: + - env: + - name: BITNAMI_DEBUG + value: "false" + - name: MINIO_SCHEME + value: http + - name: MINIO_FORCE_NEW_KEYS + value: "no" + - name: MINIO_ROOT_USER + valueFrom: + secretKeyRef: + key: root-user + name: flyte-sandbox-minio + - name: MINIO_ROOT_PASSWORD + valueFrom: + secretKeyRef: + key: root-password + name: flyte-sandbox-minio + - name: MINIO_DEFAULT_BUCKETS + value: my-s3-bucket + - name: MINIO_BROWSER + value: "on" + - name: 
MINIO_PROMETHEUS_AUTH_TYPE + value: public + - name: MINIO_CONSOLE_PORT_NUMBER + value: "9001" + - name: MINIO_BROWSER_REDIRECT_URL + value: http://localhost:30080/minio + envFrom: null + image: docker.io/bitnami/minio:sandbox + imagePullPolicy: Never + livenessProbe: + failureThreshold: 5 + httpGet: + path: /minio/health/live + port: minio-api + scheme: HTTP + initialDelaySeconds: 5 + periodSeconds: 5 + successThreshold: 1 + timeoutSeconds: 5 + name: minio + ports: + - containerPort: 9000 + name: minio-api + protocol: TCP + - containerPort: 9001 + name: minio-console + protocol: TCP + readinessProbe: + failureThreshold: 5 + initialDelaySeconds: 5 + periodSeconds: 5 + successThreshold: 1 + tcpSocket: + port: minio-api + timeoutSeconds: 1 + resources: + limits: {} + requests: {} + securityContext: + runAsNonRoot: true + runAsUser: 1001 + volumeMounts: + - mountPath: /data + name: data + initContainers: + - command: + - /bin/bash + - -ec + - | + chown -R 1001:1001 /data + image: docker.io/bitnami/bitnami-shell:sandbox + imagePullPolicy: Never + name: volume-permissions + resources: + limits: {} + requests: {} + securityContext: + runAsUser: 0 + volumeMounts: + - mountPath: /data + name: data + securityContext: + fsGroup: 1001 + serviceAccountName: flyte-sandbox-minio + volumes: + - name: data + persistentVolumeClaim: + claimName: flyte-sandbox-minio-storage +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyte-sandbox + app.kubernetes.io/version: 1.16.0 + helm.sh/chart: flyte-sandbox-0.1.0 + name: flyte-sandbox-proxy + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/component: proxy + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: flyte-sandbox + template: + metadata: + labels: + app.kubernetes.io/component: proxy + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: 
flyte-sandbox + spec: + containers: + - image: envoyproxy/envoy:sandbox + imagePullPolicy: Never + name: proxy + ports: + - containerPort: 8000 + name: http + volumeMounts: + - mountPath: /etc/envoy + name: config + volumes: + - configMap: + name: flyte-sandbox-proxy-config + name: config +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyteagent + helm.sh/chart: flyteagent-v0.1.10 + name: flyteagent + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/name: flyteagent + template: + metadata: + annotations: null + labels: + app.kubernetes.io/instance: flyteagent + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: flyteagent + helm.sh/chart: flyteagent-v0.1.10 + spec: + containers: + - command: + - pyflyte + - serve + image: ghcr.io/flyteorg/flyteagent:1.9.1 + imagePullPolicy: IfNotPresent + name: flyteagent + ports: + - containerPort: 8000 + name: agent-grpc + resources: + limits: + cpu: 500m + ephemeral-storage: 200Mi + memory: 200Mi + requests: + cpu: 500m + ephemeral-storage: 200Mi + memory: 200Mi + volumeMounts: + - mountPath: /etc/secrets + name: flyteagent + serviceAccountName: flyteagent + volumes: + - name: flyteagent + secret: + secretName: flyteagent +--- +apiVersion: apps/v1 +kind: StatefulSet +metadata: + labels: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: postgresql + helm.sh/chart: postgresql-12.1.9 + name: flyte-sandbox-postgresql + namespace: flyte +spec: + replicas: 1 + selector: + matchLabels: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: postgresql + serviceName: flyte-sandbox-postgresql-hl + template: + metadata: + annotations: null + labels: + app.kubernetes.io/component: primary 
+ app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/managed-by: Helm + app.kubernetes.io/name: postgresql + helm.sh/chart: postgresql-12.1.9 + name: flyte-sandbox-postgresql + spec: + affinity: + nodeAffinity: null + podAffinity: null + podAntiAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - podAffinityTerm: + labelSelector: + matchLabels: + app.kubernetes.io/component: primary + app.kubernetes.io/instance: flyte-sandbox + app.kubernetes.io/name: postgresql + topologyKey: kubernetes.io/hostname + weight: 1 + containers: + - env: + - name: BITNAMI_DEBUG + value: "false" + - name: POSTGRESQL_PORT_NUMBER + value: "5432" + - name: POSTGRESQL_VOLUME_DIR + value: /bitnami/postgresql + - name: PGDATA + value: /bitnami/postgresql/data + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + key: postgres-password + name: flyte-sandbox-postgresql + - name: POSTGRESQL_ENABLE_LDAP + value: "no" + - name: POSTGRESQL_ENABLE_TLS + value: "no" + - name: POSTGRESQL_LOG_HOSTNAME + value: "false" + - name: POSTGRESQL_LOG_CONNECTIONS + value: "false" + - name: POSTGRESQL_LOG_DISCONNECTIONS + value: "false" + - name: POSTGRESQL_PGAUDIT_LOG_CATALOG + value: "off" + - name: POSTGRESQL_CLIENT_MIN_MESSAGES + value: error + - name: POSTGRESQL_SHARED_PRELOAD_LIBRARIES + value: pgaudit + image: docker.io/bitnami/postgresql:sandbox + imagePullPolicy: Never + livenessProbe: + exec: + command: + - /bin/sh + - -c + - exec pg_isready -U "postgres" -h 127.0.0.1 -p 5432 + failureThreshold: 6 + initialDelaySeconds: 30 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 5 + name: postgresql + ports: + - containerPort: 5432 + name: tcp-postgresql + readinessProbe: + exec: + command: + - /bin/sh + - -c + - -e + - | + exec pg_isready -U "postgres" -h 127.0.0.1 -p 5432 + [ -f /opt/bitnami/postgresql/tmp/.initialized ] || [ -f /bitnami/postgresql/.initialized ] + failureThreshold: 6 + initialDelaySeconds: 5 + periodSeconds: 10 + successThreshold: 1 + timeoutSeconds: 5 
+ resources: + limits: {} + requests: + cpu: 250m + memory: 256Mi + securityContext: + runAsUser: 1001 + volumeMounts: + - mountPath: /bitnami/postgresql + name: data + hostIPC: false + hostNetwork: false + initContainers: + - command: + - /bin/sh + - -ec + - | + chown 1001:1001 /bitnami/postgresql + mkdir -p /bitnami/postgresql/data + chmod 700 /bitnami/postgresql/data + find /bitnami/postgresql -mindepth 1 -maxdepth 1 -not -name "conf" -not -name ".snapshot" -not -name "lost+found" | \ + xargs -r chown -R 1001:1001 + image: docker.io/bitnami/bitnami-shell:sandbox + imagePullPolicy: Never + name: init-chmod-data + resources: + limits: {} + requests: {} + securityContext: + runAsUser: 0 + volumeMounts: + - mountPath: /bitnami/postgresql + name: data + securityContext: + fsGroup: 1001 + serviceAccountName: default + volumes: + - name: data + persistentVolumeClaim: + claimName: flyte-sandbox-db-storage + updateStrategy: + rollingUpdate: {} + type: RollingUpdate diff --git a/docker/sandbox-bundled/manifests/complete.yaml b/docker/sandbox-bundled/manifests/complete.yaml index 175e7aece8..d7f6e8b0cc 100644 --- a/docker/sandbox-bundled/manifests/complete.yaml +++ b/docker/sandbox-bundled/manifests/complete.yaml @@ -448,6 +448,7 @@ data: - container - sidecar - k8s-array + - agent-service plugins: logs: kubernetes-enabled: true @@ -464,6 +465,15 @@ data: kubernetes-template-uri: http://localhost:30080/kubernetes-dashboard/#/log/{{.namespace }}/{{ .podName }}/pod?namespace={{ .namespace }} cloudwatch-enabled: false stackdriver-enabled: false + agent-service: + defaultAgent: + defaultTimeout: 10s + endpoint: dns:///flyteagent.flyte.svc.cluster.local:8000 + insecure: true + timeouts: + GetTask: 10s + supportedTaskTypes: + - default_task 002-database.yaml: | database: postgres: @@ -795,7 +805,7 @@ type: Opaque --- apiVersion: v1 data: - haSharedSecret: bzd6QlVrSG9ya1c0MUxBWg== + haSharedSecret: d1l6eWRCOXBJcFhiNEo5QQ== proxyPassword: "" proxyUsername: "" kind: Secret @@ 
-1193,7 +1203,7 @@ spec: metadata: annotations: checksum/cluster-resource-templates: 6fd9b172465e3089fcc59f738b92b8dc4d8939360c19de8ee65f68b0e7422035 - checksum/configuration: 91f9c46efb44022473a71c6c25bc6ef20190610644a48f81a9c0e1ae01c2a73d + checksum/configuration: b765a68950c83acd0c069dac2a6569cf2b0f0f76a2760eea3561d1d04d6be831 checksum/configuration-secret: 09216ffaa3d29e14f88b1f30af580d02a2a5e014de4d750b7f275cc07ed4e914 labels: app.kubernetes.io/component: flyte-binary @@ -1356,7 +1366,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: c60195b739184d9ad0f4dd231ec9b2bdbedcbc835c4651806c1fa32d29279994 + checksum/secret: 2f5b6d46fd3276b5b25c8a537298beb6943b13b0b21900db8b2da23e166f0593 labels: app: docker-registry release: flyte-sandbox diff --git a/docker/sandbox-bundled/manifests/dev.yaml b/docker/sandbox-bundled/manifests/dev.yaml index 164a867eb7..4f3f0592e8 100644 --- a/docker/sandbox-bundled/manifests/dev.yaml +++ b/docker/sandbox-bundled/manifests/dev.yaml @@ -499,7 +499,7 @@ metadata: --- apiVersion: v1 data: - haSharedSecret: S3hhYmcwb1E0enNmZXpHQw== + haSharedSecret: UkFsUVRMRndZeTNJUVNFSA== proxyPassword: "" proxyUsername: "" kind: Secret @@ -933,7 +933,7 @@ spec: metadata: annotations: checksum/config: 8f50e768255a87f078ba8b9879a0c174c3e045ffb46ac8723d2eedbe293c8d81 - checksum/secret: 32e8e4864e56d8e05e03763b1e04dc6c1821c30c5079087b39a02c1348560d34 + checksum/secret: 25a046ef1aaf34ffb59f7b92554e1cfd0015b9a11f7f165ce06bba31e3bced1b labels: app: docker-registry release: flyte-sandbox diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index daa91968bc..ef94d85756 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -2,6 +2,7 @@ package admin import ( "context" + "errors" "fmt" "net/http" @@ -16,10 +17,12 @@ import ( "google.golang.org/grpc" ) +const 
ProxyAuthorizationHeader = "proxy-authorization" + // MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. -func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error { - authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg) +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error { + authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture) if err != nil { return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) } @@ -48,19 +51,70 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T return nil } +func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) { + tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand) + if err != nil { + return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. 
Err: %w", err) + } + proxyTokenSource, err := tokenSourceProvider.GetTokenSource(ctx) + if err != nil { + return nil, err + } + return proxyTokenSource, nil +} + +func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) error { + proxyTokenSource, err := GetProxyTokenSource(ctx, cfg) + if err != nil { + return err + } + + wrappedTokenSource := NewCustomHeaderTokenSource(proxyTokenSource, cfg.UseInsecureConnection, ProxyAuthorizationHeader) + proxyCredentialsFuture.Store(wrappedTokenSource) + + return nil +} + func shouldAttemptToAuthenticate(errorCode codes.Code) bool { return errorCode == codes.Unauthenticated } +type proxyAuthTransport struct { + transport http.RoundTripper + proxyCredentialsFuture *PerRPCCredentialsFuture +} + +func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // check if the proxy credentials future is initialized + if !c.proxyCredentialsFuture.IsInitialized() { + return nil, errors.New("proxy credentials future is not initialized") + } + + metadata, err := c.proxyCredentialsFuture.GetRequestMetadata(context.Background(), "") + if err != nil { + return nil, err + } + token := metadata[ProxyAuthorizationHeader] + req.Header.Add(ProxyAuthorizationHeader, token) + return c.transport.RoundTrip(req) +} + // Set up http client used in oauth2 -func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context { +func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) context.Context { httpClient := &http.Client{} + transport := &http.Transport{} if len(cfg.HTTPProxyURL.String()) > 0 { // create a transport that uses the proxy - transport := &http.Transport{ - Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL), + transport.Proxy = http.ProxyURL(&cfg.HTTPProxyURL.URL) + } + + if cfg.ProxyCommand != nil { + httpClient.Transport = &proxyAuthTransport{ + transport: transport, + proxyCredentialsFuture: 
proxyCredentialsFuture, } + } else { httpClient.Transport = transport } @@ -77,9 +131,9 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context { // more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once // a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should // be able to find and acquire a valid AccessToken to annotate the request with. -func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { +func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = setHTTPClientContext(ctx, cfg) + ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { @@ -89,7 +143,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut // If the error we receive from executing the request expects if shouldAttemptToAuthenticate(st.Code()) { logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) - newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture) if newErr != nil { return fmt.Errorf("authentication error! 
Original Error: %v, Auth Error: %w", err, newErr) } @@ -102,3 +156,18 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut return err } } + +func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + newErr := MaterializeProxyAuthCredentials(ctx, cfg, proxyCredentialsFuture) + if newErr != nil { + return fmt.Errorf("proxy authorization error! Original Error: %v, Proxy Auth Error: %w", err, newErr) + } + return invoker(ctx, method, req, reply, cc, opts...) + } + return err + } +} diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index fccd64769f..ce99c99270 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -114,7 +115,8 @@ func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServ func Test_newAuthInterceptor(t *testing.T) { t.Run("Other Error", func(t *testing.T) { f := NewPerRPCCredentialsFuture() - interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f) + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p) otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Canceled, "").Err() } @@ -146,11 +148,12 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := 
NewPerRPCCredentialsFuture() interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Unauthenticated, "").Err() } @@ -177,11 +180,13 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return nil } @@ -216,11 +221,13 @@ func Test_newAuthInterceptor(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { return status.New(codes.Aborted, "").Err() } @@ -246,6 +253,8 @@ func TestMaterializeCredentials(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, @@ -254,7 +263,7 @@ func TestMaterializeCredentials(t *testing.T) { Scopes: []string{"all"}, Audience: "http://localhost:30081", AuthorizationHeader: "authorization", - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) assert.NoError(t, err) }) t.Run("Failed 
to fetch client metadata", func(t *testing.T) { @@ -271,13 +280,119 @@ func TestMaterializeCredentials(t *testing.T) { assert.NoError(t, err) f := NewPerRPCCredentialsFuture() + p := NewPerRPCCredentialsFuture() + err = MaterializeCredentials(ctx, &Config{ Endpoint: config.URL{URL: *u}, UseInsecureConnection: true, AuthType: AuthTypeClientSecret, TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), Scopes: []string{"all"}, - }, &mocks.TokenCache{}, f) + }, &mocks.TokenCache{}, f, p) assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err") }) } + +func TestNewProxyAuthInterceptor(t *testing.T) { + cfg := &Config{ + ProxyCommand: []string{"echo", "test-token"}, + } + + p := NewPerRPCCredentialsFuture() + + interceptor := NewProxyAuthInterceptor(cfg, p) + + ctx := context.Background() + method := "/test.method" + req := "request" + reply := "reply" + cc := new(grpc.ClientConn) + + errorInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return errors.New("test error") + } + + // Call should return an error and trigger the interceptor to materialize proxy auth credentials + err := interceptor(ctx, method, req, reply, cc, errorInvoker) + assert.Error(t, err) + + // Check if proxyCredentialsFuture contains a proxy auth header token + creds, err := p.Get().GetRequestMetadata(ctx, "") + assert.True(t, p.IsInitialized()) + assert.NoError(t, err) + assert.Equal(t, "Bearer test-token", creds[ProxyAuthorizationHeader]) +} + +type testRoundTripper struct { + RoundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RoundTripFunc(req) +} + +func TestSetHTTPClientContext(t *testing.T) { + ctx := context.Background() + + t.Run("no proxy command and no proxy url", func(t *testing.T) { + cfg := &Config{} + + newCtx := 
setHTTPClientContext(ctx, cfg, nil) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + transport, ok := httpClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.Nil(t, transport.Proxy) + }) + + t.Run("proxy url", func(t *testing.T) { + cfg := &Config{ + HTTPProxyURL: config. + URL{URL: url.URL{ + Scheme: "http", + Host: "localhost:8080", + }}, + } + newCtx := setHTTPClientContext(ctx, cfg, nil) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + transport, ok := httpClient.Transport.(*http.Transport) + assert.True(t, ok) + assert.NotNil(t, transport.Proxy) + }) + + t.Run("proxy command adds proxy-authorization header", func(t *testing.T) { + cfg := &Config{ + ProxyCommand: []string{"echo", "test-token-http-client"}, + } + + p := NewPerRPCCredentialsFuture() + err := MaterializeProxyAuthCredentials(ctx, cfg, p) + assert.NoError(t, err) + + newCtx := setHTTPClientContext(ctx, cfg, p) + + httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client) + assert.True(t, ok) + + pat, ok := httpClient.Transport.(*proxyAuthTransport) + assert.True(t, ok) + + testRoundTripper := &testRoundTripper{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + // Check if the ProxyAuthorizationHeader is correctly set + assert.Equal(t, "Bearer test-token-http-client", req.Header.Get(ProxyAuthorizationHeader)) + return &http.Response{StatusCode: http.StatusOK}, nil + }, + } + pat.transport = testRoundTripper + + req, _ := http.NewRequest("GET", "http://example.com", nil) + _, err = httpClient.Do(req) + assert.NoError(t, err) + }) +} diff --git a/flyteidl/clients/go/admin/client.go b/flyteidl/clients/go/admin/client.go index ffba612a09..7555f64045 100644 --- a/flyteidl/clients/go/admin/client.go +++ b/flyteidl/clients/go/admin/client.go @@ -110,9 +110,9 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr } // InitializeAuthMetadataClient 
creates a new anonymously Auth Metadata Service client. -func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) { +func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) (client service.AuthMetadataServiceClient, err error) { // Create an unauthenticated connection to fetch AuthMetadata - authMetadataConnection, err := NewAdminConnection(ctx, cfg) + authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture) if err != nil { return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err) } @@ -120,11 +120,11 @@ func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client serv return service.NewAuthMetadataServiceClient(authMetadataConnection), nil } -func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) { +func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture, opts ...grpc.DialOption) (*grpc.ClientConn, error) { if opts == nil { // Initialize opts list to the potential number of options we will add. Initialization optimizes memory // allocation. - opts = make([]grpc.DialOption, 0, 5) + opts = make([]grpc.DialOption, 0, 7) } if cfg.UseInsecureConnection { @@ -153,6 +153,11 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...) + if cfg.ProxyCommand != nil { + opts = append(opts, grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyCredentialsFuture))) + opts = append(opts, grpc.WithPerRPCCredentials(proxyCredentialsFuture)) + } + return grpc.Dial(cfg.Endpoint.String(), opts...) } @@ -172,15 +177,17 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp // for the process. 
Note that if called with different cfg/dialoptions, it will not refresh the connection. func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) { credentialsFuture := NewPerRPCCredentialsFuture() + proxyCredentialsFuture := NewPerRPCCredentialsFuture() + opts = append(opts, - grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)), + grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)), grpc.WithPerRPCCredentials(credentialsFuture)) if cfg.DefaultServiceConfig != "" { opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig)) } - adminConnection, err := NewAdminConnection(ctx, cfg, opts...) + adminConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture, opts...) if err != nil { logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error()) } diff --git a/flyteidl/clients/go/admin/config.go b/flyteidl/clients/go/admin/config.go index 1eea36a89e..03f2f8ecc2 100644 --- a/flyteidl/clients/go/admin/config.go +++ b/flyteidl/clients/go/admin/config.go @@ -74,6 +74,8 @@ type Config struct { Command []string `json:"command" pflag:",Command for external authentication token generation"` + ProxyCommand []string `json:"proxyCommand" pflag:",Command for external proxy-authorization token generation"` + // Set the gRPC service config formatted as a json string https://github.com/grpc/grpc/blob/master/doc/service_config.md // eg. 
{"loadBalancingConfig": [{"round_robin":{}}], "methodConfig": [{"name":[{"service": "foo", "method": "bar"}, {"service": "baz"}], "timeout": "1.000000001s"}]} // find the full schema here https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto#L625 diff --git a/flyteidl/clients/go/admin/config_flags.go b/flyteidl/clients/go/admin/config_flags.go index 0472413b57..db1305c0b1 100755 --- a/flyteidl/clients/go/admin/config_flags.go +++ b/flyteidl/clients/go/admin/config_flags.go @@ -75,6 +75,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "deviceFlowConfig.timeout"), defaultConfig.DeviceFlowConfig.Timeout.String(), "amount of time the device flow should complete or else it will be cancelled.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "deviceFlowConfig.pollInterval"), defaultConfig.DeviceFlowConfig.PollInterval.String(), "amount of time the device flow would poll the token endpoint if auth server doesn't return a polling interval. 
Okta and google IDP do return an interval'") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "command"), defaultConfig.Command, "Command for external authentication token generation") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "proxyCommand"), defaultConfig.ProxyCommand, "Command for external proxy-authorization token generation") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultServiceConfig"), defaultConfig.DefaultServiceConfig, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "httpProxyURL"), defaultConfig.HTTPProxyURL.String(), "OPTIONAL: HTTP Proxy to be used for OAuth requests.") return cmdFlags diff --git a/flyteidl/clients/go/admin/config_flags_test.go b/flyteidl/clients/go/admin/config_flags_test.go index 1fb1e2a214..e815bcb5f3 100755 --- a/flyteidl/clients/go/admin/config_flags_test.go +++ b/flyteidl/clients/go/admin/config_flags_test.go @@ -449,6 +449,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_proxyCommand", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config(defaultConfig.ProxyCommand, ",") + + cmdFlags.Set("proxyCommand", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("proxyCommand"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ProxyCommand) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_defaultServiceConfig", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator.go b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator.go index 0ca48107ae..cf5a85671e 100644 --- a/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator.go +++ b/flyteidl/clients/go/admin/pkce/auth_flow_orchestrator.go @@ -2,6 +2,7 @@ package pkce import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -63,8 +64,15 @@ func (f TokenOrchestrator) FetchTokenFromAuthFlow(ctx context.Context) (*oauth2. 
serveMux := http.NewServeMux() server := &http.Server{Addr: redirectURL.Host, Handler: serveMux, ReadHeaderTimeout: 0} // Register the call back handler + + // Pass along http client used in oauth2 + httpClient, ok := ctx.Value(oauth2.HTTPClient).(*http.Client) + if !ok { + return nil, errors.New("Unable to retrieve httpClient used in oauth2 from context") + } + serveMux.HandleFunc(redirectURL.Path, getAuthServerCallbackHandler(f.ClientConfig, pkceCodeVerifier, - tokenChannel, errorChannel, stateString)) // the oauth2 callback endpoint + tokenChannel, errorChannel, stateString, httpClient)) // the oauth2 callback endpoint defer server.Close() go func() { diff --git a/flyteidl/clients/go/admin/pkce/handle_app_call_back.go b/flyteidl/clients/go/admin/pkce/handle_app_call_back.go index 72ac53883f..775dcb4795 100644 --- a/flyteidl/clients/go/admin/pkce/handle_app_call_back.go +++ b/flyteidl/clients/go/admin/pkce/handle_app_call_back.go @@ -11,7 +11,7 @@ import ( ) func getAuthServerCallbackHandler(c *oauth.Config, codeVerifier string, tokenChannel chan *oauth2.Token, - errorChannel chan error, stateString string) func(rw http.ResponseWriter, req *http.Request) { + errorChannel chan error, stateString string, client *http.Client) func(rw http.ResponseWriter, req *http.Request) { return func(rw http.ResponseWriter, req *http.Request) { _, _ = rw.Write([]byte(`

Flyte Authentication

`)) @@ -43,7 +43,8 @@ func getAuthServerCallbackHandler(c *oauth.Config, codeVerifier string, tokenCha var opts []oauth2.AuthCodeOption opts = append(opts, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) - token, err := c.Exchange(context.Background(), req.URL.Query().Get("code"), opts...) + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + token, err := c.Exchange(ctx, req.URL.Query().Get("code"), opts...) if err != nil { errorChannel <- fmt.Errorf("error while exchanging auth code due to %v", err) _, _ = rw.Write([]byte(fmt.Sprintf(`

Couldn't get access token due to error: %s

`, err.Error()))) diff --git a/flyteidl/clients/go/admin/pkce/handle_app_call_back_test.go b/flyteidl/clients/go/admin/pkce/handle_app_call_back_test.go index c28b833322..30c409002d 100644 --- a/flyteidl/clients/go/admin/pkce/handle_app_call_back_test.go +++ b/flyteidl/clients/go/admin/pkce/handle_app_call_back_test.go @@ -25,7 +25,7 @@ func HandleAppCallBackSetup(t *testing.T, state string) (tokenChannel chan *oaut errorChannel = make(chan error, 1) tokenChannel = make(chan *oauth2.Token) testAuthConfig = &oauth.Config{Config: &oauth2.Config{}, DeviceEndpoint: "dummyDeviceEndpoint"} - callBackFn = getAuthServerCallbackHandler(testAuthConfig, "", tokenChannel, errorChannel, state) + callBackFn = getAuthServerCallbackHandler(testAuthConfig, "", tokenChannel, errorChannel, state, &http.Client{}) assert.NotNil(t, callBackFn) req = &http.Request{ Method: http.MethodGet, diff --git a/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.cc index 1c9f90856f..e4e12e3543 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.cc @@ -324,6 +324,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fadmin_2fagent_2eproto::o ~0u, // no _weak_field_map_ PROTOBUF_FIELD_OFFSET(::flyteidl::admin::Resource, state_), PROTOBUF_FIELD_OFFSET(::flyteidl::admin::Resource, outputs_), + PROTOBUF_FIELD_OFFSET(::flyteidl::admin::Resource, message_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::admin::DeleteTaskRequest, _internal_metadata_), ~0u, // no _extensions_ @@ -347,8 +348,8 @@ static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SE { 53, -1, sizeof(::flyteidl::admin::GetTaskRequest)}, { 60, -1, sizeof(::flyteidl::admin::GetTaskResponse)}, { 66, -1, sizeof(::flyteidl::admin::Resource)}, - { 73, -1, sizeof(::flyteidl::admin::DeleteTaskRequest)}, - { 80, -1, sizeof(::flyteidl::admin::DeleteTaskResponse)}, + { 74, -1, 
sizeof(::flyteidl::admin::DeleteTaskRequest)}, + { 81, -1, sizeof(::flyteidl::admin::DeleteTaskResponse)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { @@ -399,20 +400,21 @@ const char descriptor_table_protodef_flyteidl_2fadmin_2fagent_2eproto[] = "\016GetTaskRequest\022\021\n\ttask_type\030\001 \001(\t\022\025\n\rre" "source_meta\030\002 \001(\014\"=\n\017GetTaskResponse\022*\n\010" "resource\030\001 \001(\0132\030.flyteidl.admin.Resource" - "\"\\\n\010Resource\022$\n\005state\030\001 \001(\0162\025.flyteidl.a" + "\"m\n\010Resource\022$\n\005state\030\001 \001(\0162\025.flyteidl.a" "dmin.State\022*\n\007outputs\030\002 \001(\0132\031.flyteidl.c" - "ore.LiteralMap\"=\n\021DeleteTaskRequest\022\021\n\tt" - "ask_type\030\001 \001(\t\022\025\n\rresource_meta\030\002 \001(\014\"\024\n" - "\022DeleteTaskResponse*^\n\005State\022\025\n\021RETRYABL" - "E_FAILURE\020\000\022\025\n\021PERMANENT_FAILURE\020\001\022\013\n\007PE" - "NDING\020\002\022\013\n\007RUNNING\020\003\022\r\n\tSUCCEEDED\020\004B=Z;g" - "ithub.com/flyteorg/flyte/flyteidl/gen/pb" - "-go/flyteidl/adminb\006proto3" + "ore.LiteralMap\022\017\n\007message\030\003 \001(\t\"=\n\021Delet" + "eTaskRequest\022\021\n\ttask_type\030\001 \001(\t\022\025\n\rresou" + "rce_meta\030\002 \001(\014\"\024\n\022DeleteTaskResponse*^\n\005" + "State\022\025\n\021RETRYABLE_FAILURE\020\000\022\025\n\021PERMANEN" + "T_FAILURE\020\001\022\013\n\007PENDING\020\002\022\013\n\007RUNNING\020\003\022\r\n" + "\tSUCCEEDED\020\004B=Z;github.com/flyteorg/flyt" + "e/flyteidl/gen/pb-go/flyteidl/adminb\006pro" + "to3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fadmin_2fagent_2eproto = { false, InitDefaults_flyteidl_2fadmin_2fagent_2eproto, descriptor_table_protodef_flyteidl_2fadmin_2fagent_2eproto, - "flyteidl/admin/agent.proto", &assign_descriptors_table_flyteidl_2fadmin_2fagent_2eproto, 1426, + "flyteidl/admin/agent.proto", 
&assign_descriptors_table_flyteidl_2fadmin_2fagent_2eproto, 1443, }; void AddDescriptors_flyteidl_2fadmin_2fagent_2eproto() { @@ -2962,6 +2964,7 @@ void Resource::clear_outputs() { #if !defined(_MSC_VER) || _MSC_VER >= 1900 const int Resource::kStateFieldNumber; const int Resource::kOutputsFieldNumber; +const int Resource::kMessageFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 Resource::Resource() @@ -2973,6 +2976,10 @@ Resource::Resource(const Resource& from) : ::google::protobuf::Message(), _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); + message_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); + if (from.message().size() > 0) { + message_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.message_); + } if (from.has_outputs()) { outputs_ = new ::flyteidl::core::LiteralMap(*from.outputs_); } else { @@ -2985,6 +2992,7 @@ Resource::Resource(const Resource& from) void Resource::SharedCtor() { ::google::protobuf::internal::InitSCC( &scc_info_Resource_flyteidl_2fadmin_2fagent_2eproto.base); + message_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); ::memset(&outputs_, 0, static_cast( reinterpret_cast(&state_) - reinterpret_cast(&outputs_)) + sizeof(state_)); @@ -2996,6 +3004,7 @@ Resource::~Resource() { } void Resource::SharedDtor() { + message_.DestroyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); if (this != internal_default_instance()) delete outputs_; } @@ -3014,6 +3023,7 @@ void Resource::Clear() { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + message_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); if (GetArenaNoVirtual() == nullptr && outputs_ != nullptr) { delete outputs_; } @@ -3056,6 +3066,22 @@ const char* Resource::_InternalParse(const char* begin, const char* end, void* o {parser_till_end, object}, ptr 
- size, ptr)); break; } + // string message = 3; + case 3: { + if (static_cast<::google::protobuf::uint8>(tag) != 26) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + ctx->extra_parse_data().SetFieldName("flyteidl.admin.Resource.message"); + object = msg->mutable_message(); + if (size > end - ptr + ::google::protobuf::internal::ParseContext::kSlopBytes) { + parser_till_end = ::google::protobuf::internal::GreedyStringParserUTF8; + goto string_till_end; + } + GOOGLE_PROTOBUF_PARSER_ASSERT(::google::protobuf::internal::StringCheckUTF8(ptr, size, ctx)); + ::google::protobuf::internal::InlineGreedyStringParser(object, ptr, size, ctx); + ptr += size; + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -3071,6 +3097,10 @@ const char* Resource::_InternalParse(const char* begin, const char* end, void* o } // switch } // while return ptr; +string_till_end: + static_cast<::std::string*>(object)->clear(); + static_cast<::std::string*>(object)->reserve(size); + goto len_delim_till_end; len_delim_till_end: return ctx->StoreAndTailCall(ptr, end, {_InternalParse, msg}, {parser_till_end, object}, size); @@ -3111,6 +3141,21 @@ bool Resource::MergePartialFromCodedStream( break; } + // string message = 3; + case 3: { + if (static_cast< ::google::protobuf::uint8>(tag) == (26 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadString( + input, this->mutable_message())); + DO_(::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->message().data(), static_cast(this->message().length()), + ::google::protobuf::internal::WireFormatLite::PARSE, + "flyteidl.admin.Resource.message")); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -3150,6 +3195,16 @@ void Resource::SerializeWithCachedSizes( 2, HasBitSetters::outputs(this), output); } + // string message = 3; + if (this->message().size() > 0) { + 
::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->message().data(), static_cast(this->message().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.admin.Resource.message"); + ::google::protobuf::internal::WireFormatLite::WriteStringMaybeAliased( + 3, this->message(), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -3176,6 +3231,17 @@ ::google::protobuf::uint8* Resource::InternalSerializeWithCachedSizesToArray( 2, HasBitSetters::outputs(this), target); } + // string message = 3; + if (this->message().size() > 0) { + ::google::protobuf::internal::WireFormatLite::VerifyUtf8String( + this->message().data(), static_cast(this->message().length()), + ::google::protobuf::internal::WireFormatLite::SERIALIZE, + "flyteidl.admin.Resource.message"); + target = + ::google::protobuf::internal::WireFormatLite::WriteStringToArray( + 3, this->message(), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -3197,6 +3263,13 @@ size_t Resource::ByteSizeLong() const { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + // string message = 3; + if (this->message().size() > 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::StringSize( + this->message()); + } + // .flyteidl.core.LiteralMap outputs = 2; if (this->has_outputs()) { total_size += 1 + @@ -3237,6 +3310,10 @@ void Resource::MergeFrom(const Resource& from) { ::google::protobuf::uint32 cached_has_bits = 0; (void) cached_has_bits; + if (from.message().size() > 0) { + + message_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.message_); + } if (from.has_outputs()) { 
mutable_outputs()->::flyteidl::core::LiteralMap::MergeFrom(from.outputs()); } @@ -3270,6 +3347,8 @@ void Resource::Swap(Resource* other) { void Resource::InternalSwap(Resource* other) { using std::swap; _internal_metadata_.Swap(&other->_internal_metadata_); + message_.Swap(&other->message_, &::google::protobuf::internal::GetEmptyStringAlreadyInited(), + GetArenaNoVirtual()); swap(outputs_, other->outputs_); swap(state_, other->state_); } diff --git a/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.h b/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.h index 5ee08079fc..5d43236ad3 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/admin/agent.pb.h @@ -1015,6 +1015,20 @@ class Resource final : // accessors ------------------------------------------------------- + // string message = 3; + void clear_message(); + static const int kMessageFieldNumber = 3; + const ::std::string& message() const; + void set_message(const ::std::string& value); + #if LANG_CXX11 + void set_message(::std::string&& value); + #endif + void set_message(const char* value); + void set_message(const char* value, size_t size); + ::std::string* mutable_message(); + ::std::string* release_message(); + void set_allocated_message(::std::string* message); + // .flyteidl.core.LiteralMap outputs = 2; bool has_outputs() const; void clear_outputs(); @@ -1035,6 +1049,7 @@ class Resource final : class HasBitSetters; ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + ::google::protobuf::internal::ArenaStringPtr message_; ::flyteidl::core::LiteralMap* outputs_; int state_; mutable ::google::protobuf::internal::CachedSize _cached_size_; @@ -1985,6 +2000,59 @@ inline void Resource::set_allocated_outputs(::flyteidl::core::LiteralMap* output // @@protoc_insertion_point(field_set_allocated:flyteidl.admin.Resource.outputs) } +// string message = 3; +inline void Resource::clear_message() { + 
message_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline const ::std::string& Resource::message() const { + // @@protoc_insertion_point(field_get:flyteidl.admin.Resource.message) + return message_.GetNoArena(); +} +inline void Resource::set_message(const ::std::string& value) { + + message_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), value); + // @@protoc_insertion_point(field_set:flyteidl.admin.Resource.message) +} +#if LANG_CXX11 +inline void Resource::set_message(::std::string&& value) { + + message_.SetNoArena( + &::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::move(value)); + // @@protoc_insertion_point(field_set_rvalue:flyteidl.admin.Resource.message) +} +#endif +inline void Resource::set_message(const char* value) { + GOOGLE_DCHECK(value != nullptr); + + message_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), ::std::string(value)); + // @@protoc_insertion_point(field_set_char:flyteidl.admin.Resource.message) +} +inline void Resource::set_message(const char* value, size_t size) { + + message_.SetNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), + ::std::string(reinterpret_cast(value), size)); + // @@protoc_insertion_point(field_set_pointer:flyteidl.admin.Resource.message) +} +inline ::std::string* Resource::mutable_message() { + + // @@protoc_insertion_point(field_mutable:flyteidl.admin.Resource.message) + return message_.MutableNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline ::std::string* Resource::release_message() { + // @@protoc_insertion_point(field_release:flyteidl.admin.Resource.message) + + return message_.ReleaseNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); +} +inline void Resource::set_allocated_message(::std::string* message) { + if (message != nullptr) { + + } else { + + } + 
message_.SetAllocatedNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), message); + // @@protoc_insertion_point(field_set_allocated:flyteidl.admin.Resource.message) +} + // ------------------------------------------------------------------- // DeleteTaskRequest diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc index 303cb8003e..187cc04817 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.cc @@ -83,6 +83,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fkubeflow_2ften PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, ps_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, chief_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, run_policy_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask, evaluator_replicas_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec, _internal_metadata_), ~0u, // no _extensions_ @@ -95,7 +96,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fkubeflow_2ften }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingTask)}, - { 9, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec)}, + { 10, -1, sizeof(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { @@ -113,7 +114,7 @@ const char descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_ "\n*flyteidl/plugins/kubeflow/tensorflow.p" 
"roto\022\031flyteidl.plugins.kubeflow\032\031flyteid" "l/core/tasks.proto\032&flyteidl/plugins/kub" - "eflow/common.proto\"\362\002\n!DistributedTensor" + "eflow/common.proto\"\323\003\n!DistributedTensor" "flowTrainingTask\022\\\n\017worker_replicas\030\001 \001(" "\0132C.flyteidl.plugins.kubeflow.Distribute" "dTensorflowTrainingReplicaSpec\022X\n\013ps_rep" @@ -122,19 +123,21 @@ const char descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_ "\022[\n\016chief_replicas\030\003 \001(\0132C.flyteidl.plug" "ins.kubeflow.DistributedTensorflowTraini" "ngReplicaSpec\0228\n\nrun_policy\030\004 \001(\0132$.flyt" - "eidl.plugins.kubeflow.RunPolicy\"\272\001\n(Dist" - "ributedTensorflowTrainingReplicaSpec\022\020\n\010" - "replicas\030\001 \001(\005\022\r\n\005image\030\002 \001(\t\022+\n\tresourc" - "es\030\003 \001(\0132\030.flyteidl.core.Resources\022@\n\016re" - "start_policy\030\004 \001(\0162(.flyteidl.plugins.ku" - "beflow.RestartPolicyB\?Z=github.com/flyte" - "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" - "uginsb\006proto3" + "eidl.plugins.kubeflow.RunPolicy\022_\n\022evalu" + "ator_replicas\030\005 \001(\0132C.flyteidl.plugins.k" + "ubeflow.DistributedTensorflowTrainingRep" + "licaSpec\"\272\001\n(DistributedTensorflowTraini" + "ngReplicaSpec\022\020\n\010replicas\030\001 \001(\005\022\r\n\005image" + "\030\002 \001(\t\022+\n\tresources\030\003 \001(\0132\030.flyteidl.cor" + "e.Resources\022@\n\016restart_policy\030\004 \001(\0162(.fl" + "yteidl.plugins.kubeflow.RestartPolicyB\?Z" + "=github.com/flyteorg/flyte/flyteidl/gen/" + "pb-go/flyteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, - "flyteidl/plugins/kubeflow/tensorflow.proto", 
&assign_descriptors_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, 773, + "flyteidl/plugins/kubeflow/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto, 870, }; void AddDescriptors_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto() { @@ -163,6 +166,8 @@ void DistributedTensorflowTrainingTask::InitAsDefaultInstance() { ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::internal_default_instance()); ::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingTask_default_instance_._instance.get_mutable()->run_policy_ = const_cast< ::flyteidl::plugins::kubeflow::RunPolicy*>( ::flyteidl::plugins::kubeflow::RunPolicy::internal_default_instance()); + ::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingTask_default_instance_._instance.get_mutable()->evaluator_replicas_ = const_cast< ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec*>( + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::internal_default_instance()); } class DistributedTensorflowTrainingTask::HasBitSetters { public: @@ -170,6 +175,7 @@ class DistributedTensorflowTrainingTask::HasBitSetters { static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& ps_replicas(const DistributedTensorflowTrainingTask* msg); static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& chief_replicas(const DistributedTensorflowTrainingTask* msg); static const ::flyteidl::plugins::kubeflow::RunPolicy& run_policy(const DistributedTensorflowTrainingTask* msg); + static const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& evaluator_replicas(const DistributedTensorflowTrainingTask* msg); }; const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& @@ -188,6 +194,10 @@ const ::flyteidl::plugins::kubeflow::RunPolicy& DistributedTensorflowTrainingTask::HasBitSetters::run_policy(const 
DistributedTensorflowTrainingTask* msg) { return *msg->run_policy_; } +const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& +DistributedTensorflowTrainingTask::HasBitSetters::evaluator_replicas(const DistributedTensorflowTrainingTask* msg) { + return *msg->evaluator_replicas_; +} void DistributedTensorflowTrainingTask::clear_run_policy() { if (GetArenaNoVirtual() == nullptr && run_policy_ != nullptr) { delete run_policy_; @@ -199,6 +209,7 @@ const int DistributedTensorflowTrainingTask::kWorkerReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kPsReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kChiefReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kRunPolicyFieldNumber; +const int DistributedTensorflowTrainingTask::kEvaluatorReplicasFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask() @@ -230,6 +241,11 @@ DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask(const Distr } else { run_policy_ = nullptr; } + if (from.has_evaluator_replicas()) { + evaluator_replicas_ = new ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec(*from.evaluator_replicas_); + } else { + evaluator_replicas_ = nullptr; + } // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask) } @@ -237,8 +253,8 @@ void DistributedTensorflowTrainingTask::SharedCtor() { ::google::protobuf::internal::InitSCC( &scc_info_DistributedTensorflowTrainingTask_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto.base); ::memset(&worker_replicas_, 0, static_cast( - reinterpret_cast(&run_policy_) - - reinterpret_cast(&worker_replicas_)) + sizeof(run_policy_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&worker_replicas_)) + sizeof(evaluator_replicas_)); } DistributedTensorflowTrainingTask::~DistributedTensorflowTrainingTask() { @@ -251,6 +267,7 @@ void 
DistributedTensorflowTrainingTask::SharedDtor() { if (this != internal_default_instance()) delete ps_replicas_; if (this != internal_default_instance()) delete chief_replicas_; if (this != internal_default_instance()) delete run_policy_; + if (this != internal_default_instance()) delete evaluator_replicas_; } void DistributedTensorflowTrainingTask::SetCachedSize(int size) const { @@ -284,6 +301,10 @@ void DistributedTensorflowTrainingTask::Clear() { delete run_policy_; } run_policy_ = nullptr; + if (GetArenaNoVirtual() == nullptr && evaluator_replicas_ != nullptr) { + delete evaluator_replicas_; + } + evaluator_replicas_ = nullptr; _internal_metadata_.Clear(); } @@ -352,6 +373,19 @@ const char* DistributedTensorflowTrainingTask::_InternalParse(const char* begin, {parser_till_end, object}, ptr - size, ptr)); break; } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + case 5: { + if (static_cast<::google::protobuf::uint8>(tag) != 42) goto handle_unusual; + ptr = ::google::protobuf::io::ReadSize(ptr, &size); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + parser_till_end = ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::_InternalParse; + object = msg->mutable_evaluator_replicas(); + if (size > end - ptr) goto len_delim_till_end; + ptr += size; + GOOGLE_PROTOBUF_PARSER_ASSERT(ctx->ParseExactRange( + {parser_till_end, object}, ptr - size, ptr)); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -426,6 +460,17 @@ bool DistributedTensorflowTrainingTask::MergePartialFromCodedStream( break; } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + case 5: { + if (static_cast< ::google::protobuf::uint8>(tag) == (42 & 0xFF)) { + DO_(::google::protobuf::internal::WireFormatLite::ReadMessage( + input, mutable_evaluator_replicas())); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -477,6 +522,12 
@@ void DistributedTensorflowTrainingTask::SerializeWithCachedSizes( 4, HasBitSetters::run_policy(this), output); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( + 5, HasBitSetters::evaluator_replicas(this), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -518,6 +569,13 @@ ::google::protobuf::uint8* DistributedTensorflowTrainingTask::InternalSerializeW 4, HasBitSetters::run_policy(this), target); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + target = ::google::protobuf::internal::WireFormatLite:: + InternalWriteMessageToArray( + 5, HasBitSetters::evaluator_replicas(this), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -567,6 +625,13 @@ size_t DistributedTensorflowTrainingTask::ByteSizeLong() const { *run_policy_); } + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + if (this->has_evaluator_replicas()) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::MessageSize( + *evaluator_replicas_); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -606,6 +671,9 @@ void DistributedTensorflowTrainingTask::MergeFrom(const DistributedTensorflowTra if (from.has_run_policy()) { mutable_run_policy()->::flyteidl::plugins::kubeflow::RunPolicy::MergeFrom(from.run_policy()); } + if (from.has_evaluator_replicas()) { + 
mutable_evaluator_replicas()->::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec::MergeFrom(from.evaluator_replicas()); + } } void DistributedTensorflowTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -637,6 +705,7 @@ void DistributedTensorflowTrainingTask::InternalSwap(DistributedTensorflowTraini swap(ps_replicas_, other->ps_replicas_); swap(chief_replicas_, other->chief_replicas_); swap(run_policy_, other->run_policy_); + swap(evaluator_replicas_, other->evaluator_replicas_); } ::google::protobuf::Metadata DistributedTensorflowTrainingTask::GetMetadata() const { diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h index 9839ca4817..4a100b0233 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/kubeflow/tensorflow.pb.h @@ -205,6 +205,15 @@ class DistributedTensorflowTrainingTask final : ::flyteidl::plugins::kubeflow::RunPolicy* mutable_run_policy(); void set_allocated_run_policy(::flyteidl::plugins::kubeflow::RunPolicy* run_policy); + // .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + bool has_evaluator_replicas() const; + void clear_evaluator_replicas(); + static const int kEvaluatorReplicasFieldNumber = 5; + const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& evaluator_replicas() const; + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* release_evaluator_replicas(); + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* mutable_evaluator_replicas(); + void set_allocated_evaluator_replicas(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas); + // @@protoc_insertion_point(class_scope:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask) private: class HasBitSetters; @@ -214,6 +223,7 @@ class 
DistributedTensorflowTrainingTask final : ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* ps_replicas_; ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* chief_replicas_; ::flyteidl::plugins::kubeflow::RunPolicy* run_policy_; + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fkubeflow_2ftensorflow_2eproto; }; @@ -570,6 +580,57 @@ inline void DistributedTensorflowTrainingTask::set_allocated_run_policy(::flytei // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.run_policy) } +// .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; +inline bool DistributedTensorflowTrainingTask::has_evaluator_replicas() const { + return this != internal_default_instance() && evaluator_replicas_ != nullptr; +} +inline void DistributedTensorflowTrainingTask::clear_evaluator_replicas() { + if (GetArenaNoVirtual() == nullptr && evaluator_replicas_ != nullptr) { + delete evaluator_replicas_; + } + evaluator_replicas_ = nullptr; +} +inline const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec& DistributedTensorflowTrainingTask::evaluator_replicas() const { + const ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* p = evaluator_replicas_; + // @@protoc_insertion_point(field_get:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + return p != nullptr ? 
*p : *reinterpret_cast( + &::flyteidl::plugins::kubeflow::_DistributedTensorflowTrainingReplicaSpec_default_instance_); +} +inline ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* DistributedTensorflowTrainingTask::release_evaluator_replicas() { + // @@protoc_insertion_point(field_release:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + + ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* temp = evaluator_replicas_; + evaluator_replicas_ = nullptr; + return temp; +} +inline ::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* DistributedTensorflowTrainingTask::mutable_evaluator_replicas() { + + if (evaluator_replicas_ == nullptr) { + auto* p = CreateMaybeMessage<::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec>(GetArenaNoVirtual()); + evaluator_replicas_ = p; + } + // @@protoc_insertion_point(field_mutable:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) + return evaluator_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_allocated_evaluator_replicas(::flyteidl::plugins::kubeflow::DistributedTensorflowTrainingReplicaSpec* evaluator_replicas) { + ::google::protobuf::Arena* message_arena = GetArenaNoVirtual(); + if (message_arena == nullptr) { + delete evaluator_replicas_; + } + if (evaluator_replicas) { + ::google::protobuf::Arena* submessage_arena = nullptr; + if (message_arena != submessage_arena) { + evaluator_replicas = ::google::protobuf::internal::GetOwnedMessage( + message_arena, evaluator_replicas, submessage_arena); + } + + } else { + + } + evaluator_replicas_ = evaluator_replicas; + // @@protoc_insertion_point(field_set_allocated:flyteidl.plugins.kubeflow.DistributedTensorflowTrainingTask.evaluator_replicas) +} + // ------------------------------------------------------------------- // DistributedTensorflowTrainingReplicaSpec diff --git 
a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc index 7d01fad6e2..23d343be1a 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.cc @@ -55,6 +55,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2ftensorflow_2ep PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, workers_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, ps_replicas_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, chief_replicas_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::DistributedTensorflowTrainingTask, evaluator_replicas_), }; static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, sizeof(::flyteidl::plugins::DistributedTensorflowTrainingTask)}, @@ -72,16 +73,17 @@ ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_fl const char descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto[] = "\n!flyteidl/plugins/tensorflow.proto\022\020fly" - "teidl.plugins\"a\n!DistributedTensorflowTr" + "teidl.plugins\"}\n!DistributedTensorflowTr" "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" - "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B\?Z=gith" - "ub.com/flyteorg/flyte/flyteidl/gen/pb-go" - "/flyteidl/pluginsb\006proto3" + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005\022\032\n\022eval" + "uator_replicas\030\004 \001(\005B\?Z=github.com/flyte" + "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + "uginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2ftensorflow_2eproto = { false, InitDefaults_flyteidl_2fplugins_2ftensorflow_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2ftensorflow_2eproto, - "flyteidl/plugins/tensorflow.proto", 
&assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto, 225, + "flyteidl/plugins/tensorflow.proto", &assign_descriptors_table_flyteidl_2fplugins_2ftensorflow_2eproto, 253, }; void AddDescriptors_flyteidl_2fplugins_2ftensorflow_2eproto() { @@ -108,6 +110,7 @@ class DistributedTensorflowTrainingTask::HasBitSetters { const int DistributedTensorflowTrainingTask::kWorkersFieldNumber; const int DistributedTensorflowTrainingTask::kPsReplicasFieldNumber; const int DistributedTensorflowTrainingTask::kChiefReplicasFieldNumber; +const int DistributedTensorflowTrainingTask::kEvaluatorReplicasFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask() @@ -120,15 +123,15 @@ DistributedTensorflowTrainingTask::DistributedTensorflowTrainingTask(const Distr _internal_metadata_(nullptr) { _internal_metadata_.MergeFrom(from._internal_metadata_); ::memcpy(&workers_, &from.workers_, - static_cast(reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + static_cast(reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.DistributedTensorflowTrainingTask) } void DistributedTensorflowTrainingTask::SharedCtor() { ::memset(&workers_, 0, static_cast( - reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); } DistributedTensorflowTrainingTask::~DistributedTensorflowTrainingTask() { @@ -155,8 +158,8 @@ void DistributedTensorflowTrainingTask::Clear() { (void) cached_has_bits; ::memset(&workers_, 0, static_cast( - reinterpret_cast(&chief_replicas_) - - reinterpret_cast(&workers_)) + sizeof(chief_replicas_)); + reinterpret_cast(&evaluator_replicas_) - + reinterpret_cast(&workers_)) + sizeof(evaluator_replicas_)); 
_internal_metadata_.Clear(); } @@ -194,6 +197,13 @@ const char* DistributedTensorflowTrainingTask::_InternalParse(const char* begin, GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } + // int32 evaluator_replicas = 4; + case 4: { + if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; + msg->set_evaluator_replicas(::google::protobuf::internal::ReadVarint(&ptr)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -260,6 +270,19 @@ bool DistributedTensorflowTrainingTask::MergePartialFromCodedStream( break; } + // int32 evaluator_replicas = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { + + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::int32, ::google::protobuf::internal::WireFormatLite::TYPE_INT32>( + input, &evaluator_replicas_))); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -302,6 +325,11 @@ void DistributedTensorflowTrainingTask::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteInt32(3, this->chief_replicas(), output); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteInt32(4, this->evaluator_replicas(), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -330,6 +358,11 @@ ::google::protobuf::uint8* DistributedTensorflowTrainingTask::InternalSerializeW target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(3, this->chief_replicas(), target); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteInt32ToArray(4, this->evaluator_replicas(), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = 
::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -372,6 +405,13 @@ size_t DistributedTensorflowTrainingTask::ByteSizeLong() const { this->chief_replicas()); } + // int32 evaluator_replicas = 4; + if (this->evaluator_replicas() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::Int32Size( + this->evaluator_replicas()); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -408,6 +448,9 @@ void DistributedTensorflowTrainingTask::MergeFrom(const DistributedTensorflowTra if (from.chief_replicas() != 0) { set_chief_replicas(from.chief_replicas()); } + if (from.evaluator_replicas() != 0) { + set_evaluator_replicas(from.evaluator_replicas()); + } } void DistributedTensorflowTrainingTask::CopyFrom(const ::google::protobuf::Message& from) { @@ -438,6 +481,7 @@ void DistributedTensorflowTrainingTask::InternalSwap(DistributedTensorflowTraini swap(workers_, other->workers_); swap(ps_replicas_, other->ps_replicas_); swap(chief_replicas_, other->chief_replicas_); + swap(evaluator_replicas_, other->evaluator_replicas_); } ::google::protobuf::Metadata DistributedTensorflowTrainingTask::GetMetadata() const { diff --git a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h index 613ed31d80..4150592a60 100644 --- a/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h +++ b/flyteidl/gen/pb-cpp/flyteidl/plugins/tensorflow.pb.h @@ -178,6 +178,12 @@ class DistributedTensorflowTrainingTask final : ::google::protobuf::int32 chief_replicas() const; void set_chief_replicas(::google::protobuf::int32 value); + // int32 evaluator_replicas = 4; + void clear_evaluator_replicas(); + static const int kEvaluatorReplicasFieldNumber = 4; + ::google::protobuf::int32 evaluator_replicas() const; + void set_evaluator_replicas(::google::protobuf::int32 value); + // 
@@protoc_insertion_point(class_scope:flyteidl.plugins.DistributedTensorflowTrainingTask) private: class HasBitSetters; @@ -186,6 +192,7 @@ class DistributedTensorflowTrainingTask final : ::google::protobuf::int32 workers_; ::google::protobuf::int32 ps_replicas_; ::google::protobuf::int32 chief_replicas_; + ::google::protobuf::int32 evaluator_replicas_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2ftensorflow_2eproto; }; @@ -242,6 +249,20 @@ inline void DistributedTensorflowTrainingTask::set_chief_replicas(::google::prot // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.chief_replicas) } +// int32 evaluator_replicas = 4; +inline void DistributedTensorflowTrainingTask::clear_evaluator_replicas() { + evaluator_replicas_ = 0; +} +inline ::google::protobuf::int32 DistributedTensorflowTrainingTask::evaluator_replicas() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.DistributedTensorflowTrainingTask.evaluator_replicas) + return evaluator_replicas_; +} +inline void DistributedTensorflowTrainingTask::set_evaluator_replicas(::google::protobuf::int32 value) { + + evaluator_replicas_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.DistributedTensorflowTrainingTask.evaluator_replicas) +} + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif // __GNUC__ diff --git a/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.go b/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.go index 436d00d29a..7511d46559 100644 --- a/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.go @@ -349,10 +349,12 @@ type Resource struct { // The outputs of the execution. It's typically used by sql task. Agent service will create a // Structured dataset pointing to the query result table. 
// +optional - Outputs *core.LiteralMap `protobuf:"bytes,2,opt,name=outputs,proto3" json:"outputs,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + Outputs *core.LiteralMap `protobuf:"bytes,2,opt,name=outputs,proto3" json:"outputs,omitempty"` + // A descriptive message for the current state. e.g. waiting for cluster. + Message string `protobuf:"bytes,3,opt,name=message,proto3" json:"message,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *Resource) Reset() { *m = Resource{} } @@ -394,6 +396,13 @@ func (m *Resource) GetOutputs() *core.LiteralMap { return nil } +func (m *Resource) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + // A message used to delete a task. type DeleteTaskRequest struct { // A predefined yet extensible Task type identifier. @@ -494,51 +503,52 @@ func init() { func init() { proto.RegisterFile("flyteidl/admin/agent.proto", fileDescriptor_c434e52bb0028071) } var fileDescriptor_c434e52bb0028071 = []byte{ - // 726 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x55, 0xed, 0x6e, 0xe2, 0x46, - 0x14, 0x2d, 0x10, 0x08, 0x5c, 0xf2, 0x01, 0xd3, 0xa0, 0x3a, 0x24, 0xad, 0x22, 0xaa, 0x56, 0x51, - 0xab, 0x1a, 0x85, 0x54, 0x6d, 0xd2, 0xaa, 0xad, 0x48, 0x70, 0x11, 0x12, 0x41, 0xd1, 0x04, 0xaa, - 0xb6, 0xd2, 0x2e, 0x1a, 0xcc, 0x85, 0xb5, 0x30, 0x63, 0xaf, 0x67, 0x8c, 0xc2, 0xef, 0x7d, 0x89, - 0x7d, 0xdc, 0x95, 0xc7, 0x1f, 0x01, 0xc4, 0xae, 0x12, 0xed, 0x3f, 0xcf, 0x3d, 0xe7, 0x9e, 0x39, - 0x73, 0xee, 0xd8, 0x86, 0xea, 0xc4, 0x5e, 0x4a, 0xb4, 0xc6, 0x76, 0x9d, 0x8d, 0xe7, 0x16, 0xaf, - 0xb3, 0x29, 0x72, 0xa9, 0xbb, 0x9e, 0x23, 0x1d, 0x72, 0x10, 0x63, 0xba, 0xc2, 0xaa, 0xa7, 0x09, - 0xd7, 0x74, 0x3c, 0xac, 0xdb, 0x96, 0x44, 0x8f, 0xd9, 0x22, 0x64, 0x57, 0x8f, 0xd7, 0x51, 0xc9, - 0xc4, 0x2c, 0x86, 0xbe, 0x5e, 0x87, 
0x2c, 0x2e, 0xd1, 0x9b, 0x30, 0x13, 0x23, 0xf8, 0x9b, 0x0d, - 0x78, 0x8c, 0x5c, 0x5a, 0x13, 0x0b, 0xbd, 0x10, 0xaf, 0xbd, 0xcf, 0x42, 0xa5, 0xcf, 0xc4, 0xcc, - 0x78, 0x44, 0xd3, 0x97, 0x96, 0xc3, 0xef, 0x50, 0xb2, 0x31, 0x93, 0x8c, 0x50, 0x28, 0x07, 0xfb, - 0x0c, 0x31, 0x46, 0x86, 0xd6, 0x58, 0x4b, 0x9d, 0xa5, 0xce, 0x8b, 0x8d, 0xef, 0xf5, 0xc4, 0x7d, - 0xa0, 0xaa, 0xaf, 0x09, 0x74, 0x92, 0x2d, 0xe8, 0xa1, 0x5c, 0x07, 0xc8, 0x29, 0x14, 0x38, 0x9b, - 0xa3, 0x70, 0x99, 0x89, 0x5a, 0xfa, 0x2c, 0x75, 0x5e, 0xa0, 0x4f, 0x05, 0xd2, 0x81, 0x9c, 0xcd, - 0x46, 0x68, 0x0b, 0x2d, 0x73, 0x96, 0x39, 0x2f, 0x36, 0x2e, 0xf4, 0xf5, 0x90, 0xf4, 0xad, 0x46, - 0xf5, 0xae, 0xea, 0x31, 0xb8, 0xf4, 0x96, 0x34, 0x12, 0x20, 0xff, 0x42, 0x91, 0x71, 0xee, 0x48, - 0x16, 0x30, 0x85, 0xb6, 0xa3, 0xf4, 0x7e, 0x79, 0x9e, 0x5e, 0xf3, 0xa9, 0x31, 0x14, 0x5d, 0x95, - 0x22, 0x3a, 0x7c, 0x39, 0xbb, 0x12, 0x43, 0x81, 0xde, 0xc2, 0x32, 0x71, 0xc8, 0x4c, 0xd3, 0xf1, - 0xb9, 0xd4, 0xb2, 0xea, 0x30, 0xe5, 0xd9, 0x95, 0x78, 0x08, 0x91, 0x66, 0x08, 0x10, 0x09, 0x15, - 0xe4, 0x0b, 0xcb, 0x73, 0xf8, 0x1c, 0xb9, 0x1c, 0x2e, 0x98, 0x67, 0xb1, 0x91, 0x8d, 0x42, 0xcb, - 0x29, 0x4f, 0x7f, 0x3d, 0xcf, 0x93, 0xf1, 0x24, 0xf1, 0x4f, 0xac, 0x10, 0x9a, 0x3b, 0xc2, 0x2d, - 0x50, 0xf5, 0x1a, 0x8a, 0x2b, 0xb1, 0x90, 0x12, 0x64, 0x66, 0xb8, 0x54, 0xd3, 0x2b, 0xd0, 0xe0, - 0x91, 0x1c, 0x41, 0x76, 0xc1, 0x6c, 0x3f, 0x9e, 0x42, 0xb8, 0xf8, 0x2d, 0x7d, 0x95, 0xaa, 0xfe, - 0x09, 0xa5, 0xcd, 0x04, 0x5e, 0xd4, 0xdf, 0x86, 0xe3, 0x8f, 0xba, 0x7d, 0x89, 0x50, 0xed, 0x5d, - 0x1a, 0xca, 0xb7, 0x1e, 0x32, 0x89, 0x41, 0x26, 0x14, 0xdf, 0xfa, 0x28, 0x24, 0xb9, 0x80, 0x9c, - 0xc5, 0x5d, 0x5f, 0x8a, 0xe8, 0x2e, 0x1e, 0x6f, 0xdc, 0xc5, 0x6e, 0xf8, 0xe6, 0xdc, 0x31, 0x97, - 0x46, 0x44, 0xf2, 0x2b, 0xe4, 0x25, 0xce, 0x5d, 0x9b, 0xc9, 0x70, 0x97, 0x62, 0xe3, 0x64, 0xcb, - 0x05, 0xee, 0x47, 0x14, 0x9a, 0x90, 0xc9, 0xb7, 0xb0, 0xef, 0xf8, 0xd2, 0xf5, 0xe5, 0xd0, 0xf5, - 0x70, 0x62, 0x3d, 0x6a, 0x19, 0xe5, 0x71, 0x2f, 0x2c, 0xde, 0xab, 0x1a, 
0x79, 0x05, 0x5f, 0x6d, - 0xbc, 0x27, 0xf3, 0x68, 0x6a, 0xda, 0x8e, 0xda, 0xec, 0xbb, 0x67, 0x8d, 0x98, 0x56, 0xe4, 0xb6, - 0x72, 0xed, 0x1a, 0xc8, 0x6a, 0x08, 0xc2, 0x75, 0xb8, 0x50, 0xce, 0x3c, 0x14, 0x8e, 0xef, 0x99, - 0xa8, 0xb6, 0x53, 0x61, 0xec, 0xd1, 0xbd, 0xb8, 0x18, 0xb4, 0xd7, 0x28, 0x1c, 0xb4, 0x51, 0xae, - 0x86, 0x77, 0x02, 0x05, 0xe5, 0x55, 0x2e, 0x5d, 0x8c, 0x86, 0x90, 0x0f, 0x0a, 0xfd, 0xa5, 0xbb, - 0x45, 0x33, 0xbd, 0x45, 0xb3, 0x0d, 0x87, 0x89, 0x66, 0xe4, 0xe5, 0x67, 0xc8, 0xc7, 0x94, 0x68, - 0x26, 0xda, 0xe6, 0x89, 0x69, 0x84, 0xd3, 0x84, 0x59, 0xb3, 0x21, 0x1f, 0x57, 0xc9, 0x8f, 0x90, - 0x15, 0x32, 0x98, 0x4e, 0xd0, 0x7e, 0xd0, 0xa8, 0x6c, 0xb6, 0x3f, 0x04, 0x20, 0x0d, 0x39, 0xe4, - 0x12, 0x76, 0xc3, 0xfc, 0x45, 0x34, 0xcc, 0x4f, 0xdc, 0x80, 0x98, 0x59, 0x1b, 0x40, 0xb9, 0x85, - 0x36, 0xae, 0x5f, 0xa5, 0xcf, 0x4f, 0xe3, 0x08, 0xc8, 0xaa, 0x6c, 0x18, 0xc8, 0x0f, 0xaf, 0x21, - 0xab, 0x1c, 0x93, 0x0a, 0x94, 0xa9, 0xd1, 0xa7, 0xff, 0x35, 0x6f, 0xba, 0xc6, 0xf0, 0xef, 0x66, - 0xa7, 0x3b, 0xa0, 0x46, 0xe9, 0x8b, 0xa0, 0x7c, 0x6f, 0xd0, 0xbb, 0x66, 0xcf, 0xe8, 0xf5, 0x93, - 0x72, 0x8a, 0x14, 0x61, 0xf7, 0xde, 0xe8, 0xb5, 0x3a, 0xbd, 0x76, 0x29, 0x1d, 0x2c, 0xe8, 0xa0, - 0xd7, 0x0b, 0x16, 0x19, 0xb2, 0x0f, 0x85, 0x87, 0xc1, 0xed, 0xad, 0x61, 0xb4, 0x8c, 0x56, 0x69, - 0xe7, 0xe6, 0x8f, 0xff, 0x7f, 0x9f, 0x5a, 0xf2, 0x8d, 0x3f, 0xd2, 0x4d, 0x67, 0x5e, 0x57, 0x87, - 0x77, 0xbc, 0x69, 0xf8, 0x50, 0x4f, 0xbe, 0xf7, 0x53, 0xe4, 0x75, 0x77, 0xf4, 0xd3, 0xd4, 0xa9, - 0xaf, 0xff, 0x86, 0x46, 0x39, 0xf5, 0xe5, 0xbf, 0xfc, 0x10, 0x00, 0x00, 0xff, 0xff, 0x14, 0xef, - 0x1e, 0x8b, 0x9f, 0x06, 0x00, 0x00, + // 740 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x55, 0x7f, 0x6b, 0x2a, 0x47, + 0x14, 0xad, 0x1a, 0x8d, 0x5e, 0xf3, 0x43, 0xa7, 0x91, 0x6e, 0x4c, 0x5a, 0x82, 0xa5, 0x25, 0xb4, + 0x74, 0x25, 0xa6, 0xb4, 0x49, 0x4b, 0x5b, 0x4c, 0xdc, 0x8a, 0x60, 0x24, 0x4c, 0xb4, 0xb4, 0x85, + 0x56, 0xc6, 0xf5, 0xea, 
0x5b, 0x5c, 0x67, 0xf7, 0xed, 0xcc, 0x4a, 0xfc, 0xfb, 0xc1, 0xfb, 0x0c, + 0xef, 0xe3, 0x3e, 0x76, 0xf6, 0x47, 0x54, 0x7c, 0x8f, 0x84, 0xf7, 0xdf, 0xce, 0x3d, 0xf7, 0x9e, + 0x39, 0x73, 0xce, 0x2c, 0x03, 0xd5, 0x89, 0xbd, 0x94, 0x68, 0x8d, 0xed, 0x3a, 0x1b, 0xcf, 0x2d, + 0x5e, 0x67, 0x53, 0xe4, 0x52, 0x77, 0x3d, 0x47, 0x3a, 0xe4, 0x20, 0xc6, 0x74, 0x85, 0x55, 0x4f, + 0x93, 0x5e, 0xd3, 0xf1, 0xb0, 0x6e, 0x5b, 0x12, 0x3d, 0x66, 0x8b, 0xb0, 0xbb, 0x7a, 0xbc, 0x8e, + 0x4a, 0x26, 0x66, 0x31, 0xf4, 0xe5, 0x3a, 0x64, 0x71, 0x89, 0xde, 0x84, 0x99, 0x18, 0xc1, 0x5f, + 0x6d, 0xc0, 0x63, 0xe4, 0xd2, 0x9a, 0x58, 0xe8, 0x85, 0x78, 0xed, 0x5d, 0x16, 0x2a, 0x7d, 0x26, + 0x66, 0xc6, 0x23, 0x9a, 0xbe, 0xb4, 0x1c, 0x7e, 0x87, 0x92, 0x8d, 0x99, 0x64, 0x84, 0x42, 0x39, + 0xd8, 0x67, 0x88, 0x31, 0x32, 0xb4, 0xc6, 0x5a, 0xea, 0x2c, 0x75, 0x5e, 0x6c, 0x7c, 0xab, 0x27, + 0xea, 0x03, 0x56, 0x7d, 0x8d, 0xa0, 0x93, 0x6c, 0x41, 0x0f, 0xe5, 0x3a, 0x40, 0x4e, 0xa1, 0xc0, + 0xd9, 0x1c, 0x85, 0xcb, 0x4c, 0xd4, 0xd2, 0x67, 0xa9, 0xf3, 0x02, 0x7d, 0x2a, 0x90, 0x0e, 0xe4, + 0x6c, 0x36, 0x42, 0x5b, 0x68, 0x99, 0xb3, 0xcc, 0x79, 0xb1, 0x71, 0xa1, 0xaf, 0x9b, 0xa4, 0x6f, + 0x15, 0xaa, 0x77, 0xd5, 0x8c, 0xc1, 0xa5, 0xb7, 0xa4, 0x11, 0x01, 0xf9, 0x1b, 0x8a, 0x8c, 0x73, + 0x47, 0xb2, 0xa0, 0x53, 0x68, 0x3b, 0x8a, 0xef, 0xa7, 0xe7, 0xf1, 0x35, 0x9f, 0x06, 0x43, 0xd2, + 0x55, 0x2a, 0xa2, 0xc3, 0xe7, 0xb3, 0x2b, 0x31, 0x14, 0xe8, 0x2d, 0x2c, 0x13, 0x87, 0xcc, 0x34, + 0x1d, 0x9f, 0x4b, 0x2d, 0xab, 0x0e, 0x53, 0x9e, 0x5d, 0x89, 0x87, 0x10, 0x69, 0x86, 0x00, 0x91, + 0x50, 0x41, 0xbe, 0xb0, 0x3c, 0x87, 0xcf, 0x91, 0xcb, 0xe1, 0x82, 0x79, 0x16, 0x1b, 0xd9, 0x28, + 0xb4, 0x9c, 0xd2, 0xf4, 0xc7, 0xf3, 0x34, 0x19, 0x4f, 0x14, 0x7f, 0xc5, 0x0c, 0xa1, 0xb8, 0x23, + 0xdc, 0x02, 0x55, 0xaf, 0xa1, 0xb8, 0x62, 0x0b, 0x29, 0x41, 0x66, 0x86, 0x4b, 0x95, 0x5e, 0x81, + 0x06, 0x9f, 0xe4, 0x08, 0xb2, 0x0b, 0x66, 0xfb, 0x71, 0x0a, 0xe1, 0xe2, 0x97, 0xf4, 0x55, 0xaa, + 0xfa, 0x3b, 0x94, 0x36, 0x1d, 0x78, 0xd1, 0x7c, 0x1b, 0x8e, 
0x3f, 0xa8, 0xf6, 0x25, 0x44, 0xb5, + 0x37, 0x69, 0x28, 0xdf, 0x7a, 0xc8, 0x24, 0x06, 0x9e, 0x50, 0x7c, 0xed, 0xa3, 0x90, 0xe4, 0x02, + 0x72, 0x16, 0x77, 0x7d, 0x29, 0xa2, 0xbb, 0x78, 0xbc, 0x71, 0x17, 0xbb, 0xe1, 0x9f, 0x73, 0xc7, + 0x5c, 0x1a, 0x35, 0x92, 0x9f, 0x21, 0x2f, 0x71, 0xee, 0xda, 0x4c, 0x86, 0xbb, 0x14, 0x1b, 0x27, + 0x5b, 0x2e, 0x70, 0x3f, 0x6a, 0xa1, 0x49, 0x33, 0xf9, 0x1a, 0xf6, 0x1d, 0x5f, 0xba, 0xbe, 0x1c, + 0xba, 0x1e, 0x4e, 0xac, 0x47, 0x2d, 0xa3, 0x34, 0xee, 0x85, 0xc5, 0x7b, 0x55, 0x23, 0xff, 0xc1, + 0x17, 0x1b, 0xff, 0xc9, 0x3c, 0x4a, 0x4d, 0xdb, 0x51, 0x9b, 0x7d, 0xf3, 0xac, 0x88, 0x69, 0x45, + 0x6e, 0x2b, 0xd7, 0xae, 0x81, 0xac, 0x9a, 0x20, 0x5c, 0x87, 0x0b, 0xa5, 0xcc, 0x43, 0xe1, 0xf8, + 0x9e, 0x89, 0x6a, 0x3b, 0x65, 0xc6, 0x1e, 0xdd, 0x8b, 0x8b, 0xc1, 0x78, 0x8d, 0xc2, 0x41, 0x1b, + 0xe5, 0xaa, 0x79, 0x27, 0x50, 0x50, 0x5a, 0xe5, 0xd2, 0xc5, 0x28, 0x84, 0x7c, 0x50, 0xe8, 0x2f, + 0xdd, 0x2d, 0x9c, 0xe9, 0x2d, 0x9c, 0x6d, 0x38, 0x4c, 0x38, 0x23, 0x2d, 0x3f, 0x42, 0x3e, 0x6e, + 0x89, 0x32, 0xd1, 0x36, 0x4f, 0x4c, 0x23, 0x9c, 0x26, 0x9d, 0xb5, 0xb7, 0x29, 0xc8, 0xc7, 0x65, + 0xf2, 0x3d, 0x64, 0x85, 0x0c, 0xe2, 0x09, 0xe6, 0x0f, 0x1a, 0x95, 0xcd, 0xf9, 0x87, 0x00, 0xa4, + 0x61, 0x0f, 0xb9, 0x84, 0xdd, 0x30, 0x00, 0x11, 0xa5, 0xf9, 0x91, 0x2b, 0x10, 0x77, 0x12, 0x0d, + 0x76, 0xe7, 0x28, 0x04, 0x9b, 0x62, 0x14, 0x62, 0xbc, 0xac, 0x0d, 0xa0, 0xdc, 0x42, 0x1b, 0xd7, + 0x6f, 0xd9, 0xa7, 0x1b, 0x75, 0x04, 0x64, 0x95, 0x36, 0xf4, 0xea, 0xbb, 0xff, 0x21, 0xab, 0xce, + 0x42, 0x2a, 0x50, 0xa6, 0x46, 0x9f, 0xfe, 0xd3, 0xbc, 0xe9, 0x1a, 0xc3, 0x3f, 0x9b, 0x9d, 0xee, + 0x80, 0x1a, 0xa5, 0xcf, 0x82, 0xf2, 0xbd, 0x41, 0xef, 0x9a, 0x3d, 0xa3, 0xd7, 0x4f, 0xca, 0x29, + 0x52, 0x84, 0xdd, 0x7b, 0xa3, 0xd7, 0xea, 0xf4, 0xda, 0xa5, 0x74, 0xb0, 0xa0, 0x83, 0x5e, 0x2f, + 0x58, 0x64, 0xc8, 0x3e, 0x14, 0x1e, 0x06, 0xb7, 0xb7, 0x86, 0xd1, 0x32, 0x5a, 0xa5, 0x9d, 0x9b, + 0xdf, 0xfe, 0xfd, 0x75, 0x6a, 0xc9, 0x57, 0xfe, 0x48, 0x37, 0x9d, 0x79, 0x5d, 0xd9, 0xe2, 0x78, + 
0xd3, 0xf0, 0xa3, 0x9e, 0x3c, 0x05, 0x53, 0xe4, 0x75, 0x77, 0xf4, 0xc3, 0xd4, 0xa9, 0xaf, 0xbf, + 0x50, 0xa3, 0x9c, 0x7a, 0x14, 0x2e, 0xdf, 0x07, 0x00, 0x00, 0xff, 0xff, 0xc9, 0x0a, 0x0e, 0xc8, + 0xba, 0x06, 0x00, 0x00, } diff --git a/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.validate.go index 0ba579cb0b..7efab6b18a 100644 --- a/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.validate.go +++ b/flyteidl/gen/pb-go/flyteidl/admin/agent.pb.validate.go @@ -454,6 +454,8 @@ func (m *Resource) Validate() error { } } + // no validation rules for Message + return nil } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go index ccb4eff2ac..04243dec6e 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.go @@ -32,10 +32,12 @@ type DistributedTensorflowTrainingTask struct { // RunPolicy encapsulates various runtime policies of the distributed training // job, for example how to clean up resources and how long the job can stay // active. 
- RunPolicy *RunPolicy `protobuf:"bytes,4,opt,name=run_policy,json=runPolicy,proto3" json:"run_policy,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + RunPolicy *RunPolicy `protobuf:"bytes,4,opt,name=run_policy,json=runPolicy,proto3" json:"run_policy,omitempty"` + // Evaluator replicas spec + EvaluatorReplicas *DistributedTensorflowTrainingReplicaSpec `protobuf:"bytes,5,opt,name=evaluator_replicas,json=evaluatorReplicas,proto3" json:"evaluator_replicas,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *DistributedTensorflowTrainingTask) Reset() { *m = DistributedTensorflowTrainingTask{} } @@ -91,6 +93,13 @@ func (m *DistributedTensorflowTrainingTask) GetRunPolicy() *RunPolicy { return nil } +func (m *DistributedTensorflowTrainingTask) GetEvaluatorReplicas() *DistributedTensorflowTrainingReplicaSpec { + if m != nil { + return m.EvaluatorReplicas + } + return nil +} + type DistributedTensorflowTrainingReplicaSpec struct { // Number of replicas Replicas int32 `protobuf:"varint,1,opt,name=replicas,proto3" json:"replicas,omitempty"` @@ -170,28 +179,29 @@ func init() { } var fileDescriptor_93de2bd764ddf01a = []byte{ - // 358 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x92, 0x41, 0x4b, 0xc3, 0x30, - 0x14, 0xc7, 0x99, 0x73, 0xe2, 0x32, 0x56, 0xa1, 0x78, 0x98, 0x3b, 0xe9, 0x10, 0x19, 0x82, 0x0d, - 0x4c, 0xf0, 0x26, 0x82, 0xf3, 0xae, 0xc4, 0x9d, 0xbc, 0x8c, 0x36, 0x7b, 0xeb, 0x62, 0xdb, 0x24, - 0xbc, 0xa4, 0x8c, 0x7d, 0x23, 0xbf, 0x98, 0xdf, 0x43, 0xd6, 0xac, 0xed, 0x14, 0x36, 0x3c, 0xec, - 0xf6, 0x5e, 0xf3, 0xcf, 0xff, 0xf7, 0x5e, 0xfa, 0x27, 0xb7, 0xf3, 0x74, 0x65, 0x41, 0xcc, 0x52, - 0xaa, 0xd3, 0x3c, 0x16, 0xd2, 0xd0, 0x24, 0x8f, 0x60, 0x9e, 0xaa, 0x25, 0xb5, 0x20, 0x8d, 0xc2, - 0x75, 0x19, 0x68, 0x54, 0x56, 0xf9, 0x17, 0xa5, 0x36, 0xd8, 0x68, 
0x83, 0x52, 0xdb, 0xaf, 0x8e, - 0x28, 0x57, 0x08, 0xd4, 0x86, 0x26, 0x31, 0xee, 0x56, 0xff, 0x66, 0x37, 0x81, 0xab, 0x2c, 0x53, - 0xd2, 0xe9, 0x06, 0x5f, 0x4d, 0x72, 0xf5, 0x22, 0x8c, 0x45, 0x11, 0xe5, 0x16, 0x66, 0x93, 0x8a, - 0x3e, 0xc1, 0x50, 0x48, 0x21, 0xe3, 0x49, 0x68, 0x12, 0x3f, 0x25, 0x67, 0x4b, 0x85, 0x09, 0xe0, - 0x14, 0x41, 0xa7, 0x82, 0x87, 0xa6, 0xd7, 0xb8, 0x6c, 0x0c, 0x3b, 0xa3, 0x71, 0xb0, 0x73, 0xba, - 0x60, 0xaf, 0x2d, 0x73, 0x3e, 0xef, 0x1a, 0x38, 0xf3, 0x9c, 0xf7, 0xe6, 0x93, 0xf1, 0x67, 0xa4, - 0xa3, 0x4d, 0x4d, 0x3a, 0x3a, 0x1c, 0x89, 0x68, 0x53, 0x51, 0x3e, 0x89, 0xc7, 0x17, 0x02, 0xe6, - 0x35, 0xa8, 0x79, 0x38, 0x50, 0xb7, 0xb0, 0xae, 0x58, 0x63, 0x42, 0x30, 0x97, 0x53, 0xad, 0x52, - 0xc1, 0x57, 0xbd, 0xe3, 0x82, 0x73, 0xbd, 0x87, 0xc3, 0x72, 0xf9, 0x56, 0x68, 0x59, 0x1b, 0xcb, - 0x72, 0xf0, 0xdd, 0x20, 0xc3, 0xff, 0x0e, 0xe0, 0xf7, 0xc9, 0xe9, 0xaf, 0x5f, 0xd5, 0x62, 0x55, - 0xef, 0x9f, 0x93, 0x96, 0xc8, 0xc2, 0x18, 0x8a, 0x97, 0x6d, 0x33, 0xd7, 0xf8, 0x0f, 0xa4, 0x8d, - 0x60, 0x54, 0x8e, 0x1c, 0xca, 0xa7, 0xe8, 0xd5, 0x23, 0xae, 0x03, 0x16, 0xb0, 0xf2, 0x9c, 0xd5, - 0x52, 0xff, 0x95, 0x78, 0x08, 0xc6, 0x86, 0x68, 0xb7, 0xf7, 0xf3, 0x46, 0xc3, 0x7d, 0xfb, 0xb9, - 0x0b, 0x9b, 0x1d, 0xbb, 0xb8, 0xdd, 0x3e, 0x3f, 0x7d, 0x3c, 0xc6, 0xc2, 0x2e, 0xf2, 0x28, 0xe0, - 0x2a, 0xa3, 0x85, 0x89, 0xc2, 0xd8, 0x15, 0xb4, 0x8a, 0x75, 0x0c, 0x92, 0xea, 0xe8, 0x2e, 0x56, - 0xf4, 0x6f, 0xd2, 0xa3, 0x93, 0x22, 0xda, 0xf7, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x15, 0x1b, - 0x83, 0x24, 0x66, 0x03, 0x00, 0x00, + // 382 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x93, 0xc1, 0x6a, 0xe3, 0x30, + 0x10, 0x86, 0xc9, 0x6e, 0xb2, 0x6c, 0x14, 0xe2, 0x65, 0xc5, 0x1e, 0xb2, 0x39, 0xed, 0x86, 0x65, + 0x09, 0x85, 0x5a, 0x90, 0x42, 0x6f, 0xa5, 0xd0, 0xf4, 0xde, 0xa2, 0xe6, 0xd4, 0x4b, 0x90, 0x15, + 0xc5, 0x51, 0x2d, 0x4b, 0x62, 0x24, 0x35, 0xe4, 0x5d, 0xfa, 0x7a, 0x7d, 0x8f, 0x12, 0x3b, 0xb6, + 0xd3, 0x42, 0x42, 
0x0f, 0xb9, 0xcd, 0x58, 0xff, 0xfc, 0xdf, 0x78, 0xa4, 0x41, 0x67, 0x4b, 0xb5, + 0xf1, 0x42, 0x2e, 0x14, 0xb1, 0x2a, 0xa4, 0x52, 0x3b, 0x92, 0x85, 0x44, 0x2c, 0x95, 0x59, 0x13, + 0x2f, 0xb4, 0x33, 0xb0, 0x0d, 0x63, 0x0b, 0xc6, 0x1b, 0xfc, 0xbb, 0xd2, 0xc6, 0x3b, 0x6d, 0x5c, + 0x69, 0x87, 0xf5, 0x11, 0xe1, 0x06, 0x04, 0xf1, 0xcc, 0x65, 0xae, 0xac, 0x1a, 0xfe, 0x3f, 0x4c, + 0xe0, 0x26, 0xcf, 0x8d, 0x2e, 0x75, 0xa3, 0x97, 0x36, 0xfa, 0x7b, 0x2b, 0x9d, 0x07, 0x99, 0x04, + 0x2f, 0x16, 0xb3, 0x9a, 0x3e, 0x03, 0x26, 0xb5, 0xd4, 0xe9, 0x8c, 0xb9, 0x0c, 0x2b, 0xf4, 0x63, + 0x6d, 0x20, 0x13, 0x30, 0x07, 0x61, 0x95, 0xe4, 0xcc, 0x0d, 0x5a, 0x7f, 0x5a, 0xe3, 0xde, 0x64, + 0x1a, 0x1f, 0xec, 0x2e, 0x3e, 0x6a, 0x4b, 0x4b, 0x9f, 0x07, 0x2b, 0x38, 0x8d, 0x4a, 0xef, 0xdd, + 0x27, 0x87, 0x17, 0xa8, 0x67, 0x5d, 0x43, 0xfa, 0x72, 0x3a, 0x12, 0xb2, 0xae, 0xa6, 0x3c, 0xa1, + 0x88, 0xaf, 0xa4, 0x58, 0x36, 0xa0, 0xaf, 0xa7, 0x03, 0xf5, 0x0b, 0xeb, 0x9a, 0x35, 0x45, 0x08, + 0x82, 0x9e, 0x5b, 0xa3, 0x24, 0xdf, 0x0c, 0xda, 0x05, 0xe7, 0xdf, 0x11, 0x0e, 0x0d, 0xfa, 0xbe, + 0xd0, 0xd2, 0x2e, 0x54, 0x21, 0x06, 0x84, 0xc5, 0x33, 0x53, 0x81, 0x79, 0xb3, 0x77, 0x0f, 0x9d, + 0xd3, 0x35, 0xfd, 0xb3, 0xb6, 0xaf, 0x1a, 0x1f, 0xbd, 0xb6, 0xd0, 0xf8, 0xb3, 0xf5, 0x78, 0x88, + 0xbe, 0xbf, 0x7b, 0x1e, 0x1d, 0x5a, 0xe7, 0xf8, 0x17, 0xea, 0xc8, 0x9c, 0xa5, 0xa2, 0xb8, 0xcd, + 0x2e, 0x2d, 0x13, 0x7c, 0x89, 0xba, 0x20, 0x9c, 0x09, 0xc0, 0x45, 0x35, 0xfe, 0x41, 0xf3, 0x27, + 0xdb, 0x47, 0x1d, 0xd3, 0xea, 0x9c, 0x36, 0x52, 0x7c, 0x87, 0x22, 0x10, 0xce, 0x33, 0xf0, 0xfb, + 0x33, 0x8d, 0x26, 0xe3, 0x63, 0x33, 0x2d, 0x0b, 0x76, 0x73, 0xed, 0xc3, 0x7e, 0x7a, 0x73, 0xfd, + 0x78, 0x95, 0x4a, 0xbf, 0x0a, 0x49, 0xcc, 0x4d, 0x4e, 0x0a, 0x13, 0x03, 0x69, 0x19, 0x90, 0x7a, + 0x95, 0x52, 0xa1, 0x89, 0x4d, 0xce, 0x53, 0x43, 0x3e, 0x6e, 0x57, 0xf2, 0xad, 0x58, 0xa7, 0x8b, + 0xb7, 0x00, 0x00, 0x00, 0xff, 0xff, 0x74, 0xcf, 0x60, 0xf4, 0xda, 0x03, 0x00, 0x00, } diff --git 
a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go index 098b4dc7cf..397b3a813b 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow/tensorflow.pb.validate.go @@ -84,6 +84,16 @@ func (m *DistributedTensorflowTrainingTask) Validate() error { } } + if v, ok := interface{}(m.GetEvaluatorReplicas()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return DistributedTensorflowTrainingTaskValidationError{ + field: "EvaluatorReplicas", + reason: "embedded message failed validation", + cause: err, + } + } + } + return nil } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go index d9f1006792..a07ff3feeb 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.go @@ -22,11 +22,15 @@ const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator type DistributedTensorflowTrainingTask struct { - // number of worker, ps, chief replicas spawned in the cluster for this job + // number of worker replicas spawned in the cluster for this job Workers int32 `protobuf:"varint,1,opt,name=workers,proto3" json:"workers,omitempty"` // PS -> Parameter server - PsReplicas int32 `protobuf:"varint,2,opt,name=ps_replicas,json=psReplicas,proto3" json:"ps_replicas,omitempty"` - ChiefReplicas int32 `protobuf:"varint,3,opt,name=chief_replicas,json=chiefReplicas,proto3" json:"chief_replicas,omitempty"` + // number of ps replicas spawned in the cluster for this job + PsReplicas int32 `protobuf:"varint,2,opt,name=ps_replicas,json=psReplicas,proto3" json:"ps_replicas,omitempty"` + // number of chief replicas spawned in the cluster for this job + 
ChiefReplicas int32 `protobuf:"varint,3,opt,name=chief_replicas,json=chiefReplicas,proto3" json:"chief_replicas,omitempty"` + // number of evaluator replicas spawned in the cluster for this job + EvaluatorReplicas int32 `protobuf:"varint,4,opt,name=evaluator_replicas,json=evaluatorReplicas,proto3" json:"evaluator_replicas,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -78,6 +82,13 @@ func (m *DistributedTensorflowTrainingTask) GetChiefReplicas() int32 { return 0 } +func (m *DistributedTensorflowTrainingTask) GetEvaluatorReplicas() int32 { + if m != nil { + return m.EvaluatorReplicas + } + return 0 +} + func init() { proto.RegisterType((*DistributedTensorflowTrainingTask)(nil), "flyteidl.plugins.DistributedTensorflowTrainingTask") } @@ -85,18 +96,19 @@ func init() { func init() { proto.RegisterFile("flyteidl/plugins/tensorflow.proto", fileDescriptor_8da02783614e1bcc) } var fileDescriptor_8da02783614e1bcc = []byte{ - // 203 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0x8f, 0xc1, 0x4a, 0xc4, 0x30, - 0x10, 0x86, 0xa9, 0xa2, 0x42, 0x44, 0x91, 0x9c, 0x7a, 0xd3, 0x0a, 0x82, 0x17, 0x9b, 0x83, 0x67, - 0x11, 0xc4, 0x27, 0x28, 0x3d, 0x79, 0x91, 0xa6, 0x4d, 0xd3, 0xa1, 0xd9, 0x4c, 0x98, 0x49, 0x29, - 0xfb, 0x00, 0xfb, 0xde, 0x0b, 0xd9, 0xb6, 0x0b, 0x7b, 0x9b, 0xf9, 0xe7, 0x1b, 0xf8, 0x3f, 0x51, - 0xf4, 0x6e, 0x1f, 0x0d, 0x74, 0x4e, 0x05, 0x37, 0x59, 0xf0, 0xac, 0xa2, 0xf1, 0x8c, 0xd4, 0x3b, - 0x9c, 0xcb, 0x40, 0x18, 0x51, 0x3e, 0xad, 0x48, 0xb9, 0x20, 0xaf, 0x87, 0x4c, 0x14, 0xbf, 0xc0, - 0x91, 0x40, 0x4f, 0xd1, 0x74, 0xf5, 0xf6, 0x51, 0x53, 0x03, 0x1e, 0xbc, 0xad, 0x1b, 0x1e, 0x65, - 0x2e, 0xee, 0x66, 0xa4, 0xd1, 0x10, 0xe7, 0xd9, 0x4b, 0xf6, 0x7e, 0x53, 0xad, 0xab, 0x7c, 0x16, - 0xf7, 0x81, 0xff, 0xc9, 0x04, 0x07, 0x6d, 0xc3, 0xf9, 0x55, 0xba, 0x8a, 0xc0, 0xd5, 0x92, 0xc8, - 0x37, 0xf1, 0xd8, 0x0e, 0x60, 0xfa, 0x33, 0x73, 0x9d, 0x98, 0x87, 
0x94, 0xae, 0xd8, 0xcf, 0xf7, - 0xdf, 0x97, 0x85, 0x38, 0x4c, 0xba, 0x6c, 0x71, 0xa7, 0x52, 0x4d, 0x24, 0x7b, 0x1a, 0xd4, 0x26, - 0x66, 0x8d, 0x57, 0x41, 0x7f, 0x58, 0x54, 0x97, 0xae, 0xfa, 0x36, 0x19, 0x7e, 0x1e, 0x03, 0x00, - 0x00, 0xff, 0xff, 0x55, 0x40, 0x42, 0x06, 0x06, 0x01, 0x00, 0x00, + // 220 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x5c, 0xd0, 0xbf, 0x4a, 0x04, 0x31, + 0x10, 0x06, 0x70, 0xd6, 0xbf, 0x10, 0x51, 0x34, 0xd5, 0x76, 0x7a, 0x82, 0x60, 0x73, 0x9b, 0xc2, + 0x5a, 0x04, 0xf1, 0x09, 0x8e, 0xad, 0x6c, 0x24, 0xbb, 0x37, 0x9b, 0x1b, 0x2e, 0x66, 0xc2, 0x4c, + 0xe2, 0xe1, 0x7b, 0xf9, 0x80, 0x42, 0xbc, 0xcd, 0x81, 0x5d, 0xf2, 0x7d, 0xbf, 0x14, 0xf9, 0xd4, + 0x62, 0xf2, 0xdf, 0x09, 0x70, 0xed, 0x4d, 0xf4, 0xd9, 0x61, 0x10, 0x93, 0x20, 0x08, 0xf1, 0xe4, + 0x69, 0xd7, 0x45, 0xa6, 0x44, 0xfa, 0x7a, 0x26, 0xdd, 0x9e, 0xdc, 0xff, 0x34, 0x6a, 0xf1, 0x86, + 0x92, 0x18, 0x87, 0x9c, 0x60, 0xdd, 0xd7, 0x17, 0x3d, 0x5b, 0x0c, 0x18, 0x5c, 0x6f, 0x65, 0xab, + 0x5b, 0x75, 0xbe, 0x23, 0xde, 0x02, 0x4b, 0xdb, 0xdc, 0x35, 0x8f, 0xa7, 0xab, 0xf9, 0xaa, 0x6f, + 0xd5, 0x45, 0x94, 0x0f, 0x86, 0xe8, 0x71, 0xb4, 0xd2, 0x1e, 0x95, 0x56, 0x45, 0x59, 0xed, 0x13, + 0xfd, 0xa0, 0xae, 0xc6, 0x0d, 0xc2, 0x74, 0x30, 0xc7, 0xc5, 0x5c, 0x96, 0xb4, 0xb2, 0xa5, 0xd2, + 0xf0, 0x65, 0x7d, 0xb6, 0x89, 0xf8, 0x40, 0x4f, 0x0a, 0xbd, 0xa9, 0xcd, 0xcc, 0x5f, 0x5f, 0xde, + 0x9f, 0x1d, 0xa6, 0x4d, 0x1e, 0xba, 0x91, 0x3e, 0x4d, 0xf9, 0x15, 0xb1, 0xfb, 0x3b, 0x98, 0xba, + 0x83, 0x83, 0x60, 0xe2, 0xb0, 0x74, 0x64, 0xfe, 0x4f, 0x33, 0x9c, 0x95, 0x41, 0x9e, 0x7e, 0x03, + 0x00, 0x00, 0xff, 0xff, 0xa6, 0x22, 0x34, 0xcf, 0x35, 0x01, 0x00, 0x00, } diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go index ed7a8eeb80..00db969ce6 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go +++ 
b/flyteidl/gen/pb-go/flyteidl/plugins/tensorflow.pb.validate.go @@ -50,6 +50,8 @@ func (m *DistributedTensorflowTrainingTask) Validate() error { // no validation rules for ChiefReplicas + // no validation rules for EvaluatorReplicas + return nil } diff --git a/flyteidl/gen/pb-go/flyteidl/service/agent.swagger.json b/flyteidl/gen/pb-go/flyteidl/service/agent.swagger.json index 2815bbdae9..4149506f17 100644 --- a/flyteidl/gen/pb-go/flyteidl/service/agent.swagger.json +++ b/flyteidl/gen/pb-go/flyteidl/service/agent.swagger.json @@ -197,6 +197,10 @@ "outputs": { "$ref": "#/definitions/coreLiteralMap", "title": "The outputs of the execution. It's typically used by sql task. Agent service will create a\nStructured dataset pointing to the query result table.\n+optional" + }, + "message": { + "type": "string", + "description": "A descriptive message for the current state. e.g. waiting for cluster." } } }, diff --git a/flyteidl/gen/pb-java/flyteidl/admin/Agent.java b/flyteidl/gen/pb-java/flyteidl/admin/Agent.java index b0d9d2574e..b2dfd16e61 100644 --- a/flyteidl/gen/pb-java/flyteidl/admin/Agent.java +++ b/flyteidl/gen/pb-java/flyteidl/admin/Agent.java @@ -5460,6 +5460,24 @@ public interface ResourceOrBuilder extends * .flyteidl.core.LiteralMap outputs = 2; */ flyteidl.core.Literals.LiteralMapOrBuilder getOutputsOrBuilder(); + + /** + *
+     * A descriptive message for the current state. e.g. waiting for cluster.
+     * 
+ * + * string message = 3; + */ + java.lang.String getMessage(); + /** + *
+     * A descriptive message for the current state. e.g. waiting for cluster.
+     * 
+ * + * string message = 3; + */ + com.google.protobuf.ByteString + getMessageBytes(); } /** * Protobuf type {@code flyteidl.admin.Resource} @@ -5475,6 +5493,7 @@ private Resource(com.google.protobuf.GeneratedMessageV3.Builder builder) { } private Resource() { state_ = 0; + message_ = ""; } @java.lang.Override @@ -5520,6 +5539,12 @@ private Resource( break; } + case 26: { + java.lang.String s = input.readStringRequireUtf8(); + + message_ = s; + break; + } default: { if (!parseUnknownField( input, unknownFields, extensionRegistry, tag)) { @@ -5616,6 +5641,48 @@ public flyteidl.core.Literals.LiteralMapOrBuilder getOutputsOrBuilder() { return getOutputs(); } + public static final int MESSAGE_FIELD_NUMBER = 3; + private volatile java.lang.Object message_; + /** + *
+     * A descriptive message for the current state. e.g. waiting for cluster.
+     * 
+ * + * string message = 3; + */ + public java.lang.String getMessage() { + java.lang.Object ref = message_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + message_ = s; + return s; + } + } + /** + *
+     * A descriptive message for the current state. e.g. waiting for cluster.
+     * 
+ * + * string message = 3; + */ + public com.google.protobuf.ByteString + getMessageBytes() { + java.lang.Object ref = message_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + message_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -5636,6 +5703,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (outputs_ != null) { output.writeMessage(2, getOutputs()); } + if (!getMessageBytes().isEmpty()) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 3, message_); + } unknownFields.writeTo(output); } @@ -5653,6 +5723,9 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(2, getOutputs()); } + if (!getMessageBytes().isEmpty()) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, message_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -5674,6 +5747,8 @@ public boolean equals(final java.lang.Object obj) { if (!getOutputs() .equals(other.getOutputs())) return false; } + if (!getMessage() + .equals(other.getMessage())) return false; if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -5691,6 +5766,8 @@ public int hashCode() { hash = (37 * hash) + OUTPUTS_FIELD_NUMBER; hash = (53 * hash) + getOutputs().hashCode(); } + hash = (37 * hash) + MESSAGE_FIELD_NUMBER; + hash = (53 * hash) + getMessage().hashCode(); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -5832,6 +5909,8 @@ public Builder clear() { outputs_ = null; outputsBuilder_ = null; } + message_ = ""; + return this; } @@ -5864,6 +5943,7 @@ public flyteidl.admin.Agent.Resource buildPartial() { } else { result.outputs_ = outputsBuilder_.build(); } + result.message_ = message_; 
onBuilt(); return result; } @@ -5918,6 +5998,10 @@ public Builder mergeFrom(flyteidl.admin.Agent.Resource other) { if (other.hasOutputs()) { mergeOutputs(other.getOutputs()); } + if (!other.getMessage().isEmpty()) { + message_ = other.message_; + onChanged(); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -6182,6 +6266,95 @@ public flyteidl.core.Literals.LiteralMapOrBuilder getOutputsOrBuilder() { } return outputsBuilder_; } + + private java.lang.Object message_ = ""; + /** + *
+       * A descriptive message for the current state. e.g. waiting for cluster.
+       * 
+ * + * string message = 3; + */ + public java.lang.String getMessage() { + java.lang.Object ref = message_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + message_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * A descriptive message for the current state. e.g. waiting for cluster.
+       * 
+ * + * string message = 3; + */ + public com.google.protobuf.ByteString + getMessageBytes() { + java.lang.Object ref = message_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + message_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * A descriptive message for the current state. e.g. waiting for cluster.
+       * 
+ * + * string message = 3; + */ + public Builder setMessage( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + message_ = value; + onChanged(); + return this; + } + /** + *
+       * A descriptive message for the current state. e.g. waiting for cluster.
+       * 
+ * + * string message = 3; + */ + public Builder clearMessage() { + + message_ = getDefaultInstance().getMessage(); + onChanged(); + return this; + } + /** + *
+       * A descriptive message for the current state. e.g. waiting for cluster.
+       * 
+ * + * string message = 3; + */ + public Builder setMessageBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + message_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -7424,15 +7597,16 @@ public flyteidl.admin.Agent.DeleteTaskResponse getDefaultInstanceForType() { "\016GetTaskRequest\022\021\n\ttask_type\030\001 \001(\t\022\025\n\rre" + "source_meta\030\002 \001(\014\"=\n\017GetTaskResponse\022*\n\010" + "resource\030\001 \001(\0132\030.flyteidl.admin.Resource" + - "\"\\\n\010Resource\022$\n\005state\030\001 \001(\0162\025.flyteidl.a" + + "\"m\n\010Resource\022$\n\005state\030\001 \001(\0162\025.flyteidl.a" + "dmin.State\022*\n\007outputs\030\002 \001(\0132\031.flyteidl.c" + - "ore.LiteralMap\"=\n\021DeleteTaskRequest\022\021\n\tt" + - "ask_type\030\001 \001(\t\022\025\n\rresource_meta\030\002 \001(\014\"\024\n" + - "\022DeleteTaskResponse*^\n\005State\022\025\n\021RETRYABL" + - "E_FAILURE\020\000\022\025\n\021PERMANENT_FAILURE\020\001\022\013\n\007PE" + - "NDING\020\002\022\013\n\007RUNNING\020\003\022\r\n\tSUCCEEDED\020\004B=Z;g" + - "ithub.com/flyteorg/flyte/flyteidl/gen/pb" + - "-go/flyteidl/adminb\006proto3" + "ore.LiteralMap\022\017\n\007message\030\003 \001(\t\"=\n\021Delet" + + "eTaskRequest\022\021\n\ttask_type\030\001 \001(\t\022\025\n\rresou" + + "rce_meta\030\002 \001(\014\"\024\n\022DeleteTaskResponse*^\n\005" + + "State\022\025\n\021RETRYABLE_FAILURE\020\000\022\025\n\021PERMANEN" + + "T_FAILURE\020\001\022\013\n\007PENDING\020\002\022\013\n\007RUNNING\020\003\022\r\n" + + "\tSUCCEEDED\020\004B=Z;github.com/flyteorg/flyt" + + "e/flyteidl/gen/pb-go/flyteidl/adminb\006pro" + + "to3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. 
InternalDescriptorAssigner() { @@ -7503,7 +7677,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_admin_Resource_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_admin_Resource_descriptor, - new java.lang.String[] { "State", "Outputs", }); + new java.lang.String[] { "State", "Outputs", "Message", }); internal_static_flyteidl_admin_DeleteTaskRequest_descriptor = getDescriptor().getMessageTypes().get(6); internal_static_flyteidl_admin_DeleteTaskRequest_fieldAccessorTable = new diff --git a/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java b/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java index 14ab5db6bd..2353f8b78b 100644 --- a/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java +++ b/flyteidl/gen/pb-java/flyteidl/plugins/Tensorflow.java @@ -20,7 +20,7 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends /** *
-     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * number of worker replicas spawned in the cluster for this job
      * 
* * int32 workers = 1; @@ -30,6 +30,7 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends /** *
      * PS -> Parameter server
+     * number of ps replicas spawned in the cluster for this job
      * 
* * int32 ps_replicas = 2; @@ -37,9 +38,22 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends int getPsReplicas(); /** + *
+     * number of chief replicas spawned in the cluster for this job
+     * 
+ * * int32 chief_replicas = 3; */ int getChiefReplicas(); + + /** + *
+     * number of evaluator replicas spawned in the cluster for this job
+     * 
+ * + * int32 evaluator_replicas = 4; + */ + int getEvaluatorReplicas(); } /** *
@@ -99,6 +113,11 @@ private DistributedTensorflowTrainingTask(
               chiefReplicas_ = input.readInt32();
               break;
             }
+            case 32: {
+
+              evaluatorReplicas_ = input.readInt32();
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -135,7 +154,7 @@ private DistributedTensorflowTrainingTask(
     private int workers_;
     /**
      * 
-     * number of worker, ps, chief replicas spawned in the cluster for this job
+     * number of worker replicas spawned in the cluster for this job
      * 
* * int32 workers = 1; @@ -149,6 +168,7 @@ public int getWorkers() { /** *
      * PS -> Parameter server
+     * number of ps replicas spawned in the cluster for this job
      * 
* * int32 ps_replicas = 2; @@ -160,12 +180,29 @@ public int getPsReplicas() { public static final int CHIEF_REPLICAS_FIELD_NUMBER = 3; private int chiefReplicas_; /** + *
+     * number of chief replicas spawned in the cluster for this job
+     * 
+ * * int32 chief_replicas = 3; */ public int getChiefReplicas() { return chiefReplicas_; } + public static final int EVALUATOR_REPLICAS_FIELD_NUMBER = 4; + private int evaluatorReplicas_; + /** + *
+     * number of evaluator replicas spawned in the cluster for this job
+     * 
+ * + * int32 evaluator_replicas = 4; + */ + public int getEvaluatorReplicas() { + return evaluatorReplicas_; + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -189,6 +226,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (chiefReplicas_ != 0) { output.writeInt32(3, chiefReplicas_); } + if (evaluatorReplicas_ != 0) { + output.writeInt32(4, evaluatorReplicas_); + } unknownFields.writeTo(output); } @@ -210,6 +250,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeInt32Size(3, chiefReplicas_); } + if (evaluatorReplicas_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(4, evaluatorReplicas_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -231,6 +275,8 @@ public boolean equals(final java.lang.Object obj) { != other.getPsReplicas()) return false; if (getChiefReplicas() != other.getChiefReplicas()) return false; + if (getEvaluatorReplicas() + != other.getEvaluatorReplicas()) return false; if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -248,6 +294,8 @@ public int hashCode() { hash = (53 * hash) + getPsReplicas(); hash = (37 * hash) + CHIEF_REPLICAS_FIELD_NUMBER; hash = (53 * hash) + getChiefReplicas(); + hash = (37 * hash) + EVALUATOR_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getEvaluatorReplicas(); hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -391,6 +439,8 @@ public Builder clear() { chiefReplicas_ = 0; + evaluatorReplicas_ = 0; + return this; } @@ -420,6 +470,7 @@ public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask buildPartia result.workers_ = workers_; result.psReplicas_ = psReplicas_; result.chiefReplicas_ = chiefReplicas_; + result.evaluatorReplicas_ = evaluatorReplicas_; onBuilt(); return result; } @@ -477,6 +528,9 @@ public Builder 
mergeFrom(flyteidl.plugins.Tensorflow.DistributedTensorflowTraini if (other.getChiefReplicas() != 0) { setChiefReplicas(other.getChiefReplicas()); } + if (other.getEvaluatorReplicas() != 0) { + setEvaluatorReplicas(other.getEvaluatorReplicas()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -509,7 +563,7 @@ public Builder mergeFrom( private int workers_ ; /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -519,7 +573,7 @@ public int getWorkers() { } /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -532,7 +586,7 @@ public Builder setWorkers(int value) { } /** *
-       * number of worker, ps, chief replicas spawned in the cluster for this job
+       * number of worker replicas spawned in the cluster for this job
        * 
* * int32 workers = 1; @@ -548,6 +602,7 @@ public Builder clearWorkers() { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -558,6 +613,7 @@ public int getPsReplicas() { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -571,6 +627,7 @@ public Builder setPsReplicas(int value) { /** *
        * PS -> Parameter server
+       * number of ps replicas spawned in the cluster for this job
        * 
* * int32 ps_replicas = 2; @@ -584,12 +641,20 @@ public Builder clearPsReplicas() { private int chiefReplicas_ ; /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public int getChiefReplicas() { return chiefReplicas_; } /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public Builder setChiefReplicas(int value) { @@ -599,6 +664,10 @@ public Builder setChiefReplicas(int value) { return this; } /** + *
+       * number of chief replicas spawned in the cluster for this job
+       * 
+ * * int32 chief_replicas = 3; */ public Builder clearChiefReplicas() { @@ -607,6 +676,44 @@ public Builder clearChiefReplicas() { onChanged(); return this; } + + private int evaluatorReplicas_ ; + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public int getEvaluatorReplicas() { + return evaluatorReplicas_; + } + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public Builder setEvaluatorReplicas(int value) { + + evaluatorReplicas_ = value; + onChanged(); + return this; + } + /** + *
+       * number of evaluator replicas spawned in the cluster for this job
+       * 
+ * + * int32 evaluator_replicas = 4; + */ + public Builder clearEvaluatorReplicas() { + + evaluatorReplicas_ = 0; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -675,11 +782,12 @@ public flyteidl.plugins.Tensorflow.DistributedTensorflowTrainingTask getDefaultI static { java.lang.String[] descriptorData = { "\n!flyteidl/plugins/tensorflow.proto\022\020fly" + - "teidl.plugins\"a\n!DistributedTensorflowTr" + + "teidl.plugins\"}\n!DistributedTensorflowTr" + "ainingTask\022\017\n\007workers\030\001 \001(\005\022\023\n\013ps_replic" + - "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005B?Z=gith" + - "ub.com/flyteorg/flyte/flyteidl/gen/pb-go" + - "/flyteidl/pluginsb\006proto3" + "as\030\002 \001(\005\022\026\n\016chief_replicas\030\003 \001(\005\022\032\n\022eval" + + "uator_replicas\030\004 \001(\005B?Z=github.com/flyte" + + "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + + "uginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. 
InternalDescriptorAssigner() { @@ -698,7 +806,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_DistributedTensorflowTrainingTask_descriptor, - new java.lang.String[] { "Workers", "PsReplicas", "ChiefReplicas", }); + new java.lang.String[] { "Workers", "PsReplicas", "ChiefReplicas", "EvaluatorReplicas", }); } // @@protoc_insertion_point(outer_class_scope) diff --git a/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java b/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java index ec763705c8..324f3e5d01 100644 --- a/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java +++ b/flyteidl/gen/pb-java/flyteidl/plugins/kubeflow/Tensorflow.java @@ -123,6 +123,31 @@ public interface DistributedTensorflowTrainingTaskOrBuilder extends * .flyteidl.plugins.kubeflow.RunPolicy run_policy = 4; */ flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder(); + + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + boolean hasEvaluatorReplicas(); + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas(); + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder(); } /** *
@@ -219,6 +244,19 @@ private DistributedTensorflowTrainingTask(
 
               break;
             }
+            case 42: {
+              flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder subBuilder = null;
+              if (evaluatorReplicas_ != null) {
+                subBuilder = evaluatorReplicas_.toBuilder();
+              }
+              evaluatorReplicas_ = input.readMessage(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.parser(), extensionRegistry);
+              if (subBuilder != null) {
+                subBuilder.mergeFrom(evaluatorReplicas_);
+                evaluatorReplicas_ = subBuilder.buildPartial();
+              }
+
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -389,6 +427,39 @@ public flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder
       return getRunPolicy();
     }
 
+    public static final int EVALUATOR_REPLICAS_FIELD_NUMBER = 5;
+    private flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec evaluatorReplicas_;
+    /**
+     * 
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public boolean hasEvaluatorReplicas() { + return evaluatorReplicas_ != null; + } + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas() { + return evaluatorReplicas_ == null ? flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } + /** + *
+     * Evaluator replicas spec
+     * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder() { + return getEvaluatorReplicas(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -415,6 +486,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (runPolicy_ != null) { output.writeMessage(4, getRunPolicy()); } + if (evaluatorReplicas_ != null) { + output.writeMessage(5, getEvaluatorReplicas()); + } unknownFields.writeTo(output); } @@ -440,6 +514,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(4, getRunPolicy()); } + if (evaluatorReplicas_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, getEvaluatorReplicas()); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -475,6 +553,11 @@ public boolean equals(final java.lang.Object obj) { if (!getRunPolicy() .equals(other.getRunPolicy())) return false; } + if (hasEvaluatorReplicas() != other.hasEvaluatorReplicas()) return false; + if (hasEvaluatorReplicas()) { + if (!getEvaluatorReplicas() + .equals(other.getEvaluatorReplicas())) return false; + } if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -502,6 +585,10 @@ public int hashCode() { hash = (37 * hash) + RUN_POLICY_FIELD_NUMBER; hash = (53 * hash) + getRunPolicy().hashCode(); } + if (hasEvaluatorReplicas()) { + hash = (37 * hash) + EVALUATOR_REPLICAS_FIELD_NUMBER; + hash = (53 * hash) + getEvaluatorReplicas().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -663,6 +750,12 @@ public Builder clear() { runPolicy_ = null; runPolicyBuilder_ = null; } + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = null; + } else { + evaluatorReplicas_ = 
null; + evaluatorReplicasBuilder_ = null; + } return this; } @@ -709,6 +802,11 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingTask bu } else { result.runPolicy_ = runPolicyBuilder_.build(); } + if (evaluatorReplicasBuilder_ == null) { + result.evaluatorReplicas_ = evaluatorReplicas_; + } else { + result.evaluatorReplicas_ = evaluatorReplicasBuilder_.build(); + } onBuilt(); return result; } @@ -769,6 +867,9 @@ public Builder mergeFrom(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorf if (other.hasRunPolicy()) { mergeRunPolicy(other.getRunPolicy()); } + if (other.hasEvaluatorReplicas()) { + mergeEvaluatorReplicas(other.getEvaluatorReplicas()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -1427,6 +1528,159 @@ public flyteidl.plugins.kubeflow.Common.RunPolicyOrBuilder getRunPolicyOrBuilder } return runPolicyBuilder_; } + + private flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec evaluatorReplicas_; + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder> evaluatorReplicasBuilder_; + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public boolean hasEvaluatorReplicas() { + return evaluatorReplicasBuilder_ != null || evaluatorReplicas_ != null; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec getEvaluatorReplicas() { + if (evaluatorReplicasBuilder_ == null) { + return evaluatorReplicas_ == null ? flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } else { + return evaluatorReplicasBuilder_.getMessage(); + } + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder setEvaluatorReplicas(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec value) { + if (evaluatorReplicasBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + evaluatorReplicas_ = value; + onChanged(); + } else { + evaluatorReplicasBuilder_.setMessage(value); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder setEvaluatorReplicas( + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder builderForValue) { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = builderForValue.build(); + onChanged(); + } else { + evaluatorReplicasBuilder_.setMessage(builderForValue.build()); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder mergeEvaluatorReplicas(flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec value) { + if (evaluatorReplicasBuilder_ == null) { + if (evaluatorReplicas_ != null) { + evaluatorReplicas_ = + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.newBuilder(evaluatorReplicas_).mergeFrom(value).buildPartial(); + } else { + evaluatorReplicas_ = value; + } + onChanged(); + } else { + evaluatorReplicasBuilder_.mergeFrom(value); + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public Builder clearEvaluatorReplicas() { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicas_ = null; + onChanged(); + } else { + evaluatorReplicas_ = null; + evaluatorReplicasBuilder_ = null; + } + + return this; + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder getEvaluatorReplicasBuilder() { + + onChanged(); + return getEvaluatorReplicasFieldBuilder().getBuilder(); + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder getEvaluatorReplicasOrBuilder() { + if (evaluatorReplicasBuilder_ != null) { + return evaluatorReplicasBuilder_.getMessageOrBuilder(); + } else { + return evaluatorReplicas_ == null ? + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.getDefaultInstance() : evaluatorReplicas_; + } + } + /** + *
+       * Evaluator replicas spec
+       * 
+ * + * .flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder> + getEvaluatorReplicasFieldBuilder() { + if (evaluatorReplicasBuilder_ == null) { + evaluatorReplicasBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpec.Builder, flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplicaSpecOrBuilder>( + getEvaluatorReplicas(), + getParentForChildren(), + isClean()); + evaluatorReplicas_ = null; + } + return evaluatorReplicasBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -2553,7 +2807,7 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplica "\n*flyteidl/plugins/kubeflow/tensorflow.p" + "roto\022\031flyteidl.plugins.kubeflow\032\031flyteid" + "l/core/tasks.proto\032&flyteidl/plugins/kub" + - "eflow/common.proto\"\362\002\n!DistributedTensor" + + "eflow/common.proto\"\323\003\n!DistributedTensor" + "flowTrainingTask\022\\\n\017worker_replicas\030\001 \001(" + "\0132C.flyteidl.plugins.kubeflow.Distribute" + "dTensorflowTrainingReplicaSpec\022X\n\013ps_rep" + @@ -2562,14 +2816,16 @@ public flyteidl.plugins.kubeflow.Tensorflow.DistributedTensorflowTrainingReplica "\022[\n\016chief_replicas\030\003 \001(\0132C.flyteidl.plug" + "ins.kubeflow.DistributedTensorflowTraini" + "ngReplicaSpec\0228\n\nrun_policy\030\004 \001(\0132$.flyt" + - "eidl.plugins.kubeflow.RunPolicy\"\272\001\n(Dist" + - "ributedTensorflowTrainingReplicaSpec\022\020\n\010" + 
- "replicas\030\001 \001(\005\022\r\n\005image\030\002 \001(\t\022+\n\tresourc" + - "es\030\003 \001(\0132\030.flyteidl.core.Resources\022@\n\016re" + - "start_policy\030\004 \001(\0162(.flyteidl.plugins.ku" + - "beflow.RestartPolicyB?Z=github.com/flyte" + - "org/flyte/flyteidl/gen/pb-go/flyteidl/pl" + - "uginsb\006proto3" + "eidl.plugins.kubeflow.RunPolicy\022_\n\022evalu" + + "ator_replicas\030\005 \001(\0132C.flyteidl.plugins.k" + + "ubeflow.DistributedTensorflowTrainingRep" + + "licaSpec\"\272\001\n(DistributedTensorflowTraini" + + "ngReplicaSpec\022\020\n\010replicas\030\001 \001(\005\022\r\n\005image" + + "\030\002 \001(\t\022+\n\tresources\030\003 \001(\0132\030.flyteidl.cor" + + "e.Resources\022@\n\016restart_policy\030\004 \001(\0162(.fl" + + "yteidl.plugins.kubeflow.RestartPolicyB?Z" + + "=github.com/flyteorg/flyte/flyteidl/gen/" + + "pb-go/flyteidl/pluginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. 
InternalDescriptorAssigner() { @@ -2590,7 +2846,7 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingTask_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingTask_descriptor, - new java.lang.String[] { "WorkerReplicas", "PsReplicas", "ChiefReplicas", "RunPolicy", }); + new java.lang.String[] { "WorkerReplicas", "PsReplicas", "ChiefReplicas", "RunPolicy", "EvaluatorReplicas", }); internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingReplicaSpec_descriptor = getDescriptor().getMessageTypes().get(1); internal_static_flyteidl_plugins_kubeflow_DistributedTensorflowTrainingReplicaSpec_fieldAccessorTable = new diff --git a/flyteidl/gen/pb-js/flyteidl.d.ts b/flyteidl/gen/pb-js/flyteidl.d.ts index aa22fa7a18..2a49c5d9d5 100644 --- a/flyteidl/gen/pb-js/flyteidl.d.ts +++ b/flyteidl/gen/pb-js/flyteidl.d.ts @@ -8181,6 +8181,9 @@ export namespace flyteidl { /** Resource outputs */ outputs?: (flyteidl.core.ILiteralMap|null); + + /** Resource message */ + message?: (string|null); } /** Represents a Resource. */ @@ -8198,6 +8201,9 @@ export namespace flyteidl { /** Resource outputs. */ public outputs?: (flyteidl.core.ILiteralMap|null); + /** Resource message. */ + public message: string; + /** * Creates a new Resource instance using the specified properties. 
* @param [properties] Properties to set diff --git a/flyteidl/gen/pb-js/flyteidl.js b/flyteidl/gen/pb-js/flyteidl.js index b7cd35b681..666026623d 100644 --- a/flyteidl/gen/pb-js/flyteidl.js +++ b/flyteidl/gen/pb-js/flyteidl.js @@ -19955,6 +19955,7 @@ * @interface IResource * @property {flyteidl.admin.State|null} [state] Resource state * @property {flyteidl.core.ILiteralMap|null} [outputs] Resource outputs + * @property {string|null} [message] Resource message */ /** @@ -19988,6 +19989,14 @@ */ Resource.prototype.outputs = null; + /** + * Resource message. + * @member {string} message + * @memberof flyteidl.admin.Resource + * @instance + */ + Resource.prototype.message = ""; + /** * Creates a new Resource instance using the specified properties. * @function create @@ -20016,6 +20025,8 @@ writer.uint32(/* id 1, wireType 0 =*/8).int32(message.state); if (message.outputs != null && message.hasOwnProperty("outputs")) $root.flyteidl.core.LiteralMap.encode(message.outputs, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.message != null && message.hasOwnProperty("message")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.message); return writer; }; @@ -20043,6 +20054,9 @@ case 2: message.outputs = $root.flyteidl.core.LiteralMap.decode(reader, reader.uint32()); break; + case 3: + message.message = reader.string(); + break; default: reader.skipType(tag & 7); break; @@ -20078,6 +20092,9 @@ if (error) return "outputs." 
+ error; } + if (message.message != null && message.hasOwnProperty("message")) + if (!$util.isString(message.message)) + return "message: string expected"; return null; }; diff --git a/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.py b/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.py index 236728acf1..bb0aa9384a 100644 --- a/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.py @@ -17,7 +17,7 @@ from flyteidl.core import identifier_pb2 as flyteidl_dot_core_dot_identifier__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/admin/agent.proto\x12\x0e\x66lyteidl.admin\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1e\x66lyteidl/core/identifier.proto\"\x98\x05\n\x15TaskExecutionMetadata\x12R\n\x11task_execution_id\x18\x01 \x01(\x0b\x32&.flyteidl.core.TaskExecutionIdentifierR\x0ftaskExecutionId\x12\x1c\n\tnamespace\x18\x02 \x01(\tR\tnamespace\x12I\n\x06labels\x18\x03 \x03(\x0b\x32\x31.flyteidl.admin.TaskExecutionMetadata.LabelsEntryR\x06labels\x12X\n\x0b\x61nnotations\x18\x04 \x03(\x0b\x32\x36.flyteidl.admin.TaskExecutionMetadata.AnnotationsEntryR\x0b\x61nnotations\x12.\n\x13k8s_service_account\x18\x05 \x01(\tR\x11k8sServiceAccount\x12t\n\x15\x65nvironment_variables\x18\x06 \x03(\x0b\x32?.flyteidl.admin.TaskExecutionMetadata.EnvironmentVariablesEntryR\x14\x65nvironmentVariables\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a>\n\x10\x41nnotationsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aG\n\x19\x45nvironmentVariablesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\x83\x02\n\x11\x43reateTaskRequest\x12\x31\n\x06inputs\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x06inputs\x12\x37\n\x08template\x18\x02 
\x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateR\x08template\x12#\n\routput_prefix\x18\x03 \x01(\tR\x0coutputPrefix\x12]\n\x17task_execution_metadata\x18\x04 \x01(\x0b\x32%.flyteidl.admin.TaskExecutionMetadataR\x15taskExecutionMetadata\"9\n\x12\x43reateTaskResponse\x12#\n\rresource_meta\x18\x01 \x01(\x0cR\x0cresourceMeta\"R\n\x0eGetTaskRequest\x12\x1b\n\ttask_type\x18\x01 \x01(\tR\x08taskType\x12#\n\rresource_meta\x18\x02 \x01(\x0cR\x0cresourceMeta\"G\n\x0fGetTaskResponse\x12\x34\n\x08resource\x18\x01 \x01(\x0b\x32\x18.flyteidl.admin.ResourceR\x08resource\"l\n\x08Resource\x12+\n\x05state\x18\x01 \x01(\x0e\x32\x15.flyteidl.admin.StateR\x05state\x12\x33\n\x07outputs\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x07outputs\"U\n\x11\x44\x65leteTaskRequest\x12\x1b\n\ttask_type\x18\x01 \x01(\tR\x08taskType\x12#\n\rresource_meta\x18\x02 \x01(\x0cR\x0cresourceMeta\"\x14\n\x12\x44\x65leteTaskResponse*^\n\x05State\x12\x15\n\x11RETRYABLE_FAILURE\x10\x00\x12\x15\n\x11PERMANENT_FAILURE\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x42\xb6\x01\n\x12\x63om.flyteidl.adminB\nAgentProtoP\x01Z;github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin\xa2\x02\x03\x46\x41X\xaa\x02\x0e\x46lyteidl.Admin\xca\x02\x0e\x46lyteidl\\Admin\xe2\x02\x1a\x46lyteidl\\Admin\\GPBMetadata\xea\x02\x0f\x46lyteidl::Adminb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lyteidl/admin/agent.proto\x12\x0e\x66lyteidl.admin\x1a\x1c\x66lyteidl/core/literals.proto\x1a\x19\x66lyteidl/core/tasks.proto\x1a\x1d\x66lyteidl/core/interface.proto\x1a\x1e\x66lyteidl/core/identifier.proto\"\x98\x05\n\x15TaskExecutionMetadata\x12R\n\x11task_execution_id\x18\x01 \x01(\x0b\x32&.flyteidl.core.TaskExecutionIdentifierR\x0ftaskExecutionId\x12\x1c\n\tnamespace\x18\x02 \x01(\tR\tnamespace\x12I\n\x06labels\x18\x03 \x03(\x0b\x32\x31.flyteidl.admin.TaskExecutionMetadata.LabelsEntryR\x06labels\x12X\n\x0b\x61nnotations\x18\x04 
\x03(\x0b\x32\x36.flyteidl.admin.TaskExecutionMetadata.AnnotationsEntryR\x0b\x61nnotations\x12.\n\x13k8s_service_account\x18\x05 \x01(\tR\x11k8sServiceAccount\x12t\n\x15\x65nvironment_variables\x18\x06 \x03(\x0b\x32?.flyteidl.admin.TaskExecutionMetadata.EnvironmentVariablesEntryR\x14\x65nvironmentVariables\x1a\x39\n\x0bLabelsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a>\n\x10\x41nnotationsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aG\n\x19\x45nvironmentVariablesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\x83\x02\n\x11\x43reateTaskRequest\x12\x31\n\x06inputs\x18\x01 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x06inputs\x12\x37\n\x08template\x18\x02 \x01(\x0b\x32\x1b.flyteidl.core.TaskTemplateR\x08template\x12#\n\routput_prefix\x18\x03 \x01(\tR\x0coutputPrefix\x12]\n\x17task_execution_metadata\x18\x04 \x01(\x0b\x32%.flyteidl.admin.TaskExecutionMetadataR\x15taskExecutionMetadata\"9\n\x12\x43reateTaskResponse\x12#\n\rresource_meta\x18\x01 \x01(\x0cR\x0cresourceMeta\"R\n\x0eGetTaskRequest\x12\x1b\n\ttask_type\x18\x01 \x01(\tR\x08taskType\x12#\n\rresource_meta\x18\x02 \x01(\x0cR\x0cresourceMeta\"G\n\x0fGetTaskResponse\x12\x34\n\x08resource\x18\x01 \x01(\x0b\x32\x18.flyteidl.admin.ResourceR\x08resource\"\x86\x01\n\x08Resource\x12+\n\x05state\x18\x01 \x01(\x0e\x32\x15.flyteidl.admin.StateR\x05state\x12\x33\n\x07outputs\x18\x02 \x01(\x0b\x32\x19.flyteidl.core.LiteralMapR\x07outputs\x12\x18\n\x07message\x18\x03 \x01(\tR\x07message\"U\n\x11\x44\x65leteTaskRequest\x12\x1b\n\ttask_type\x18\x01 \x01(\tR\x08taskType\x12#\n\rresource_meta\x18\x02 
\x01(\x0cR\x0cresourceMeta\"\x14\n\x12\x44\x65leteTaskResponse*^\n\x05State\x12\x15\n\x11RETRYABLE_FAILURE\x10\x00\x12\x15\n\x11PERMANENT_FAILURE\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x42\xb6\x01\n\x12\x63om.flyteidl.adminB\nAgentProtoP\x01Z;github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin\xa2\x02\x03\x46\x41X\xaa\x02\x0e\x46lyteidl.Admin\xca\x02\x0e\x46lyteidl\\Admin\xe2\x02\x1a\x46lyteidl\\Admin\\GPBMetadata\xea\x02\x0f\x46lyteidl::Adminb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,8 +32,8 @@ _TASKEXECUTIONMETADATA_ANNOTATIONSENTRY._serialized_options = b'8\001' _TASKEXECUTIONMETADATA_ENVIRONMENTVARIABLESENTRY._options = None _TASKEXECUTIONMETADATA_ENVIRONMENTVARIABLESENTRY._serialized_options = b'8\001' - _globals['_STATE']._serialized_start=1530 - _globals['_STATE']._serialized_end=1624 + _globals['_STATE']._serialized_start=1557 + _globals['_STATE']._serialized_end=1651 _globals['_TASKEXECUTIONMETADATA']._serialized_start=167 _globals['_TASKEXECUTIONMETADATA']._serialized_end=831 _globals['_TASKEXECUTIONMETADATA_LABELSENTRY']._serialized_start=637 @@ -50,10 +50,10 @@ _globals['_GETTASKREQUEST']._serialized_end=1236 _globals['_GETTASKRESPONSE']._serialized_start=1238 _globals['_GETTASKRESPONSE']._serialized_end=1309 - _globals['_RESOURCE']._serialized_start=1311 - _globals['_RESOURCE']._serialized_end=1419 - _globals['_DELETETASKREQUEST']._serialized_start=1421 - _globals['_DELETETASKREQUEST']._serialized_end=1506 - _globals['_DELETETASKRESPONSE']._serialized_start=1508 - _globals['_DELETETASKRESPONSE']._serialized_end=1528 + _globals['_RESOURCE']._serialized_start=1312 + _globals['_RESOURCE']._serialized_end=1446 + _globals['_DELETETASKREQUEST']._serialized_start=1448 + _globals['_DELETETASKREQUEST']._serialized_end=1533 + _globals['_DELETETASKRESPONSE']._serialized_start=1535 + _globals['_DELETETASKRESPONSE']._serialized_end=1555 # 
@@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.pyi index 830158d0b2..6ca7ba23a7 100644 --- a/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/admin/agent_pb2.pyi @@ -93,12 +93,14 @@ class GetTaskResponse(_message.Message): def __init__(self, resource: _Optional[_Union[Resource, _Mapping]] = ...) -> None: ... class Resource(_message.Message): - __slots__ = ["state", "outputs"] + __slots__ = ["state", "outputs", "message"] STATE_FIELD_NUMBER: _ClassVar[int] OUTPUTS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] state: State outputs: _literals_pb2.LiteralMap - def __init__(self, state: _Optional[_Union[State, str]] = ..., outputs: _Optional[_Union[_literals_pb2.LiteralMap, _Mapping]] = ...) -> None: ... + message: str + def __init__(self, state: _Optional[_Union[State, str]] = ..., outputs: _Optional[_Union[_literals_pb2.LiteralMap, _Mapping]] = ..., message: _Optional[str] = ...) -> None: ... 
class DeleteTaskRequest(_message.Message): __slots__ = ["task_type", "resource_meta"] diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py index 15f0d96558..f0c086f9e7 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.py @@ -15,7 +15,7 @@ from flyteidl.plugins.kubeflow import common_pb2 as flyteidl_dot_plugins_dot_kubeflow_dot_common__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\xa8\x03\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\"\xe5\x01\n(DistributedTensorflowTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf5\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n*flyteidl/plugins/kubeflow/tensorflow.proto\x12\x19\x66lyteidl.plugins.kubeflow\x1a\x19\x66lyteidl/core/tasks.proto\x1a&flyteidl/plugins/kubeflow/common.proto\"\x9c\x04\n!DistributedTensorflowTrainingTask\x12l\n\x0fworker_replicas\x18\x01 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x0eworkerReplicas\x12\x64\n\x0bps_replicas\x18\x02 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\npsReplicas\x12j\n\x0e\x63hief_replicas\x18\x03 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\rchiefReplicas\x12\x43\n\nrun_policy\x18\x04 \x01(\x0b\x32$.flyteidl.plugins.kubeflow.RunPolicyR\trunPolicy\x12r\n\x12\x65valuator_replicas\x18\x05 \x01(\x0b\x32\x43.flyteidl.plugins.kubeflow.DistributedTensorflowTrainingReplicaSpecR\x11\x65valuatorReplicas\"\xe5\x01\n(DistributedTensorflowTrainingReplicaSpec\x12\x1a\n\x08replicas\x18\x01 \x01(\x05R\x08replicas\x12\x14\n\x05image\x18\x02 \x01(\tR\x05image\x12\x36\n\tresources\x18\x03 \x01(\x0b\x32\x18.flyteidl.core.ResourcesR\tresources\x12O\n\x0erestart_policy\x18\x04 \x01(\x0e\x32(.flyteidl.plugins.kubeflow.RestartPolicyR\rrestartPolicyB\xf5\x01\n\x1d\x63om.flyteidl.plugins.kubeflowB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PK\xaa\x02\x19\x46lyteidl.Plugins.Kubeflow\xca\x02\x19\x46lyteidl\\Plugins\\Kubeflow\xe2\x02%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\xea\x02\x1b\x46lyteidl::Plugins::Kubeflowb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,7 +25,7 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = 
b'\n\035com.flyteidl.plugins.kubeflowB\017TensorflowProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPK\252\002\031Flyteidl.Plugins.Kubeflow\312\002\031Flyteidl\\Plugins\\Kubeflow\342\002%Flyteidl\\Plugins\\Kubeflow\\GPBMetadata\352\002\033Flyteidl::Plugins::Kubeflow' _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_start=141 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=565 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_start=568 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=797 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=681 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_start=684 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGREPLICASPEC']._serialized_end=913 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi index e08a1ff983..4a999f70e8 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/kubeflow/tensorflow_pb2.pyi @@ -7,16 +7,18 @@ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Opti DESCRIPTOR: _descriptor.FileDescriptor class DistributedTensorflowTrainingTask(_message.Message): - __slots__ = ["worker_replicas", "ps_replicas", "chief_replicas", "run_policy"] + __slots__ = ["worker_replicas", "ps_replicas", "chief_replicas", "run_policy", "evaluator_replicas"] WORKER_REPLICAS_FIELD_NUMBER: _ClassVar[int] PS_REPLICAS_FIELD_NUMBER: _ClassVar[int] CHIEF_REPLICAS_FIELD_NUMBER: _ClassVar[int] RUN_POLICY_FIELD_NUMBER: _ClassVar[int] + EVALUATOR_REPLICAS_FIELD_NUMBER: _ClassVar[int] worker_replicas: DistributedTensorflowTrainingReplicaSpec ps_replicas: DistributedTensorflowTrainingReplicaSpec chief_replicas: DistributedTensorflowTrainingReplicaSpec run_policy: 
_common_pb2.RunPolicy - def __init__(self, worker_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., ps_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., chief_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ...) -> None: ... + evaluator_replicas: DistributedTensorflowTrainingReplicaSpec + def __init__(self, worker_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., ps_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., chief_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ..., run_policy: _Optional[_Union[_common_pb2.RunPolicy, _Mapping]] = ..., evaluator_replicas: _Optional[_Union[DistributedTensorflowTrainingReplicaSpec, _Mapping]] = ...) -> None: ... class DistributedTensorflowTrainingReplicaSpec(_message.Message): __slots__ = ["replicas", "image", "resources", "restart_policy"] diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py index 187526d6d5..ceed4231bb 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!flyteidl/plugins/tensorflow.proto\x12\x10\x66lyteidl.plugins\"\x85\x01\n!DistributedTensorflowTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x1f\n\x0bps_replicas\x18\x02 \x01(\x05R\npsReplicas\x12%\n\x0e\x63hief_replicas\x18\x03 
\x01(\x05R\rchiefReplicasB\xc7\x01\n\x14\x63om.flyteidl.pluginsB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n!flyteidl/plugins/tensorflow.proto\x12\x10\x66lyteidl.plugins\"\xb4\x01\n!DistributedTensorflowTrainingTask\x12\x18\n\x07workers\x18\x01 \x01(\x05R\x07workers\x12\x1f\n\x0bps_replicas\x18\x02 \x01(\x05R\npsReplicas\x12%\n\x0e\x63hief_replicas\x18\x03 \x01(\x05R\rchiefReplicas\x12-\n\x12\x65valuator_replicas\x18\x04 \x01(\x05R\x11\x65valuatorReplicasB\xc7\x01\n\x14\x63om.flyteidl.pluginsB\x0fTensorflowProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -23,5 +23,5 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b'\n\024com.flyteidl.pluginsB\017TensorflowProtoP\001Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\242\002\003FPX\252\002\020Flyteidl.Plugins\312\002\020Flyteidl\\Plugins\342\002\034Flyteidl\\Plugins\\GPBMetadata\352\002\021Flyteidl::Plugins' _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_start=56 - _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=189 + _globals['_DISTRIBUTEDTENSORFLOWTRAININGTASK']._serialized_end=236 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi index d3dc028af3..81e2bc30b9 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi +++ 
b/flyteidl/gen/pb_python/flyteidl/plugins/tensorflow_pb2.pyi @@ -5,11 +5,13 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor class DistributedTensorflowTrainingTask(_message.Message): - __slots__ = ["workers", "ps_replicas", "chief_replicas"] + __slots__ = ["workers", "ps_replicas", "chief_replicas", "evaluator_replicas"] WORKERS_FIELD_NUMBER: _ClassVar[int] PS_REPLICAS_FIELD_NUMBER: _ClassVar[int] CHIEF_REPLICAS_FIELD_NUMBER: _ClassVar[int] + EVALUATOR_REPLICAS_FIELD_NUMBER: _ClassVar[int] workers: int ps_replicas: int chief_replicas: int - def __init__(self, workers: _Optional[int] = ..., ps_replicas: _Optional[int] = ..., chief_replicas: _Optional[int] = ...) -> None: ... + evaluator_replicas: int + def __init__(self, workers: _Optional[int] = ..., ps_replicas: _Optional[int] = ..., chief_replicas: _Optional[int] = ..., evaluator_replicas: _Optional[int] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.admin.rs b/flyteidl/gen/pb_rust/flyteidl.admin.rs index 5ac0ff4471..759983a507 100644 --- a/flyteidl/gen/pb_rust/flyteidl.admin.rs +++ b/flyteidl/gen/pb_rust/flyteidl.admin.rs @@ -78,6 +78,9 @@ pub struct Resource { /// +optional #[prost(message, optional, tag="2")] pub outputs: ::core::option::Option, + /// A descriptive message for the current state. e.g. waiting for cluster. + #[prost(string, tag="3")] + pub message: ::prost::alloc::string::String, } /// A message used to delete a task. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs index 59c1f681a0..96d46653da 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.kubeflow.rs @@ -182,6 +182,9 @@ pub struct DistributedTensorflowTrainingTask { /// active. 
#[prost(message, optional, tag="4")] pub run_policy: ::core::option::Option, + /// Evaluator replicas spec + #[prost(message, optional, tag="5")] + pub evaluator_replicas: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index 5c7873b5d2..11e4ad05af 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -290,14 +290,19 @@ pub struct SparkJob { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistributedTensorflowTrainingTask { - /// number of worker, ps, chief replicas spawned in the cluster for this job + /// number of worker replicas spawned in the cluster for this job #[prost(int32, tag="1")] pub workers: i32, /// PS -> Parameter server + /// number of ps replicas spawned in the cluster for this job #[prost(int32, tag="2")] pub ps_replicas: i32, + /// number of chief replicas spawned in the cluster for this job #[prost(int32, tag="3")] pub chief_replicas: i32, + /// number of evaluator replicas spawned in the cluster for this job + #[prost(int32, tag="4")] + pub evaluator_replicas: i32, } /// Represents an Execution that was launched and could be waited on. #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/protos/flyteidl/admin/agent.proto b/flyteidl/protos/flyteidl/admin/agent.proto index d85d4d9a9e..a4dade0953 100644 --- a/flyteidl/protos/flyteidl/admin/agent.proto +++ b/flyteidl/protos/flyteidl/admin/agent.proto @@ -73,6 +73,8 @@ message Resource { // Structured dataset pointing to the query result table. // +optional core.LiteralMap outputs = 2; + // A descriptive message for the current state. e.g. waiting for cluster. + string message = 3; } // A message used to delete a task. 
diff --git a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto index 4cf3153548..789666b989 100644 --- a/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/kubeflow/tensorflow.proto @@ -22,6 +22,9 @@ message DistributedTensorflowTrainingTask { // job, for example how to clean up resources and how long the job can stay // active. RunPolicy run_policy = 4; + + // Evaluator replicas spec + DistributedTensorflowTrainingReplicaSpec evaluator_replicas = 5; } message DistributedTensorflowTrainingReplicaSpec { diff --git a/flyteidl/protos/flyteidl/plugins/tensorflow.proto b/flyteidl/protos/flyteidl/plugins/tensorflow.proto index e768ad65ad..e494a6cc32 100644 --- a/flyteidl/protos/flyteidl/plugins/tensorflow.proto +++ b/flyteidl/protos/flyteidl/plugins/tensorflow.proto @@ -6,9 +6,13 @@ option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugi // Custom proto for plugin that enables distributed training using https://github.com/kubeflow/tf-operator message DistributedTensorflowTrainingTask { - // number of worker, ps, chief replicas spawned in the cluster for this job + // number of worker replicas spawned in the cluster for this job int32 workers = 1; // PS -> Parameter server + // number of ps replicas spawned in the cluster for this job int32 ps_replicas = 2; + // number of chief replicas spawned in the cluster for this job int32 chief_replicas = 3; + // number of evaluator replicas spawned in the cluster for this job + int32 evaluator_replicas = 4; } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index 4f3836a8b1..3f8c2a0914 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -226,7 +226,6 @@ func PhaseInfoQueuedWithTaskInfo(version uint32, reason string, info *TaskInfo) } func 
PhaseInfoInitializing(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo { - pi := phaseInfo(PhaseInitializing, version, nil, info, false) pi.reason = reason return pi diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go new file mode 100644 index 0000000000..d2f5042cf8 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/non_interruptible.go @@ -0,0 +1,35 @@ +package flytek8s + +import ( + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" +) + +// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false +// This is useful as the runner and the scheduler pods should never be interruptible +type NonInterruptibleTaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata +} + +func (n NonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { + return false +} + +// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is +// non-interruptible +type NonInterruptibleTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadata NonInterruptibleTaskExecutionMetadata +} + +func (n NonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return n.metadata +} + +func NewNonInterruptibleTaskExecutionContext(ctx pluginsCore.TaskExecutionContext) NonInterruptibleTaskExecutionContext { + return NonInterruptibleTaskExecutionContext{ + TaskExecutionContext: ctx, + metadata: NonInterruptibleTaskExecutionMetadata{ + ctx.TaskExecutionMetadata(), + }, + } +} diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index dbd7c47baf..caa83def5e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go 
@@ -432,6 +432,15 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return podSpec, objectMeta, primaryContainerName, nil } +func GetContainer(podSpec *v1.PodSpec, name string) (*v1.Container, error) { + for _, container := range podSpec.Containers { + if container.Name == name { + return &container, nil + } + } + return nil, pluginserrors.Errorf(pluginserrors.BadTaskSpecification, "invalid TaskSpecification, container [%s] not defined", name) +} + // getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can // come from one of the following: // (1) PodTemplate name in the TaskMetadata: This name is then looked up in the PodTemplateStore. diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index f8272b919a..65050f5bb2 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -6,6 +6,12 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" @@ -15,11 +21,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - v1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/intstr" - "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" ) const ( @@ -27,42 +28,12 @@ const ( 
KindDaskJob = "DaskJob" ) -// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false -// This is useful as the runner and the scheduler pods should never be interruptible -type nonInterruptibleTaskExecutionMetadata struct { - pluginsCore.TaskExecutionMetadata -} - -func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { - return false -} - -// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is -// non-interruptible -type nonInterruptibleTaskExecutionContext struct { - pluginsCore.TaskExecutionContext - metadata nonInterruptibleTaskExecutionMetadata -} - -func (n nonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { - return n.metadata -} - func mergeMapInto(src map[string]string, dst map[string]string) { for key, value := range src { dst[key] = value } } -func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) (*v1.Container, error) { - for _, container := range spec.Containers { - if container.Name == primaryContainerName { - return &container, nil - } - } - return nil, errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) -} - func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, container v1.Container) error { for i, c := range spec.Containers { if c.Name == primaryContainerName { @@ -104,8 +75,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC if err != nil { return nil, err } - nonInterruptibleTaskMetadata := nonInterruptibleTaskExecutionMetadata{taskCtx.TaskExecutionMetadata()} - nonInterruptibleTaskCtx := nonInterruptibleTaskExecutionContext{taskCtx, nonInterruptibleTaskMetadata} + nonInterruptibleTaskCtx := flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { 
return nil, err @@ -144,7 +114,7 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) { workerPodSpec := podSpec.DeepCopy() - primaryContainer, err := getPrimaryContainer(workerPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(workerPodSpec, primaryContainerName) if err != nil { return nil, err } @@ -206,7 +176,7 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.SchedulerSpec, error) { schedulerPodSpec := podSpec.DeepCopy() - primaryContainer, err := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(schedulerPodSpec, primaryContainerName) if err != nil { return nil, err } @@ -283,7 +253,7 @@ func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.Schedule jobPodSpec := podSpec.DeepCopy() jobPodSpec.RestartPolicy = v1.RestartPolicyNever - primaryContainer, err := getPrimaryContainer(jobPodSpec, primaryContainerName) + primaryContainer, err := flytek8s.GetContainer(jobPodSpec, primaryContainerName) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 4e38d791c4..f6a9787dbb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -93,7 +93,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim // GetLogs will return the logs for kubeflow job func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v1.ObjectMeta, hasMaster bool, - 
workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { + workersCount int32, psReplicasCount int32, chiefReplicasCount int32, evaluatorReplicasCount int32) ([]*core.TaskLog, error) { name := objectMeta.Name namespace := objectMeta.Namespace @@ -181,6 +181,18 @@ func GetLogs(pluginContext k8s.PluginContext, taskType string, objectMeta meta_v } taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...) } + // get evaluator log, and the max number of evaluator is 1 + if evaluatorReplicasCount != 0 { + evaluatorReplicasCount, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-evaluatorReplica-%d", 0), + Namespace: namespace, + TaskExecutionIdentifier: &taskExecID, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, evaluatorReplicasCount.TaskLogs...) + } return taskLogs, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 4f5d70dc5c..6c083b6898 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -167,7 +167,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "mpi-namespace", } - jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0) + jobLogs, err := GetLogs(taskCtx, MPITaskType, mpiJobObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 1, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri) @@ -176,7 +176,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "pytorch-namespace", } - jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, workers, launcher, 0) + jobLogs, err = GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 
workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri) @@ -186,7 +186,7 @@ func TestGetLogs(t *testing.T) { Name: "test", Namespace: "tensorflow-namespace", } - jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1) + jobLogs, err = GetLogs(taskCtx, TensorflowTaskType, tensorflowJobObjectMeta, false, workers, launcher, 1, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri) @@ -209,7 +209,7 @@ func TestGetLogsTemplateUri(t *testing.T) { Time: time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC), }, } - jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0) + jobLogs, err := GetLogs(taskCtx, PytorchTaskType, pytorchJobObjectMeta, true, 1, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("https://console.cloud.google.com/logs/query;query=resource.labels.pod_name=%s-master-0×tamp>%s", "test", "2022-01-01T12:00:00Z"), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 492dd32235..25e45ad727 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -210,7 +210,7 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false, - *numWorkers, *numLauncherReplicas, 0) + *numWorkers, *numLauncherReplicas, 0, 0) if err != nil { return 
pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 8dc176d833..bda0faff75 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -543,7 +543,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil) - jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0) + jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 07f3df0ef5..2461c7bc18 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -231,7 +231,7 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas - taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0) + taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go 
b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index a9e304842a..f980c89741 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -588,7 +588,7 @@ func TestGetLogs(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) @@ -608,7 +608,7 @@ func TestGetLogsElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) taskCtx := dummyPytorchTaskContext(dummyPytorchTaskTemplate("", dummyPytorchCustomObj(workers)), resourceRequirements, nil) - jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0) + jobLogs, err := common.GetLogs(taskCtx, common.PytorchTaskType, pytorchJob.ObjectMeta, hasMaster, workers, 0, 0, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 578bd2a0d6..1c4c965819 100644 --- 
a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -80,6 +80,11 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task PodSpec: podSpec.DeepCopy(), RestartPolicy: commonOp.RestartPolicyNever, }, + kubeflowv1.TFJobReplicaTypeEval: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, } runPolicy := commonOp.RunPolicy{} @@ -94,6 +99,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = tensorflowTaskExtraArgs.GetChiefReplicas() replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = tensorflowTaskExtraArgs.GetWorkers() replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = tensorflowTaskExtraArgs.GetPsReplicas() + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = tensorflowTaskExtraArgs.GetEvaluatorReplicas() } else if taskTemplate.TaskTypeVersion == 1 { kfTensorflowTaskExtraArgs := kfplugins.DistributedTensorflowTrainingTask{} @@ -151,6 +157,22 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() } + evaluatorReplicaSpec := kfTensorflowTaskExtraArgs.GetEvaluatorReplicas() + if evaluatorReplicaSpec != nil { + err := common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + evaluatorReplicaSpec.GetImage(), + evaluatorReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].RestartPolicy = common.ParseRestartPolicy(evaluatorReplicaSpec.GetRestartPolicy()) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeEval].ReplicaNum = evaluatorReplicaSpec.GetReplicas() + } + if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil { 
runPolicy = common.ParseRunPolicy(*kfTensorflowTaskExtraArgs.GetRunPolicy()) } @@ -207,9 +229,10 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC workersCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas + evaluatorReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false, - *workersCount, *psReplicasCount, *chiefCount) + *workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 6ee5394453..ee5e97db16 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -71,11 +71,12 @@ var ( jobNamespace = "tensorflow-namespace" ) -func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int32) *plugins.DistributedTensorflowTrainingTask { +func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int32, evaluatorReplicas int32) *plugins.DistributedTensorflowTrainingTask { return &plugins.DistributedTensorflowTrainingTask{ - Workers: workers, - PsReplicas: psReplicas, - ChiefReplicas: chiefReplicas, + Workers: workers, + PsReplicas: psReplicas, + ChiefReplicas: chiefReplicas, + EvaluatorReplicas: evaluatorReplicas, } } @@ -177,7 +178,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate, resources *core } func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorResourceHandler, - workers int32, psReplicas 
int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { + workers int32, psReplicas int32, chiefReplicas int32, evaluatorReplicas int32, conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { var jobConditions []commonOp.JobCondition now := time.Now() @@ -276,7 +277,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso } } - tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas) + tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) if err != nil { @@ -302,7 +303,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso func TestBuildResourceTensorFlow(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tfObj := dummyTensorFlowCustomObj(100, 50, 1) + tfObj := dummyTensorFlowCustomObj(100, 50, 1, 1) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) @@ -314,6 +315,7 @@ func TestBuildResourceTensorFlow(t *testing.T) { assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) // verify TaskExecutionMetadata labels and annotations are copied to the TensorFlowJob for k, v := range dummyAnnotations { @@ -444,7 +446,7 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { 
}, } - v0TaskTemplate := dummyTensorFlowTaskTemplate("v0", dummyTensorFlowCustomObj(100, 50, 1)) + v0TaskTemplate := dummyTensorFlowTaskTemplate("v0", dummyTensorFlowCustomObj(100, 50, 1, 1)) v1TaskTemplate := dummyTensorFlowTaskTemplate("v1", &kfplugins.DistributedTensorflowTrainingTask{ ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ Replicas: 1, @@ -455,6 +457,9 @@ func TestBuildResourceTensorFlowExtendedResources(t *testing.T) { PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ Replicas: 50, }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + }, }) v1TaskTemplate.TaskTypeVersion = 1 testConfigs := []struct { @@ -500,10 +505,10 @@ func TestGetTaskPhase(t *testing.T) { ctx := context.TODO() dummyTensorFlowJobResourceCreator := func(conditionType commonOp.JobConditionType) *kubeflowv1.TFJob { - return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, conditionType) + return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, 1, conditionType) } - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1)), resourceRequirements, nil) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(2, 1, 1, 1)), resourceRequirements, nil) taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, taskCtx, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) assert.NoError(t, err) assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) @@ -544,18 +549,20 @@ func TestGetLogs(t *testing.T) { workers := int32(2) psReplicas := int32(1) chiefReplicas := int32(1) + evaluatorReplicas := int32(1) tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) - taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", 
dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas)), resourceRequirements, nil) + tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, evaluatorReplicas, commonOp.JobRunning) + taskCtx := dummyTensorFlowTaskContext(dummyTensorFlowTaskTemplate("", dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas, evaluatorReplicas)), resourceRequirements, nil) jobLogs, err := common.GetLogs(taskCtx, common.TensorflowTaskType, tensorFlowJob.ObjectMeta, false, - workers, psReplicas, chiefReplicas) + workers, psReplicas, chiefReplicas, evaluatorReplicas) assert.NoError(t, err) - assert.Equal(t, 4, len(jobLogs)) + assert.Equal(t, 5, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[0].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[1].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[2].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[3].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-evaluatorReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[4].Uri) } func TestGetProperties(t *testing.T) { @@ -566,26 +573,31 @@ func TestGetProperties(t *testing.T) { func TestReplicaCounts(t *testing.T) { for _, test := range []struct { - name string - chiefReplicaCount int32 - psReplicaCount int32 - workerReplicaCount int32 - expectError bool - contains []commonOp.ReplicaType - notContains []commonOp.ReplicaType + name string + chiefReplicaCount int32 + psReplicaCount int32 + workerReplicaCount int32 + evaluatorReplicaCount int32 + expectError bool + contains []commonOp.ReplicaType + notContains []commonOp.ReplicaType }{ - {"NoWorkers", 1, 1, 0, 
true, nil, nil}, - {"SingleChief", 1, 0, 1, false, + {"NoWorkers", 1, 1, 0, 1, true, nil, nil}, + {"SingleChief", 1, 0, 1, 0, false, []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeWorker}, - []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS}}, - {"SinglePS", 0, 1, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeEval}}, + {"SinglePS", 0, 1, 1, 0, false, []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeWorker}, - []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief}}, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeEval}}, + {"AllContains", 1, 1, 1, 1, false, + []commonOp.ReplicaType{kubeflowv1.TFJobReplicaTypePS, kubeflowv1.TFJobReplicaTypeWorker, kubeflowv1.TFJobReplicaTypeChief, kubeflowv1.TFJobReplicaTypeEval}, + nil, + }, } { t.Run(test.name, func(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount) + tfObj := dummyTensorFlowCustomObj(test.workerReplicaCount, test.psReplicaCount, test.chiefReplicaCount, test.evaluatorReplicaCount) taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj) resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate, resourceRequirements, nil)) @@ -653,6 +665,21 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { }, }, }, + EvaluatorReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: 
kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, RunPolicy: &kfplugins.RunPolicy{ CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, ActiveDeadlineSeconds: int32(100), @@ -688,6 +715,16 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { corev1.ResourceCPU: resource.MustParse("500m"), }, }, + kubeflowv1.TFJobReplicaTypeEval: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, } tensorflowResourceHandler := tensorflowOperatorResourceHandler{} @@ -704,6 +741,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeEval].Replicas) for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { var hasContainerWithDefaultTensorFlowName = false diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 7c7ba34f9a..d0506ccfb5 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -3,35 +3,29 @@ package spark import ( "context" "fmt" - - "sigs.k8s.io/controller-runtime/pkg/client" - + "regexp" "strconv" + "strings" + "time" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/template" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" - 
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" + sparkOpConfig "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/config" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - "k8s.io/client-go/kubernetes/scheme" - - sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "regexp" - "strings" - "time" ) const KindSparkApplication = "SparkApplication" @@ -80,70 +74,20 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v].", taskTemplate.GetCustom()) } - annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, 
utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) - container := taskTemplate.GetContainer() - - envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()), - taskCtx.TaskExecutionMetadata().GetEnvironmentVariables(), taskCtx.TaskExecutionMetadata().GetTaskExecutionID()) - - sparkEnvVars := make(map[string]string) - for _, envVar := range envVars { - sparkEnvVars[envVar.Name] = envVar.Value - } - - sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) - - serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) - - if len(serviceAccountName) == 0 { - serviceAccountName = sparkTaskType - } - driverSpec := sparkOp.DriverSpec{ - SparkPodSpec: sparkOp.SparkPodSpec{ - Affinity: config.GetK8sPluginConfig().DefaultAffinity, - Annotations: annotations, - Labels: labels, - EnvVars: sparkEnvVars, - Image: &container.Image, - SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), - DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), - Tolerations: config.GetK8sPluginConfig().DefaultTolerations, - SchedulerName: &config.GetK8sPluginConfig().SchedulerName, - NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, - HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, - }, - ServiceAccount: &serviceAccountName, - } - - executorSpec := sparkOp.ExecutorSpec{ - SparkPodSpec: sparkOp.SparkPodSpec{ - Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(), - Annotations: annotations, - Labels: labels, - Image: &container.Image, - EnvVars: sparkEnvVars, - SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), - DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), - Tolerations: 
config.GetK8sPluginConfig().DefaultTolerations, - SchedulerName: &config.GetK8sPluginConfig().SchedulerName, - NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, - HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, - }, + sparkConfig := getSparkConfig(taskCtx, &sparkJob) + driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig) + if err != nil { + return nil, err } - - modifiedArgs, err := template.Render(ctx, container.GetArgs(), template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - }) + executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig) if err != nil { return nil, err } + app := createSparkApplication(&sparkJob, sparkConfig, driverSpec, executorSpec) + return app, nil +} - // Hack: Retry submit failures in-case of resource limits hit. - submissionFailureRetries := int32(14) +func getSparkConfig(taskCtx pluginsCore.TaskExecutionContext, sparkJob *plugins.SparkJob) map[string]string { // Start with default config values. sparkConfig := make(map[string]string) for k, v := range GetSparkConfig().DefaultSparkConfig { @@ -165,57 +109,145 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } // Set pod limits. 
- if len(sparkConfig["spark.kubernetes.driver.limit.cores"]) == 0 { + if len(sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey]) == 0 { // spark.kubernetes.driver.request.cores takes precedence over spark.driver.cores - if len(sparkConfig["spark.kubernetes.driver.request.cores"]) != 0 { - sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.kubernetes.driver.request.cores"] + if len(sparkConfig[sparkOpConfig.SparkDriverCoreRequestKey]) != 0 { + sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey] = sparkConfig[sparkOpConfig.SparkDriverCoreRequestKey] } else if len(sparkConfig["spark.driver.cores"]) != 0 { - sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"] + sparkConfig[sparkOpConfig.SparkDriverCoreLimitKey] = sparkConfig["spark.driver.cores"] } } - if len(sparkConfig["spark.kubernetes.executor.limit.cores"]) == 0 { + if len(sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey]) == 0 { // spark.kubernetes.executor.request.cores takes precedence over spark.executor.cores - if len(sparkConfig["spark.kubernetes.executor.request.cores"]) != 0 { - sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.kubernetes.executor.request.cores"] + if len(sparkConfig[sparkOpConfig.SparkExecutorCoreRequestKey]) != 0 { + sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey] = sparkConfig[sparkOpConfig.SparkExecutorCoreRequestKey] } else if len(sparkConfig["spark.executor.cores"]) != 0 { - sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"] + sparkConfig[sparkOpConfig.SparkExecutorCoreLimitKey] = sparkConfig["spark.executor.cores"] } } sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() sparkConfig["spark.kubernetes.driverEnv.FLYTE_START_TIME"] = strconv.FormatInt(time.Now().UnixNano()/1000000, 10) - // Add driver/executor defaults to CRD Driver/Executor Spec as well. 
- cores, err := strconv.ParseInt(sparkConfig["spark.driver.cores"], 10, 32) - if err == nil { - driverSpec.Cores = intPtr(int32(cores)) + return sparkConfig +} + +func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { + name := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(metadata) + if len(name) == 0 { + name = sparkTaskType + } + return name +} + +func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) *sparkOp.SparkPodSpec { + annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) + labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) + + sparkEnvVars := make(map[string]string) + for _, envVar := range container.Env { + sparkEnvVars[envVar.Name] = envVar.Value } - driverSpec.Memory = strPtr(sparkConfig["spark.driver.memory"]) + sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) + + spec := sparkOp.SparkPodSpec{ + Affinity: podSpec.Affinity, + Annotations: annotations, + Labels: labels, + EnvVars: sparkEnvVars, + Image: &container.Image, + SecurityContenxt: podSpec.SecurityContext.DeepCopy(), + DNSConfig: podSpec.DNSConfig.DeepCopy(), + Tolerations: podSpec.Tolerations, + SchedulerName: &podSpec.SchedulerName, + NodeSelector: podSpec.NodeSelector, + HostNetwork: &podSpec.HostNetwork, + } + return &spec +} - execCores, err := strconv.ParseInt(sparkConfig["spark.executor.cores"], 10, 32) - if err == nil { - executorSpec.Cores = intPtr(int32(execCores)) +type driverSpec struct { + sparkSpec *sparkOp.DriverSpec +} + +func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) { + // Spark driver pods should always run as non-interruptible + nonInterruptibleTaskCtx := 
flytek8s.NewNonInterruptibleTaskExecutionContext(taskCtx) + podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) + if err != nil { + return nil, err + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) + if err != nil { + return nil, err + } + sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer) + serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) + spec := driverSpec{ + &sparkOp.DriverSpec{ + SparkPodSpec: *sparkPodSpec, + ServiceAccount: &serviceAccountName, + }, + } + if cores, err := strconv.ParseInt(sparkConfig["spark.driver.cores"], 10, 32); err == nil { + spec.sparkSpec.Cores = intPtr(int32(cores)) } + spec.sparkSpec.Memory = strPtr(sparkConfig["spark.driver.memory"]) + return &spec, nil +} - execCount, err := strconv.ParseInt(sparkConfig["spark.executor.instances"], 10, 32) - if err == nil { - executorSpec.Instances = intPtr(int32(execCount)) +type executorSpec struct { + container *v1.Container + sparkSpec *sparkOp.ExecutorSpec + serviceAccountName string +} + +func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*executorSpec, error) { + podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, err + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) + if err != nil { + return nil, err + } + sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer) + serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) + spec := executorSpec{ + primaryContainer, + &sparkOp.ExecutorSpec{ + SparkPodSpec: *sparkPodSpec, + }, + serviceAccountName, } - executorSpec.Memory = strPtr(sparkConfig["spark.executor.memory"]) + if execCores, err := strconv.ParseInt(sparkConfig["spark.executor.cores"], 10, 32); err == nil { + spec.sparkSpec.Cores = 
intPtr(int32(execCores)) + } + if execCount, err := strconv.ParseInt(sparkConfig["spark.executor.instances"], 10, 32); err == nil { + spec.sparkSpec.Instances = intPtr(int32(execCount)) + } + spec.sparkSpec.Memory = strPtr(sparkConfig["spark.executor.memory"]) + return &spec, nil +} + +func createSparkApplication(sparkJob *plugins.SparkJob, sparkConfig map[string]string, driverSpec *driverSpec, + executorSpec *executorSpec) *sparkOp.SparkApplication { + // Hack: Retry submit failures in-case of resource limits hit. + submissionFailureRetries := int32(14) - j := &sparkOp.SparkApplication{ + app := &sparkOp.SparkApplication{ TypeMeta: metav1.TypeMeta{ Kind: KindSparkApplication, APIVersion: sparkOp.SchemeGroupVersion.String(), }, Spec: sparkOp.SparkApplicationSpec{ - ServiceAccount: &serviceAccountName, + ServiceAccount: &executorSpec.serviceAccountName, Type: getApplicationType(sparkJob.GetApplicationType()), - Image: &container.Image, - Arguments: modifiedArgs, - Driver: driverSpec, - Executor: executorSpec, + Image: &executorSpec.container.Image, + Arguments: executorSpec.container.Args, + Driver: *driverSpec.sparkSpec, + Executor: *executorSpec.sparkSpec, SparkConf: sparkConfig, HadoopConf: sparkJob.GetHadoopConf(), // SubmissionFailures handled here. Task Failures handled at Propeller/Job level. @@ -227,32 +259,16 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } if val, ok := sparkConfig["spark.batchScheduler"]; ok { - j.Spec.BatchScheduler = &val + app.Spec.BatchScheduler = &val } if sparkJob.MainApplicationFile != "" { - j.Spec.MainApplicationFile = &sparkJob.MainApplicationFile + app.Spec.MainApplicationFile = &sparkJob.MainApplicationFile } if sparkJob.MainClass != "" { - j.Spec.MainClass = &sparkJob.MainClass - } - - // Spark driver pods should always run as non-interruptible. 
As such, we hardcode - // `interruptible=false` to explicitly add non-interruptible node selector - // requirements to the driver pods - flytek8s.ApplyInterruptibleNodeSelectorRequirement(false, j.Spec.Driver.Affinity) - - // Add Interruptible Tolerations/NodeSelector to only Executor pods. - // The Interruptible NodeSelector takes precedence over the DefaultNodeSelector - if taskCtx.TaskExecutionMetadata().IsInterruptible() { - j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) - j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector + app.Spec.MainClass = &sparkJob.MainClass } - - // Add interruptible/non-interruptible node selector requirements to executor pod - flytek8s.ApplyInterruptibleNodeSelectorRequirement(taskCtx.TaskExecutionMetadata().IsInterruptible(), j.Spec.Executor.Affinity) - - return j, nil + return app } func addConfig(sparkConfig map[string]string, key string, value string) { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index e981c0dce1..be2f9477b6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -6,28 +6,25 @@ import ( "strconv" "testing" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" - - "github.com/stretchr/testify/mock" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" - - pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" - - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - - pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" - sj 
"github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" + pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) const sparkMainClass = "MainClass" @@ -87,7 +84,7 @@ func TestGetEventInfo(t *testing.T) { }, }, })) - taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("blah-1", dummySparkConf), false) + taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false) info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState)) assert.NoError(t, err) assert.Len(t, info.Logs, 6) @@ -157,7 +154,7 @@ func TestGetTaskPhase(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} ctx := context.TODO() - taskCtx := dummySparkTaskContext(dummySparkTaskTemplate("", dummySparkConf), false) + taskCtx := 
dummySparkTaskContext(dummySparkTaskTemplateContainer("", dummySparkConf), false) taskPhase, err := sparkResourceHandler.GetTaskPhase(ctx, taskCtx, dummySparkApplication(sj.NewState)) assert.NoError(t, err) assert.Equal(t, taskPhase.Phase(), pluginsCore.PhaseQueued) @@ -250,8 +247,33 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob { return &sparkJob } -func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate { +func dummyPodSpec() *corev1.PodSpec { + return &corev1.PodSpec{ + InitContainers: []corev1.Container{ + { + Name: "init", + Image: testImage, + Args: testArgs, + }, + }, + Containers: []corev1.Container{ + { + Name: "primary", + Image: testImage, + Args: testArgs, + Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + }, + { + Name: "secondary", + Image: testImage, + Args: testArgs, + Env: flytek8s.ToK8sEnvVar(dummyEnvVars), + }, + }, + } +} +func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *core.TaskTemplate { sparkJob := dummySparkCustomObj(sparkConf) sparkJobJSON, err := utils.MarshalToString(sparkJob) if err != nil { @@ -279,6 +301,40 @@ func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTe } } +func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate { + sparkJob := dummySparkCustomObj(sparkConf) + sparkJobJSON, err := utils.MarshalToString(sparkJob) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(sparkJobJSON, &structObj) + if err != nil { + panic(err) + } + + podSpecPb, err := utils.MarshalObjToStruct(podSpec) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "k8s_pod", + Target: &core.TaskTemplate_K8SPod{ + K8SPod: &core.K8SPod{ + PodSpec: podSpecPb, + }, + }, + Config: map[string]string{ + flytek8s.PrimaryContainerKey: "primary", + }, + Custom: &structObj, + } +} + func 
dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) pluginsCore.TaskExecutionContext { taskCtx := &mocks.TaskExecutionContext{} inputReader := &pluginIOMocks.InputReader{} @@ -312,6 +368,11 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) }) tID.On("GetGeneratedName").Return("some-acceptable-name") + overrides := &mocks.TaskOverrides{} + overrides.On("GetResources").Return(&corev1.ResourceRequirements{}) + // No support for GPUs, and consequently, ExtendedResources on Spark plugin. + overrides.On("GetExtendedResources").Return(nil) + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.On("GetTaskExecutionID").Return(tID) taskExecutionMetadata.On("GetNamespace").Return("test-namespace") @@ -327,30 +388,14 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) taskExecutionMetadata.On("IsInterruptible").Return(interruptible) taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1)) taskExecutionMetadata.On("GetEnvironmentVariables").Return(nil) + taskExecutionMetadata.On("GetPlatformResources").Return(nil) + taskExecutionMetadata.On("GetOverrides").Return(overrides) + taskExecutionMetadata.On("GetK8sServiceAccount").Return("new-val") taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) return taskCtx } -func TestBuildResourceSpark(t *testing.T) { - sparkResourceHandler := sparkResourceHandler{} - - // Case1: Valid Spark Task-Template - taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf) - - // Set spark custom feature config. 
- assert.NoError(t, setSparkConfig(&Config{ - Features: []Feature{ - { - Name: "feature1", - SparkConfig: map[string]string{"spark.hadoop.feature1": "true"}, - }, - { - Name: "feature2", - SparkConfig: map[string]string{"spark.hadoop.feature2": "true"}, - }, - }, - })) - +func defaultPluginConfig() *config.K8sPluginConfig { // Set Interruptible Config runAsUser := int64(1000) dnsOptVal1 := "1" @@ -400,7 +445,7 @@ func TestBuildResourceSpark(t *testing.T) { }, } - // interruptible/non-interruptible nodeselector requirement + // Interruptible/non-interruptible nodeselector requirement interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{ Key: "x/interruptible", Operator: corev1.NodeSelectorOpIn, @@ -413,9 +458,7 @@ func TestBuildResourceSpark(t *testing.T) { Values: []string{"true"}, } - // NonInterruptibleNodeSelectorRequirement - - assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + config := &config.K8sPluginConfig{ DefaultAffinity: defaultAffinity, DefaultPodSecurityContext: &corev1.PodSecurityContext{ RunAsUser: &runAsUser, @@ -465,8 +508,32 @@ func TestBuildResourceSpark(t *testing.T) { EnableHostNetworkingPod: &defaultPodHostNetwork, DefaultEnvVars: defaultEnvVars, DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv, - }), - ) + } + return config +} + +func TestBuildResourceContainer(t *testing.T) { + sparkResourceHandler := sparkResourceHandler{} + + // Case1: Valid Spark Task-Template + taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) + + // Set spark custom feature config. 
+ assert.NoError(t, setSparkConfig(&Config{ + Features: []Feature{ + { + Name: "feature1", + SparkConfig: map[string]string{"spark.hadoop.feature1": "true"}, + }, + { + Name: "feature2", + SparkConfig: map[string]string{"spark.hadoop.feature2": "true"}, + }, + }, + })) + + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true)) assert.Nil(t, err) @@ -479,28 +546,16 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, testArgs, sparkApp.Spec.Arguments) assert.Equal(t, testImage, *sparkApp.Spec.Image) assert.NotNil(t, sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt) - assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser) + assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser) assert.NotNil(t, sparkApp.Spec.Driver.DNSConfig) assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, sparkApp.Spec.Driver.DNSConfig.Nameservers) - assert.Equal(t, "ndots", sparkApp.Spec.Driver.DNSConfig.Options[0].Name) - assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Driver.DNSConfig.Options[0].Value) - assert.Equal(t, "single-request-reopen", sparkApp.Spec.Driver.DNSConfig.Options[1].Name) - assert.Equal(t, "timeout", sparkApp.Spec.Driver.DNSConfig.Options[2].Name) - assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Driver.DNSConfig.Options[2].Value) - assert.Equal(t, "attempts", sparkApp.Spec.Driver.DNSConfig.Options[3].Name) - assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Driver.DNSConfig.Options[3].Value) + assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Driver.DNSConfig.Options) assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Driver.DNSConfig.Searches) assert.NotNil(t, sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt) - assert.Equal(t, 
*sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser) + assert.Equal(t, *sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser) assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig) assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig) - assert.Equal(t, "ndots", sparkApp.Spec.Executor.DNSConfig.Options[0].Name) - assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Executor.DNSConfig.Options[0].Value) - assert.Equal(t, "single-request-reopen", sparkApp.Spec.Executor.DNSConfig.Options[1].Name) - assert.Equal(t, "timeout", sparkApp.Spec.Executor.DNSConfig.Options[2].Name) - assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Executor.DNSConfig.Options[2].Value) - assert.Equal(t, "attempts", sparkApp.Spec.Executor.DNSConfig.Options[3].Name) - assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Executor.DNSConfig.Options[3].Value) + assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Executor.DNSConfig.Options) assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Executor.DNSConfig.Searches) //Validate Driver/Executor Spec. 
@@ -515,19 +570,19 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler) - assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName) - assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName) - assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork) - assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Driver.HostNetwork) // Validate - // * Interruptible Toleration and NodeSelector set for Executor but not Driver. - // * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor. - // * Default Tolerations set for both Driver and Executor. - // * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity. + // * Default tolerations set for both Driver and Executor. + // * Interruptible tolerations and node selector set for Executor but not Driver. + // * Default node selector set for both Driver and Executor. + // * Interruptible node selector requirements set for Executor Affinity, non-interruptible for Driver Affinity. 
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0] assert.Equal(t, tolDriverDefault.Key, "x/flyte") assert.Equal(t, tolDriverDefault.Value, "default") @@ -535,21 +590,23 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule")) assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) - assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) - assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector) - - tolExecDefault := sparkApp.Spec.Executor.Tolerations[0] - assert.Equal(t, tolExecDefault.Key, "x/flyte") - assert.Equal(t, tolExecDefault.Value, "default") - assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal")) - assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule")) + assert.Equal(t, 2, len(sparkApp.Spec.Executor.NodeSelector)) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/interruptible": "true", + }, sparkApp.Spec.Executor.NodeSelector) - tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1] + tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[0] assert.Equal(t, tolExecInterrupt.Key, "x/flyte") assert.Equal(t, tolExecInterrupt.Value, "interruptible") assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal")) assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule")) - assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"]) + + tolExecDefault := sparkApp.Spec.Executor.Tolerations[1] + assert.Equal(t, tolExecDefault.Key, "x/flyte") + assert.Equal(t, tolExecDefault.Value, "default") + assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal")) + 
assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule")) for confKey, confVal := range dummySparkConf { exists := false @@ -583,31 +640,36 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"]) assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) - assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"]) - assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"]) - assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv) - assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv) - - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *interruptibleNodeSelectorRequirement, - ) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, 
defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) // Case 2: Driver/Executor request cores set. dummyConfWithRequest := make(map[string]string) @@ -619,7 +681,7 @@ func TestBuildResourceSpark(t *testing.T) { dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3" dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4" - taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest) + taskTemplate = dummySparkTaskTemplateContainer("blah-1", dummyConfWithRequest) resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false)) assert.Nil(t, err) assert.NotNil(t, resource) @@ -640,36 +702,41 @@ func TestBuildResourceSpark(t *testing.T) { // Validate that the default Toleration and NodeSelector are set for both Driver and Executors. 
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) - assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Executor.NodeSelector) assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte") assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default") assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte") assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default") // Validate correct affinity and nodeselector requirements are set for both Driver and Executors. - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], - ) - assert.Equal( - t, - sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], - *nonInterruptibleNodeSelectorRequirement, - ) + assert.Equal(t, &corev1.NodeAffinity{ + 
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) // Case 4: Invalid Spark Task-Template taskTemplate.Custom = nil @@ -678,6 +745,122 @@ func TestBuildResourceSpark(t *testing.T) { assert.Nil(t, resource) } +func TestBuildResourcePodTemplate(t *testing.T) { + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) + extraToleration := corev1.Toleration{ + Key: "x/flyte", + Value: "extra", + Operator: "Equal", + } + podSpec := dummyPodSpec() + podSpec.Tolerations = append(podSpec.Tolerations, extraToleration) + podSpec.NodeSelector = map[string]string{"x/custom": "foo"} + taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec) + taskTemplate.GetK8SPod() + sparkResourceHandler := sparkResourceHandler{} + + taskCtx := dummySparkTaskContext(taskTemplate, true) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) + + assert.Nil(t, err) + assert.NotNil(t, resource) + sparkApp, ok := resource.(*sj.SparkApplication) + assert.True(t, ok) + + // Application + assert.Equal(t, v1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: 
sparkOp.SchemeGroupVersion.String(), + }, sparkApp.TypeMeta) + + // Application spec + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) + assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: intPtr(int32(14)), + }, sparkApp.Spec.RestartPolicy) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + + // Driver + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Driver.Labels) + assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], + extraToleration, + }, sparkApp.Spec.Driver.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": 
"true", + "x/custom": "foo", + }, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores) + assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) + + // Executor + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{"annotation-1": "val1"}), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{"label-1": "val1"}), sparkApp.Spec.Executor.Labels) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"]) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"]) + assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.ElementsMatch(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], + extraToleration, + defaultConfig.InterruptibleTolerations[0], + }, sparkApp.Spec.Executor.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo", + "x/interruptible": "true", + 
}, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) + cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32) + instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32) + assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores) + assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) +} + func TestGetPropertiesSpark(t *testing.T) { sparkResourceHandler := sparkResourceHandler{} expected := k8s.PluginProperties{} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index a9b233a3bb..c50d361726 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -5,6 +5,7 @@ import ( "crypto/x509" "encoding/gob" "fmt" + "time" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flytestdlib/config" @@ -39,6 +40,7 @@ type Plugin struct { type ResourceWrapper struct { State admin.State Outputs *flyteIdl.LiteralMap + Message string } type ResourceMetaWrapper struct { @@ -141,6 +143,7 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba return ResourceWrapper{ State: res.Resource.State, Outputs: res.Resource.Outputs, + Message: res.Resource.Message, }, nil } @@ -171,6 +174,8 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) 
(phase taskInfo := &core.TaskInfo{} switch resource.State { + case admin.State_PENDING: + return core.PhaseInfoInitializing(time.Now(), core.DefaultPhaseVersion, resource.Message, taskInfo), nil case admin.State_RUNNING: return core.PhaseInfoRunning(core.DefaultPhaseVersion, taskInfo), nil case admin.State_PERMANENT_FAILURE: diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 7a4ea350b6..3a8e759908 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -5,12 +5,14 @@ import ( "testing" "time" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flytestdlib/config" "google.golang.org/grpc" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/stretchr/testify/assert" ) @@ -99,4 +101,70 @@ func TestPlugin(t *testing.T) { ctx, _ = getFinalContext(context.TODO(), "CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.NotEqual(t, context.TODO(), ctx) }) + + t.Run("test PENDING Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(ResourceWrapper{ + State: admin.State_PENDING, + Outputs: nil, + Message: "Waiting for cluster", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseInitializing, phase.Phase()) + assert.Equal(t, "Waiting for cluster", phase.Reason()) + }) + + t.Run("test RUNNING Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + 
taskContext.On("Resource").Return(ResourceWrapper{ + State: admin.State_RUNNING, + Outputs: nil, + Message: "Job is running", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase()) + }) + + t.Run("test PERMANENT_FAILURE Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(ResourceWrapper{ + State: admin.State_PERMANENT_FAILURE, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhasePermanentFailure, phase.Phase()) + }) + + t.Run("test RETRYABLE_FAILURE Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(ResourceWrapper{ + State: admin.State_RETRYABLE_FAILURE, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, phase.Phase()) + }) + + t.Run("test UNDEFINED Status", func(t *testing.T) { + taskContext := new(webapiPlugin.StatusContext) + taskContext.On("Resource").Return(ResourceWrapper{ + State: 5, + Outputs: nil, + Message: "", + }) + + phase, err := plugin.Status(context.Background(), taskContext) + assert.Error(t, err) + assert.Equal(t, pluginsCore.PhaseUndefined, phase.Phase()) + }) } diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index 1ff81ef28c..85d6850325 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -315,7 +315,7 @@ func getAdminClient(ctx context.Context) (client service.AdminServiceClient, sig credentialsFuture := admin.NewPerRPCCredentialsFuture() opts := []grpc.DialOption{ - grpc.WithChainUnaryInterceptor(admin.NewAuthInterceptor(cfg, nil, credentialsFuture)), + 
grpc.WithChainUnaryInterceptor(admin.NewAuthInterceptor(cfg, nil, credentialsFuture, nil)), grpc.WithPerRPCCredentials(credentialsFuture), }