Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ async def run_context() -> AsyncIterator[None]:
super().__init__(
name="OpenAIAgentsPlugin",
data_converter=_data_converter,
worker_interceptors=[OpenAIAgentsTracingInterceptor()],
interceptors=[OpenAIAgentsTracingInterceptor()],
activities=add_activities,
workflow_runner=workflow_runner,
workflow_failure_exception_types=[AgentsWorkflowError],
Expand Down
2 changes: 1 addition & 1 deletion temporalio/contrib/opentelemetry/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:

super().__init__(
"OpenTelemetryPlugin",
client_interceptors=interceptors,
interceptors=interceptors,
workflow_runner=workflow_runner,
)
116 changes: 71 additions & 45 deletions temporalio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,13 @@ def __init__(
name: str,
*,
data_converter: PluginParameter[temporalio.converter.DataConverter] = None,
client_interceptors: PluginParameter[
Sequence[temporalio.client.Interceptor]
interceptors: PluginParameter[
Sequence[temporalio.client.Interceptor | temporalio.worker.Interceptor]
] = None,
activities: PluginParameter[Sequence[Callable]] = None,
nexus_service_handlers: PluginParameter[Sequence[Any]] = None,
workflows: PluginParameter[Sequence[type]] = None,
workflow_runner: PluginParameter[WorkflowRunner] = None,
worker_interceptors: PluginParameter[
Sequence[temporalio.worker.Interceptor]
] = None,
workflow_failure_exception_types: PluginParameter[
Sequence[type[BaseException]]
] = None,
Expand All @@ -66,9 +63,10 @@ def __init__(
name: The name of the plugin.
data_converter: Data converter for serialization, or callable to customize existing one.
Applied to the Client and Replayer.
client_interceptors: Client interceptors to append, or callable to customize existing ones.
Applied to the Client. Note, if the provided interceptor is also a worker.Interceptor,
it will be added to any worker which uses that client.
interceptors: Interceptors to append, or callable to customize existing ones.
Client interceptors are applied to the Client, worker interceptors are applied
to the Worker and Replayer. Interceptors that implement both interfaces will
be applied to both, with exactly one instance used per worker to avoid duplication.
activities: Activity functions to append, or callable to customize existing ones.
Applied to the Worker.
nexus_service_handlers: Nexus service handlers to append, or callable to customize existing ones.
Expand All @@ -77,8 +75,6 @@ def __init__(
Applied to the Worker and Replayer.
workflow_runner: Workflow runner, or callable to customize existing one.
Applied to the Worker and Replayer.
worker_interceptors: Worker interceptors to append, or callable to customize existing ones.
Applied to the Worker and Replayer.
workflow_failure_exception_types: Exception types for workflow failures to append,
or callable to customize existing ones. Applied to the Worker and Replayer.
run_context: A place to run custom code to wrap around the Worker (or Replayer) execution.
Expand All @@ -89,12 +85,11 @@ def __init__(
"""
self._name = name
self.data_converter = data_converter
self.client_interceptors = client_interceptors
self.interceptors = interceptors
self.activities = activities
self.nexus_service_handlers = nexus_service_handlers
self.workflows = workflows
self.workflow_runner = workflow_runner
self.worker_interceptors = worker_interceptors
self.workflow_failure_exception_types = workflow_failure_exception_types
self.run_context = run_context

Expand All @@ -110,11 +105,22 @@ def configure_client(self, config: ClientConfig) -> ClientConfig:
if data_converter:
config["data_converter"] = data_converter

interceptors = _resolve_append_parameter(
config.get("interceptors"), self.client_interceptors
# Resolve the combined interceptors first, then filter to client ones
all_interceptors = _resolve_append_parameter(
cast(
Sequence[temporalio.client.Interceptor | temporalio.worker.Interceptor]
| None,
config.get("interceptors"),
),
self.interceptors,
)
if interceptors is not None:
config["interceptors"] = interceptors
if all_interceptors is not None:
client_interceptors = [
interceptor
for interceptor in all_interceptors
if isinstance(interceptor, temporalio.client.Interceptor)
]
config["interceptors"] = client_interceptors

return config

Expand Down Expand Up @@ -150,36 +156,46 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
if workflow_runner:
config["workflow_runner"] = workflow_runner

interceptors = list(
_resolve_append_parameter(
config.get("interceptors"), self.worker_interceptors
if callable(self.interceptors):
interceptors = (
_resolve_append_parameter(
cast(
Sequence[
temporalio.client.Interceptor
| temporalio.worker.Interceptor
]
| None,
config.get("interceptors"),
),
self.interceptors,
)
or []
)
or []
)

# Only propagate client interceptors if they are provided as a simple list (not callable)
if self.client_interceptors is not None and not callable(
self.client_interceptors
):
client_worker_interceptors = [
# Filter out any client only interceptors the callable returned
config["interceptors"] = [
interceptor
for interceptor in self.client_interceptors
for interceptor in interceptors
if isinstance(interceptor, temporalio.worker.Interceptor)
]
for interceptor in client_worker_interceptors:
if interceptor not in interceptors:
# Check if interceptor is already in client's interceptors to avoid duplication
client_config = config.get("client")
if client_config is not None:
client_interceptors_list = client_config.config(
active_config=True
).get("interceptors", [])
if interceptor not in client_interceptors_list:
interceptors.append(interceptor)
else:
interceptors.append(interceptor)

config["interceptors"] = interceptors
elif self.interceptors is not None:
client_interceptors_list = (
config["client"].config(active_config=True).get("interceptors", []) # type:ignore[reportTypedDictNotRequiredAccess]
)

# Exclude any already registered interceptors and client only interceptors
worker_interceptors = [
interceptor
for interceptor in self.interceptors
if isinstance(interceptor, temporalio.worker.Interceptor)
and interceptor not in client_interceptors_list
]

provided_interceptors = _resolve_append_parameter(
config.get("interceptors"), worker_interceptors
)
if provided_interceptors is not None:
config["interceptors"] = provided_interceptors

failure_exception_types = _resolve_append_parameter(
config.get("workflow_failure_exception_types"),
Expand Down Expand Up @@ -208,11 +224,21 @@ def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig:
if workflow_runner:
config["workflow_runner"] = workflow_runner

interceptors = _resolve_append_parameter(
config.get("interceptors"), self.worker_interceptors
all_interceptors = _resolve_append_parameter(
cast(
Sequence[temporalio.client.Interceptor | temporalio.worker.Interceptor]
| None,
config.get("interceptors"),
),
self.interceptors,
)
if interceptors is not None:
config["interceptors"] = interceptors
if all_interceptors is not None:
worker_interceptors = [
interceptor
for interceptor in all_interceptors
if isinstance(interceptor, temporalio.worker.Interceptor)
]
config["interceptors"] = worker_interceptors

failure_exception_types = _resolve_append_parameter(
config.get("workflow_failure_exception_types"),
Expand Down
54 changes: 39 additions & 15 deletions tests/contrib/openai_agents/test_openai_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ async def test_tracing(client: Client):
execution_timeout=timedelta(seconds=120),
)
await workflow_handle.result()
print("\n".join([str({"name": t.name}) for t, _ in processor.trace_events]))

# There is one closed root trace
assert len(processor.trace_events) == 2
# There are two traces, one is created in the client because it is needed to start the temporal spans
assert len(processor.trace_events) == 4
assert (
processor.trace_events[0][0].trace_id
== processor.trace_events[1][0].trace_id
Expand All @@ -76,25 +77,48 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
assert a[1]
assert not b[1]

print(
"\n".join(
[
str({"id": t.span_id, "data": t.span_data.export()})
for t, _ in processor.span_events
]
)
)

# Start workflow traces
paired_span(processor.span_events[0], processor.span_events[1])
assert (
processor.span_events[0][0].span_data.export().get("name")
== "temporal:startWorkflow:ResearchWorkflow"
)

# Execute workflow
paired_span(processor.span_events[2], processor.span_events[-1])
assert (
processor.span_events[2][0].span_data.export().get("name")
== "temporal:executeWorkflow"
)

# Initial planner spans - There are only 3 because we don't make an actual model call
paired_span(processor.span_events[0], processor.span_events[5])
paired_span(processor.span_events[3], processor.span_events[8])
assert (
processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent"
processor.span_events[3][0].span_data.export().get("name") == "PlannerAgent"
)

paired_span(processor.span_events[1], processor.span_events[4])
paired_span(processor.span_events[4], processor.span_events[7])
assert (
processor.span_events[1][0].span_data.export().get("name")
processor.span_events[4][0].span_data.export().get("name")
== "temporal:startActivity"
)

paired_span(processor.span_events[2], processor.span_events[3])
paired_span(processor.span_events[5], processor.span_events[6])
assert (
processor.span_events[2][0].span_data.export().get("name")
processor.span_events[5][0].span_data.export().get("name")
== "temporal:executeActivity"
)

for span, start in processor.span_events[6:-6]:
for span, start in processor.span_events[9:-7]:
span_data = span.span_data.export()

# All spans should be closed
Expand Down Expand Up @@ -126,19 +150,19 @@ def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
)

# Final writer spans - There are only 3 because we don't make an actual model call
paired_span(processor.span_events[-6], processor.span_events[-1])
paired_span(processor.span_events[-7], processor.span_events[-2])
assert (
processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent"
processor.span_events[-7][0].span_data.export().get("name") == "WriterAgent"
)

paired_span(processor.span_events[-5], processor.span_events[-2])
paired_span(processor.span_events[-6], processor.span_events[-3])
assert (
processor.span_events[-5][0].span_data.export().get("name")
processor.span_events[-6][0].span_data.export().get("name")
== "temporal:startActivity"
)

paired_span(processor.span_events[-4], processor.span_events[-3])
paired_span(processor.span_events[-5], processor.span_events[-4])
assert (
processor.span_events[-4][0].span_data.export().get("name")
processor.span_events[-5][0].span_data.export().get("name")
== "temporal:executeActivity"
)
27 changes: 13 additions & 14 deletions tests/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,15 @@ async def test_simple_plugin_worker_interceptor_only_used_on_worker(
client: Client,
) -> None:
"""Test that when a combined client/worker interceptor is provided by SimplePlugin
to client_interceptors, and the plugin is only used on a worker (not on the client
to interceptors, and the plugin is only used on a worker (not on the client
used to create that worker), the worker interceptor functionality is still provided."""

interceptor = CombinedClientWorkerInterceptor()

# Create SimplePlugin that provides the combined interceptor as client_interceptors
# Create SimplePlugin that provides the combined interceptor
plugin = SimplePlugin(
"TestCombinedPlugin",
client_interceptors=[interceptor],
interceptors=[interceptor],
)

# Create worker with the plugin (but don't add plugin to client)
Expand All @@ -468,23 +468,23 @@ async def test_simple_plugin_worker_interceptor_only_used_on_worker(
), "Client interceptor should not have been used"

# The interceptor SHOULD have been used for worker interception
# even though it was specified in client_interceptors
# even though it was specified in interceptors
assert interceptor.worker_intercepted, "Worker interceptor should have been used"


async def test_simple_plugin_interceptor_duplication_when_used_on_client_and_worker(
client: Client,
) -> None:
"""Test that when a combined client/worker interceptor is provided by SimplePlugin
to client_interceptors, and the plugin is used on both client and worker,
to interceptors, and the plugin is used on both client and worker,
the interceptor is not duplicated in the worker."""

interceptor = CombinedClientWorkerInterceptor()

# Create SimplePlugin that provides the combined interceptor as client_interceptors
# Create SimplePlugin that provides the combined interceptor
plugin = SimplePlugin(
"TestCombinedPlugin",
client_interceptors=[interceptor],
interceptors=[interceptor],
)

# Add plugin to client first
Expand Down Expand Up @@ -535,16 +535,15 @@ async def test_simple_plugin_interceptor_duplication_when_used_on_client_and_wor
async def test_simple_plugin_no_duplication_when_interceptor_in_both_client_and_worker_params(
client: Client,
) -> None:
"""Test that when the same interceptor is provided to both client_interceptors
and worker_interceptors in a SimplePlugin, it doesn't get duplicated."""
"""Test that when the same interceptor is provided to the unified interceptors
parameter in a SimplePlugin, it doesn't get duplicated."""

interceptor = CombinedClientWorkerInterceptor()

# Create SimplePlugin that provides the same interceptor to both client and worker
# Create SimplePlugin that provides the interceptor once to the unified parameter
plugin = SimplePlugin(
"TestCombinedPlugin",
client_interceptors=[interceptor],
worker_interceptors=[interceptor], # Same interceptor in both places
interceptors=[interceptor], # Single unified parameter
)

# Create worker with plugin (not on client)
Expand Down Expand Up @@ -585,10 +584,10 @@ async def test_simple_plugin_no_duplication_in_interceptor_chain(

interceptor = CombinedClientWorkerInterceptor()

# Create SimplePlugin that provides the combined interceptor as client_interceptors only
# Create SimplePlugin that provides the combined interceptor
plugin = SimplePlugin(
"CountingPlugin",
client_interceptors=[interceptor],
interceptors=[interceptor],
)

# Add plugin to client (like OpenTelemetryPlugin does)
Expand Down
Loading