WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit ab420cc

Browse files
authored
Overhaul tools (autogenhub#53)
* Overhaul tools * add a simple test * mypy fixes * format
1 parent 837c388 commit ab420cc

36 files changed

+450
-228
lines changed

examples/futures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33

44
from agnext.application import SingleThreadedAgentRuntime
5-
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
5+
from agnext.components import TypeRoutedAgent, message_handler
66
from agnext.core import Agent, AgentRuntime, CancellationToken
77

88

examples/orchestrator.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import logging
55
import os
6-
from typing import Annotated, Callable
6+
from typing import Callable
77

88
import openai
99
from agnext.application import (
@@ -13,19 +13,44 @@
1313
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
1414
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
1515
from agnext.chat.types import TextMessage
16-
from agnext.components.function_executor._impl.in_process_function_executor import (
17-
InProcessFunctionExecutor,
18-
)
1916
from agnext.components.models import OpenAI, SystemMessage
20-
from agnext.core import Agent, AgentRuntime
17+
from agnext.components.tools import BaseTool
18+
from agnext.core import Agent, AgentRuntime, CancellationToken
2119
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
22-
from tavily import TavilyClient
20+
from pydantic import BaseModel, Field
21+
from tavily import TavilyClient # type: ignore
2322
from typing_extensions import Any, override
2423

2524
logging.basicConfig(level=logging.WARNING)
2625
logging.getLogger("agnext").setLevel(logging.DEBUG)
2726

2827

28+
class SearchQuery(BaseModel):
29+
query: str = Field(description="The search query.")
30+
31+
32+
class SearchResult(BaseModel):
33+
result: str = Field(description="The search results.")
34+
35+
36+
class SearchTool(BaseTool[SearchQuery, SearchResult]):
37+
def __init__(self) -> None:
38+
super().__init__(
39+
args_type=SearchQuery,
40+
return_type=SearchResult,
41+
name="search",
42+
description="Search the web.",
43+
)
44+
45+
async def run(self, args: SearchQuery, cancellation_token: CancellationToken) -> SearchResult:
46+
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) # type: ignore
47+
result = await asyncio.create_task(client.search(args.query)) # type: ignore
48+
if result:
49+
return SearchResult(result=json.dumps(result, indent=2, ensure_ascii=False))
50+
51+
return SearchResult(result="No results found.")
52+
53+
2954
class LoggingHandler(DefaultInterventionHandler): # type: ignore
3055
send_color = "\033[31m"
3156
response_color = "\033[34m"
@@ -76,16 +101,6 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
76101
thread_id=tester_oai_thread.id,
77102
)
78103

79-
def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The search results."]:
80-
"""Search the web."""
81-
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
82-
result = client.search(query) # type: ignore
83-
if result:
84-
return json.dumps(result, indent=2, ensure_ascii=False) # type: ignore
85-
return "No results found."
86-
87-
function_executor = InProcessFunctionExecutor(functions=[search])
88-
89104
product_manager = ChatCompletionAgent(
90105
name="ProductManager",
91106
description="A product manager that performs research and comes up with specs.",
@@ -95,7 +110,7 @@ def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The se
95110
SystemMessage("You can use the search tool to find information on the web."),
96111
],
97112
model_client=OpenAI(model="gpt-4-turbo"),
98-
function_executor=function_executor,
113+
tools=[SearchTool()],
99114
)
100115

