Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

batch_daily example #113

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improvements to the scheduling and workflow
hellais committed Apr 25, 2024

Verified

This commit was signed with the committer’s verified signature.
khaneliman Austin Horstman
commit 1a70e3a17ac713567414d11049bcdb20cbe6687a
29 changes: 20 additions & 9 deletions batch_daily/activities.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from typing import List
import asyncio
import time
import random
from typing import Any, Dict, List
from temporalio import activity

from dataclasses import dataclass

log = logging.getLogger(__name__)


@dataclass
class ListRecordActivityInput:
@@ -18,15 +18,26 @@ class ProcessRecordActivityInput:
uri: str


async def random_sleep():
    """Simulate a long-running operation.

    Sleeps for a random interval between 1/100 s and 1 s (inclusive),
    skewed heavily toward short waits.
    """
    delay = 1 / random.randint(1, 100)
    await asyncio.sleep(delay)


@activity.defn
def list_records(activity_input: ListRecordActivityInput) -> List[str]:
log.info(
async def list_records(activity_input: ListRecordActivityInput) -> List[str]:
    """Pretend to look up record URIs for the given day and filter.

    Returns a fixed set of ten fake URIs after a simulated delay.
    """
    print(
        f"filtering records on {activity_input.day} based on filter: {activity_input.record_filter}"
    )
    await random_sleep()
    return [f"uri://record-id{i}" for i in range(10)]


@activity.defn
def process_record(activity_input: ProcessRecordActivityInput) -> str:
log.info(f"this record is yummy: {activity_input.uri}")
return activity_input.uri
async def process_record(activity_input: ProcessRecordActivityInput) -> Dict[str, Any]:
    """Process one record URI, returning how long the (simulated) work took."""
    started_at = time.monotonic()
    print(f"this record is yummy: {activity_input.uri}")
    await random_sleep()
    elapsed = time.monotonic() - started_at
    return {"runtime": elapsed}
21 changes: 8 additions & 13 deletions batch_daily/create_schedule.py
Original file line number Diff line number Diff line change
@@ -11,32 +11,27 @@
WorkflowFailureError,
)

