Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] Refactor to use binding flyteidl-rust's RawSynchronousFlyteClient #2560

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import typing

import flyteidl_rust as flyteidl
from flyteidl.admin import common_pb2 as _common_pb2
from flyteidl.admin import execution_pb2 as _execution_pb2
from flyteidl.admin import launch_plan_pb2 as _launch_plan_pb2
Expand All @@ -9,13 +10,11 @@
from flyteidl.admin import project_domain_attributes_pb2 as _project_domain_attributes_pb2
from flyteidl.admin import project_pb2 as _project_pb2
from flyteidl.admin import task_execution_pb2 as _task_execution_pb2
from flyteidl.admin import task_pb2 as _task_pb2
from flyteidl.admin import workflow_attributes_pb2 as _workflow_attributes_pb2
from flyteidl.admin import workflow_pb2 as _workflow_pb2
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from google.protobuf.duration_pb2 import Duration

from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient
from flytekit.models import common as _common
from flytekit.models import execution as _execution
from flytekit.models import filters as _filters
Expand All @@ -29,7 +28,7 @@
from flytekit.models.core import identifier as _identifier


class SynchronousFlyteClient(_RawSynchronousFlyteClient):
class SynchronousFlyteClient(flyteidl.RawSynchronousFlyteClient):
"""
This is a low-level client that users can use to make direct gRPC service calls to the control plane. See the
:std:doc:`service spec <idl:protos/docs/service/index>`. This is more user-friendly interface than the
Expand Down Expand Up @@ -75,7 +74,7 @@ def create_task(self, task_identifer, task_spec):
:raises grpc.RpcError:
"""
super(SynchronousFlyteClient, self).create_task(
_task_pb2.TaskCreateRequest(id=task_identifer.to_flyte_idl(), spec=task_spec.to_flyte_idl())
flyteidl.admin.TaskCreateRequest(id=task_identifer.to_flyte_idl(), spec=task_spec.to_flyte_idl())
)

def list_task_ids_paginated(self, project, domain, limit=100, token=None, sort_by=None):
Expand Down Expand Up @@ -173,7 +172,7 @@ def get_task(self, id):
:rtype: flytekit.models.task.Task
"""
return _task.Task.from_flyte_idl(
super(SynchronousFlyteClient, self).get_task(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl()))
super(SynchronousFlyteClient, self).get_task(flyteidl.admin.ObjectGetRequest(id=id.to_flyte_idl()))
)

####################################################################################################################
Expand Down Expand Up @@ -551,12 +550,13 @@ def create_execution(self, project, domain, name, execution_spec, inputs):
return _identifier.WorkflowExecutionIdentifier.from_flyte_idl(
super(SynchronousFlyteClient, self)
.create_execution(
_execution_pb2.ExecutionCreateRequest(
flyteidl.admin.ExecutionCreateRequest(
project=project,
domain=domain,
name=name,
spec=execution_spec.to_flyte_idl(),
inputs=inputs.to_flyte_idl(),
org="",
)
)
.id
Expand All @@ -582,7 +582,7 @@ def get_execution(self, id):
"""
return _execution.Execution.from_flyte_idl(
super(SynchronousFlyteClient, self).get_execution(
_execution_pb2.WorkflowExecutionGetRequest(id=id.to_flyte_idl())
flyteidl.admin.WorkflowExecutionGetRequest(id=id.to_flyte_idl())
)
)

Expand All @@ -595,7 +595,7 @@ def get_execution_data(self, id):
"""
return _execution.WorkflowExecutionGetDataResponse.from_flyte_idl(
super(SynchronousFlyteClient, self).get_execution_data(
_execution_pb2.WorkflowExecutionGetDataRequest(id=id.to_flyte_idl())
flyteidl.admin.WorkflowExecutionGetDataRequest(id=id.to_flyte_idl())
)
)

Expand Down Expand Up @@ -677,7 +677,7 @@ def get_node_execution(self, node_execution_identifier):
"""
return _node_execution.NodeExecution.from_flyte_idl(
super(SynchronousFlyteClient, self).get_node_execution(
_node_execution_pb2.NodeExecutionGetRequest(id=node_execution_identifier.to_flyte_idl())
flyteidl.admin.NodeExecutionGetRequest(id=node_execution_identifier.to_flyte_idl())
)
)