101116
planner = ChatCompletionAgent(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ disallow_untyped_decorators = true
7575
disallow_any_unimported = true
7676

7777
[tool.pyright]
78-
include = ["src", "tests"]
78+
include = ["src", "tests", "examples"]
7979
typeCheckingMode = "strict"
8080
reportUnnecessaryIsInstance = false
8181
reportMissingTypeStubs = false

src/agnext/chat/agents/chat_completion_agent.py

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import json
3-
from typing import Any, Coroutine, Dict, List, Mapping, Tuple
3+
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple
44

55
from agnext.chat.agents.base import BaseChatAgent
66
from agnext.chat.types import (
@@ -12,13 +12,13 @@
1212
TextMessage,
1313
)
1414
from agnext.chat.utils import convert_messages_to_llm_messages
15-
from agnext.components.function_executor import FunctionExecutor
16-
from agnext.components.models import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
17-
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
18-
from agnext.components.types import (
15+
from agnext.components import (
1916
FunctionCall,
20-
FunctionSignature,
17+
TypeRoutedAgent,
18+
message_handler,
2119
)
20+
from agnext.components.models import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
21+
from agnext.components.tools import Tool
2222
from agnext.core import AgentRuntime, CancellationToken
2323

2424

@@ -30,13 +30,13 @@ def __init__(
3030
runtime: AgentRuntime,
3131
system_messages: List[SystemMessage],
3232
model_client: ModelClient,
33-
function_executor: FunctionExecutor | None = None,
33+
tools: Sequence[Tool] = [],
3434
) -> None:
3535
super().__init__(name, description, runtime)
3636
self._system_messages = system_messages
3737
self._client = model_client
3838
self._chat_messages: List[Message] = []
39-
self._function_executor = function_executor
39+
self._tools = tools
4040

4141
@message_handler()
4242
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
@@ -52,23 +52,18 @@ async def on_reset(self, message: Reset, cancellation_token: CancellationToken)
5252
async def on_respond_now(
5353
self, message: RespondNow, cancellation_token: CancellationToken
5454
) -> TextMessage | FunctionCallMessage:
55-
# Get function signatures.
56-
function_signatures: List[FunctionSignature] = (
57-
[] if self._function_executor is None else list(self._function_executor.function_signatures)
58-
)
59-
6055
# Get a response from the model.
6156
response = await self._client.create(
6257
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
63-
functions=function_signatures,
58+
tools=self._tools,
6459
json_output=message.response_format == ResponseFormat.json_object,
6560
)
6661

6762
# If the agent has function executor, and the response is a list of
6863
# tool calls, iterate with itself until we get a response that is not a
6964
# list of tool calls.
7065
while (
71-
self._function_executor is not None
66+
len(self._tools) > 0
7267
and isinstance(response.content, list)
7368
and all(isinstance(x, FunctionCall) for x in response.content)
7469
):
@@ -81,7 +76,7 @@ async def on_respond_now(
8176
# Make an assistant message from the response.
8277
response = await self._client.create(
8378
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
84-
functions=function_signatures,
79+
tools=self._tools,
8580
json_output=message.response_format == ResponseFormat.json_object,
8681
)
8782

@@ -105,8 +100,8 @@ async def on_respond_now(
105100
async def on_tool_call_message(
106101
self, message: FunctionCallMessage, cancellation_token: CancellationToken
107102
) -> FunctionExecutionResultMessage:
108-
if self._function_executor is None:
109-
raise ValueError("Function executor is not set.")
103+
if len(self._tools) == 0:
104+
raise ValueError("No tools available")
110105

111106
# Add a tool call message.
112107
self._chat_messages.append(message)
@@ -127,7 +122,9 @@ async def on_tool_call_message(
127122
)
128123
continue
129124
# Execute the function.
130-
future = self.execute_function(function_call.name, arguments, function_call.id)
125+
future = self.execute_function(
126+
function_call.name, arguments, function_call.id, cancellation_token=cancellation_token
127+
)
131128
# Append the async result.
132129
execution_futures.append(future)
133130
if execution_futures:
@@ -146,14 +143,25 @@ async def on_tool_call_message(
146143
# Return the results.
147144
return tool_call_result_msg
148145

149-
async def execute_function(self, name: str, args: Dict[str, Any], call_id: str) -> Tuple[str, str]:
150-
if self._function_executor is None:
151-
raise ValueError("Function executor is not set.")
146+
async def execute_function(
147+
self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken
148+
) -> Tuple[str, str]:
149+
# Find tool
150+
tool = next((t for t in self._tools if t.name == name), None)
151+
if tool is None:
152+
raise ValueError(f"Tool {name} not found.")
152153
try:
153-
result = await self._function_executor.execute_function(name, args)
154+
result = await tool.run_json(args, cancellation_token)
155+
result_json_or_str = result.model_dump()
156+
if isinstance(result, dict):
157+
result_str = json.dumps(result_json_or_str)
158+
elif isinstance(result_json_or_str, str):
159+
result_str = result_json_or_str
160+
else:
161+
raise ValueError(f"Unexpected result type: {type(result)}")
154162
except Exception as e:
155-
result = f"Error: {str(e)}"
156-
return (result, call_id)
163+
result_str = f"Error: {str(e)}"
164+
return (result_str, call_id)
157165

158166
def save_state(self) -> Mapping[str, Any]:
159167
return {

src/agnext/chat/agents/oai_assistant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from agnext.chat.agents.base import BaseChatAgent
77
from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage
8-
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
8+
from agnext.components import TypeRoutedAgent, message_handler
99
from agnext.core import AgentRuntime, CancellationToken
1010

1111

src/agnext/chat/patterns/group_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, List, Protocol, Sequence
22

3-
from ...components.type_routed_agent import TypeRoutedAgent, message_handler
3+
from ...components import TypeRoutedAgent, message_handler
44
from ...core import AgentRuntime, CancellationToken
55
from ..agents.base import BaseChatAgent
66
from ..types import Reset, RespondNow, TextMessage

src/agnext/chat/patterns/orchestrator_chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing import Any, Sequence, Tuple
33

4-
from ...components.type_routed_agent import TypeRoutedAgent, message_handler
4+
from ...components import TypeRoutedAgent, message_handler
55
from ...core import AgentRuntime, CancellationToken
66
from ..agents.base import BaseChatAgent
77
from ..types import Reset, RespondNow, ResponseFormat, TextMessage

src/agnext/chat/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from enum import Enum
55
from typing import List, Union
66

7-
from agnext.components.image import Image
7+
from agnext.components import FunctionCall, Image
88
from agnext.components.models import FunctionExecutionResultMessage
9-
from agnext.components.types import FunctionCall
109

1110

1211
@dataclass(kw_only=True)

src/agnext/components/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
"""
22
The :mod:`agnext.components` module provides building blocks for creating single agents
33
"""
4+
5+
from ._image import Image
6+
from ._type_routed_agent import TypeRoutedAgent, message_handler
7+
from ._types import FunctionCall
8+
9+
__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]

src/agnext/components/_function_utils.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
from logging import getLogger
66
from typing import (
7+
Annotated,
78
Any,
89
Callable,
910
Dict,
@@ -15,10 +16,13 @@
1516
Type,
1617
TypeVar,
1718
Union,
19+
get_args,
20+
get_origin,
1821
)
1922

20-
from pydantic import BaseModel, Field
21-
from typing_extensions import Annotated, Literal
23+
from pydantic import BaseModel, Field, create_model # type: ignore
24+
from pydantic_core import PydanticUndefined
25+
from typing_extensions import Literal
2226

2327
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
2428

@@ -125,6 +129,18 @@ class ToolFunction(BaseModel):
125129
function: Annotated[Function, Field(description="Function under tool")]
126130

127131

132+
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
133+
# handles Annotated
134+
if hasattr(v, "__metadata__"):
135+
retval = v.__metadata__[0]
136+
if isinstance(retval, str):
137+
return retval
138+
else:
139+
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
140+
else:
141+
return k
142+
143+
128144
def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
129145
"""Get a JSON schema for a parameter as defined by the OpenAI API
130146
@@ -137,17 +153,6 @@ def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) ->
137153
A Pydanitc model for the parameter
138154
"""
139155

140-
def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
141-
# handles Annotated
142-
if hasattr(v, "__metadata__"):
143-
retval = v.__metadata__[0]
144-
if isinstance(retval, str):
145-
return retval
146-
else:
147-
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
148-
else:
149-
return k
150-
151156
schema = type2schema(v)
152157
if k in default_values:
153158
dv = default_values[k]
@@ -297,3 +302,48 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
297302
)
298303

299304
return model_dump(function)
305+
306+
307+
def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
308+
"""Normalize typing.Annotated types to the inner type."""
309+
if get_origin(type_hint) is Annotated:
310+
# Extract the inner type from Annotated
311+
return get_args(type_hint)[0] # type: ignore
312+
return type_hint
313+
314+
315+
def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
316+
fields: List[tuple[str, Any]] = []
317+
for name, param in sig.parameters.items():
318+
# This is handled externally
319+
if name == "cancellation_token":
320+
continue
321+
322+
if param.annotation is inspect.Parameter.empty:
323+
raise ValueError("No annotation")
324+
325+
type = normalize_annotated_type(param.annotation)
326+
description = type2description(name, param.annotation)
327+
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
328+
329+
fields.append((name, (type, Field(default=default_value, description=description))))
330+
331+
return create_model(name, *fields)
332+
333+
334+
def return_value_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
335+
if issubclass(BaseModel, sig.return_annotation):
336+
return sig.return_annotation # type: ignore
337+
338+
fields: List[tuple[str, Any]] = []
339+
for name, param in sig.return_annotation:
340+
if param.annotation is inspect.Parameter.empty:
341+
raise ValueError("No annotation")
342+
343+
type = normalize_annotated_type(param.annotation)
344+
description = type2description(name, param.annotation)
345+
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined
346+
347+
fields.append((name, (type, Field(default=default_value, description=description))))
348+
349+
return create_model(name, *fields)

0 commit comments

Comments
 (0)