diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index c5c10d20d6a..3a516fc1d2c 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -989,14 +989,14 @@ def get_upload_signed_url( filename_root: typing.Optional[str] = None, ) -> _data_proxy_pb2.CreateUploadLocationResponse: """ - Get a signed url to be used during fast registration. + Get a signed url to be used during fast registration - :param str project: Project to create the upload location for - :param str domain: Domain to create the upload location for - :param bytes content_md5: ContentMD5 restricts the upload location to the specific MD5 provided. The content_md5 + :param project: Project to create the upload location for + :param domain: Domain to create the upload location for + :param content_md5: ContentMD5 restricts the upload location to the specific MD5 provided. The content_md5 will also appear in the generated path. - :param str filename: [Optional] If provided this specifies a desired suffix for the generated location - :param datetime.timedelta expires_in: [Optional] If provided this defines a requested expiration duration for + :param filename: If provided this specifies a desired suffix for the generated location + :param expires_in: If provided this defines a requested expiration duration for the generated url :param filename_root: If provided will be used as the root of the filename. If not, Admin will use a hash This option is useful when uploading a series of files that you want to be grouped together. diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py new file mode 100644 index 00000000000..6c709f59a18 --- /dev/null +++ b/flytekit/core/artifact.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +import datetime +import typing +from datetime import timedelta +from typing import Optional, Union + +from flyteidl.core import artifact_id_pb2 as art_id +from google.protobuf.timestamp_pb2 import Timestamp + +TIME_PARTITION_KWARG = "time_partition" + + +class InputsBase(object): + """ + A class to provide better partition semantics + Used for invoking an Artifact to bind partition keys to input values. + If there's a good reason to use a metaclass in the future we can, but a simple instance suffices for now + """ + + def __getattr__(self, name: str) -> art_id.InputBindingData: + return art_id.InputBindingData(var=name) + + +Inputs = InputsBase() + + +class ArtifactIDSpecification(object): + """ + This is a special object that helps specify how Artifacts are to be created. See the comment in the + call function of the main Artifact class. Also see the handling code in transform_variable_map for more + information. There's a limited set of information that we ultimately need in a TypedInterface, so it + doesn't make sense to carry the full Artifact object around. This object should be sufficient, despite + having a pointer to the main artifact. + """ + + def __init__(self, a: Artifact): + self.artifact = a + self.partitions: Optional[Partitions] = None + self.time_partition: Optional[TimePartition] = None + + # todo: add time partition arg hint + def __call__(self, *args, **kwargs): + return self.bind_partitions(*args, **kwargs) + + def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification: + # See the parallel function in the main Artifact class for more information. + if len(args) > 0: + raise ValueError("Cannot set partition values by position") + + if TIME_PARTITION_KWARG in kwargs: + if not self.artifact.time_partitioned: + raise ValueError("Cannot bind time partition to non-time partitioned artifact") + p = kwargs[TIME_PARTITION_KWARG] + if isinstance(p, datetime.datetime): + t = Timestamp() + t.FromDatetime(p) + self.time_partition = TimePartition(value=art_id.LabelValue(time_value=t)) + elif isinstance(p, art_id.InputBindingData): + self.time_partition = TimePartition(value=art_id.LabelValue(input_binding=p)) + else: + raise ValueError(f"Time partition needs to be input binding data or static string, not {p}") + # Given the context, shouldn't need to set further reference_artifacts. + + del kwargs[TIME_PARTITION_KWARG] + + if len(kwargs) > 0: + p = Partitions(None) + # k is the partition key, v should be static, or an input to the task or workflow + for k, v in kwargs.items(): + if not self.artifact.partition_keys or k not in self.artifact.partition_keys: + raise ValueError(f"Partition key {k} not found in {self.artifact.partition_keys}") + if isinstance(v, art_id.InputBindingData): + p.partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k) + elif isinstance(v, str): + p.partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k) + else: + raise ValueError(f"Partition key {k} needs to be input binding data or static string, not {v}") + # Given the context, shouldn't need to set further reference_artifacts. + self.partitions = p + + return self + + def to_partial_artifact_id(self) -> art_id.ArtifactID: + # This function should only be called by transform_variable_map + artifact_id = self.artifact.to_id_idl() + # Use the partitions from this object, but replacement is not allowed by protobuf, so generate new object + p = Serializer.partitions_to_idl(self.partitions) + tp = None + if self.artifact.time_partitioned: + if not self.time_partition: + raise ValueError( + f"Artifact {artifact_id.artifact_key} requires a time partition, but it hasn't been bound." + ) + tp = self.time_partition.to_flyte_idl() + + if self.artifact.partition_keys: + required = len(self.artifact.partition_keys) + fulfilled = len(p.value) if p else 0 + if required != fulfilled: + raise ValueError( + f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are " + f"bound." + ) + artifact_id = art_id.ArtifactID( + artifact_key=artifact_id.artifact_key, + partitions=p, + time_partition=tp, + version=artifact_id.version, # this should almost never be set since setting it + # hardcodes the query to one version + ) + return artifact_id + + +class ArtifactQuery(object): + def __init__( + self, + artifact: Artifact, + name: str, + project: Optional[str] = None, + domain: Optional[str] = None, + time_partition: Optional[TimePartition] = None, + partitions: Optional[Partitions] = None, + tag: Optional[str] = None, + ): + if not name: + raise ValueError("Cannot create query without name") + + # So normally, if you just do MyData.query(partitions="region": Inputs.region), it will just + # use the input value to fill in the partition. But if you do + # MyData.query(region=OtherArtifact.partitions.region) + # then you now have a dependency on the other artifact. This list keeps track of all the other Artifacts you've + # referenced. + self.artifact = artifact + bindings: typing.List[Artifact] = [] + if time_partition: + if time_partition.reference_artifact and time_partition.reference_artifact is not artifact: + bindings.append(time_partition.reference_artifact) + if partitions and partitions.partitions: + for k, v in partitions.partitions.items(): + if v.reference_artifact and v.reference_artifact is not artifact: + bindings.append(v.reference_artifact) + + self.name = name + self.project = project + self.domain = domain + self.time_partition = time_partition + self.partitions = partitions + self.tag = tag + self.bindings = bindings + + def to_flyte_idl( + self, + **kwargs, + ) -> art_id.ArtifactQuery: + return Serializer.artifact_query_to_idl(self, **kwargs) + + +class TimePartition(object): + def __init__( + self, + value: Union[art_id.LabelValue, art_id.InputBindingData, str, datetime.datetime, None], + op: Optional[str] = None, + other: Optional[timedelta] = None, + ): + if isinstance(value, str): + raise ValueError(f"value to a time partition shouldn't be a str {value}") + elif isinstance(value, datetime.datetime): + t = Timestamp() + t.FromDatetime(value) + value = art_id.LabelValue(time_value=t) + elif isinstance(value, art_id.InputBindingData): + value = art_id.LabelValue(input_binding=value) + # else should already be a LabelValue or None + self.value: art_id.LabelValue = value + self.op = op + self.other = other + self.reference_artifact: Optional[Artifact] = None + + def __add__(self, other: timedelta) -> TimePartition: + tp = TimePartition(self.value, op="+", other=other) + tp.reference_artifact = self.reference_artifact + return tp + + def __sub__(self, other: timedelta) -> TimePartition: + tp = TimePartition(self.value, op="-", other=other) + tp.reference_artifact = self.reference_artifact + return tp + + def to_flyte_idl(self, **kwargs) -> Optional[art_id.TimePartition]: + return Serializer.time_partition_to_idl(self, **kwargs) + + +class Partition(object): + def __init__(self, value: Optional[art_id.LabelValue], name: str): + self.name = name + self.value = value + self.reference_artifact: Optional[Artifact] = None + + +class Partitions(object): + def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.InputBindingData, Partition]]]): + self._partitions = {} + if partitions: + for k, v in partitions.items(): + if isinstance(v, Partition): + self._partitions[k] = v + elif isinstance(v, art_id.InputBindingData): + self._partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k) + else: + self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k) + self.reference_artifact: Optional[Artifact] = None + + @property + def partitions(self) -> Optional[typing.Dict[str, Partition]]: + return self._partitions + + def set_reference_artifact(self, artifact: Artifact): + self.reference_artifact = artifact + if self.partitions: + for p in self.partitions.values(): + p.reference_artifact = artifact + + def __getattr__(self, item): + if self.partitions and item in self.partitions: + return self.partitions[item] + raise AttributeError(f"Partition {item} not found in {self}") + + def to_flyte_idl(self, **kwargs) -> Optional[art_id.Partitions]: + return Serializer.partitions_to_idl(self, **kwargs) + + +class Artifact(object): + """ + An Artifact is effectively just a metadata layer on top of data that exists in Flyte. Most data of interest + will be the output of tasks and workflows. The other category is user uploads. + + This Python class has limited purpose, as a way for users to specify that tasks/workflows create Artifacts + and the manner (i.e. name, partitions) in which they are created. + + Control creation parameters at task/workflow execution time :: + + @task + def t1() -> Annotated[nn.Module, Artifact(name="my.artifact.name")]: + ... + """ + + def __init__( + self, + project: Optional[str] = None, + domain: Optional[str] = None, + name: Optional[str] = None, + version: Optional[str] = None, + time_partitioned: bool = False, + time_partition: Optional[TimePartition] = None, + partition_keys: Optional[typing.List[str]] = None, + partitions: Optional[Union[Partitions, typing.Dict[str, str]]] = None, + ): + """ + :param project: Should not be directly user provided, the project/domain will come from the project/domain of + the execution that produced the output. These values will be filled in automatically when retrieving however. + :param domain: See above. + :param name: The name of the Artifact. This should be user provided. + :param version: Version of the Artifact, typically the execution ID, plus some additional entropy. + Not user provided. + :param time_partitioned: Whether or not this Artifact will have a time partition. + :param partition_keys: This is a list of keys that will be used to partition the Artifact. These are not the + values. Values are set via a () on the artifact and will end up in the partition_values field. + :param partitions: This is a dictionary of partition keys to values. + """ + if not name: + raise ValueError("Can't instantiate an Artifact without a name.") + self.project = project + self.domain = domain + self.name = name + self.version = version + self.time_partitioned = time_partitioned + self._time_partition = None + if time_partition: + self._time_partition = time_partition + self._time_partition.reference_artifact = self + self.partition_keys = partition_keys + self._partitions: Optional[Partitions] = None + if partitions: + if isinstance(partitions, dict): + self._partitions = Partitions(partitions) + self.partition_keys = list(partitions.keys()) + elif isinstance(partitions, Partitions): + self._partitions = partitions + if not partitions.partitions: + raise ValueError("Partitions must be non-empty") + self.partition_keys = list(partitions.partitions.keys()) + else: + raise ValueError(f"Partitions must be a dict or Partitions object, not {type(partitions)}") + self._partitions.set_reference_artifact(self) + if not partitions and partition_keys: + # this should be the only time where we create Partition objects with None + p = {k: Partition(None, name=k) for k in partition_keys} + self._partitions = Partitions(p) + self._partitions.set_reference_artifact(self) + + def __call__(self, *args, **kwargs) -> ArtifactIDSpecification: + """ + This __call__ should only ever happen in the context of a task or workflow's output, to be + used in an Annotated[] call. The other styles will go through different call functions. + """ + # Can't guarantee the order in which time/non-time partitions are bound so create the helper + # object and invoke the function there. + partial_id = ArtifactIDSpecification(self) + return partial_id.bind_partitions(*args, **kwargs) + + @property + def partitions(self) -> Optional[Partitions]: + return self._partitions + + @property + def time_partition(self) -> TimePartition: + if not self.time_partitioned: + raise ValueError(f"Artifact {self.name} is not time partitioned") + if not self._time_partition and self.time_partitioned: + self._time_partition = TimePartition(None) + self._time_partition.reference_artifact = self + return self._time_partition + + def __str__(self): + tp_str = f" time partition={self.time_partition}\n" if self.time_partitioned else "" + return ( + f"Artifact: project={self.project}, domain={self.domain}, name={self.name}, version={self.version}\n" + f" name={self.name}\n" + f" partitions={self.partitions}\n" + f"{tp_str}" + ) + + def __repr__(self): + return self.__str__() + + def query( + self, + project: Optional[str] = None, + domain: Optional[str] = None, + time_partition: Optional[Union[datetime.datetime, TimePartition, art_id.InputBindingData]] = None, + partitions: Optional[Union[typing.Dict[str, str], Partitions]] = None, + **kwargs, + ) -> ArtifactQuery: + if self.partition_keys: + fn_args = {"project", "domain", "time_partition", "partitions", "tag"} + k = set(self.partition_keys) + if len(fn_args & k) > 0: + raise ValueError( + f"There are conflicting partition key names {fn_args ^ k}, please rename" + f" use a partitions object" + ) + if partitions and kwargs: + raise ValueError("Please either specify kwargs or a partitions object not both") + + p_obj: Optional[Partitions] = None + if kwargs: + p_obj = Partitions(kwargs) + p_obj.reference_artifact = self # only set top level + + if partitions and isinstance(partitions, dict): + p_obj = Partitions(partitions) + p_obj.reference_artifact = self # only set top level + + tp = None + if time_partition: + if isinstance(time_partition, TimePartition): + tp = time_partition + else: + tp = TimePartition(time_partition) + tp.reference_artifact = self + + tp = tp or (self.time_partition if self.time_partitioned else None) + + aq = ArtifactQuery( + artifact=self, + name=self.name, + project=project or self.project or None, + domain=domain or self.domain or None, + time_partition=tp, + partitions=p_obj or self.partitions, + ) + return aq + + @property + def concrete_artifact_id(self) -> art_id.ArtifactID: + # This property is used when you want to ensure that this is a materialized artifact, all fields are known. + if self.name is None or self.project is None or self.domain is None or self.version is None: + raise ValueError("Cannot create artifact id without name, project, domain, version") + return self.to_id_idl() + + def embed_as_query( + self, + bindings: typing.List[Artifact], + partition: Optional[str] = None, + bind_to_time_partition: Optional[bool] = None, + expr: Optional[str] = None, + ) -> art_id.ArtifactQuery: + """ + This should only be called in the context of a Trigger + :param bindings: The list of artifacts in trigger_on + :param partition: Can embed a time partition + :param bind_to_time_partition: Set to true if you want to bind to a time partition + :param expr: Only valid if there's a time partition. + """ + # Find self in the list, raises ValueError if not there. + idx = bindings.index(self) + aq = art_id.ArtifactQuery( + binding=art_id.ArtifactBindingData( + index=idx, + partition_key=partition, + bind_to_time_partition=bind_to_time_partition, + transform=str(expr) if expr and (partition or bind_to_time_partition) else None, + ) + ) + + return aq + + def to_id_idl(self) -> art_id.ArtifactID: + """ + Converts this object to the IDL representation. + This is here instead of translator because it's in the interface, a relatively simple proto object + that's exposed to the user. + """ + p = Serializer.partitions_to_idl(self.partitions) + tp = Serializer.time_partition_to_idl(self.time_partition) if self.time_partitioned else None + + i = art_id.ArtifactID( + artifact_key=art_id.ArtifactKey( + project=self.project, + domain=self.domain, + name=self.name, + ), + version=self.version, + partitions=p, + time_partition=tp, + ) + + return i + + +class ArtifactSerializationHandler(typing.Protocol): + """ + This protocol defines the interface for serializing artifact-related entities down to Flyte IDL. + """ + + def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: + ... + + def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]: + ... + + def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: + ... + + +class DefaultArtifactSerializationHandler(ArtifactSerializationHandler): + def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: + if p and p.partitions: + pp = {} + for k, v in p.partitions.items(): + if v.value is None: + # For specifying partitions in the Variable partial id + pp[k] = art_id.LabelValue(static_value="") + else: + pp[k] = v.value + return art_id.Partitions(value=pp) + return None + + def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]: + if tp: + return art_id.TimePartition(value=tp.value) + return None + + def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: + ak = art_id.ArtifactKey( + name=aq.name, + project=aq.project, + domain=aq.domain, + ) + + p = self.partitions_to_idl(aq.partitions) + tp = self.time_partition_to_idl(aq.time_partition) + + i = art_id.ArtifactID( + artifact_key=ak, + partitions=p, + time_partition=tp, + ) + + aq = art_id.ArtifactQuery( + artifact_id=i, + ) + + return aq + + +class Serializer(object): + serializer: ArtifactSerializationHandler = DefaultArtifactSerializationHandler() + + @classmethod + def register_serializer(cls, serializer: ArtifactSerializationHandler): + cls.serializer = serializer + + @classmethod + def partitions_to_idl(cls, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: + return cls.serializer.partitions_to_idl(p, **kwargs) + + @classmethod + def time_partition_to_idl(cls, tp: TimePartition, **kwargs) -> Optional[art_id.TimePartition]: + return cls.serializer.time_partition_to_idl(tp, **kwargs) + + @classmethod + def artifact_query_to_idl(cls, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: + return cls.serializer.artifact_query_to_idl(aq, **kwargs) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index a548f6a49b5..24cfd24581c 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,15 +7,17 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast +from flyteidl.core import artifact_id_pb2 as art_id from typing_extensions import get_args, get_origin, get_type_hints from flytekit.core import context_manager +from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models -from flytekit.models.literals import Void +from flytekit.models.literals import Literal, Scalar, Void T = typing.TypeVar("T") @@ -202,6 +204,7 @@ def transform_inputs_to_parameters( ) -> _interface_models.ParameterMap: """ Transforms the given interface (with inputs) to a Parameter Map with defaults set + :param ctx: context :param interface: the interface object """ if interface is None or interface.inputs_with_defaults is None: @@ -215,16 +218,20 @@ def transform_inputs_to_parameters( for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): - from flytekit import Literal, Scalar - literal = Literal(scalar=Scalar(none_type=Void())) params[k] = _interface_models.Parameter(var=v, default=literal, required=False) else: - required = _default is None - default_lv = None - if _default is not None: - default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) - params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) + if isinstance(_default, ArtifactQuery): + params[k] = _interface_models.Parameter(var=v, required=False, artifact_query=_default.to_flyte_idl()) + elif isinstance(_default, Artifact): + artifact_id = _default.concrete_artifact_id # may raise + params[k] = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id) + else: + required = _default is None + default_lv = None + if _default is not None: + default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) + params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) return _interface_models.ParameterMap(params) @@ -246,9 +253,36 @@ def transform_interface_to_typed_interface( inputs_map = transform_variable_map(interface.inputs, input_descriptions) outputs_map = transform_variable_map(interface.outputs, output_descriptions) + verify_outputs_artifact_bindings(interface.inputs, outputs_map) return _interface_models.TypedInterface(inputs_map, outputs_map) +def verify_outputs_artifact_bindings(inputs: Dict[str, type], outputs: Dict[str, _interface_models.Variable]): + # collect Artifacts + for k, v in outputs.items(): + # Iterate through output partition values if any and verify that if they're bound to an input, that that input + # actually exists in the interface. + if ( + v.artifact_partial_id + and v.artifact_partial_id.HasField("partitions") + and v.artifact_partial_id.partitions.value + ): + for pk, pv in v.artifact_partial_id.partitions.value.items(): + if pv.HasField("input_binding"): + input_name = pv.input_binding.var + if input_name not in inputs: + raise FlyteValidationException( + f"Output partition {k} is bound to input {input_name} which does not exist in the interface" + ) + if v.artifact_partial_id.HasField("time_partition"): + if v.artifact_partial_id.time_partition.value.HasField("input_binding"): + input_name = v.artifact_partial_id.time_partition.value.input_binding.var + if input_name not in inputs: + raise FlyteValidationException( + f"Output time partition is bound to input {input_name} which does not exist in the interface" + ) + + def transform_types_to_list_of_type( m: Dict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False ) -> Dict[str, type]: @@ -333,21 +367,45 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc def transform_variable_map( variable_map: Dict[str, type], - descriptions: Dict[str, str] = {}, + descriptions: Optional[Dict[str, str]] = None, ) -> Dict[str, _interface_models.Variable]: """ Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a Flyte Variable object with that type. """ res = OrderedDict() + descriptions = descriptions or {} if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) return res +def detect_artifact( + ts: typing.Tuple[typing.Any, ...], +) -> Optional[art_id.ArtifactID]: + """ + If the user wishes to control how Artifacts are created (i.e. naming them, etc.) this is where we pick it up and + store it in the interface. + """ + for t in ts: + if isinstance(t, Artifact): + id_spec = t() + return id_spec.to_partial_artifact_id() + elif isinstance(t, ArtifactIDSpecification): + artifact_id = t.to_partial_artifact_id() + return artifact_id + + return None + + def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable: - return _interface_models.Variable(type=TypeEngine.to_literal_type(x), description=description) + artifact_id = detect_artifact(get_args(x)) + if artifact_id: + logger.debug(f"Found artifact id spec: {artifact_id}") + return _interface_models.Variable( + type=TypeEngine.to_literal_type(x), description=description, artifact_partial_id=artifact_id + ) def default_output_name(index: int = 0) -> str: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 500a865ade8..a96f83b8ce6 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -311,6 +311,7 @@ def __init__( raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, + additional_metadata: Optional[Any] = None, ): self._name = name self._workflow = workflow @@ -328,6 +329,7 @@ def __init__( self._raw_output_data_config = raw_output_data_config self._max_parallelism = max_parallelism self._security_context = security_context + self._additional_metadata = additional_metadata FlyteEntities.entities.append(self) @@ -418,6 +420,10 @@ def max_parallelism(self) -> Optional[int]: def security_context(self) -> Optional[security.SecurityContext]: return self._security_context + @property + def additional_metadata(self) -> Optional[Any]: + return self._additional_metadata + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: return self.workflow.construct_node_metadata() diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 392f2c45241..108b323a48d 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -24,7 +24,6 @@ from flytekit.core.interface import ( Interface, transform_function_to_interface, - transform_inputs_to_parameters, transform_interface_to_typed_interface, ) from flytekit.core.node import Node @@ -702,7 +701,6 @@ def compile(self, **kwargs): self.compiled = True ctx = FlyteContextManager.current_context() - self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface) all_nodes = [] prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else "" diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index b0c9fc882a0..f80bfb9e52b 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -1,5 +1,6 @@ import typing +from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import interface_pb2 as _interface_pb2 from flytekit.models import common as _common @@ -8,15 +9,25 @@ class Variable(_common.FlyteIdlEntity): - def __init__(self, type, description): + def __init__( + self, + type, + description, + artifact_partial_id: typing.Optional[art_id.ArtifactID] = None, + artifact_tag: typing.Optional[art_id.ArtifactTag] = None, + ): """ :param flytekit.models.types.LiteralType type: This describes the type of value that must be provided to satisfy this variable. :param Text description: This is a help string that can provide context for what this variable means in relation to a task or workflow. + :param artifact_partial_id: Optional Artifact object to control how the artifact is created when the task runs. + :param artifact_tag: Optional ArtifactTag object to automatically tag things. """ self._type = type self._description = description + self._artifact_partial_id = artifact_partial_id + self._artifact_tag = artifact_tag @property def type(self): @@ -34,21 +45,37 @@ def description(self): """ return self._description + @property + def artifact_partial_id(self) -> typing.Optional[art_id.ArtifactID]: + return self._artifact_partial_id + + @property + def artifact_tag(self) -> typing.Optional[art_id.ArtifactTag]: + return self._artifact_tag + def to_flyte_idl(self): """ :rtype: flyteidl.core.interface_pb2.Variable """ - return _interface_pb2.Variable(type=self.type.to_flyte_idl(), description=self.description) + return _interface_pb2.Variable( + type=self.type.to_flyte_idl(), + description=self.description, + artifact_partial_id=self.artifact_partial_id, + artifact_tag=self.artifact_tag, + ) @classmethod - def from_flyte_idl(cls, variable_proto): + def from_flyte_idl(cls, variable_proto) -> _interface_pb2.Variable: """ :param flyteidl.core.interface_pb2.Variable variable_proto: - :rtype: Variable """ return cls( type=_types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description, + artifact_partial_id=variable_proto.artifact_partial_id + if variable_proto.HasField("artifact_partial_id") + else None, + artifact_tag=variable_proto.artifact_tag if variable_proto.HasField("artifact_tag") else None, ) @@ -121,7 +148,14 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface class Parameter(_common.FlyteIdlEntity): - def __init__(self, var, default=None, required=None): + def __init__( + self, + var, + default=None, + required=None, + artifact_query: typing.Optional[art_id.ArtifactQuery] = None, + artifact_id: typing.Optional[art_id.ArtifactID] = None, + ): """ Declares an input parameter. A parameter is used as input to a launch plan and has the special ability to have a default value or mark itself as required. @@ -129,10 +163,14 @@ def __init__(self, var, default=None, required=None): :param flytekit.models.literals.Literal default: [Optional] Defines a default value that has to match the variable type defined. :param bool required: [Optional] is this value required to be filled in? + :param artifact_query: Specify this to bind to a query instead of a constant. + :param artifact_id: When you want to bind to a known artifact pointer. """ self._var = var self._default = default self._required = required + self._artifact_query = artifact_query + self._artifact_id = artifact_id @property def var(self): @@ -163,7 +201,15 @@ def behavior(self): """ :rtype: T """ - return self._default or self._required + return self._default or self._required or self._artifact_query + + @property + def artifact_query(self) -> typing.Optional[art_id.ArtifactQuery]: + return self._artifact_query + + @property + def artifact_id(self) -> typing.Optional[art_id.ArtifactID]: + return self._artifact_id def to_flyte_idl(self): """ @@ -172,7 +218,9 @@ def to_flyte_idl(self): return _interface_pb2.Parameter( var=self.var.to_flyte_idl(), default=self.default.to_flyte_idl() if self.default is not None else None, - required=self.required if self.default is None else None, + required=self.required if self.default is None and self.artifact_query is None else None, + artifact_query=self.artifact_query if self.artifact_query else None, + artifact_id=self.artifact_id if self.artifact_id else None, ) @classmethod @@ -185,6 +233,8 @@ def from_flyte_idl(cls, pb2_object): Variable.from_flyte_idl(pb2_object.var), _literals.Literal.from_flyte_idl(pb2_object.default) if pb2_object.HasField("default") else None, pb2_object.required if pb2_object.HasField("required") else None, + artifact_query=pb2_object.artifact_query if pb2_object.HasField("artifact_query") else None, + artifact_id=pb2_object.artifact_id if pb2_object.HasField("artifact_id") else None, ) diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index b63d94bf4ac..9f2af1b92e8 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -1,6 +1,7 @@ import typing from flyteidl.admin import launch_plan_pb2 as _launch_plan +from google.protobuf.any_pb2 import Any from flytekit.models import common as _common from flytekit.models import interface as _interface @@ -11,15 +12,17 @@ class LaunchPlanMetadata(_common.FlyteIdlEntity): - def __init__(self, schedule, notifications): + def __init__(self, schedule, notifications, launch_conditions=None): """ :param flytekit.models.schedule.Schedule schedule: Schedule to execute the Launch Plan :param list[flytekit.models.common.Notification] notifications: List of notifications based on execution status transitions + :param launch_conditions: Additional metadata for launching """ self._schedule = schedule self._notifications = notifications + self._launch_conditions = launch_conditions @property def schedule(self): @@ -37,14 +40,24 @@ def notifications(self): """ return self._notifications + @property + def launch_conditions(self): + return self._launch_conditions + def to_flyte_idl(self): """ List of notifications based on Execution status transitions :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanMetadata """ + if self.launch_conditions: + a = Any() + a.Pack(self.launch_conditions) + else: + a = None return _launch_plan.LaunchPlanMetadata( schedule=self.schedule.to_flyte_idl() if self.schedule is not None else None, notifications=[n.to_flyte_idl() for n in self.notifications], + launch_conditions=a, ) @classmethod @@ -58,6 +71,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("schedule") else None, notifications=[_common.Notification.from_flyte_idl(n) for n in pb2_object.notifications], + launch_conditions=pb2_object.launch_conditions if pb2_object.HasField("launch_conditions") else None, ) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 82a915cda1c..c59de4afde6 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -919,10 +919,9 @@ def hash(self, value): self._hash = value @property - def metadata(self): + def metadata(self) -> Optional[Dict[str, str]]: """ This value holds metadata about the literal. - :rtype: typing.Dict[str, str] """ return self._metadata diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index beef4108b0b..c8acaea259b 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -24,12 +24,13 @@ import fsspec import requests from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest -from flyteidl.core import literals_pb2 as literals_pb2 +from flyteidl.core import literals_pb2 from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.core import constants, utils +from flytekit.core.artifact import Artifact from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider @@ -800,7 +801,10 @@ def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: s return self.upload_file(pathlib.Path(zip_file)) def upload_file( - self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None + self, + to_upload: pathlib.Path, + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, ) -> typing.Tuple[bytes, str]: """ Function will use remote's client to hash and then upload the file using Admin's data proxy service. @@ -1042,6 +1046,8 @@ def _execute( ) if isinstance(v, Literal): lit = v + elif isinstance(v, Artifact): + raise user_exceptions.FlyteValueException(v, "Running with an artifact object is not yet possible.") else: if k not in type_hints: try: diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5d182c89b8b..1c2016b681d 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -372,6 +372,7 @@ def get_serializable_launch_plan( entity_metadata=_launch_plan_models.LaunchPlanMetadata( schedule=entity.schedule, notifications=options.notifications or entity.notifications, + launch_conditions=entity.additional_metadata, ), default_inputs=entity.parameters, fixed_inputs=entity.fixed_inputs, diff --git a/pyproject.toml b/pyproject.toml index a3fbde5be5a..01bbf89daef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "grpcio", "grpcio-status", "importlib-metadata", + "isodate", "joblib", "jsonpickle", "keyring>=18.0.1", diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py new file mode 100644 index 00000000000..abd9e456c2e --- /dev/null +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -0,0 +1,335 @@ +import datetime +import sys +from collections import OrderedDict + +import pytest +from flyteidl.core import artifact_id_pb2 as art_id +from typing_extensions import Annotated, get_args + +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.artifact import Artifact, Inputs +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.interface import detect_artifact +from flytekit.core.launch_plan import LaunchPlan +from flytekit.core.task import task +from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteValidationException +from flytekit.tools.translator import get_serializable + +if "pandas" not in sys.modules: + pytest.skip(reason="Requires pandas", allow_module_level=True) + + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +class CustomReturn(object): + def __init__(self, data): + self.data = data + + +def test_basic_option_a_rev(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + + @task + def t1( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dt, b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition.value.input_binding.var == "dt" + assert p["b"].HasField("input_binding") + assert p["b"].input_binding.var == "b_value" + assert p["a"].HasField("static_value") + assert p["a"].static_value == "manual" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.version == "" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + + +def test_args_getting(): + a1 = Artifact(name="argstst") + a1_called = a1() + x = Annotated[int, a1_called] + gotten = get_args(x) + assert len(gotten) == 2 + assert gotten[1] is a1_called + detected = detect_artifact(get_args(int)) + assert detected is None + detected = detect_artifact(get_args(x)) + assert detected == a1_called.to_partial_artifact_id() + + +def test_basic_option_no_tp(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"]) + assert not a1_t_ab.time_partitioned + + # trying to bind to a time partition when not so raises an error. + with pytest.raises(ValueError): + + @task + def t1x( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dt, b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + @task + def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + p = t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.HasField("time_partition") is False + assert p["b"].HasField("input_binding") + + +def test_basic_option_hardcoded_tp(): + a1_t_ab = Artifact(name="my_data", time_partitioned=True) + + dt = datetime.datetime.strptime("04/05/2063", "%m/%d/%Y") + + id_spec = a1_t_ab(time_partition=dt) + assert id_spec.partitions is None + assert id_spec.time_partition.value.HasField("time_value") + + +def test_basic_option_a(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + + @task + def t1( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(b=Inputs.b_value, a="manual", time_partition=Inputs.dt)]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + entities = OrderedDict() + t1_s = get_serializable(entities, serialization_settings, t1) + assert len(t1_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.version == "" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + assert t1_s.template.interface.outputs["o0"].artifact_partial_id.time_partition is not None + + +def test_basic_no_call(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + + # raise an error because the user hasn't () the artifact + with pytest.raises(ValueError): + + @task + def t1(b_value: str, dt: datetime.datetime) -> Annotated[pd.DataFrame, a1_t_ab]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + +def test_basic_option_a2(): + import pandas as pd + + a2_ab = Artifact(name="my_data2", partition_keys=["a", "b"]) + + with pytest.raises(ValueError): + + @task + def t2x(b_value: str) -> Annotated[pd.DataFrame, a2_ab(a=Inputs.b_value)]: + ... + + @task + def t2(b_value: str) -> Annotated[pd.DataFrame, a2_ab(a=Inputs.b_value, b="manualval")]: + ... + + entities = OrderedDict() + t2_s = get_serializable(entities, serialization_settings, t2) + assert len(t2_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 2 + assert t2_s.template.interface.outputs["o0"].artifact_partial_id.version == "" + assert t2_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data2" + assert t2_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.project == "" + + +def test_basic_option_a3(): + import pandas as pd + + a3 = Artifact(name="my_data3") + + @task + def t3(b_value: str) -> Annotated[pd.DataFrame, a3]: + ... + + entities = OrderedDict() + t3_s = get_serializable(entities, serialization_settings, t3) + assert len(t3_s.template.interface.outputs["o0"].artifact_partial_id.partitions.value) == 0 + assert t3_s.template.interface.outputs["o0"].artifact_partial_id.artifact_key.name == "my_data3" + + +def test_query_basic(): + aa = Artifact( + name="ride_count_data", + time_partitioned=True, + partition_keys=["region"], + ) + data_query = aa.query(time_partition=Inputs.dt, region=Inputs.blah) + assert data_query.bindings == [] + assert data_query.artifact is aa + dq_idl = data_query.to_flyte_idl() + assert dq_idl.HasField("artifact_id") + assert dq_idl.artifact_id.artifact_key.name == "ride_count_data" + assert len(dq_idl.artifact_id.partitions.value) == 1 + assert dq_idl.artifact_id.partitions.value["region"].HasField("input_binding") + assert dq_idl.artifact_id.partitions.value["region"].input_binding.var == "blah" + assert dq_idl.artifact_id.time_partition.value.input_binding.var == "dt" + + +def test_not_specified_behavior(): + wf_artifact_no_tag = Artifact(project="project1", domain="dev", name="wf_artifact", version="1", partitions=None) + aq = wf_artifact_no_tag.query("pr", "dom").to_flyte_idl() + assert aq.artifact_id.HasField("partitions") is False + assert aq.artifact_id.artifact_key.project == "pr" + assert aq.artifact_id.artifact_key.domain == "dom" + + assert wf_artifact_no_tag.concrete_artifact_id.HasField("partitions") is False + + wf_artifact_no_tag = Artifact(project="project1", domain="dev", name="wf_artifact", partitions={}) + assert wf_artifact_no_tag.partitions is None + aq = wf_artifact_no_tag.query().to_flyte_idl() + assert aq.artifact_id.HasField("partitions") is False + assert aq.artifact_id.HasField("time_partition") is False + + +def test_artifact_as_promise_query(): + # when artifact is partially specified, can be used as a query input + wf_artifact = Artifact(project="project1", domain="dev", name="wf_artifact") + + @task + def t1(a: CustomReturn) -> CustomReturn: + print(a) + return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) + + @workflow + def wf(a: CustomReturn = wf_artifact.query()): + u = t1(a=a) + return u + + ctx = FlyteContextManager.current_context() + lp = LaunchPlan.get_default_launch_plan(ctx, wf) + entities = OrderedDict() + spec = get_serializable(entities, serialization_settings, lp) + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.project == "project1" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.domain == "dev" + assert spec.spec.default_inputs.parameters["a"].artifact_query.artifact_id.artifact_key.name == "wf_artifact" + + +def test_artifact_as_promise(): + # when the full artifact is specified, the artifact should be bindable as a literal + wf_artifact = Artifact(project="pro", domain="dom", name="key", version="v0.1.0", partitions={"region": "LAX"}) + + @task + def t1(a: CustomReturn) -> CustomReturn: + print(a) + return CustomReturn({"name": ["Tom", "Joseph"], "age": [20, 22]}) + + @workflow + def wf2(a: CustomReturn = wf_artifact): + u = t1(a=a) + return u + + ctx = FlyteContextManager.current_context() + lp = LaunchPlan.get_default_launch_plan(ctx, wf2) + entities = OrderedDict() + spec = get_serializable(entities, serialization_settings, lp) + assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.project == "pro" + assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.domain == "dom" + assert spec.spec.default_inputs.parameters["a"].artifact_id.artifact_key.name == "key" + + aq = wf_artifact.query().to_flyte_idl() + assert aq.artifact_id.HasField("partitions") is True + assert aq.artifact_id.partitions.value["region"].static_value == "LAX" + + +def test_partition_none(): + # confirm that we can distinguish between partitions being set to empty, and not being set + # though this is not currently used. + ak = art_id.ArtifactKey(project="p", domain="d", name="name") + no_partition = art_id.ArtifactID(artifact_key=ak, version="without_p") + assert not no_partition.HasField("partitions") + + p = art_id.Partitions() + with_partition = art_id.ArtifactID(artifact_key=ak, version="without_p", partitions=p) + assert with_partition.HasField("partitions") + + +def test_as_artf_no_partitions(): + int_artf = Artifact(name="important_int") + + @task + def greet(day_of_week: str, number: int, am: bool) -> str: + greeting = "Have a great " + day_of_week + " " + greeting += "morning" if am else "evening" + return greeting + "!" * number + + @workflow + def go_greet(day_of_week: str, number: int = int_artf.query(), am: bool = False) -> str: + return greet(day_of_week=day_of_week, number=number, am=am) + + tst_lp = LaunchPlan.create( + "morning_lp", + go_greet, + fixed_inputs={"am": True}, + default_inputs={"day_of_week": "monday"}, + ) + + entities = OrderedDict() + spec = get_serializable(entities, serialization_settings, tst_lp) + aq = spec.spec.default_inputs.parameters["number"].artifact_query + assert aq.artifact_id.artifact_key.name == "important_int" + assert not aq.artifact_id.HasField("partitions") + assert not aq.artifact_id.HasField("time_partition") + + +def test_check_input_binding(): + import pandas as pd + + a1_t_ab = Artifact(name="my_data", partition_keys=["a", "b"], time_partitioned=True) + + with pytest.raises(FlyteValidationException): + + @task + def t1( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dt, b=Inputs.xyz, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df + + with pytest.raises(FlyteValidationException): + + @task + def t2( + b_value: str, dt: datetime.datetime + ) -> Annotated[pd.DataFrame, a1_t_ab(time_partition=Inputs.dtt, b=Inputs.b_value, a="manual")]: + df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) + return df