WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content

Commit 46b93b4

Browse files
authored
fix[studio]: studio optimization (#3367)
1 parent c5b831d commit 46b93b4

File tree

406 files changed

+500
-151
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

406 files changed

+500
-151
lines changed

examples/chatbot/src/main/java/com/alibaba/cloud/ai/examples/chatbot/AgentStaticLoader.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.alibaba.cloud.ai.agent.studio.loader.AgentLoader;
2020
import com.alibaba.cloud.ai.graph.GraphRepresentation;
21+
import com.alibaba.cloud.ai.graph.agent.Agent;
2122
import com.alibaba.cloud.ai.graph.agent.BaseAgent;
2223
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
2324

@@ -48,9 +49,9 @@
4849
@Component
4950
class AgentStaticLoader implements AgentLoader {
5051

51-
private final Map<String, BaseAgent> agents = new ConcurrentHashMap<>();
52+
private final Map<String, Agent> agents = new ConcurrentHashMap<>();
5253

53-
public AgentStaticLoader(BaseAgent agent) {
54+
public AgentStaticLoader(Agent agent) {
5455

5556
GraphRepresentation representation = agent.getAndCompileGraph().stateGraph.getGraph(GraphRepresentation.Type.PLANTUML);
5657
System.out.println(representation.content());
@@ -65,12 +66,12 @@ public List<String> listAgents() {
6566
}
6667

6768
@Override
68-
public BaseAgent loadAgent(String name) {
69+
public Agent loadAgent(String name) {
6970
if (name == null || name.trim().isEmpty()) {
7071
throw new IllegalArgumentException("Agent name cannot be null or empty");
7172
}
7273

73-
BaseAgent agent = agents.get(name);
74+
Agent agent = agents.get(name);
7475
if (agent == null) {
7576
throw new NoSuchElementException("Agent not found: " + name);
7677
}

examples/chatbot/src/main/java/com/alibaba/cloud/ai/examples/chatbot/ChatbotAgent.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
package com.alibaba.cloud.ai.examples.chatbot;
1717

1818
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
19+
import com.alibaba.cloud.ai.graph.agent.hook.shelltool.ShellToolAgentHook;
1920
import com.alibaba.cloud.ai.graph.agent.tools.ShellTool;
2021
import com.alibaba.cloud.ai.graph.agent.extension.tools.filesystem.ReadFileTool;
22+
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
2123

2224
import org.springframework.ai.chat.model.ChatModel;
2325
import org.springframework.ai.chat.model.ToolContext;
@@ -50,12 +52,16 @@ public class ChatbotAgent {
5052
public ReactAgent chatbotReactAgent(ChatModel chatModel,
5153
ToolCallback executeShellCommand,
5254
ToolCallback executePythonCode,
53-
ToolCallback viewTextFile) {
55+
ToolCallback viewTextFile,
56+
MemorySaver memorySaver) {
5457
return ReactAgent.builder()
5558
.name("SAA")
5659
.model(chatModel)
5760
.instruction(INSTRUCTION)
5861
.enableLogging(true)
62+
.saver(memorySaver)
63+
// Must set ShellToolAgentHook to manage shell session lifecycle for executeShellCommand
64+
.hooks(ShellToolAgentHook.builder().shellToolName(executeShellCommand.getToolDefinition().name()).build())
5965
.tools(
6066
executeShellCommand,
6167
executePythonCode,
@@ -64,6 +70,11 @@ public ReactAgent chatbotReactAgent(ChatModel chatModel,
6470
.build();
6571
}
6672

73+
@Bean
74+
public MemorySaver memorySaver() {
75+
return new MemorySaver();
76+
}
77+
6778
// Tool: execute_shell_command
6879
@Bean
6980
public ToolCallback executeShellCommand() {

examples/deepresearch/src/main/java/com/alibaba/cloud/ai/examples/deepresearch/AgentStaticLoader.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.alibaba.cloud.ai.agent.studio.loader.AgentLoader;
2020
import com.alibaba.cloud.ai.graph.GraphRepresentation;
21+
import com.alibaba.cloud.ai.graph.agent.Agent;
2122
import com.alibaba.cloud.ai.graph.agent.BaseAgent;
2223
import com.alibaba.cloud.ai.graph.agent.ReactAgent;
2324

@@ -48,7 +49,7 @@
4849
@Component
4950
class AgentStaticLoader implements AgentLoader {
5051

51-
private Map<String, BaseAgent> agents = new ConcurrentHashMap<>();
52+
private Map<String, Agent> agents = new ConcurrentHashMap<>();
5253

5354
// public AgentStaticLoader(){}
5455

@@ -73,12 +74,12 @@ public List<String> listAgents() {
7374
}
7475

7576
@Override
76-
public BaseAgent loadAgent(String name) {
77+
public Agent loadAgent(String name) {
7778
if (name == null || name.trim().isEmpty()) {
7879
throw new IllegalArgumentException("Agent name cannot be null or empty");
7980
}
8081

81-
BaseAgent agent = agents.get(name);
82+
Agent agent = agents.get(name);
8283
if (agent == null) {
8384
throw new NoSuchElementException("Agent not found: " + name);
8485
}

examples/deepresearch/src/main/java/com/alibaba/cloud/ai/examples/deepresearch/DeepResearchAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ private Interceptor subAgentAsInterceptors(List<ToolCallback> toolsFromMcp) {
175175
patchToolCallsInterceptor,
176176
largeResultEvictionInterceptor
177177
)
178-
.defaultHooks(summarizationHook, toolCallLimitHook)
178+
.defaultHooks(humanInTheLoopHook, summarizationHook, toolCallLimitHook)
179179
.addSubAgent(researchAgent)
180180
.includeGeneralPurpose(true)
181181
.addSubAgent(critiqueAgent);

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/hook/hip/HumanInTheLoopHook.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ private Optional<InterruptionMetadata> buildInterruptionMetadata(OverAllState st
202202
.name(toolCall.name()).description(content).arguments(toolCall.arguments()).build())
203203
.build();
204204
needsInterruption = true;
205+
} else {
206+
builder.addToolsAutomaticallyApproved(toolCall);
205207
}
206208
}
207209
return needsInterruption ? Optional.of(builder.build()) : Optional.empty();

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/hook/shelltool/ShellToolAgentHook.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public class ShellToolAgentHook extends AgentHook implements ToolInjection {
4242
private static final Logger log = LoggerFactory.getLogger(ShellToolAgentHook.class);
4343

4444
private ShellTool shellTool;
45+
private String shellToolName;
4546

4647
/**
4748
* Private constructor for builder pattern.
@@ -53,8 +54,9 @@ private ShellToolAgentHook() {
5354
* Private constructor with ShellTool for builder pattern.
5455
* @param shellTool the ShellTool instance to use
5556
*/
56-
private ShellToolAgentHook(ShellTool shellTool) {
57+
private ShellToolAgentHook(ShellTool shellTool, String shellToolName) {
5758
this.shellTool = shellTool;
59+
this.shellToolName = shellToolName;
5860
}
5961

6062
/**
@@ -165,7 +167,7 @@ private ShellTool extractShellTool(ToolCallback toolCallback) {
165167
@Override
166168
public String getRequiredToolName() {
167169
// Match by tool name "shell"
168-
return "shell";
170+
return shellToolName;
169171
}
170172

171173
@Override
@@ -188,6 +190,7 @@ protected ShellTool getShellTool() {
188190
*/
189191
public static class Builder {
190192
private ShellTool shellTool;
193+
private String shellToolName;
191194

192195
/**
193196
* Set the ShellTool instance.
@@ -199,12 +202,17 @@ public Builder shellTool(ShellTool shellTool) {
199202
return this;
200203
}
201204

205+
public Builder shellToolName(String shellToolName) {
206+
this.shellToolName = shellToolName;
207+
return this;
208+
}
209+
202210
/**
203211
* Build the ShellToolAgentHook instance.
204212
* @return a new ShellToolAgentHook instance
205213
*/
206214
public ShellToolAgentHook build() {
207-
return new ShellToolAgentHook(this.shellTool);
215+
return new ShellToolAgentHook(this.shellTool, this.shellToolName);
208216
}
209217
}
210218

spring-ai-alibaba-agent-framework/src/main/java/com/alibaba/cloud/ai/graph/agent/tools/ShellSessionManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ private void doCleanup(RunnableConfig config) {
149149
public CommandResult executeCommand(String command, RunnableConfig config) {
150150
ShellSession session = (ShellSession) config.context().get(SESSION_INSTANCE_CONTEXT_KEY);
151151
if (session == null) {
152-
throw new IllegalStateException("Shell session not initialized. Call initialize() first.");
152+
throw new IllegalStateException("Shell session not initialized. Call initialize() first, you might need to enable ShellToolAgentHook to enable shell session management.");
153153
}
154154

155155
log.info("Executing shell command: {}", command);

spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/action/InterruptionMetadata.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import com.alibaba.cloud.ai.graph.OverAllState;
2121
import com.alibaba.cloud.ai.graph.utils.CollectionsUtils;
2222

23+
import org.springframework.ai.chat.messages.AssistantMessage;
24+
2325
import java.util.ArrayList;
2426
import java.util.List;
2527
import java.util.Map;
@@ -37,12 +39,19 @@ public final class InterruptionMetadata extends NodeOutput implements HasMetadat
3739

3840
private final Map<String, Object> metadata;
3941

42+
private List<AssistantMessage.ToolCall> toolsAutomaticallyApproved;
43+
4044
private List<ToolFeedback> toolFeedbacks;
4145

4246
private InterruptionMetadata(Builder builder) {
4347
super(builder.nodeId, builder.state);
4448
this.metadata = builder.metadata();
4549
this.toolFeedbacks = new ArrayList<>(builder.toolFeedbacks);
50+
if (builder.toolsAutomaticallyApproved != null) {
51+
this.toolsAutomaticallyApproved = builder.toolsAutomaticallyApproved;
52+
} else {
53+
this.toolsAutomaticallyApproved = new ArrayList<>();
54+
}
4655
}
4756

4857
/**
@@ -69,6 +78,10 @@ public List<ToolFeedback> toolFeedbacks() {
6978
return toolFeedbacks;
7079
}
7180

81+
public List<AssistantMessage.ToolCall> getToolsAutomaticallyApproved() {
82+
return toolsAutomaticallyApproved;
83+
}
84+
7285
@Override
7386
public String toString() {
7487
return String.format("""
@@ -92,9 +105,16 @@ public static Builder builder() {
92105
}
93106

94107
public static Builder builder(InterruptionMetadata interruptionMetadata) {
95-
return new Builder(interruptionMetadata.metadata().orElse(Map.of()))
108+
Builder builder = new Builder(interruptionMetadata.metadata().orElse(Map.of()))
96109
.nodeId(interruptionMetadata.node())
97110
.state(interruptionMetadata.state());
111+
if (interruptionMetadata.getToolsAutomaticallyApproved() != null) {
112+
builder.toolsAutomaticallyApproved(interruptionMetadata.getToolsAutomaticallyApproved());
113+
}
114+
// if (interruptionMetadata.toolFeedbacks() != null && !interruptionMetadata.toolFeedbacks().isEmpty()) {
115+
// builder.toolFeedbacks(interruptionMetadata.toolFeedbacks());
116+
// }
117+
return builder;
98118
}
99119

100120
/**
@@ -104,6 +124,8 @@ public static Builder builder(InterruptionMetadata interruptionMetadata) {
104124
public static class Builder extends HasMetadata.Builder<Builder> {
105125
List<ToolFeedback> toolFeedbacks;
106126

127+
List<AssistantMessage.ToolCall> toolsAutomaticallyApproved;
128+
107129
String nodeId;
108130

109131
OverAllState state;
@@ -147,6 +169,19 @@ public Builder toolFeedbacks(List<ToolFeedback> toolFeedbacks) {
147169
return this;
148170
}
149171

172+
public Builder addToolsAutomaticallyApproved(AssistantMessage.ToolCall toolCall) {
173+
if (this.toolsAutomaticallyApproved == null) {
174+
this.toolsAutomaticallyApproved = new ArrayList<>();
175+
}
176+
this.toolsAutomaticallyApproved.add(toolCall);
177+
return this;
178+
}
179+
180+
public Builder toolsAutomaticallyApproved(List<AssistantMessage.ToolCall> toolsAutomaticallyApproved) {
181+
this.toolsAutomaticallyApproved = new ArrayList<>(toolsAutomaticallyApproved);
182+
return this;
183+
}
184+
150185
/**
151186
* Builds the {@link InterruptionMetadata} instance.
152187
* @return a new, immutable {@link InterruptionMetadata} instance

spring-ai-alibaba-studio/src/main/java/com/alibaba/cloud/ai/agent/studio/controller/ExecutionController.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import com.alibaba.cloud.ai.graph.NodeOutput;
2626
import com.alibaba.cloud.ai.graph.RunnableConfig;
2727
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
28-
import com.alibaba.cloud.ai.graph.agent.BaseAgent;
28+
import com.alibaba.cloud.ai.graph.agent.Agent;
2929
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
3030
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
3131

@@ -171,7 +171,7 @@ public Flux<ServerSentEvent<String>> agentRunSse(@RequestBody AgentRunRequest re
171171
}
172172

173173
try {
174-
BaseAgent agent = agentLoader.loadAgent(request.appName);
174+
Agent agent = agentLoader.loadAgent(request.appName);
175175
RunnableConfig runnableConfig = RunnableConfig.builder()
176176
.threadId(request.threadId)
177177
.addMetadata("user_id", request.userId)
@@ -205,7 +205,7 @@ public Flux<ServerSentEvent<String>> agentResumeSse(@RequestBody AgentResumeRequ
205205
}
206206

207207
try {
208-
BaseAgent agent = agentLoader.loadAgent(request.appName);
208+
Agent agent = agentLoader.loadAgent(request.appName);
209209

210210
InterruptionMetadata.Builder metadataBuilder = InterruptionMetadata.builder();
211211

@@ -247,7 +247,7 @@ public Flux<ServerSentEvent<String>> agentResumeSse(@RequestBody AgentResumeRequ
247247
}
248248

249249
@NotNull
250-
private Flux<ServerSentEvent<String>> executeAgent(UserMessage userMessage, BaseAgent agent, RunnableConfig runnableConfig) throws GraphRunnerException {
250+
private Flux<ServerSentEvent<String>> executeAgent(UserMessage userMessage, Agent agent, RunnableConfig runnableConfig) throws GraphRunnerException {
251251

252252
Flux<NodeOutput> agentStream;
253253

0 commit comments

Comments
 (0)