Skip to content

Commit

Permalink
[close flyteorg/flyte#4241]: add nested types support for dict, datac…
Browse files Browse the repository at this point in the history
…lass, kwargs or mix

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Mar 10, 2024
1 parent 64b8468 commit ba969bd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
24 changes: 22 additions & 2 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,37 @@
UNSET_CARD = "_uc"


def kwtypes(**kwargs) -> OrderedDict[str, Type]:
def flatten_dict(nested_dict):
    """
    Flatten a nested dictionary into a single-level dictionary with dotted keys.

    Values that are ``dict`` instances are descended into, and every nested key is
    prefixed with the keys of its ancestors joined by ``"."``. Top-level keys are kept
    unchanged. An empty nested dict contributes no entries.

    .. code-block:: python

        flatten_dict({"a": int, "b": {"c": str}})  # {"a": int, "b.c": str}

    :param nested_dict: dictionary whose values may themselves be dictionaries.
    :return: a new flat ``dict``; if two paths produce the same key, the later one wins.
    """

    def _flatten(sub_dict, parent_key=None):
        result = {}
        for key, value in sub_dict.items():
            # Compare against None instead of relying on truthiness, so a falsy parent
            # key such as "" or 0 still prefixes its children instead of being dropped.
            current_key = key if parent_key is None else f"{parent_key}.{key}"
            if isinstance(value, dict):
                result.update(_flatten(value, current_key))
            else:
                result[current_key] = value
        return result

    return _flatten(nested_dict)


def kwtypes(*args, **kwargs) -> OrderedDict[str, Type]:
    """
    Build an ordered mapping of column name to type.

    Positional arguments are dictionaries; each one is passed through ``flatten_dict``
    and merged in the order given. Keyword arguments are merged last, unchanged, so a
    keyword overrides an identically named key coming from a positional dictionary.

    .. code-block:: python

        kwtypes(a=int, b=str)
        kwtypes({"a": int, "b": str})

    :param args: dictionaries mapping names to types.
    :param kwargs: name/type pairs.
    :return: the merged ``OrderedDict``.
    """
    columns = collections.OrderedDict()
    for mapping in args:
        columns.update(flatten_dict(mapping))
    columns.update(kwargs)
    return columns


Expand Down
18 changes: 16 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,11 @@ def iter_as(
raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}")
return result

def _column_to_type_list(self, dataclass):
if not isinstance(dataclass, type) or not hasattr(dataclass, "__dataclass_fields__"):
return [{"": dataclass}]
return [{field.name: field.type} for field in dataclass.__dataclass_fields__.values()]

def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType:
if t in get_supported_types():
return get_supported_types()[t]
Expand All @@ -827,8 +832,17 @@ def _convert_ordered_dict_of_columns_to_list(
if column_map is None or len(column_map) == 0:
return converted_cols
for k, v in column_map.items():
lt = self._get_dataset_column_literal_type(v)
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
vt_lst = self._column_to_type_list(v)
for vt_i in vt_lst:
lt = self._get_dataset_column_literal_type(list(vt_i.values())[0])
if list(vt_i.keys())[0] == "":
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
else:
converted_cols.append(
StructuredDatasetType.DatasetColumn(
name=f"{k.capitalize()}.{list(vt_i.keys())[0].capitalize()}", literal_type=lt
)
)
return converted_cols

def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType:
Expand Down

0 comments on commit ba969bd

Please sign in to comment.