|
46 | 46 | @HookPositions(HookPosition.AFTER_MODEL) |
47 | 47 | public class HumanInTheLoopHook extends ModelHook implements AsyncNodeActionWithConfig, InterruptableAction { |
48 | 48 | private static final Logger log = LoggerFactory.getLogger(HumanInTheLoopHook.class); |
49 | | - |
| 49 | + public static final String HITL_NODE_NAME = "HITL"; |
50 | 50 | private Map<String, ToolConfig> approvalOn; |
51 | 51 |
|
52 | 52 | private HumanInTheLoopHook(Builder builder) { |
@@ -209,39 +209,67 @@ private Optional<InterruptionMetadata> buildInterruptionMetadata(OverAllState st |
209 | 209 | return needsInterruption ? Optional.of(builder.build()) : Optional.empty(); |
210 | 210 | } |
211 | 211 |
|
212 | | - private boolean validateFeedback(InterruptionMetadata feedback, List<AssistantMessage.ToolCall> toolCalls) { |
213 | | - if (feedback == null || feedback.toolFeedbacks() == null || feedback.toolFeedbacks().isEmpty()) { |
214 | | - return false; |
215 | | - } |
| 212 | + private boolean validateFeedback(InterruptionMetadata feedback, List<AssistantMessage.ToolCall> toolCalls) { |
| 213 | + if (feedback == null || feedback.toolFeedbacks() == null || feedback.toolFeedbacks().isEmpty()) { |
| 214 | + return false; |
| 215 | + } |
216 | 216 |
|
217 | | - List<InterruptionMetadata.ToolFeedback> toolFeedbacks = feedback.toolFeedbacks(); |
| 217 | + List<InterruptionMetadata.ToolFeedback> toolFeedbacks = feedback.toolFeedbacks(); |
218 | 218 |
|
219 | | - // 1. Ensure each ToolFeedback's result is not empty |
220 | | - for (InterruptionMetadata.ToolFeedback toolFeedback : toolFeedbacks) { |
221 | | - if (toolFeedback.getResult() == null) { |
222 | | - log.warn("No tool feedback provided, continue to wait for human input."); |
223 | | - return false; |
224 | | - } |
225 | | - } |
| 219 | + // 1. Tool calls in this step that actually require human approval (names defined in approvalOn) |
| 220 | + List<AssistantMessage.ToolCall> toolCallsNeedingApproval = toolCalls.stream() |
| 221 | + .filter(tc -> approvalOn.containsKey(tc.name())) |
| 222 | + .toList(); |
226 | 223 |
|
227 | | - // 2. Ensure ToolFeedback count matches approvalOn count and all names are in approvalOn |
228 | | - if (toolFeedbacks.size() != toolCalls.size()) { |
229 | | - log.warn("Only {} tool feedbacks provided, but {} tool calls need approval, continue to wait for human input.", toolFeedbacks.size(), toolCalls.size()); |
230 | | - return false; |
231 | | - } |
232 | | - for (InterruptionMetadata.ToolFeedback toolFeedback : toolFeedbacks) { |
233 | | - if (!approvalOn.containsKey(toolFeedback.getName())) { |
234 | | - log.warn("Tool feedback for tool {} is not expected(not in the tool executing list), continue to wait for human input.", toolFeedback.getName()); |
235 | | - return false; |
236 | | - } |
237 | | - } |
| 224 | + // If no tool calls in this step require human approval, validation is trivially satisfied |
| 225 | + if (toolCallsNeedingApproval.isEmpty()) { |
| 226 | + return true; |
| 227 | + } |
238 | 228 |
|
239 | | - return true; |
240 | | - } |
| 229 | + // 2. For each tool call requiring approval, ensure corresponding feedback exists and its result is non-null |
| 230 | + for (AssistantMessage.ToolCall call : toolCallsNeedingApproval) { |
| 231 | + InterruptionMetadata.ToolFeedback matchedFeedback = toolFeedbacks.stream() |
| 232 | + .filter(tf -> tf.getName().equals(call.name()) |
| 233 | + // Also validate id if ToolFeedback contains id field |
| 234 | + && call.id().equals(tf.getId())) |
| 235 | + .findFirst() |
| 236 | + .orElse(null); |
| 237 | + |
| 238 | + if (matchedFeedback == null) { |
| 239 | + log.warn("Missing feedback for tool {} (id={}); waiting for human input.", |
| 240 | + call.name(), call.id()); |
| 241 | + return false; |
| 242 | + } |
| 243 | + |
| 244 | + // Ensure the feedback result is provided |
| 245 | + if (matchedFeedback.getResult() == null) { |
| 246 | + log.warn("Feedback result for tool {} (id={}) is null; waiting for human input.", |
| 247 | + call.name(), call.id()); |
| 248 | + return false; |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + // 3. Optional: log unexpected or extra feedback entries that do not match any pending approval tool |
| 253 | + for (InterruptionMetadata.ToolFeedback tf : toolFeedbacks) { |
| 254 | + boolean matched = toolCallsNeedingApproval.stream() |
| 255 | + .anyMatch(call -> call.name().equals(tf.getName()) && call.id().equals(tf.getId())); |
| 256 | + if (!matched) { |
| 257 | + log.warn("Ignoring unexpected tool feedback: name={}, id={}", tf.getName(), tf.getId()); |
| 258 | + } |
| 259 | + } |
| 260 | + |
| 261 | + |
| 262 | + |
| 263 | + |
| 264 | + |
| 265 | + |
| 266 | + |
| 267 | + return true; |
| 268 | + } |
241 | 269 |
|
242 | 270 | @Override |
243 | 271 | public String getName() { |
244 | | - return "HITL"; |
| 272 | + return HITL_NODE_NAME; |
245 | 273 | } |
246 | 274 |
|
247 | 275 | @Override |
|
0 commit comments