Expand All @@ -689,7 +689,7 @@ def get_node_execution_data(self, node_execution_identifier) -> _execution.NodeE
"""
return _execution.NodeExecutionGetDataResponse.from_flyte_idl(
super(SynchronousFlyteClient, self).get_node_execution_data(
_node_execution_pb2.NodeExecutionGetDataRequest(id=node_execution_identifier.to_flyte_idl())
flyteidl.admin.NodeExecutionGetDataRequest(id=node_execution_identifier.to_flyte_idl())
)
)

Expand All @@ -715,7 +715,7 @@ def list_node_executions(
:rtype: list[flytekit.models.node_execution.NodeExecution], Text
"""
exec_list = super(SynchronousFlyteClient, self).list_node_executions_paginated(
_node_execution_pb2.NodeExecutionListRequest(
flyteidl.admin.NodeExecutionListRequest(
workflow_execution_id=workflow_execution_identifier.to_flyte_idl(),
limit=limit,
token=token,
Expand Down Expand Up @@ -775,7 +775,7 @@ def get_task_execution(self, id):
"""
return _task_execution.TaskExecution.from_flyte_idl(
super(SynchronousFlyteClient, self).get_task_execution(
_task_execution_pb2.TaskExecutionGetRequest(id=id.to_flyte_idl())
flyteidl.admin.TaskExecutionGetRequest(id=id.to_flyte_idl())
)
)

Expand All @@ -788,7 +788,7 @@ def get_task_execution_data(self, task_execution_identifier):
"""
return _execution.TaskExecutionGetDataResponse.from_flyte_idl(
super(SynchronousFlyteClient, self).get_task_execution_data(
_task_execution_pb2.TaskExecutionGetDataRequest(id=task_execution_identifier.to_flyte_idl())
flyteidl.admin.TaskExecutionGetDataRequest(id=task_execution_identifier.to_flyte_idl())
)
)

Expand Down Expand Up @@ -1010,14 +1010,15 @@ def get_upload_signed_url(
expires_in_pb = Duration()
expires_in_pb.FromTimedelta(expires_in)
return super(SynchronousFlyteClient, self).create_upload_location(
_data_proxy_pb2.CreateUploadLocationRequest(
flyteidl.service.CreateUploadLocationRequest(
project=project,
domain=domain,
content_md5=content_md5,
filename=filename,
expires_in=expires_in_pb,
filename_root=filename_root,
filename_root=filename_root or "",
add_content_md5_metadata=add_content_md5_metadata,
org="",
)
)
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,12 @@ def literal_type_to_click_type(lt: LiteralType, python_type: typing.Type) -> cli
Converts a Flyte LiteralType given a python_type to a click.ParamType
"""
if lt.simple:
if lt.simple == SimpleType.STRUCT:
if int(str(lt.simple)) == SimpleType.STRUCT:
ct = JsonParamType(python_type)
ct.name = f"JSON object {python_type.__name__}"
return ct
if lt.simple in SIMPLE_TYPE_CONVERTER:
return SIMPLE_TYPE_CONVERTER[lt.simple]
if int(str(lt.simple)) in SIMPLE_TYPE_CONVERTER:
return SIMPLE_TYPE_CONVERTER[int(str(lt.simple))]
raise NotImplementedError(f"Type {lt.simple} is not supported in pyflyte run")

if lt.enum_type:
Expand Down
10 changes: 4 additions & 6 deletions flytekit/models/annotation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
from typing import Any, Dict

from flyteidl.core import types_pb2 as _types_pb2
import flyteidl_rust as flyteidl
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct


class TypeAnnotation:
Expand All @@ -19,17 +18,16 @@ def annotations(self) -> Dict[str, Any]:
"""
return self._annotations

def to_flyte_idl(self) -> _types_pb2.TypeAnnotation:
def to_flyte_idl(self) -> flyteidl.core.TypeAnnotation:
"""
:rtype: flyteidl.core.types_pb2.TypeAnnotation
"""

if self._annotations is not None:
annotations = _json_format.Parse(json.dumps(self.annotations), _struct.Struct())
annotations = _json_format.Parse(json.dumps(self.annotations), flyteidl.wkt.Struct())
else:
annotations = None

return _types_pb2.TypeAnnotation(
return flyteidl.core.TypeAnnotation(
annotations=annotations,
)

Expand Down
35 changes: 22 additions & 13 deletions flytekit/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import re
from typing import Dict

import flyteidl_rust as flytedidl
from flyteidl.admin import common_pb2 as _common_pb2
from flyteidl.core import literals_pb2 as _literals_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct

Expand Down Expand Up @@ -302,11 +302,20 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.common_pb2.Notification
"""
return _common_pb2.Notification(
_type = None
if self.email:
_type = flytedidl.notification.Type.Email(self.email)
elif self.pager_duty:
_type = flytedidl.notification.Type.PagerDuty(self.pager_duty)
elif self.slack:
_type = flytedidl.notification.Type.Slack(self.slack)

return flytedidl.notification.Type(
phases=self.phases,
email=self.email.to_flyte_idl() if self.email else None,
pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None,
slack=self.slack.to_flyte_idl() if self.slack else None,
type=_type,
# email=self.email.to_flyte_idl() if self.email else None,
# pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None,
# slack=self.slack.to_flyte_idl() if self.slack else None,
)

@classmethod
Expand Down Expand Up @@ -340,7 +349,7 @@ def to_flyte_idl(self):
"""
:rtype: dict[Text, Text]
"""
return _common_pb2.Labels(values={k: v for k, v in self.values.items()})
return flytedidl.admin.Labels(values={k: v for k, v in self.values.items()})

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand Down Expand Up @@ -368,7 +377,7 @@ def to_flyte_idl(self):
"""
:rtype: _common_pb2.Annotations
"""
return _common_pb2.Annotations(values={k: v for k, v in self.values.items()})
return flytedidl.admin.Annotations(values={k: v for k, v in self.values.items()})

@classmethod
def from_flyte_idl(cls, pb2_object):
Expand Down Expand Up @@ -451,9 +460,9 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.launch_plan_pb2.Auth
"""
return _common_pb2.AuthRole(
assumable_iam_role=self.assumable_iam_role if self.assumable_iam_role else None,
kubernetes_service_account=self.kubernetes_service_account if self.kubernetes_service_account else None,
return flytedidl.admin.AuthRole(
assumable_iam_role=self.assumable_iam_role or "",
kubernetes_service_account=self.kubernetes_service_account or "",
)

@classmethod
Expand Down Expand Up @@ -483,7 +492,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.common_pb2.Auth
"""
return _common_pb2.RawOutputDataConfig(output_location_prefix=self.output_location_prefix)
return flytedidl.admin.RawOutputDataConfig(output_location_prefix=self.output_location_prefix)

@classmethod
def from_flyte_idl(cls, pb2):
Expand All @@ -498,8 +507,8 @@ def __init__(self, envs: Dict[str, str]):
def envs(self) -> Dict[str, str]:
return self._envs

def to_flyte_idl(self) -> _common_pb2.Envs:
return _common_pb2.Envs(values=[_literals_pb2.KeyValuePair(key=k, value=v) for k, v in self.envs.items()])
def to_flyte_idl(self) -> flytedidl.admin.Envs:
return flytedidl.admin.Envs(values=[flytedidl.core.KeyValuePair(key=k, value=v) for k, v in self.envs.items()])

@classmethod
def from_flyte_idl(cls, pb2: _common_pb2.Envs) -> _common_pb2.Envs:
Expand Down
33 changes: 26 additions & 7 deletions flytekit/models/core/identifier.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import flyteidl_rust as flyteidl
from flyteidl.core import identifier_pb2 as identifier_pb2

from flytekit.models import common as _common_models


class ResourceType(object):
UNSPECIFIED = identifier_pb2.UNSPECIFIED
TASK = identifier_pb2.TASK
WORKFLOW = identifier_pb2.WORKFLOW
LAUNCH_PLAN = identifier_pb2.LAUNCH_PLAN
UNSPECIFIED = int(flyteidl.core.ResourceType.Unspecified)
TASK = int(flyteidl.core.ResourceType.Task)
WORKFLOW = int(flyteidl.core.ResourceType.Workflow)
LAUNCH_PLAN = int(flyteidl.core.ResourceType.LaunchPlan)


class Identifier(_common_models.FlyteIdlEntity):
Expand All @@ -24,6 +25,7 @@ def __init__(self, resource_type, project, domain, name, version):
self._domain = domain
self._name = name
self._version = version
self._org = ""

@property
def resource_type(self):
Expand Down Expand Up @@ -64,16 +66,24 @@ def version(self):
"""
return self._version

@property
def org(self):
"""
:rtype: Text
"""
return self._org

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.Identifier
"""
return identifier_pb2.Identifier(
return flyteidl.core.Identifier(
resource_type=self.resource_type,
project=self.project,
domain=self.domain,
name=self.name,
version=self.version,
org=self.org,
)

@classmethod
Expand Down Expand Up @@ -107,6 +117,7 @@ def __init__(self, project, domain, name):
self._project = project
self._domain = domain
self._name = name
self._org = ""

@property
def project(self):
Expand All @@ -129,14 +140,22 @@ def name(self):
"""
return self._name

@property
def org(self):
"""
:rtype: Text
"""
return self._org

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier
"""
return identifier_pb2.WorkflowExecutionIdentifier(
return flyteidl.core.WorkflowExecutionIdentifier(
project=self.project,
domain=self.domain,
name=self.name,
org=self.org,
)

@classmethod
Expand Down Expand Up @@ -179,7 +198,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier
"""
return identifier_pb2.NodeExecutionIdentifier(
return flyteidl.core.NodeExecutionIdentifier(
node_id=self.node_id,
execution_id=self.execution_id.to_flyte_idl(),
)
Expand Down
Loading