From 685403d8c52910cfc98cce5cdbd8748312fbcae0 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 5 Jan 2024 13:43:29 +0800 Subject: [PATCH] Literal metadata model update (#2089) Signed-off-by: Yee Hing Tong --- flytekit/models/literals.py | 10 +++++++++- tests/flytekit/unit/models/test_literals.py | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index fa1f3d8ade..f164ab7b25 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -1,6 +1,6 @@ from datetime import datetime as _datetime from datetime import timezone as _timezone -from typing import Optional +from typing import Optional, Dict from flyteidl.core import literals_pb2 as _literals_pb2 from google.protobuf.struct_pb2 import Struct @@ -859,6 +859,7 @@ def __init__( collection: Optional[LiteralCollection] = None, map: Optional[LiteralMap] = None, hash: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, ): """ This IDL message represents a literal value in the Flyte ecosystem. @@ -871,6 +872,7 @@ def __init__( self._collection = collection self._map = map self._hash = hash + self._metadata = metadata @property def scalar(self): @@ -916,6 +918,10 @@ def hash(self): def hash(self, value): self._hash = value + @property + def metadata(self) -> Optional[Dict[str, str]]: + return self._metadata + def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Literal @@ -925,6 +931,7 @@ def to_flyte_idl(self): collection=self.collection.to_flyte_idl() if self.collection is not None else None, map=self.map.to_flyte_idl() if self.map is not None else None, hash=self.hash, + metadata=self.metadata, ) @classmethod @@ -942,4 +949,5 @@ def from_flyte_idl(cls, pb2_object): collection=collection, map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, hash=pb2_object.hash if pb2_object.hash else None, + metadata={k: v for k, v in pb2_object.metadata.items()}, ) diff --git a/tests/flytekit/unit/models/test_literals.py b/tests/flytekit/unit/models/test_literals.py index 73d8508d75..df1172949c 100644 --- a/tests/flytekit/unit/models/test_literals.py +++ b/tests/flytekit/unit/models/test_literals.py @@ -500,7 +500,7 @@ def test_binding_data_collection_nested(): @pytest.mark.parametrize("scalar_value_pair", parameterizers.LIST_OF_SCALARS_AND_PYTHON_VALUES) def test_scalar_literals(scalar_value_pair): scalar, _ = scalar_value_pair - obj = literals.Literal(scalar=scalar) + obj = literals.Literal(scalar=scalar, metadata={"a": "b"}) assert obj.value == scalar assert obj.scalar == scalar assert obj.collection is None @@ -512,6 +512,7 @@ def test_scalar_literals(scalar_value_pair): assert obj2.scalar == scalar assert obj2.collection is None assert obj2.map is None + assert obj2.metadata == {"a": "b"} @pytest.mark.parametrize("literal_value_pair", parameterizers.LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE)