From 5ec11e017aad013e11f28d07559c33bb8be287ba Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 4 Jul 2024 20:29:26 +0800 Subject: [PATCH] refactor to use binding `flyteidl-rust`'s `RawSynchronousFlyteClient` Signed-off-by: Austin Liu --- flytekit/clients/friendly.py | 31 +++---- flytekit/interaction/click_types.py | 6 +- flytekit/models/annotation.py | 10 +-- flytekit/models/common.py | 35 +++++--- flytekit/models/core/identifier.py | 33 +++++-- flytekit/models/documentation.py | 19 ++-- flytekit/models/execution.py | 119 ++++++++++++++++---------- flytekit/models/interface.py | 19 ++-- flytekit/models/literals.py | 68 +++++++++------ flytekit/models/matchable_resource.py | 3 +- flytekit/models/security.py | 18 ++-- flytekit/models/task.py | 102 ++++++++++++---------- flytekit/models/types.py | 101 +++++++++++----------- flytekit/remote/entities.py | 2 +- flytekit/remote/remote.py | 4 +- 15 files changed, 332 insertions(+), 238 deletions(-) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 58038d12ec..d297d7569a 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1,6 +1,7 @@ import datetime import typing +import flyteidl_rust as flyteidl from flyteidl.admin import common_pb2 as _common_pb2 from flyteidl.admin import execution_pb2 as _execution_pb2 from flyteidl.admin import launch_plan_pb2 as _launch_plan_pb2 @@ -9,13 +10,11 @@ from flyteidl.admin import project_domain_attributes_pb2 as _project_domain_attributes_pb2 from flyteidl.admin import project_pb2 as _project_pb2 from flyteidl.admin import task_execution_pb2 as _task_execution_pb2 -from flyteidl.admin import task_pb2 as _task_pb2 from flyteidl.admin import workflow_attributes_pb2 as _workflow_attributes_pb2 from flyteidl.admin import workflow_pb2 as _workflow_pb2 from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 from google.protobuf.duration_pb2 import Duration -from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient from flytekit.models import common as _common from flytekit.models import execution as _execution from flytekit.models import filters as _filters @@ -29,7 +28,7 @@ from flytekit.models.core import identifier as _identifier -class SynchronousFlyteClient(_RawSynchronousFlyteClient): +class SynchronousFlyteClient(flyteidl.RawSynchronousFlyteClient): """ This is a low-level client that users can use to make direct gRPC service calls to the control plane. See the :std:doc:`service spec `. This is more user-friendly interface than the @@ -75,7 +74,7 @@ def create_task(self, task_identifer, task_spec): :raises grpc.RpcError: """ super(SynchronousFlyteClient, self).create_task( - _task_pb2.TaskCreateRequest(id=task_identifer.to_flyte_idl(), spec=task_spec.to_flyte_idl()) + flyteidl.admin.TaskCreateRequest(id=task_identifer.to_flyte_idl(), spec=task_spec.to_flyte_idl()) ) def list_task_ids_paginated(self, project, domain, limit=100, token=None, sort_by=None): @@ -173,7 +172,7 @@ def get_task(self, id): :rtype: flytekit.models.task.Task """ return _task.Task.from_flyte_idl( - super(SynchronousFlyteClient, self).get_task(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl())) + super(SynchronousFlyteClient, self).get_task(flyteidl.admin.ObjectGetRequest(id=id.to_flyte_idl())) ) #################################################################################################################### @@ -551,12 +550,13 @@ def create_execution(self, project, domain, name, execution_spec, inputs): return _identifier.WorkflowExecutionIdentifier.from_flyte_idl( super(SynchronousFlyteClient, self) .create_execution( - _execution_pb2.ExecutionCreateRequest( + flyteidl.admin.ExecutionCreateRequest( project=project, domain=domain, name=name, spec=execution_spec.to_flyte_idl(), inputs=inputs.to_flyte_idl(), + org="", ) ) .id @@ -582,7 +582,7 @@ def get_execution(self, id): """ return _execution.Execution.from_flyte_idl( super(SynchronousFlyteClient, self).get_execution( - _execution_pb2.WorkflowExecutionGetRequest(id=id.to_flyte_idl()) + flyteidl.admin.WorkflowExecutionGetRequest(id=id.to_flyte_idl()) ) ) @@ -595,7 +595,7 @@ def get_execution_data(self, id): """ return _execution.WorkflowExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_execution_data( - _execution_pb2.WorkflowExecutionGetDataRequest(id=id.to_flyte_idl()) + flyteidl.admin.WorkflowExecutionGetDataRequest(id=id.to_flyte_idl()) ) ) @@ -677,7 +677,7 @@ def get_node_execution(self, node_execution_identifier): """ return _node_execution.NodeExecution.from_flyte_idl( super(SynchronousFlyteClient, self).get_node_execution( - _node_execution_pb2.NodeExecutionGetRequest(id=node_execution_identifier.to_flyte_idl()) + flyteidl.admin.NodeExecutionGetRequest(id=node_execution_identifier.to_flyte_idl()) ) ) @@ -689,7 +689,7 @@ def get_node_execution_data(self, node_execution_identifier) -> _execution.NodeE """ return _execution.NodeExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_node_execution_data( - _node_execution_pb2.NodeExecutionGetDataRequest(id=node_execution_identifier.to_flyte_idl()) + flyteidl.admin.NodeExecutionGetDataRequest(id=node_execution_identifier.to_flyte_idl()) ) ) @@ -715,7 +715,7 @@ def list_node_executions( :rtype: list[flytekit.models.node_execution.NodeExecution], Text """ exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated( - _node_execution_pb2.NodeExecutionListRequest( + flyteidl.admin.NodeExecutionListRequest( workflow_execution_id=workflow_execution_identifier.to_flyte_idl(), limit=limit, token=token, @@ -775,7 +775,7 @@ def get_task_execution(self, id): """ return _task_execution.TaskExecution.from_flyte_idl( super(SynchronousFlyteClient, self).get_task_execution( - _task_execution_pb2.TaskExecutionGetRequest(id=id.to_flyte_idl()) + flyteidl.admin.TaskExecutionGetRequest(id=id.to_flyte_idl()) ) ) @@ -788,7 +788,7 @@ def get_task_execution_data(self, task_execution_identifier): """ return _execution.TaskExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_task_execution_data( - _task_execution_pb2.TaskExecutionGetDataRequest(id=task_execution_identifier.to_flyte_idl()) + flyteidl.admin.TaskExecutionGetDataRequest(id=task_execution_identifier.to_flyte_idl()) ) ) @@ -1010,14 +1010,15 @@ def get_upload_signed_url( expires_in_pb = Duration() expires_in_pb.FromTimedelta(expires_in) return super(SynchronousFlyteClient, self).create_upload_location( - _data_proxy_pb2.CreateUploadLocationRequest( + flyteidl.service.CreateUploadLocationRequest( project=project, domain=domain, content_md5=content_md5, filename=filename, expires_in=expires_in_pb, - filename_root=filename_root, + filename_root=filename_root or "", add_content_md5_metadata=add_content_md5_metadata, + org="", ) ) except Exception as e: diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 5728dce3d0..652218a174 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -326,12 +326,12 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli Converts a Flyte LiteralType given a python_type to a click.ParamType """ if lt.simple: - if lt.simple == SimpleType.STRUCT: + if int(str(lt.simple)) == SimpleType.STRUCT: ct = JsonParamType(python_type) ct.name = f"JSON object {python_type.__name__}" return ct - if lt.simple in SIMPLE_TYPE_CONVERTER: - return SIMPLE_TYPE_CONVERTER[lt.simple] + if int(str(lt.simple)) in SIMPLE_TYPE_CONVERTER: + return SIMPLE_TYPE_CONVERTER[int(str(lt.simple))] raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run") if lt.enum_type: diff --git a/flytekit/models/annotation.py b/flytekit/models/annotation.py index 1c17aabc5e..530c66bb4f 100644 --- a/flytekit/models/annotation.py +++ b/flytekit/models/annotation.py @@ -1,9 +1,8 @@ import json from typing import Any, Dict -from flyteidl.core import types_pb2 as _types_pb2 +import flyteidl_rust as flyteidl from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct class TypeAnnotation: @@ -19,17 +18,16 @@ def annotations(self) -> Dict[str, Any]: """ return self._annotations - def to_flyte_idl(self) -> _types_pb2.TypeAnnotation: + def to_flyte_idl(self) -> flyteidl.core.TypeAnnotation: """ :rtype: flyteidl.core.types_pb2.TypeAnnotation """ - if self._annotations is not None: - annotations = _json_format.Parse(json.dumps(self.annotations), _struct.Struct()) + annotations = _json_format.Parse(json.dumps(self.annotations), flyteidl.wkt.Struct()) else: annotations = None - return _types_pb2.TypeAnnotation( + return flyteidl.core.TypeAnnotation( annotations=annotations, ) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 77ae72e703..2678e8f847 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -3,8 +3,8 @@ import re from typing import Dict +import flyteidl_rust as flytedidl from flyteidl.admin import common_pb2 as _common_pb2 -from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -302,11 +302,20 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.Notification """ - return _common_pb2.Notification( + _type = None + if self.email: + _type = flytedidl.notification.Type.Email(self.email) + elif self.pager_duty: + _type = flytedidl.notification.Type.PagerDuty(self.pager_duty) + elif self.slack: + _type = flytedidl.notification.Type.Slack(self.slack) + + return flytedidl.notification.Type( phases=self.phases, - email=self.email.to_flyte_idl() if self.email else None, - pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None, - slack=self.slack.to_flyte_idl() if self.slack else None, + type=_type, + # email=self.email.to_flyte_idl() if self.email else None, + # pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None, + # slack=self.slack.to_flyte_idl() if self.slack else None, ) @classmethod @@ -340,7 +349,7 @@ def to_flyte_idl(self): """ :rtype: dict[Text, Text] """ - return _common_pb2.Labels(values={k: v for k, v in self.values.items()}) + return flytedidl.admin.Labels(values={k: v for k, v in self.values.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -368,7 +377,7 @@ def to_flyte_idl(self): """ :rtype: _common_pb2.Annotations """ - return _common_pb2.Annotations(values={k: v for k, v in self.values.items()}) + return flytedidl.admin.Annotations(values={k: v for k, v in self.values.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -451,9 +460,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.launch_plan_pb2.Auth """ - return _common_pb2.AuthRole( - assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None, - kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None, + return flytedidl.admin.AuthRole( + assumable_iam_role=self.assumable_iam_role or "", + kubernetes_service_account=self.kubernetes_service_account or "", ) @classmethod @@ -483,7 +492,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.Auth """ - return _common_pb2.RawOutputDataConfig(output_location_prefix=self.output_location_prefix) + return flytedidl.admin.RawOutputDataConfig(output_location_prefix=self.output_location_prefix) @classmethod def from_flyte_idl(cls, pb2): @@ -498,8 +507,8 @@ def __init__(self, envs: Dict[str, str]): def envs(self) -> Dict[str, str]: return self._envs - def to_flyte_idl(self) -> _common_pb2.Envs: - return _common_pb2.Envs(values=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.envs.items()]) + def to_flyte_idl(self) -> flytedidl.admin.Envs: + return flytedidl.admin.Envs(values=[flytedidl.core.KeyValuePair(key=k, value=v) for k, v in self.envs.items()]) @classmethod def from_flyte_idl(cls, pb2: _common_pb2.Envs) -> _common_pb2.Envs: diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index 8a45232e38..f6e0906d0a 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -1,13 +1,14 @@ +import flyteidl_rust as flyteidl from flyteidl.core import identifier_pb2 as identifier_pb2 from flytekit.models import common as _common_models class ResourceType(object): - UNSPECIFIED = identifier_pb2.UNSPECIFIED - TASK = identifier_pb2.TASK - WORKFLOW = identifier_pb2.WORKFLOW - LAUNCH_PLAN = identifier_pb2.LAUNCH_PLAN + UNSPECIFIED = int(flyteidl.core.ResourceType.Unspecified) + TASK = int(flyteidl.core.ResourceType.Task) + WORKFLOW = int(flyteidl.core.ResourceType.Workflow) + LAUNCH_PLAN = int(flyteidl.core.ResourceType.LaunchPlan) class Identifier(_common_models.FlyteIdlEntity): @@ -24,6 +25,7 @@ def __init__(self, resource_type, project, domain, name, version): self._domain = domain self._name = name self._version = version + self._org = "" @property def resource_type(self): @@ -64,16 +66,24 @@ def version(self): """ return self._version + @property + def org(self): + """ + :rtype: Text + """ + return self._org + def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.Identifier """ - return identifier_pb2.Identifier( + return flyteidl.core.Identifier( resource_type=self.resource_type, project=self.project, domain=self.domain, name=self.name, version=self.version, + org=self.org, ) @classmethod @@ -107,6 +117,7 @@ def __init__(self, project, domain, name): self._project = project self._domain = domain self._name = name + self._org = "" @property def project(self): @@ -129,14 +140,22 @@ def name(self): """ return self._name + @property + def org(self): + """ + :rtype: Text + """ + return self._org + def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier """ - return identifier_pb2.WorkflowExecutionIdentifier( + return flyteidl.core.WorkflowExecutionIdentifier( project=self.project, domain=self.domain, name=self.name, + org=self.org, ) @classmethod @@ -179,7 +198,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier """ - return identifier_pb2.NodeExecutionIdentifier( + return flyteidl.core.NodeExecutionIdentifier( node_id=self.node_id, execution_id=self.execution_id.to_flyte_idl(), ) diff --git a/flytekit/models/documentation.py b/flytekit/models/documentation.py index e1bae8122e..7d6001415d 100644 --- a/flytekit/models/documentation.py +++ b/flytekit/models/documentation.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Optional +import flyteidl_rust as flyteidl from flyteidl.admin import description_entity_pb2 from flytekit.models import common as _common_models @@ -27,11 +28,16 @@ class DescriptionFormat(Enum): format: DescriptionFormat = DescriptionFormat.RST def to_flyte_idl(self): - return description_entity_pb2.Description( - value=self.value if self.value else None, - uri=self.uri if self.uri else None, + return flyteidl.admin.Description( + content=flyteidl.description.Content.Value(self.value) + if self.value + else flyteidl.description.Content.Uri(self.uri) + if self.uri + else None, + # value=self.value if self.value else None, + # uri=self.uri if self.uri else None, format=self.format.value, - icon_link=self.icon_link, + icon_link=self.icon_link or "", ) @classmethod @@ -76,10 +82,11 @@ class Documentation(_common_models.FlyteIdlEntity): source_code: Optional[SourceCode] = None def to_flyte_idl(self): - return description_entity_pb2.DescriptionEntity( - short_description=self.short_description, + return flyteidl.admin.DescriptionEntity( + short_description=self.short_description or "", long_description=self.long_description.to_flyte_idl() if self.long_description else None, source_code=self.source_code.to_flyte_idl() if self.source_code else None, + tags=[], ) @classmethod diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 7e4ff02645..959839d8be 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -2,14 +2,11 @@ import datetime import typing +from datetime import timedelta from datetime import timezone as _timezone from typing import Optional -import flyteidl -import flyteidl.admin.cluster_assignment_pb2 as _cluster_assignment_pb2 -import flyteidl.admin.execution_pb2 as _execution_pb2 -import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 -import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +import flyteidl_rust as flyteidl import flytekit from flytekit.models import common as _common_models @@ -21,6 +18,13 @@ from flytekit.models.node_execution import DynamicWorkflowNodeMetadata +# A helping function for `from_flyte_idl()` +def convert_to_datetime(seconds: int, nanos: int) -> datetime.datetime: + total_microseconds = (seconds * 1_000_000) + (nanos // 1_000) + dt = (datetime.datetime(1970, 1, 1) + timedelta(microseconds=total_microseconds)).replace(tzinfo=_timezone.utc) + return dt + + class SystemMetadata(_common_models.FlyteIdlEntity): def __init__(self, execution_cluster: str): self._execution_cluster = execution_cluster @@ -30,7 +34,7 @@ def execution_cluster(self) -> str: return self._execution_cluster def to_flyte_idl(self) -> flyteidl.admin.execution_pb2.SystemMetadata: - return _execution_pb2.SystemMetadata(execution_cluster=self.execution_cluster) + return flyteidl.admin.SystemMetadata(execution_cluster=self.execution_cluster) @classmethod def from_flyte_idl(cls, pb2_object: flyteidl.admin.execution_pb2.SystemMetadata) -> SystemMetadata: @@ -126,7 +130,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionMetadata """ - p = _execution_pb2.ExecutionMetadata( + p = flyteidl.admin.ExecutionMetadata( mode=self.mode, principal=self.principal, nesting=self.nesting, @@ -137,9 +141,11 @@ def to_flyte_idl(self): if self.reference_execution is not None else None, system_metadata=self.system_metadata.to_flyte_idl() if self.system_metadata is not None else None, + artifact_ids=[], ) + if self.scheduled_at is not None: - p.scheduled_at.FromDatetime(self.scheduled_at) + p.scheduled_at = flyteidl.wkt.Timestamp(seconds=self.scheduled_at.seconds, nanos=self.scheduled_at.nanos) return p @classmethod @@ -152,15 +158,19 @@ def from_flyte_idl(cls, pb2_object): mode=pb2_object.mode, principal=pb2_object.principal, nesting=pb2_object.nesting, - scheduled_at=pb2_object.scheduled_at.ToDatetime() if pb2_object.HasField("scheduled_at") else None, + scheduled_at=convert_to_datetime(pb2_object.scheduled_at.seconds, pb2_object.scheduled_at.nanos) + if pb2_object.scheduled_at + else convert_to_datetime( + 0, 0 + ), # pb2_object.scheduled_at.ToDatetime() if pb2_object.HasField("scheduled_at") else None, parent_node_execution=_identifier.NodeExecutionIdentifier.from_flyte_idl(pb2_object.parent_node_execution) - if pb2_object.HasField("parent_node_execution") + if pb2_object.parent_node_execution else None, reference_execution=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(pb2_object.reference_execution) - if pb2_object.HasField("reference_execution") + if pb2_object.reference_execution else None, system_metadata=SystemMetadata.from_flyte_idl(pb2_object.system_metadata) - if pb2_object.HasField("system_metadata") + if pb2_object.system_metadata else None, ) @@ -307,26 +317,31 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionSpec """ - return _execution_pb2.ExecutionSpec( + + return flyteidl.admin.ExecutionSpec( launch_plan=self.launch_plan.to_flyte_idl(), metadata=self.metadata.to_flyte_idl(), - notifications=self.notifications.to_flyte_idl() if self.notifications else None, - disable_all=self.disable_all, # type: ignore + # notifications=self.notifications.to_flyte_idl() if self.notifications else None, + # disable_all=self.disable_all, # type: ignore + notification_overrides=self.notifications.to_flyte_idl() + if isinstance(self.notifications, flyteidl.execution_spec.NotificationOverrides.Notifications) + else flyteidl.execution_spec.NotificationOverrides.DisableAll(self.disable_all), labels=self.labels.to_flyte_idl(), annotations=self.annotations.to_flyte_idl(), auth_role=self._auth_role.to_flyte_idl() if self.auth_role else None, raw_output_data_config=self._raw_output_data_config.to_flyte_idl() if self._raw_output_data_config else None, - max_parallelism=self.max_parallelism, + max_parallelism=self.max_parallelism or 0, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, envs=self.envs.to_flyte_idl() if self.envs else None, - tags=self.tags, + tags=self.tags or [], cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, execution_cluster_label=self._execution_cluster_label.to_flyte_idl() if self._execution_cluster_label else None, + execution_env_assignments=[], ) @classmethod @@ -338,26 +353,28 @@ def from_flyte_idl(cls, p): return cls( launch_plan=_identifier.Identifier.from_flyte_idl(p.launch_plan), metadata=ExecutionMetadata.from_flyte_idl(p.metadata), - notifications=NotificationList.from_flyte_idl(p.notifications) if p.HasField("notifications") else None, - disable_all=p.disable_all if p.HasField("disable_all") else None, + notifications=NotificationList.from_flyte_idl(p.notification_overrides[0]) + if isinstance(p.notification_overrides, flyteidl.execution_spec.NotificationOverrides.Notifications) + else None, + disable_all=p.notification_overrides[0] + if isinstance(p.notification_overrides, flyteidl.execution_spec.NotificationOverrides.DisableAll) + else None, labels=_common_models.Labels.from_flyte_idl(p.labels), annotations=_common_models.Annotations.from_flyte_idl(p.annotations), auth_role=_common_models.AuthRole.from_flyte_idl(p.auth_role), raw_output_data_config=_common_models.RawOutputDataConfig.from_flyte_idl(p.raw_output_data_config) - if p.HasField("raw_output_data_config") + if p.raw_output_data_config else None, max_parallelism=p.max_parallelism, security_context=security.SecurityContext.from_flyte_idl(p.security_context) if p.security_context else None, overwrite_cache=p.overwrite_cache, - envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, + envs=_common_models.Envs.from_flyte_idl(p.envs) if p.envs else None, tags=p.tags, - cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) - if p.HasField("cluster_assignment") - else None, + cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) if p.cluster_assignment else None, execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(p.execution_cluster_label) - if p.HasField("execution_cluster_label") + if p.execution_cluster_label else None, ) @@ -380,7 +397,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin._cluster_assignment_pb2.ClusterAssignment """ - return _cluster_assignment_pb2.ClusterAssignment( + return flyteidl.admin.ClusterAssignment( cluster_pool_name=self._cluster_pool, ) @@ -420,7 +437,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.LiteralMapBlob """ - return _execution_pb2.LiteralMapBlob( + return flyteidl.admin.LiteralMapBlob( values=self.values.to_flyte_idl() if self.values is not None else None, uri=self.uri, ) @@ -473,7 +490,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.Execution """ - return _execution_pb2.Execution( + return flyteidl.admin.Execution( id=self.id.to_flyte_idl(), closure=self.closure.to_flyte_idl(), spec=self.spec.to_flyte_idl(), @@ -506,7 +523,7 @@ def principal(self) -> str: return self._principal def to_flyte_idl(self) -> flyteidl.admin.execution_pb2.AbortMetadata: - return _execution_pb2.AbortMetadata(cause=self.cause, principal=self.principal) + return flyteidl.admin.AbortMetadata(cause=self.cause, principal=self.principal) @classmethod def from_flyte_idl(cls, pb2_object: flyteidl.admin.execution_pb2.AbortMetadata) -> AbortMetadata: @@ -584,7 +601,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionClosure """ - obj = _execution_pb2.ExecutionClosure( + obj = flyteidl.admin.ExecutionClosure( phase=self.phase, error=self.error.to_flyte_idl() if self.error is not None else None, outputs=self.outputs.to_flyte_idl() if self.outputs is not None else None, @@ -605,26 +622,36 @@ def from_flyte_idl(cls, pb2_object): :rtype: ExecutionClosure """ error = None - if pb2_object.HasField("error"): - error = _core_execution.ExecutionError.from_flyte_idl(pb2_object.error) + if isinstance(pb2_object, flyteidl.execution_closure.OutputResult.Error): + error = _core_execution.ExecutionError.from_flyte_idl(pb2_object.output_result) outputs = None - if pb2_object.HasField("outputs"): - outputs = LiteralMapBlob.from_flyte_idl(pb2_object.outputs) + if isinstance(pb2_object, flyteidl.execution_closure.OutputResult.Outputs): + outputs = LiteralMapBlob.from_flyte_idl(pb2_object.output_result) abort_metadata = None - if pb2_object.HasField("abort_metadata"): - abort_metadata = AbortMetadata.from_flyte_idl(pb2_object.abort_metadata) + if isinstance(pb2_object, flyteidl.execution_closure.OutputResult.AbortMetadata): + abort_metadata = AbortMetadata.from_flyte_idl(pb2_object.output_result) return cls( error=error, outputs=outputs, phase=pb2_object.phase, - started_at=pb2_object.started_at.ToDatetime().replace(tzinfo=_timezone.utc), - duration=pb2_object.duration.ToTimedelta(), + # started_at=pb2_object.started_at.ToDatetime().replace(tzinfo=_timezone.utc), + started_at=convert_to_datetime(pb2_object.started_at.seconds, pb2_object.started_at.nanos) + if pb2_object.started_at + else None, + # duration=pb2_object.duration.ToTimedelta(), + duration=timedelta(seconds=pb2_object.duration.seconds, milliseconds=pb2_object.duration.nanos // 1_000_000) + if pb2_object.duration + else timedelta() + if pb2_object.duration + else None, abort_metadata=abort_metadata, - created_at=pb2_object.created_at.ToDatetime().replace(tzinfo=_timezone.utc) - if pb2_object.HasField("created_at") + # created_at=pb2_object.created_at.ToDatetime().replace(tzinfo=_timezone.utc) + created_at=convert_to_datetime(pb2_object.created_at.seconds, pb2_object.created_at.nanos) + if pb2_object.created_at else None, - updated_at=pb2_object.updated_at.ToDatetime().replace(tzinfo=_timezone.utc) - if pb2_object.HasField("updated_at") + # updated_at=pb2_object.updated_at.ToDatetime().replace(tzinfo=_timezone.utc) + updated_at=convert_to_datetime(pb2_object.updated_at.seconds, pb2_object.updated_at.nanos) + if pb2_object.updated_at else None, ) @@ -647,7 +674,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.NotificationList """ - return _execution_pb2.NotificationList(notifications=[n.to_flyte_idl() for n in self.notifications]) + return flyteidl.admin.NotificationList(notifications=[n.to_flyte_idl() for n in self.notifications]) @classmethod def from_flyte_idl(cls, pb2_object): @@ -723,7 +750,7 @@ def to_flyte_idl(self): """ :rtype: _execution_pb2.WorkflowExecutionGetDataResponse """ - return _execution_pb2.WorkflowExecutionGetDataResponse( + return flyteidl.admin.WorkflowExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), full_inputs=self.full_inputs.to_flyte_idl(), @@ -749,7 +776,7 @@ def to_flyte_idl(self): """ :rtype: _task_execution_pb2.TaskExecutionGetDataResponse """ - return _task_execution_pb2.TaskExecutionGetDataResponse( + return flyteidl.admin.TaskExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), full_inputs=self.full_inputs.to_flyte_idl(), @@ -786,7 +813,7 @@ def to_flyte_idl(self): """ :rtype: _node_execution_pb2.NodeExecutionGetDataResponse """ - return _node_execution_pb2.NodeExecutionGetDataResponse( + return flyteidl.admin.NodeExecutionGetDataResponse( inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), full_inputs=self.full_inputs.to_flyte_idl(), diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index f80bfb9e52..9c8d9458ef 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -1,5 +1,6 @@ import typing +import flyteidl_rust as flyteidl from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import interface_pb2 as _interface_pb2 @@ -57,7 +58,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.interface_pb2.Variable """ - return _interface_pb2.Variable( + return flyteidl.core.Variable( type=self.type.to_flyte_idl(), description=self.description, artifact_partial_id=self.artifact_partial_id, @@ -65,17 +66,15 @@ def to_flyte_idl(self): ) @classmethod - def from_flyte_idl(cls, variable_proto) -> _interface_pb2.Variable: + def from_flyte_idl(cls, variable_proto) -> flyteidl.core.Variable: """ :param flyteidl.core.interface_pb2.Variable variable_proto: """ return cls( type=_types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description, - artifact_partial_id=variable_proto.artifact_partial_id - if variable_proto.HasField("artifact_partial_id") - else None, - artifact_tag=variable_proto.artifact_tag if variable_proto.HasField("artifact_tag") else None, + artifact_partial_id=variable_proto.artifact_partial_id or None, + artifact_tag=variable_proto.artifact_tag or None, ) @@ -130,10 +129,10 @@ def inputs(self) -> typing.Dict[str, Variable]: def outputs(self) -> typing.Dict[str, Variable]: return self._outputs - def to_flyte_idl(self) -> _interface_pb2.TypedInterface: - return _interface_pb2.TypedInterface( - inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.inputs.items()}), - outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.outputs.items()}), + def to_flyte_idl(self) -> flyteidl.core.TypedInterface: + return flyteidl.core.TypedInterface( + inputs=flyteidl.core.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.inputs.items()}), + outputs=flyteidl.core.VariableMap(variables={k: v.to_flyte_idl() for k, v in self.outputs.items()}), ) @classmethod diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index e08c495b67..294201b95a 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -2,6 +2,7 @@ from datetime import timezone as _timezone from typing import Dict, Optional +import flyteidl_rust as flyteidl from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf.struct_pb2 import Struct @@ -34,7 +35,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.RetryStrategy """ - return _literals_pb2.RetryStrategy(retries=self.retries) + return flyteidl.core.RetryStrategy(retries=self.retries) @classmethod def from_flyte_idl(cls, pb2_object): @@ -141,11 +142,21 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Primitive """ - primitive = _literals_pb2.Primitive( - integer=self.integer, - float_value=self.float_value, - string_value=self.string_value, - boolean=self.boolean, + value = None + if self.string_value: + value = flyteidl.primitive.Value.StringValue(self.string_value) + elif self.integer: + value = flyteidl.primitive.Value.Integer(self.integer) + elif self.float_value: + value = flyteidl.primitive.Value.FloatValue(self.float_value) + elif self.boolean: + value = flyteidl.primitive.Value.Boolean(self.string_value) + primitive = flyteidl.core.Primitive( + value + # integer=self.integer, + # float_value=self.float_value, + # string_value=self.string_value, + # boolean=self.boolean, ) if self.datetime is not None: # Convert to UTC and remove timezone so protobuf behaves. @@ -689,7 +700,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.LiteralMap """ - return _literals_pb2.LiteralMap(literals={k: v.to_flyte_idl() for k, v in self.literals.items()}) + return flyteidl.core.LiteralMap(literals={k: v.to_flyte_idl() for k, v in self.literals.items()}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -818,16 +829,22 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Scalar """ - return _literals_pb2.Scalar( - primitive=self.primitive.to_flyte_idl() if self.primitive is not None else None, - blob=self.blob.to_flyte_idl() if self.blob is not None else None, - binary=self.binary.to_flyte_idl() if self.binary is not None else None, - schema=self.schema.to_flyte_idl() if self.schema is not None else None, - union=self.union.to_flyte_idl() if self.union is not None else None, - none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None, - error=self.error.to_flyte_idl() if self.error is not None else None, - generic=self.generic, - structured_dataset=self.structured_dataset.to_flyte_idl() if self.structured_dataset is not None else None, + return flyteidl.literal.Value.Scalar( + flyteidl.core.Scalar( + flyteidl.scalar.Value.Primitive( + self.primitive.to_flyte_idl() if self.primitive is not None else None, + # TODO: + # primitive=self.primitive.to_flyte_idl() if self.primitive is not None else None, + # blob=self.blob.to_flyte_idl() if self.blob is not None else None, + # binary=self.binary.to_flyte_idl() if self.binary is not None else None, + # schema=self.schema.to_flyte_idl() if self.schema is not None else None, + # union=self.union.to_flyte_idl() if self.union is not None else None, + # none_type=self.none_type.to_flyte_idl() if self.none_type is not None else None, + # error=self.error.to_flyte_idl() if self.error is not None else None, + # generic=self.generic, + # structured_dataset=self.structured_dataset.to_flyte_idl() if self.structured_dataset is not None else None, + ) + ) ) @classmethod @@ -836,9 +853,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.literals_pb2.Scalar pb2_object: :rtype: flytekit.models.literals.Scalar """ - # todo finish return cls( - primitive=Primitive.from_flyte_idl(pb2_object.primitive) if pb2_object.HasField("primitive") else None, + # TODO: + primitive=Primitive.from_flyte_idl(pb2_object.primitive) if pb2_object.primitive else None, blob=Blob.from_flyte_idl(pb2_object.blob) if pb2_object.HasField("blob") else None, binary=Binary.from_flyte_idl(pb2_object.binary) if pb2_object.HasField("binary") else None, schema=Schema.from_flyte_idl(pb2_object.schema) if pb2_object.HasField("schema") else None, @@ -929,12 +946,13 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Literal """ - return _literals_pb2.Literal( - scalar=self.scalar.to_flyte_idl() if self.scalar is not None else None, - collection=self.collection.to_flyte_idl() if self.collection is not None else None, - map=self.map.to_flyte_idl() if self.map is not None else None, - hash=self.hash, - metadata=self.metadata, + return flyteidl.core.Literal( + value=self.scalar.to_flyte_idl() if self.scalar is not None else None, + # scalar=self.scalar.to_flyte_idl() if self.scalar is not None else None, + # collection=self.collection.to_flyte_idl() if self.collection is not None else None, + # map=self.map.to_flyte_idl() if self.map is not None else None, + hash=self.hash or "", + metadata=self.metadata or {}, ) @classmethod diff --git a/flytekit/models/matchable_resource.py b/flytekit/models/matchable_resource.py index 64247f5bf5..cf5d05e4a8 100644 --- a/flytekit/models/matchable_resource.py +++ b/flytekit/models/matchable_resource.py @@ -1,3 +1,4 @@ +import flyteidl_rust as flyteidl from flyteidl.admin import matchable_resource_pb2 as _matchable_resource from flytekit.models import common as _common @@ -143,7 +144,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel """ - return _matchable_resource.ExecutionClusterLabel( + return flyteidl.admin.ExecutionClusterLabel( value=self.value, ) diff --git a/flytekit/models/security.py b/flytekit/models/security.py index e210c910b7..ad230e368e 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -2,6 +2,7 @@ from enum import Enum from typing import List, Optional +import flyteidl_rust as flyteidl from flyteidl.core import security_pb2 as _sec from flytekit.models import common as _common @@ -162,21 +163,22 @@ def __post_init__(self): if self.tokens and not isinstance(self.tokens, list): self.tokens = [self.tokens] - def to_flyte_idl(self) -> _sec.SecurityContext: + def to_flyte_idl(self) -> flyteidl.core.SecurityContext: if self.run_as is None and self.secrets is None and self.tokens is None: return None - return _sec.SecurityContext( + return flyteidl.core.SecurityContext( run_as=self.run_as.to_flyte_idl() if self.run_as else None, secrets=[s.to_flyte_idl() for s in self.secrets] if self.secrets else None, tokens=[t.to_flyte_idl() for t in self.tokens] if self.tokens else None, ) @classmethod - def from_flyte_idl(cls, pb2_object: _sec.SecurityContext) -> "SecurityContext": + def from_flyte_idl(cls, pb2_object: flyteidl.core.SecurityContext) -> "flyteidl.core.SecurityContext": + # TODO: return cls( - run_as=Identity.from_flyte_idl(pb2_object.run_as) - if pb2_object.run_as and pb2_object.run_as.ByteSize() > 0 - else None, - secrets=[Secret.from_flyte_idl(s) for s in pb2_object.secrets] if pb2_object.secrets else None, - tokens=[OAuth2TokenRequest.from_flyte_idl(t) for t in pb2_object.tokens] if pb2_object.tokens else None, + run_as=None, # Identity.from_flyte_idl(pb2_object.run_as) + # if pb2_object.run_as and pb2_object.run_as.ByteSize() > 0 + # else None, + secrets=[], # [Secret.from_flyte_idl(s) for s in pb2_object.secrets] if pb2_object.secrets else None, + tokens=[], # [OAuth2TokenRequest.from_flyte_idl(t) for t in pb2_object.tokens] if pb2_object.tokens else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 0532b276e2..7739468c5a 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -1,13 +1,10 @@ import json as _json import typing +import flyteidl_rust as flyteidl from flyteidl.admin import agent_pb2 as _admin_agent -from flyteidl.admin import task_pb2 as _admin_task -from flyteidl.core import compiler_pb2 as _compiler -from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core import tasks_pb2 as _core_task from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct from flytekit.models import common as _common from flytekit.models import interface as _interface @@ -54,7 +51,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.ResourceEntry """ - return _core_task.Resources.ResourceEntry(name=self.name, value=self.value) + return flyteidl.core.ResourceEntry(name=self.name, value=self.value) @classmethod def from_flyte_idl(cls, pb2_object): @@ -94,7 +91,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.Resources """ - return _core_task.Resources( + return flyteidl.core.Resources( requests=[r.to_flyte_idl() for r in self.requests], limits=[r.to_flyte_idl() for r in self.limits], ) @@ -155,7 +152,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata """ - return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) + return flyteidl.core.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) @classmethod def from_flyte_idl(cls, pb2_object): @@ -300,15 +297,17 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskMetadata """ - tm = _core_task.TaskMetadata( + tm = flyteidl.core.TaskMetadata( discoverable=self.discoverable, runtime=self.runtime.to_flyte_idl(), retries=self.retries.to_flyte_idl(), - interruptible=self.interruptible, + interruptible_value=self.interruptible, + generates_deck=False, + tags={}, discovery_version=self.discovery_version, deprecated_error_message=self.deprecated_error_message, cache_serializable=self.cache_serializable, - pod_template_name=self.pod_template_name, + pod_template_name=self.pod_template_name if self.pod_template_name else "", cache_ignore_input_vars=self.cache_ignore_input_vars, ) if self.timeout: @@ -321,11 +320,15 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.task_pb2.TaskMetadata pb2_object: :rtype: TaskMetadata """ + from datetime import timedelta + return cls( discoverable=pb2_object.discoverable, runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), - timeout=pb2_object.timeout.ToTimedelta(), - interruptible=pb2_object.interruptible if pb2_object.HasField("interruptible") else None, + timeout=timedelta(seconds=pb2_object.timeout.seconds, milliseconds=pb2_object.timeout.nanos // 1_000_000) + if pb2_object.timeout + else timedelta(), + interruptible=pb2_object.interruptible_value or None, retries=_literals.RetryStrategy.from_flyte_idl(pb2_object.retries), discovery_version=pb2_object.discovery_version, deprecated_error_message=pb2_object.deprecated_error_message, @@ -479,19 +482,20 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.TaskTemplate """ - task_template = _core_task.TaskTemplate( + task_template = flyteidl.core.TaskTemplate( id=self.id.to_flyte_idl(), type=self.type, metadata=self.metadata.to_flyte_idl(), interface=self.interface.to_flyte_idl(), - custom=_json_format.Parse(_json.dumps(self.custom), _struct.Struct()) if self.custom else None, - container=self.container.to_flyte_idl() if self.container else None, + custom=_json_format.Parse(_json.dumps(self.custom), flyteidl.wkt.Struct()) if self.custom else None, task_type_version=self.task_type_version, security_context=self.security_context.to_flyte_idl() if self.security_context else None, extended_resources=self.extended_resources, config={k: v for k, v in self.config.items()} if self.config is not None else None, - k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, - sql=self.sql.to_flyte_idl() if self.sql else None, + target=flyteidl.task_template.Target.Container(self.container.to_flyte_idl() or None), + # container=self.container.to_flyte_idl() if self.container else None, + # k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None, + # sql=self.sql.to_flyte_idl() if self.sql else None, ) return task_template @@ -506,16 +510,23 @@ def from_flyte_idl(cls, pb2_object): type=pb2_object.type, metadata=TaskMetadata.from_flyte_idl(pb2_object.metadata), interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface), - custom=_json_format.MessageToDict(pb2_object.custom) if pb2_object else None, - container=Container.from_flyte_idl(pb2_object.container) if pb2_object.HasField("container") else None, + custom=pb2_object.custom if pb2_object else None, + container=Container.from_flyte_idl(pb2_object.target[0]) + if isinstance(pb2_object.target, flyteidl.task_template.Target.Container) + else None, task_type_version=pb2_object.task_type_version, security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context) - if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0 + # TODO: `security_context.ByteSize()` + if pb2_object.security_context + else None, + extended_resources=pb2_object.extended_resources or None, + config={k: v for k, v in pb2_object.config.items()} if pb2_object.config else {}, + k8s_pod=K8sPod.from_flyte_idl(pb2_object.target[0]) + if isinstance(pb2_object.target, flyteidl.task_template.Target.K8sPod) + else None, + sql=Sql.from_flyte_idl(pb2_object.target[0]) + if isinstance(pb2_object.target, flyteidl.task_template.Target.Sql) else None, - extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None, - config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None, - k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None, - sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None, ) @@ -643,7 +654,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return _admin_task.TaskSpec( + return flyteidl.admin.TaskSpec( template=self.template.to_flyte_idl(), description=self.docs.to_flyte_idl() if self.docs else None ) @@ -688,9 +699,10 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.Task """ - return _admin_task.Task( + return flyteidl.admin.Task( closure=self.closure.to_flyte_idl(), id=self.id.to_flyte_idl(), + short_description="", ) @classmethod @@ -723,7 +735,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskClosure """ - return _admin_task.TaskClosure(compiled_task=self.compiled_task.to_flyte_idl()) + return flyteidl.admin.TaskClosure(compiled_task=self.compiled_task.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -752,7 +764,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.compiler_pb2.CompiledTask """ - return _compiler.CompiledTask(template=self.template.to_flyte_idl()) + return flyteidl.core.CompiledTask(template=self.template.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -798,17 +810,17 @@ def from_flyte_idl(cls, pb2_object: _core_task.IOStrategy): class DataLoadingConfig(_common.FlyteIdlEntity): - LITERALMAP_FORMAT_PROTO = _core_task.DataLoadingConfig.PROTO - LITERALMAP_FORMAT_JSON = _core_task.DataLoadingConfig.JSON - LITERALMAP_FORMAT_YAML = _core_task.DataLoadingConfig.YAML - _LITERALMAP_FORMATS = frozenset([LITERALMAP_FORMAT_JSON, LITERALMAP_FORMAT_PROTO, LITERALMAP_FORMAT_YAML]) + LITERALMAP_FORMAT_PROTO = flyteidl.data_loading_config.LiteralMapFormat.Proto + LITERALMAP_FORMAT_JSON = flyteidl.data_loading_config.LiteralMapFormat.Json + LITERALMAP_FORMAT_YAML = flyteidl.data_loading_config.LiteralMapFormat.Yaml + _LITERALMAP_FORMATS = [LITERALMAP_FORMAT_JSON, LITERALMAP_FORMAT_PROTO, LITERALMAP_FORMAT_YAML] def __init__( self, input_path: str, output_path: str, enabled: bool = True, - format: _core_task.DataLoadingConfig.LiteralMapFormat = LITERALMAP_FORMAT_PROTO, + format: flyteidl.core.DataLoadingConfig.format = LITERALMAP_FORMAT_PROTO, io_strategy: IOStrategy = None, ): if format not in self._LITERALMAP_FORMATS: @@ -821,8 +833,8 @@ def __init__( self._format = format self._io_strategy = io_strategy - def to_flyte_idl(self) -> _core_task.DataLoadingConfig: - return _core_task.DataLoadingConfig( + def to_flyte_idl(self) -> flyteidl.core.DataLoadingConfig: + return flyteidl.core.DataLoadingConfig( input_path=self._input_path, output_path=self._output_path, format=self._format, @@ -839,7 +851,7 @@ def from_flyte_idl(cls, pb2: _core_task.DataLoadingConfig) -> "DataLoadingConfig output_path=pb2.output_path, enabled=pb2.enabled, format=pb2.format, - io_strategy=IOStrategy.from_flyte_idl(pb2.io_strategy) if pb2.HasField("io_strategy") else None, + io_strategy=IOStrategy.from_flyte_idl(pb2.io_strategy) if pb2.io_strategy else None, ) @@ -929,14 +941,16 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.Container """ - return _core_task.Container( + return flyteidl.core.Container( image=self.image, command=self.command, args=self.args, resources=self.resources.to_flyte_idl(), - env=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.env.items()], - config=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.config.items()], + env=[flyteidl.core.KeyValuePair(key=k, value=v) for k, v in self.env.items()], + config=[flyteidl.core.KeyValuePair(key=k, value=v) for k, v in self.config.items()], data_config=self._data_loading_config.to_flyte_idl() if self._data_loading_config else None, + ports=[], + architecture=0, ) @classmethod @@ -953,7 +967,7 @@ def from_flyte_idl(cls, pb2_object): env={kv.key: kv.value for kv in pb2_object.env}, config={kv.key: kv.value for kv in pb2_object.config}, data_loading_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) - if pb2_object.HasField("data_config") + if pb2_object.data_config else None, ) @@ -1019,7 +1033,7 @@ def data_config(self) -> typing.Optional[DataLoadingConfig]: def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl(), - pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None, + pod_spec=self.pod_spec or None, data_config=self.data_config.to_flyte_idl() if self.data_config else None, ) @@ -1027,10 +1041,8 @@ def to_flyte_idl(self) -> _core_task.K8sPod: def from_flyte_idl(cls, pb2_object: _core_task.K8sPod): return cls( metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata), - pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None, - data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) - if pb2_object.HasField("data_config") - else None, + pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.pod_spec else None, + data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) if pb2_object.data_config else None, ) diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 23f818e7a6..bbdfc3f422 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -1,10 +1,8 @@ -import json import typing from typing import Dict +import flyteidl_rust as flyteidl from flyteidl.core import types_pb2 as _types_pb2 -from google.protobuf import json_format as _json_format -from google.protobuf import struct_pb2 as _struct from flytekit.models import common as _common from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel @@ -12,16 +10,16 @@ class SimpleType(object): - NONE = _types_pb2.NONE - INTEGER = _types_pb2.INTEGER - FLOAT = _types_pb2.FLOAT - STRING = _types_pb2.STRING - BOOLEAN = _types_pb2.BOOLEAN - DATETIME = _types_pb2.DATETIME - DURATION = _types_pb2.DURATION - BINARY = _types_pb2.BINARY - ERROR = _types_pb2.ERROR - STRUCT = _types_pb2.STRUCT + NONE = int(0) # flyteidl.core.SimpleType.None + INTEGER = int(flyteidl.core.SimpleType.Integer) + FLOAT = int(flyteidl.core.SimpleType.Float) + STRING = int(flyteidl.core.SimpleType.String) + BOOLEAN = int(flyteidl.core.SimpleType.Boolean) + DATETIME = int(flyteidl.core.SimpleType.Datetime) + DURATION = int(flyteidl.core.SimpleType.Duration) + BINARY = int(flyteidl.core.SimpleType.Binary) + ERROR = int(flyteidl.core.SimpleType.Error) + STRUCT = int(flyteidl.core.SimpleType.Struct) class SchemaType(_common.FlyteIdlEntity): @@ -128,7 +126,7 @@ class TypeStructure(_common.FlyteIdlEntity): Models _types_pb2.TypeStructure """ - def __init__(self, tag: str, dataclass_type: Dict[str, "LiteralType"] = None): + def __init__(self, tag: str, dataclass_type: Dict[str, "flyteidl.core.LiteralType"] = None): self._tag = tag self._dataclass_type = dataclass_type @@ -137,11 +135,11 @@ def tag(self) -> str: return self._tag @property - def dataclass_type(self) -> Dict[str, "LiteralType"]: + def dataclass_type(self) -> Dict[str, "flyteidl.core.LiteralType"]: return self._dataclass_type - def to_flyte_idl(self) -> _types_pb2.TypeStructure: - return _types_pb2.TypeStructure( + def to_flyte_idl(self) -> flyteidl.core.TypeStructure: + return flyteidl.core.TypeStructure( tag=self._tag, dataclass_type={k: v.to_flyte_idl() for k, v in self._dataclass_type.items()} if self._dataclass_type is not None @@ -358,23 +356,24 @@ def to_flyte_idl(self): :rtype: flyteidl.core.types_pb2.LiteralType """ - if self.metadata is not None: - metadata = _json_format.Parse(json.dumps(self.metadata), _struct.Struct()) - else: - metadata = None - - t = _types_pb2.LiteralType( - simple=self.simple if self.simple is not None else None, - schema=self.schema.to_flyte_idl() if self.schema is not None else None, - collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None, - map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None, - blob=self.blob.to_flyte_idl() if self.blob is not None else None, - enum_type=self.enum_type.to_flyte_idl() if self.enum_type else None, - union_type=self.union_type.to_flyte_idl() if self.union_type else None, - structured_dataset_type=self.structured_dataset_type.to_flyte_idl() - if self.structured_dataset_type - else None, - metadata=metadata, + # if self.metadata is not None: + # metadata = _json_format.Parse(json.dumps(self.metadata), flyteidl.wkt.Struct()) + # else: + # metadata = None + + t = flyteidl.core.LiteralType( + type=flyteidl.literal_type.Type.Simple(int(str(self.simple))) if self.simple else None, + # simple=self.simple if self.simple is not None else None, + # schema=self.schema.to_flyte_idl() if self.schema is not None else None, + # collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None, + # map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None, + # blob=self.blob.to_flyte_idl() if self.blob is not None else None, + # enum_type=self.enum_type.to_flyte_idl() if self.enum_type else None, + # union_type=self.union_type.to_flyte_idl() if self.union_type else None, + # structured_dataset_type=self.structured_dataset_type.to_flyte_idl() + # if self.structured_dataset_type + # else None, + metadata=self.metadata, annotation=self.annotation.to_flyte_idl() if self.annotation else None, structure=self.structure.to_flyte_idl() if self.structure else None, ) @@ -388,24 +387,24 @@ def from_flyte_idl(cls, proto): """ collection_type = None map_value_type = None - if proto.HasField("collection_type"): - collection_type = LiteralType.from_flyte_idl(proto.collection_type) - if proto.HasField("map_value_type"): - map_value_type = LiteralType.from_flyte_idl(proto.map_value_type) + if isinstance(proto, flyteidl.literal_type.Type.CollectionType): + collection_type = cls.from_flyte_idl(proto.collection_type) + if isinstance(proto, flyteidl.literal_type.Type.MapValueType): + map_value_type = cls.from_flyte_idl(proto.map_value_type) return cls( - simple=proto.simple if proto.HasField("simple") else None, - schema=SchemaType.from_flyte_idl(proto.schema) if proto.HasField("schema") else None, - collection_type=collection_type, - map_value_type=map_value_type, - blob=_core_types.BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None, - enum_type=_core_types.EnumType.from_flyte_idl(proto.enum_type) if proto.HasField("enum_type") else None, - union_type=UnionType.from_flyte_idl(proto.union_type) if proto.HasField("union_type") else None, - structured_dataset_type=StructuredDatasetType.from_flyte_idl(proto.structured_dataset_type) - if proto.HasField("structured_dataset_type") - else None, - metadata=_json_format.MessageToDict(proto.metadata) or None, - structure=TypeStructure.from_flyte_idl(proto.structure) if proto.HasField("structure") else None, - annotation=TypeAnnotationModel.from_flyte_idl(proto.annotation) if proto.HasField("annotation") else None, + simple=proto.type if isinstance(proto.type, flyteidl.literal_type.Type.Simple) else None, + # schema=SchemaType.from_flyte_idl(proto.schema) if proto.HasField("schema") else None, + collection_type=collection_type or None, + map_value_type=map_value_type or None, + # blob=_core_types.BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None, + # enum_type=_core_types.EnumType.from_flyte_idl(proto.enum_type) if proto.HasField("enum_type") else None, + # union_type=UnionType.from_flyte_idl(proto.union_type) if proto.HasField("union_type") else None, + # structured_dataset_type=StructuredDatasetType.from_flyte_idl(proto.structured_dataset_type) + # if proto.HasField("structured_dataset_type") + # else None, + metadata=proto.metadata if proto.metadata else None, # _json_format.MessageToDict(proto.metadata) or None, + structure=TypeStructure.from_flyte_idl(proto.structure) if proto.structure else None, + annotation=TypeAnnotationModel.from_flyte_idl(proto.annotation) if proto.annotation else None, ) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index fd78d4c3c4..657727fd67 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -174,7 +174,7 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: task_type_version=base_model.task_type_version, ) # Override the newly generated name if one exists in the base model - if not base_model.id.is_empty: + if len(str(base_model.id)) > 0: # not is_empty t._id = base_model.id return t diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 301ce4b5fb..a7187e6e26 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -241,7 +241,7 @@ def context(self) -> FlyteContext: def client(self) -> SynchronousFlyteClient: """Return a SynchronousFlyteClient for additional operations.""" if not self._client_initialized: - self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs) + self._client = SynchronousFlyteClient(self.config.platform.endpoint, **self._kwargs) self._client_initialized = True return self._client @@ -351,6 +351,8 @@ def fetch_task(self, project: str = None, domain: str = None, name: str = None, version, ) admin_task = self.client.get_task(task_id) + admin_task = self.client.get_task(task_id) + flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template) flyte_task.template._id = task_id return flyte_task