diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..dc3a0c7ff 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -1,6 +1,7 @@ """A framework for building, deploying, and managing AI agents.""" from . import agent, models, telemetry, types +from .agent._agent_as_tool import AgentAsTool from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy @@ -11,6 +12,7 @@ __all__ = [ "Agent", + "AgentAsTool", "AgentBase", "AgentSkills", "agent", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index c901e800f..d0254852d 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -10,6 +10,7 @@ from typing import Any from ..event_loop._retry import ModelRetryStrategy +from ._agent_as_tool import AgentAsTool from .agent import Agent from .agent_result import AgentResult from .base import AgentBase @@ -24,6 +25,7 @@ "Agent", "AgentBase", "AgentResult", + "AgentAsTool", "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", diff --git a/src/strands/agent/_agent_as_tool.py b/src/strands/agent/_agent_as_tool.py new file mode 100644 index 000000000..a54b67df9 --- /dev/null +++ b/src/strands/agent/_agent_as_tool.py @@ -0,0 +1,237 @@ +"""Agent-as-tool adapter. + +This module provides the AgentAsTool class that wraps an Agent (or any AgentBase) as a tool +so it can be passed to another agent's tool list. +""" + +import copy +import logging +from typing import Any + +from typing_extensions import override + +from ..agent.state import AgentState +from ..types._events import AgentAsToolStreamEvent, ToolResultEvent +from ..types.content import Messages +from ..types.tools import AgentTool, ToolGenerator, ToolSpec, ToolUse +from .base import AgentBase + +logger = logging.getLogger(__name__) + + +class AgentAsTool(AgentTool): + """Adapter that exposes an Agent as a tool for use by other agents. + + The tool accepts a single ``input`` string parameter, invokes the wrapped + agent, and returns the text response. + + Example: + ```python + from strands import Agent + from strands.agent import AgentAsTool + + researcher = Agent(name="researcher", description="Finds information") + + # Use directly + tool = AgentAsTool(researcher, name="researcher", description="Finds information") + + # Or via convenience method + tool = researcher.as_tool() + + # Start each invocation with a fresh conversation + tool = researcher.as_tool(preserve_context=False) + + writer = Agent(name="writer", tools=[tool]) + writer("Write about AI agents") + ``` + """ + + def __init__( + self, + agent: AgentBase, + *, + name: str, + description: str, + preserve_context: bool = False, + ) -> None: + r"""Initialize the agent-as-tool adapter. + + Args: + agent: The agent to wrap as a tool. + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + description: Tool description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. Only effective when the + wrapped agent exposes a mutable ``messages`` list and/or an ``AgentState`` + (e.g. ``strands.agent.Agent``). + """ + super().__init__() + self._agent = agent + self._tool_name = name + self._description = description + self._preserve_context = preserve_context + + # When preserve_context=False, we snapshot the agent's initial state so we can + # restore it before each invocation. This mirrors GraphNode.reset_executor_state(). + # We require an Agent instance for this since AgentBase doesn't guarantee + # messages/state attributes. + self._initial_messages: Messages = [] + self._initial_state: AgentState = AgentState() + + if not preserve_context: + from .agent import Agent + + if not isinstance(agent, Agent): + raise TypeError(f"preserve_context=False requires an Agent instance, got {type(agent).__name__}") + self._initial_messages = copy.deepcopy(agent.messages) + self._initial_state = AgentState(agent.state.get()) + + @property + def agent(self) -> AgentBase: + """The wrapped agent instance.""" + return self._agent + + @property + def tool_name(self) -> str: + """Get the tool name.""" + return self._tool_name + + @property + def tool_spec(self) -> ToolSpec: + """Get the tool specification.""" + return { + "name": self._tool_name, + "description": self._description, + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "The input to send to the agent tool.", + }, + }, + "required": ["input"], + } + }, + } + + @property + def tool_type(self) -> str: + """Get the tool type.""" + return "agent" + + @override + async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + """Invoke the wrapped agent via streaming and yield events. + + Intermediate agent events are wrapped in AgentAsToolStreamEvent so the caller + can distinguish sub-agent progress from regular tool events. The final + AgentResult is yielded as a ToolResultEvent. + + Args: + tool_use: The tool use request containing the input parameter. + invocation_state: Context for the tool invocation. + **kwargs: Additional keyword arguments. + + Yields: + AgentAsToolStreamEvent for intermediate events, then ToolResultEvent with the final response. + """ + tool_input = tool_use["input"] + if isinstance(tool_input, dict): + prompt = tool_input.get("input", "") + elif isinstance(tool_input, str): + prompt = tool_input + else: + logger.warning("tool_name=<%s> | unexpected input type: %s", self._tool_name, type(tool_input)) + prompt = str(tool_input) + + tool_use_id = tool_use["toolUseId"] + + if not self._preserve_context: + self._reset_agent_state(tool_use_id) + + logger.debug("tool_name=<%s>, tool_use_id=<%s> | invoking agent", self._tool_name, tool_use_id) + + try: + result = None + async for event in self._agent.stream_async(prompt): + if "result" in event: + result = event["result"] + else: + yield AgentAsToolStreamEvent(tool_use, event, self) + + if result is None: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": "Agent did not produce a result"}], + } + ) + return + + if result.structured_output: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"json": result.structured_output.model_dump()}], + } + ) + else: + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": str(result)}], + } + ) + + except Exception as e: + logger.warning( + "tool_name=<%s>, tool_use_id=<%s> | agent invocation failed: %s", + self._tool_name, + tool_use_id, + e, + ) + yield ToolResultEvent( + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Agent error: {e}"}], + } + ) + + def _reset_agent_state(self, tool_use_id: str) -> None: + """Reset the wrapped agent to its initial state. + + Restores messages and state to the values captured at construction time. + This mirrors the pattern used by ``GraphNode.reset_executor_state()``. + + Args: + tool_use_id: Tool use ID for logging context. + """ + from .agent import Agent + + # isinstance narrows the type for mypy; __init__ guarantees this when preserve_context=False + if not isinstance(self._agent, Agent): + return + + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | resetting agent to initial state", + self._tool_name, + tool_use_id, + ) + self._agent.messages = copy.deepcopy(self._initial_messages) + self._agent.state = AgentState(self._initial_state.get()) + + @override + def get_display_properties(self) -> dict[str, str]: + """Get properties for UI display.""" + properties = super().get_display_properties() + properties["Agent"] = getattr(self._agent, "name", "unknown") + return properties diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..f09399dbf 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -62,6 +62,7 @@ from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException from ..types.traces import AttributeValue +from ._agent_as_tool import AgentAsTool from .agent_result import AgentResult from .base import AgentBase from .conversation_manager import ( @@ -612,6 +613,40 @@ async def structured_output_async(self, output_model: type[T], prompt: AgentInpu finally: await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self, invocation_state={})) + def as_tool( + self, + name: str | None = None, + description: str | None = None, + preserve_context: bool = False, + ) -> AgentAsTool: + r"""Convert this agent into a tool for use by another agent. + + Args: + name: Tool name. Must match the pattern ``[a-zA-Z0-9_\\-]{1,64}``. + Defaults to the agent's name. + description: Tool description. Defaults to the agent's description. + preserve_context: Whether to preserve the agent's conversation history across + invocations. When False, the agent's messages and state are reset to the + values they had at construction time before each call, ensuring every + invocation starts from the same baseline regardless of any external + interactions with the agent. Defaults to False. + + Returns: + An AgentAsTool wrapping this agent. + + Example: + ```python + researcher = Agent(name="researcher", description="Finds information") + writer = Agent(name="writer", tools=[researcher.as_tool()]) + writer("Write about AI agents") + ``` + """ + if not name: + name = self.name + if not description: + description = self.description or f"Use the {name} agent as a tool by providing a natural language input" + return AgentAsTool(self, name=name, description=description, preserve_context=preserve_context) + def cleanup(self) -> None: """Clean up resources used by the agent. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 5b0ae78f6..5603aedfb 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..agent import AgentResult + from ..agent._agent_as_tool import AgentAsTool from ..multiagent.base import MultiAgentResult, NodeResult @@ -323,6 +324,31 @@ def tool_use_id(self) -> str: return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use"))["toolUseId"] +class AgentAsToolStreamEvent(ToolStreamEvent): + """Event emitted when an agent-as-tool yields intermediate events during execution. + + Extends ToolStreamEvent with a reference to the originating AgentAsTool so callers + can distinguish sub-agent stream events from regular tool stream events and access + the wrapped agent, tool name, description, etc. + """ + + def __init__(self, tool_use: ToolUse, tool_stream_data: Any, agent_as_tool: "AgentAsTool") -> None: + """Initialize with tool streaming data and agent-tool reference. + + Args: + tool_use: The tool invocation producing the stream. + tool_stream_data: The yielded event from the sub-agent execution. + agent_as_tool: The AgentAsTool instance that produced this event. + """ + super().__init__(tool_use, tool_stream_data) + self._agent_as_tool = agent_as_tool + + @property + def agent_as_tool(self) -> "AgentAsTool": + """The AgentAsTool instance that produced this event.""" + return self._agent_as_tool + + class ToolCancelEvent(TypedEvent): """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook.""" diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 967a0dafb..c089ba808 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -15,7 +15,7 @@ import strands from strands import Agent, Plugin, ToolContext -from strands.agent import AgentResult +from strands.agent import AgentAsTool, AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState @@ -2699,3 +2699,52 @@ def hook_callback(event: BeforeModelCallEvent): agent("test") assert len(hook_called) == 1 + + +def test_as_tool_returns_agent_tool(): + """Test that as_tool returns an AgentAsTool wrapping the agent.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert isinstance(tool, AgentAsTool) + assert tool.agent is agent + + +def test_as_tool_defaults_name_from_agent(): + """Test that as_tool defaults the tool name to the agent's name.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_name == "researcher" + + +def test_as_tool_defaults_description_from_agent(): + """Test that as_tool defaults the description to the agent's description.""" + agent = Agent(name="researcher", description="Finds information") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Finds information" + + +def test_as_tool_custom_name(): + """Test that as_tool accepts a custom name.""" + agent = Agent(name="researcher") + tool = agent.as_tool(name="custom_name") + + assert tool.tool_name == "custom_name" + + +def test_as_tool_custom_description(): + """Test that as_tool accepts a custom description.""" + agent = Agent(name="researcher", description="Original") + tool = agent.as_tool(description="Custom description") + + assert tool.tool_spec["description"] == "Custom description" + + +def test_as_tool_defaults_description_when_agent_has_none(): + """Test that as_tool generates a default description when agent has none.""" + agent = Agent(name="researcher") + tool = agent.as_tool() + + assert tool.tool_spec["description"] == "Use the researcher agent as a tool by providing a natural language input" diff --git a/tests/strands/agent/test_agent_as_tool.py b/tests/strands/agent/test_agent_as_tool.py new file mode 100644 index 000000000..68128e6e5 --- /dev/null +++ b/tests/strands/agent/test_agent_as_tool.py @@ -0,0 +1,463 @@ +"""Tests for AgentAsTool - the agent-as-tool adapter.""" + +from unittest.mock import MagicMock + +import pytest + +from strands.agent import AgentAsTool +from strands.agent.agent_result import AgentResult +from strands.telemetry.metrics import EventLoopMetrics +from strands.types._events import AgentAsToolStreamEvent, ToolResultEvent, ToolStreamEvent + + +async def _mock_stream_async(result, intermediate_events=None): + """Helper that yields intermediate events then the final result event.""" + for event in intermediate_events or []: + yield event + yield {"result": result} + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + agent.name = "test_agent" + agent.description = "A test agent" + return agent + + +@pytest.fixture +def fake_agent(): + """A real Agent instance for tests that need Agent-specific features.""" + from strands.agent.agent import Agent + + return Agent(name="fake_agent", callback_handler=None) + + +@pytest.fixture +def tool(mock_agent): + return AgentAsTool(mock_agent, name="test_agent", description="A test agent", preserve_context=True) + + +@pytest.fixture +def tool_use(): + return { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {"input": "hello"}, + } + + +@pytest.fixture +def agent_result(): + return AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "response text"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + + +# --- init --- + + +def test_init(mock_agent): + tool = AgentAsTool(mock_agent, name="my_tool", description="custom desc", preserve_context=True) + assert tool.tool_name == "my_tool" + assert tool._description == "custom desc" + assert tool.agent is mock_agent + + +def test_init_preserve_context_defaults_false(fake_agent): + tool = AgentAsTool(fake_agent, name="t", description="d") + assert tool._preserve_context is False + + +def test_init_preserve_context_true(mock_agent): + tool = AgentAsTool(mock_agent, name="t", description="d", preserve_context=True) + assert tool._preserve_context is True + + +# --- properties --- + + +def test_tool_properties(tool): + assert tool.tool_name == "test_agent" + assert tool.tool_type == "agent" + + spec = tool.tool_spec + assert spec["name"] == "test_agent" + assert spec["description"] == "A test agent" + + schema = spec["inputSchema"]["json"] + assert schema["type"] == "object" + assert "input" in schema["properties"] + assert schema["properties"]["input"]["type"] == "string" + assert schema["required"] == ["input"] + + props = tool.get_display_properties() + assert props["Agent"] == "test_agent" + assert props["Type"] == "agent" + + +# --- stream --- + + +@pytest.mark.asyncio +async def test_stream_success(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["text"] == "response text\n" + + +@pytest.mark.asyncio +async def test_stream_passes_input_to_agent(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("hello") + + +@pytest.mark.asyncio +async def test_stream_empty_input(tool, mock_agent, agent_result): + empty_tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": {}, + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(empty_tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("") + + +@pytest.mark.asyncio +async def test_stream_string_input(tool, mock_agent, agent_result): + tool_use = { + "toolUseId": "tool-123", + "name": "test_agent", + "input": "direct string", + } + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + async for _ in tool.stream(tool_use, {}): + pass + + mock_agent.stream_async.assert_called_once_with("direct string") + + +@pytest.mark.asyncio +async def test_stream_error(tool, mock_agent, tool_use): + mock_agent.stream_async.side_effect = RuntimeError("boom") + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "boom" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_propagates_tool_use_id(tool, mock_agent, tool_use, agent_result): + mock_agent.stream_async.return_value = _mock_stream_async(agent_result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["toolUseId"] == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_forwards_intermediate_events(tool, mock_agent, tool_use, agent_result): + intermediate = [{"data": "partial"}, {"data": "more"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 2 + assert stream_events[0]["tool_stream_event"]["data"]["data"] == "partial" + assert stream_events[1]["tool_stream_event"]["data"]["data"] == "more" + assert stream_events[0].agent_as_tool is tool + assert stream_events[0].tool_use_id == "tool-123" + + +@pytest.mark.asyncio +async def test_stream_events_not_double_wrapped_by_executor(tool, mock_agent, tool_use, agent_result): + """AgentAsToolStreamEvent is a ToolStreamEvent subclass, so the executor should pass it through directly.""" + intermediate = [{"data": "chunk"}] + mock_agent.stream_async.return_value = _mock_stream_async(agent_result, intermediate) + + events = [event async for event in tool.stream(tool_use, {})] + + stream_events = [e for e in events if isinstance(e, AgentAsToolStreamEvent)] + assert len(stream_events) == 1 + + event = stream_events[0] + # It's a ToolStreamEvent (so the executor yields it directly) + assert isinstance(event, ToolStreamEvent) + # But it's specifically an AgentAsToolStreamEvent (not re-wrapped) + assert type(event) is AgentAsToolStreamEvent + # And it references the originating AgentAsTool + assert event.agent_as_tool is tool + + +@pytest.mark.asyncio +async def test_stream_no_result_yields_error(tool, mock_agent, tool_use): + async def _empty_stream(): + return + yield # noqa: RET504 - make it an async generator + + mock_agent.stream_async.return_value = _empty_stream() + + events = [event async for event in tool.stream(tool_use, {})] + + assert len(events) == 1 + assert events[0]["tool_result"]["status"] == "error" + assert "did not produce a result" in events[0]["tool_result"]["content"][0]["text"] + + +@pytest.mark.asyncio +async def test_stream_structured_output(tool, mock_agent, tool_use): + from pydantic import BaseModel + + class MyOutput(BaseModel): + answer: str + + structured = MyOutput(answer="42") + result = AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ignored"}]}, + metrics=EventLoopMetrics(), + state={}, + structured_output=structured, + ) + mock_agent.stream_async.return_value = _mock_stream_async(result) + + events = [event async for event in tool.stream(tool_use, {})] + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert result_events[0]["tool_result"]["status"] == "success" + assert result_events[0]["tool_result"]["content"][0]["json"] == {"answer": "42"} + + +# --- preserve_context --- + + +@pytest.mark.asyncio +async def test_stream_resets_to_initial_state_when_preserve_context_false(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "initial"}]}] + fake_agent.state.set("counter", 0) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + # Mutate agent state as if a previous invocation happened + fake_agent.messages.append({"role": "assistant", "content": [{"text": "reply"}]}) + fake_agent.state.set("counter", 5) + + # Mock stream_async so we don't need a real model + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "initial"}]}] + assert fake_agent.state.get("counter") == 0 + + +@pytest.mark.asyncio +async def test_stream_resets_on_every_invocation(fake_agent): + """Each call should reset to the same initial snapshot, not to the previous call's state.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "seed"}]}] + fake_agent.state.set("count", 1) + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-1", + "name": "fake_agent", + "input": {"input": "first"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + fake_agent.messages.append({"role": "assistant", "content": [{"text": "added"}]}) + fake_agent.state.set("count", 99) + + tool_use["toolUseId"] = "tool-2" + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "seed"}]}] + assert fake_agent.state.get("count") == 1 + + +@pytest.mark.asyncio +async def test_stream_initial_snapshot_is_deep_copy(fake_agent): + """Mutating the agent's messages after construction should not affect the snapshot.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "original"}]}] + + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages[0]["content"][0]["text"] = "mutated" + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [{"role": "user", "content": [{"text": "original"}]}] + + +@pytest.mark.asyncio +async def test_stream_resets_empty_initial_state_when_preserve_context_false(fake_agent): + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=False) + + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert fake_agent.messages == [] + assert fake_agent.state.get() == {} + + +@pytest.mark.asyncio +async def test_stream_resets_context_by_default(fake_agent): + """Default preserve_context=False means each invocation starts fresh.""" + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc") + + # Mutate after construction + fake_agent.messages.append({"role": "assistant", "content": [{"text": "extra"}]}) + fake_agent.state.set("key", "changed") + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + # Should reset to construction-time snapshot + assert fake_agent.messages == [{"role": "user", "content": [{"text": "old"}]}] + assert fake_agent.state.get("key") == "value" + + +@pytest.mark.asyncio +async def test_stream_preserves_context_when_explicitly_true(fake_agent): + fake_agent.messages = [{"role": "user", "content": [{"text": "old"}]}] + fake_agent.state.set("key", "value") + tool = AgentAsTool(fake_agent, name="fake_agent", description="desc", preserve_context=True) + + fake_agent.stream_async = lambda prompt, **kw: _mock_stream_async( + AgentResult( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "ok"}]}, + metrics=EventLoopMetrics(), + state={}, + ) + ) + + tool_use = { + "toolUseId": "tool-123", + "name": "fake_agent", + "input": {"input": "hello"}, + } + + async for _ in tool.stream(tool_use, {}): + pass + + assert len(fake_agent.messages) >= 1 + assert fake_agent.state.get("key") == "value" + + +def test_preserve_context_false_requires_agent_instance(): + """Default preserve_context=False should raise TypeError for non-Agent instances.""" + + class _NotAnAgent: + name = "not_agent" + + async def invoke_async(self, prompt=None, **kwargs): + pass + + def __call__(self, prompt=None, **kwargs): + pass + + def stream_async(self, prompt=None, **kwargs): + pass + + with pytest.raises(TypeError, match="requires an Agent instance"): + AgentAsTool(_NotAnAgent(), name="bad", description="desc") diff --git a/tests/strands/types/test__events.py b/tests/strands/types/test__events.py index 6163faeb6..48465e1f6 100644 --- a/tests/strands/types/test__events.py +++ b/tests/strands/types/test__events.py @@ -6,6 +6,7 @@ from strands.telemetry import EventLoopMetrics from strands.types._events import ( + AgentAsToolStreamEvent, AgentResultEvent, CitationStreamEvent, EventLoopStopEvent, @@ -465,3 +466,39 @@ def test_event_inheritance(self): assert hasattr(event, "is_callback_event") assert hasattr(event, "as_dict") assert hasattr(event, "prepare") + + +class TestAgentAsToolStreamEvent: + """Tests for AgentAsToolStreamEvent.""" + + def test_initialization(self): + """Test AgentAsToolStreamEvent initialization with agent-tool reference.""" + tool_use: ToolUse = { + "toolUseId": "agent_tool_123", + "name": "researcher", + "input": {"input": "hello"}, + } + agent_event = {"data": "partial response"} + mock_agent_as_tool = MagicMock() + mock_agent_as_tool.tool_name = "researcher" + + event = AgentAsToolStreamEvent(tool_use, agent_event, mock_agent_as_tool) + + assert event["tool_stream_event"]["tool_use"] == tool_use + assert event["tool_stream_event"]["data"] == agent_event + assert event.agent_as_tool is mock_agent_as_tool + assert event.tool_use_id == "agent_tool_123" + + def test_is_tool_stream_event_subclass(self): + """Test that AgentAsToolStreamEvent is a ToolStreamEvent subclass.""" + tool_use: ToolUse = { + "toolUseId": "id_123", + "name": "tool", + "input": {}, + } + mock_agent_as_tool = MagicMock() + event = AgentAsToolStreamEvent(tool_use, {}, mock_agent_as_tool) + + assert isinstance(event, ToolStreamEvent) + assert isinstance(event, TypedEvent) + assert type(event) is AgentAsToolStreamEvent diff --git a/tests_integ/test_agent_as_tool.py b/tests_integ/test_agent_as_tool.py new file mode 100644 index 000000000..a808fcd23 --- /dev/null +++ b/tests_integ/test_agent_as_tool.py @@ -0,0 +1,36 @@ +import pytest + +from strands import Agent, tool + + +@tool +def get_tiger_height() -> int: + """Returns the height of a tiger in centimeters.""" + return 100 + + +@pytest.mark.asyncio +async def test_stream_async_with_agent_tool(): + inner_agent = Agent( + name="myAgentTool", + description="An agent tool knowledgeable about tigers", + tools=[get_tiger_height], + ) + agent_tool = inner_agent.as_tool() + agent = Agent( + name="myOtherAgent", + tools=[agent_tool], + ) + + result = await agent.invoke_async( + prompt="Invoke the myAgentTool and ask about the height of tigers.", + ) + + # Outer agent completed and called the agent tool + assert result.stop_reason == "end_turn" + assert "myAgentTool" in result.metrics.tool_metrics + assert result.metrics.tool_metrics["myAgentTool"].success_count >= 1 + + # Inner agent called get_tiger_height + assert "get_tiger_height" in inner_agent.event_loop_metrics.tool_metrics + assert inner_agent.event_loop_metrics.tool_metrics["get_tiger_height"].success_count >= 1