Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def pandas(self):
return PandasChunkCountWithInitialStateProcessor()


class CompositeOutputProcessorFactory(StatefulProcessorFactory):
    """Factory handing out the pandas and row variants of the composite-output processor."""

    def pandas(self):
        """Return a new ``PandasCompositeOutputProcessor`` instance."""
        processor = PandasCompositeOutputProcessor()
        return processor

    def row(self):
        """Return a new ``RowCompositeOutputProcessor`` instance."""
        processor = RowCompositeOutputProcessor()
        return processor


# StatefulProcessor implementations


Expand Down Expand Up @@ -1873,3 +1881,161 @@ def handleInitialState(self, key, initialState, timerValues) -> None:

def close(self) -> None:
    """Do nothing; this method performs no cleanup and returns ``None``."""


# A Pandas stateful processor with a simple ValueState computation and composite output schema:
#
# primitiveValue: StringType
# listOfPrimitive: ArrayType(StringType)
# mapOfPrimitive: MapType(StringType, StringType)
# listOfComposite: ArrayType(InnerNestedClass)
# mapOfComposite: MapType(StringType, InnerNestedClass)
#
# where InnerNestedClass is a StructType with:
# intValue: IntegerType
# doubleValue: DoubleType
# arrayValue: ArrayType(StringType)
# mapValue: MapType(StringType, StringType)
class PandasCompositeOutputProcessor(StatefulProcessor):
    """Stateful processor that emits one pandas row of nested, count-derived values.

    A running count is kept in the ``count_state`` value state. Each call to
    ``handleInputRows`` grows the count by the number of input rows and yields a
    single-row DataFrame whose columns hold a string, a list of strings, a
    string-to-string dict, a list of nested dicts and a dict of nested dicts,
    all sized by the updated count.
    """

    def init(self, handle: StatefulProcessorHandle) -> None:
        """Create the ``count_state`` value state holding a single integer field."""
        count_schema = StructType([StructField("value", IntegerType(), True)])
        self.count_state = handle.getValueState("count_state", count_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        """Add the incoming row total to the stored count and yield one output frame.

        ``key[0]`` is embedded in the ``primitiveValue`` string; ``timerValues``
        is not used.
        """

        def nested(index, int_value, double_value):
            # Dict form of the nested struct: intValue, doubleValue, arrayValue, mapValue.
            width = index + 1
            return {
                "intValue": int_value,
                "doubleValue": double_value,
                "arrayValue": [f"elem_{index}_{j}" for j in range(width)],
                "mapValue": {f"map_{index}_{j}": f"val_{index}_{j}" for j in range(width)},
            }

        # Resume from the stored count when there is one, otherwise start at zero,
        # then add the length of every incoming batch.
        stored = self.count_state.get()[0] if self.count_state.exists() else 0
        count = sum((len(batch) for batch in rows), stored)
        self.count_state.update((count,))

        indices = range(count)

        # One entry per output column, in output column order.
        columns = {
            # primitiveValue: a single string built from the key and the count
            "primitiveValue": f"key_{key[0]}_count_{count}",
            # listOfPrimitive: `count` strings
            "listOfPrimitive": [f"item_{i}" for i in indices],
            # mapOfPrimitive: `count` string-to-string entries
            "mapOfPrimitive": {f"key{i}": f"value{i}" for i in indices},
            # listOfComposite: `count` nested dicts
            "listOfComposite": [nested(i, i, i * 1.5) for i in indices],
            # mapOfComposite: `count` nested dicts keyed by "nested_key<i>"
            "mapOfComposite": {f"nested_key{i}": nested(i, i * 10, i * 2.5) for i in indices},
        }

        # Wrap every value in a one-element list so the frame has exactly one row.
        yield pd.DataFrame({name: [value] for name, value in columns.items()})

    def close(self) -> None:
        """Do nothing; no cleanup is performed."""


# A Row stateful processor with a simple ValueState computation and composite output schema:
#
# primitiveValue: StringType
# listOfPrimitive: ArrayType(StringType)
# mapOfPrimitive: MapType(StringType, StringType)
# listOfComposite: ArrayType(InnerNestedClass)
# mapOfComposite: MapType(StringType, InnerNestedClass)
#
# where InnerNestedClass is a StructType with:
# intValue: IntegerType
# doubleValue: DoubleType
# arrayValue: ArrayType(StringType)
# mapValue: MapType(StringType, StringType)
class RowCompositeOutputProcessor(StatefulProcessor):
    """Stateful processor that emits one ``Row`` of nested, count-derived values.

    A running count is kept in the ``count_state`` value state. Each call to
    ``handleInputRows`` grows the count by the number of input rows and yields a
    single ``Row`` whose fields hold a string, a list of strings, a
    string-to-string dict, a list of nested ``Row`` objects and a dict of nested
    ``Row`` objects, all sized by the updated count.
    """

    def init(self, handle: StatefulProcessorHandle) -> None:
        """Create the ``count_state`` value state holding a single integer field."""
        count_schema = StructType([StructField("value", IntegerType(), True)])
        self.count_state = handle.getValueState("count_state", count_schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
        """Add the number of incoming rows to the stored count and yield one output Row.

        ``key[0]`` is embedded in the ``primitiveValue`` string; ``timerValues``
        is not used.
        """

        def nested(index, int_value, double_value):
            # Row form of the nested struct: intValue, doubleValue, arrayValue, mapValue.
            width = index + 1
            return Row(
                intValue=int_value,
                doubleValue=double_value,
                arrayValue=[f"elem_{index}_{j}" for j in range(width)],
                mapValue={f"map_{index}_{j}": f"val_{index}_{j}" for j in range(width)},
            )

        # Resume from the stored count when there is one, otherwise start at zero,
        # then add one for every incoming row.
        stored = self.count_state.get()[0] if self.count_state.exists() else 0
        count = sum((1 for _ in rows), stored)
        self.count_state.update((count,))

        indices = range(count)

        yield Row(
            # primitiveValue: a single string built from the key and the count
            primitiveValue=f"key_{key[0]}_count_{count}",
            # listOfPrimitive: `count` strings
            listOfPrimitive=[f"item_{i}" for i in indices],
            # mapOfPrimitive: `count` string-to-string entries
            mapOfPrimitive={f"key{i}": f"value{i}" for i in indices},
            # listOfComposite: `count` nested Rows
            listOfComposite=[nested(i, i, i * 1.5) for i in indices],
            # mapOfComposite: `count` nested Rows keyed by "nested_key<i>"
            mapOfComposite={f"nested_key{i}": nested(i, i * 10, i * 2.5) for i in indices},
        )

    def close(self) -> None:
        """Do nothing; no cleanup is performed."""
Loading