Working tree-of-thought for agent selection
This commit is contained in:
parent
195d2daa59
commit
cb5012a077
|
@ -109,27 +109,27 @@ public class IoASystem {
|
||||||
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
|
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
|
||||||
new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
|
new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
|
||||||
Arrays.asList("travel", "family planning"),
|
Arrays.asList("travel", "family planning"),
|
||||||
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants")),
|
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants"))
|
||||||
|
|
||||||
new Task("task7", "Organize an international tech conference with virtual and in-person components",
|
// new Task("task7", "Organize an international tech conference with virtual and in-person components",
|
||||||
Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
|
// Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
|
||||||
Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates")),
|
// Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates")),
|
||||||
|
|
||||||
new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
|
// new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
|
||||||
Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
|
// Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
|
||||||
Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
|
// Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
|
||||||
|
|
||||||
new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
|
// new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
|
||||||
Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
|
// Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
|
||||||
Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
|
// Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
|
||||||
|
|
||||||
new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
|
// new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
|
||||||
Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
|
// Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
|
||||||
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
|
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
|
||||||
|
|
||||||
new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
|
// new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
|
||||||
Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
|
// Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
|
||||||
Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
|
// Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -16,4 +16,8 @@ public class AgentInfo {
|
||||||
this.capabilities = capabilities;
|
this.capabilities = capabilities;
|
||||||
this.tools = tools;
|
this.tools = tools;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<String> getCapabilities() {
|
||||||
|
return this.capabilities;
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -3,9 +3,7 @@ package com.ioa.agent;
|
||||||
import com.ioa.tool.ToolRegistry;
|
import com.ioa.tool.ToolRegistry;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.*;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
|
@ -32,12 +30,28 @@ public class AgentRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<AgentInfo> searchAgents(List<String> capabilities) {
|
public List<AgentInfo> searchAgents(List<String> capabilities) {
|
||||||
|
return searchAgents(capabilities, 1.0); // Default to exact match
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<AgentInfo> searchAgents(List<String> capabilities, double matchThreshold) {
|
||||||
return agents.values().stream()
|
return agents.values().stream()
|
||||||
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
|
.filter(agent -> calculateMatchScore(agent.getCapabilities(), capabilities) >= matchThreshold)
|
||||||
|
.sorted(Comparator.comparingDouble((AgentInfo agent) -> calculateMatchScore(agent.getCapabilities(), capabilities)).reversed())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<AgentInfo> searchAgentsPartial(List<String> capabilities) {
|
||||||
|
return searchAgents(capabilities, 0.0); // Return all agents with any matching capability
|
||||||
|
}
|
||||||
|
|
||||||
|
private double calculateMatchScore(List<String> agentCapabilities, List<String> requiredCapabilities) {
|
||||||
|
long matchingCapabilities = requiredCapabilities.stream()
|
||||||
|
.filter(agentCapabilities::contains)
|
||||||
|
.count();
|
||||||
|
return (double) matchingCapabilities / requiredCapabilities.size();
|
||||||
|
}
|
||||||
|
|
||||||
public List<AgentInfo> getAllAgents() {
|
public List<AgentInfo> getAllAgents() {
|
||||||
return List.copyOf(agents.values());
|
return new ArrayList<>(agents.values());
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -4,76 +4,62 @@ import com.ioa.agent.AgentInfo;
|
||||||
import com.ioa.agent.AgentRegistry;
|
import com.ioa.agent.AgentRegistry;
|
||||||
import com.ioa.model.BedrockLanguageModel;
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import com.ioa.task.Task;
|
import com.ioa.task.Task;
|
||||||
|
import com.ioa.util.TreeOfThought;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
public class TeamFormation {
|
public class TeamFormation {
|
||||||
private AgentRegistry agentRegistry;
|
private AgentRegistry agentRegistry;
|
||||||
private BedrockLanguageModel model;
|
private TreeOfThought treeOfThought;
|
||||||
|
|
||||||
public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
||||||
this.agentRegistry = agentRegistry;
|
this.agentRegistry = agentRegistry;
|
||||||
this.model = model;
|
this.treeOfThought = new TreeOfThought(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<AgentInfo> formTeam(Task task) {
|
public List<AgentInfo> formTeam(Task task) {
|
||||||
List<String> requiredCapabilities = task.getRequiredCapabilities();
|
List<String> requiredCapabilities = task.getRequiredCapabilities();
|
||||||
List<String> requiredTools = task.getRequiredTools();
|
List<String> requiredTools = task.getRequiredTools();
|
||||||
List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities);
|
List<AgentInfo> potentialAgents = agentRegistry.searchAgentsPartial(requiredCapabilities);
|
||||||
|
|
||||||
System.out.println("DEBUG: Potential agents: " + potentialAgents);
|
System.out.println("DEBUG: Potential agents: " + potentialAgents);
|
||||||
|
|
||||||
String initialPrompt = "Task: " + task.getDescription() + "\n" +
|
String teamFormationTask = "Form the best team for this task: " + task.getDescription() +
|
||||||
"Required capabilities: " + String.join(", ", requiredCapabilities) + "\n" +
|
"\nRequired capabilities: " + requiredCapabilities +
|
||||||
"Required tools: " + String.join(", ", requiredTools) + "\n" +
|
"\nRequired tools: " + requiredTools +
|
||||||
"Available agents:\n" + formatAgentDetails(potentialAgents) + "\n" +
|
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents) +
|
||||||
"Instructions: Analyze the task requirements and the available agents. " +
|
"\nAnalyze the task, evaluate agents, and propose a team composition. " +
|
||||||
"Form the best team by selecting agents whose combined capabilities and tools meet the task requirements. " +
|
"Conclude with a final team selection in the format: 'Final Team Selection: agent1, agent2, ...'";
|
||||||
"Consider the following steps:\n" +
|
|
||||||
"1. Identify which capabilities and tools are crucial for the task.\n" +
|
|
||||||
"2. Match these requirements with the available agents.\n" +
|
|
||||||
"3. Consider how agents can complement each other's skills and tools.\n" +
|
|
||||||
"4. Aim to cover all required capabilities and tools with the smallest effective team.\n" +
|
|
||||||
"5. If a perfect match isn't possible, prioritize the most important requirements.\n" +
|
|
||||||
"Provide your reasoning for each step, then conclude with a final team selection in this format: 'Selected Team: agent1, agent2, ...'";
|
|
||||||
|
|
||||||
System.out.println("DEBUG: Sending initial prompt to language model: " + initialPrompt);
|
String reasoning = treeOfThought.reason(teamFormationTask, 3, 2);
|
||||||
|
System.out.println("DEBUG: Tree of Thought reasoning:\n" + reasoning);
|
||||||
String reasoning = model.generate(initialPrompt, null);
|
|
||||||
System.out.println("DEBUG: Language model reasoning:\n" + reasoning);
|
|
||||||
|
|
||||||
return parseTeamComposition(reasoning, potentialAgents);
|
return parseTeamComposition(reasoning, potentialAgents);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String formatAgentDetails(List<AgentInfo> agents) {
|
private String formatAgentTools(List<AgentInfo> agents) {
|
||||||
StringBuilder sb = new StringBuilder();
|
return agents.stream()
|
||||||
for (AgentInfo agent : agents) {
|
.map(agent -> agent.getId() + " (capabilities: " + agent.getCapabilities() + ", tools: " + agent.getTools() + ")")
|
||||||
sb.append(agent.getId()).append(":\n")
|
.collect(Collectors.joining(", "));
|
||||||
.append(" Capabilities: ").append(String.join(", ", agent.getCapabilities())).append("\n")
|
|
||||||
.append(" Tools: ").append(String.join(", ", agent.getTools())).append("\n\n");
|
|
||||||
}
|
|
||||||
return sb.toString();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AgentInfo> parseTeamComposition(String reasoning, List<AgentInfo> potentialAgents) {
|
private List<AgentInfo> parseTeamComposition(String reasoning, List<AgentInfo> potentialAgents) {
|
||||||
// Extract the team selection from the reasoning
|
|
||||||
String[] lines = reasoning.split("\n");
|
String[] lines = reasoning.split("\n");
|
||||||
String selectedTeamLine = "";
|
String selectedTeamLine = "";
|
||||||
for (String line : lines) {
|
for (String line : lines) {
|
||||||
if (line.startsWith("Selected Team:")) {
|
if (line.startsWith("Final Team Selection:")) {
|
||||||
selectedTeamLine = line.substring("Selected Team:".length()).trim();
|
selectedTeamLine = line.substring("Final Team Selection:".length()).trim();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectedTeamLine.isEmpty()) {
|
if (selectedTeamLine.isEmpty()) {
|
||||||
System.out.println("DEBUG: No team selection found in the response.");
|
System.out.println("DEBUG: No team selection found in the response.");
|
||||||
return Collections.emptyList();
|
return List.of();
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> selectedIds = Arrays.asList(selectedTeamLine.split(",\\s*"));
|
List<String> selectedIds = Arrays.asList(selectedTeamLine.split(",\\s*"));
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user