from batch.workflows import (
DailyBatch,
DailyBatchWorkflowInput,
from batch_daily.workflows import (
RecordBatchProcessor,
RecordBatchProcessorWorkflowInput,
TASK_QUEUE_NAME,
)


async def main() -> None:
"""Main function to run temporal workflow."""
client = await Client.connect("localhost:7233")

wf_input = DailyBatchWorkflowInput(
record_filter="taste=yummy",
# XXX: how do we get the current day in a way that works with the schedule?
start_day=datetime.now().date().strftime("%Y-%m-%d"),
end_day=((datetime.now().date()) + timedelta(days=1)).strftime("%Y-%m-%d"),
)

try:
wf_input = RecordBatchProcessorWorkflowInput(record_filter="taste=yummy")
await client.create_schedule(
"daily-batch-wf-schedule",
Schedule(
action=ScheduleActionStartWorkflow(
DailyBatch.run,
RecordBatchProcessor.run,
wf_input,
id=f"daily-batch-{wf_input.record_filter}",
task_queue="TASK_QUEUE",
id=f"record-filter-{wf_input.record_filter}",
task_queue=TASK_QUEUE_NAME,
),
spec=ScheduleSpec(
intervals=[ScheduleIntervalSpec(every=timedelta(hours=1))]
23 changes: 7 additions & 16 deletions batch_daily/run_worker.py
Original file line number Diff line number Diff line change
@@ -3,32 +3,23 @@

from temporalio.client import Client
from temporalio.worker import Worker
from temporalio.worker.workflow_sandbox import (
SandboxedWorkflowRunner,
SandboxRestrictions,
)

from cloud_export_to_parquet.data_trans_activities import (
data_trans_and_land,
get_object_keys,
from batch_daily.activities import (
list_records,
process_record,
)
from cloud_export_to_parquet.workflows import ProtoToParquet
from batch_daily.workflows import DailyBatch, RecordBatchProcessor, TASK_QUEUE_NAME


async def main() -> None:
"""Main worker function."""
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")

# Run the worker
worker: Worker = Worker(
client,
task_queue="DATA_TRANSFORMATION_TASK_QUEUE",
workflows=[ProtoToParquet],
activities=[get_object_keys, data_trans_and_land],
workflow_runner=SandboxedWorkflowRunner(
restrictions=SandboxRestrictions.default.with_passthrough_modules("boto3")
),
task_queue=TASK_QUEUE_NAME,
workflows=[DailyBatch, RecordBatchProcessor],
activities=[list_records, process_record],
activity_executor=ThreadPoolExecutor(100),
)
await worker.run()
28 changes: 28 additions & 0 deletions batch_daily/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio

from temporalio.client import Client

# from batch_daily.activity import
from batch_daily.workflows import DailyBatchWorkflowInput, TASK_QUEUE_NAME, DailyBatch


async def main():
    """Connect to the local Temporal server and run the DailyBatch workflow
    for the configured date range, then print the aggregate result.

    Requires a Temporal server listening on localhost:7233 and a worker
    polling TASK_QUEUE_NAME.
    """
    client = await Client.connect(
        "localhost:7233",
    )

    result = await client.execute_workflow(
        DailyBatch.run,
        DailyBatchWorkflowInput(
            start_day="2024-01-01",
            end_day="2024-03-01",
            record_filter="taste=yummy",
        ),
        # Plain string: the previous f-prefix had no placeholders (ruff F541).
        id="daily_batch-workflow-id",
        task_queue=TASK_QUEUE_NAME,
    )
    print(f"Workflow result: {result}")


if __name__ == "__main__":
    asyncio.run(main())
98 changes: 60 additions & 38 deletions batch_daily/workflows.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,72 @@
import asyncio
from datetime import datetime, timedelta

from dataclasses import dataclass
import time
from typing import Any, Dict, Optional

from temporalio import workflow
from temporalio.common import RetryPolicy
from temporalio.exceptions import ActivityError
from temporalio.common import SearchAttributeKey

with workflow.unsafe.imports_passed_through():
from batch.activities import (
from batch_daily.activities import (
ListRecordActivityInput,
list_records,
ProcessRecordActivityInput,
process_record,
)

TASK_QUEUE_NAME = "MY_TASK_QUEUE"


@dataclass
class RecordProcessorWorkflowInput:
day: str
record_uri: str
class RecordBatchProcessorWorkflowInput:
record_filter: str
day: Optional[str] = None


@workflow.defn
class RecordProcessor:
class RecordBatchProcessor:
@workflow.run
async def run(self, workflow_input: RecordProcessorWorkflowInput) -> str:
async def run(
self, workflow_input: RecordBatchProcessorWorkflowInput
) -> Dict[str, Any]:
if workflow_input.day is None:
schedule_time = workflow.info().typed_search_attributes.get(
SearchAttributeKey.for_datetime("TemporalScheduledStartTime")
)
assert schedule_time is not None, "when not scheduled, day must be provided"
workflow_input.day = schedule_time.strftime("%Y-%m-%d")

print(f"starting RecordProcessor with {workflow_input}")

list_records_input = ListRecordActivityInput(
record_filter="taste=yummy", day=workflow_input.day
record_filter=workflow_input.record_filter, day=workflow_input.day
)

record_uri_list = await workflow.execute_activity(
list_records,
list_records_input,
start_to_close_timeout=timedelta(minutes=5),
)
try:

task_list = []
async with asyncio.TaskGroup() as tg:
for key in record_uri_list:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no good, as it risks spamming the event history log. It should probably work like the Java example, which uses continue_as_new to reset the event history.

process_record_input = ProcessRecordActivityInput(uri=key)
await workflow.execute_activity(
process_record,
process_record_input,
start_to_close_timeout=timedelta(minutes=1),
task_list.append(
tg.create_task(
workflow.execute_activity(
process_record,
process_record_input,
start_to_close_timeout=timedelta(minutes=1),
)
)
)

except ActivityError as output_err:
workflow.logger.error(f"failed: {output_err}")
raise output_err
total_runtime = sum(map(lambda task: task.result()["runtime"], task_list))
return {"runtime": total_runtime}


@dataclass
@@ -60,25 +81,26 @@ class DailyBatch:
"""DailyBatch workflow"""

@workflow.run
async def run(self, workflow_input: DailyBatchWorkflowInput) -> str:
if workflow_input.start_day == workflow_input.end_day:
return ""

await workflow.execute_child_workflow(
RecordProcessor.run,
RecordProcessorWorkflowInput(
day=workflow_input.start_day, record_uri=workflow_input.record_filter
),
)

next_start_day = (
datetime.strptime(workflow_input.start_day, "%Y-%m-%d") + timedelta(days=1)
).strftime("%Y-%m-%d")

return workflow.continue_as_new(
DailyBatchWorkflowInput(
start_day=next_start_day,
end_day=workflow_input.end_day,
record_filter=workflow_input.record_filter,
)
)
async def run(self, workflow_input: DailyBatchWorkflowInput) -> Dict[str, Any]:
print(f"starting DailyBatch with {workflow_input}")

start = datetime.strptime(workflow_input.start_day, "%Y-%m-%d")
end = datetime.strptime(workflow_input.end_day, "%Y-%m-%d")
task_list = []
async with asyncio.TaskGroup() as tg:
for day in [
start + timedelta(days=x) for x in range(0, (end - start).days)
]:
task_list.append(
tg.create_task(
workflow.execute_child_workflow(
RecordBatchProcessor.run,
RecordBatchProcessorWorkflowInput(
day=day.strftime("%Y-%m-%d"),
record_filter=workflow_input.record_filter,
),
)
)
)
total_runtime = sum(map(lambda task: task.result()["runtime"], task_list))
return {"runtime": total_runtime}