diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index 20d489b65..e1a0ba357 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -370,8 +370,7 @@ class _ContextPropagationWorkflowOutboundInterceptor( async def signal_child_workflow( self, input: temporalio.worker.SignalChildWorkflowInput ) -> None: - trace = get_trace_provider().get_current_trace() - if trace: + if get_trace_provider().get_current_trace(): with custom_span( name="temporal:signalChildWorkflow", data={"workflowId": input.child_workflow_id}, @@ -385,8 +384,7 @@ async def signal_child_workflow( async def signal_external_workflow( self, input: temporalio.worker.SignalExternalWorkflowInput ) -> None: - trace = get_trace_provider().get_current_trace() - if trace: + if get_trace_provider().get_current_trace(): with custom_span( name="temporal:signalExternalWorkflow", data={"workflowId": input.workflow_id}, @@ -400,48 +398,38 @@ async def signal_external_workflow( def start_activity( self, input: temporalio.worker.StartActivityInput ) -> temporalio.workflow.ActivityHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( + if get_trace_provider().get_current_trace(): + with custom_span( name="temporal:startActivity", data={"activity": input.activity} - ) - span.start(mark_as_current=True) + ): + set_header_from_context(input, temporalio.workflow.payload_converter()) + return self.next.start_activity(input) set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_activity(input) - if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore - return handle + return self.next.start_activity(input) async def start_child_workflow( self, input: temporalio.worker.StartChildWorkflowInput ) -> temporalio.workflow.ChildWorkflowHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( + if get_trace_provider().get_current_trace(): + with custom_span( name="temporal:startChildWorkflow", data={"workflow": input.workflow} - ) - span.start(mark_as_current=True) + ): + set_header_from_context(input, temporalio.workflow.payload_converter()) + return await self.next.start_child_workflow(input) + set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = await self.next.start_child_workflow(input) - if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore - return handle + return await self.next.start_child_workflow(input) def start_local_activity( self, input: temporalio.worker.StartLocalActivityInput ) -> temporalio.workflow.ActivityHandle: - trace = get_trace_provider().get_current_trace() - span: Optional[Span] = None - if trace: - span = custom_span( + if get_trace_provider().get_current_trace(): + with custom_span( name="temporal:startLocalActivity", data={"activity": input.activity} - ) - span.start(mark_as_current=True) + ): + set_header_from_context(input, temporalio.workflow.payload_converter()) + return self.next.start_local_activity(input) + set_header_from_context(input, temporalio.workflow.payload_converter()) - handle = self.next.start_local_activity(input) - if span: - handle.add_done_callback(lambda _: span.finish()) # type: ignore - return handle + return self.next.start_local_activity(input) diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index c8ad366e6..f905f89e4 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -82,15 +82,16 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: paired_span(processor.span_events[0], processor.span_events[5]) assert processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent" - paired_span(processor.span_events[1], processor.span_events[4]) + # startActivity happens before executeActivity + paired_span(processor.span_events[1], processor.span_events[2]) assert ( processor.span_events[1][0].span_data.export().get("name") == "temporal:startActivity" ) - paired_span(processor.span_events[2], processor.span_events[3]) + paired_span(processor.span_events[3], processor.span_events[4]) assert ( - processor.span_events[2][0].span_data.export().get("name") + processor.span_events[3][0].span_data.export().get("name") == "temporal:executeActivity" ) @@ -113,7 +114,7 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: len(parents) == 2 and parents[0].span_data.export()["type"] == "agent" ) - # Execute is parented to start + # Execute is parented to the start activity span if span_data.get("name") == "temporal:executeActivity": parents = [ s for (s, _) in processor.span_events if s.span_id == span.parent_id @@ -127,14 +128,14 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None: paired_span(processor.span_events[-6], processor.span_events[-1]) assert processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent" - paired_span(processor.span_events[-5], processor.span_events[-2]) + paired_span(processor.span_events[-5], processor.span_events[-4]) assert ( processor.span_events[-5][0].span_data.export().get("name") == "temporal:startActivity" ) - paired_span(processor.span_events[-4], processor.span_events[-3]) + paired_span(processor.span_events[-3], processor.span_events[-2]) assert ( - processor.span_events[-4][0].span_data.export().get("name") + processor.span_events[-3][0].span_data.export().get("name") == "temporal:executeActivity" )