Main code with websockets, Tree-of-thought, and springboot

This commit is contained in:
Mahesh Kommareddi 2024-07-16 20:03:09 -04:00
parent a97cd853fe
commit 634b4b2561
16 changed files with 329 additions and 100 deletions

95
pom.xml
View File

@ -1,12 +1,101 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.ioa</groupId>
<artifactId>ioa-system</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<java.version>11</java.version>
<spring-boot.version>2.5.5</spring-boot.version>
<aws.sdk.version>2.26.9</aws.sdk.version>
</properties>
<dependencyManagement>
<dependencies> <dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${spring-boot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency> <dependency>
<groupId>software.amazon.awssdk</groupId> <groupId>software.amazon.awssdk</groupId>
<artifactId>bedrock-runtime</artifactId> <artifactId>bom</artifactId>
<version>2.20.0</version> <version>${aws.sdk.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrockruntime</artifactId>
<version>${aws.sdk.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId> <artifactId>jackson-databind</artifactId>
<version>2.13.0</version> </dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<version>${spring-boot.version}</version>
<configuration>
<mainClass>com.ioa.IoASystem</mainClass>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -1,35 +1,100 @@
package com.ioa; package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry; import com.ioa.agent.AgentRegistry;
import com.ioa.task.Task;
import com.ioa.task.TaskManager; import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation; import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools; import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel; import com.ioa.model.BedrockLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel; import com.ioa.service.WebSocketService;
import com.ioa.tool.Tool;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
@SpringBootApplication
public class IoASystem { public class IoASystem {
public static void initialize() {
@Bean
public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry(); ToolRegistry toolRegistry = new ToolRegistry();
CommonTools commonTools = new CommonTools(); CommonTools commonTools = new CommonTools();
// Register all tools from CommonTools
for (Method method : CommonTools.class.getMethods()) { for (Method method : CommonTools.class.getMethods()) {
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) { if (method.isAnnotationPresent(Tool.class)) {
toolRegistry.registerTool(method.getName(), method); toolRegistry.registerTool(method.getName(), method);
} }
} }
return toolRegistry;
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); @Bean
ChatLanguageModel model = OpenAiChatModel.builder() public AgentRegistry agentRegistry(ToolRegistry toolRegistry) {
.apiKey(System.getenv("OPENAI_API_KEY")) return new AgentRegistry(toolRegistry);
.build(); }
TeamFormation teamFormation = new TeamFormation(agentRegistry, model); @Bean
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); public BedrockLanguageModel bedrockLanguageModel() {
return new BedrockLanguageModel("anthropic.claude-v2");
}
// Initialize other components as needed @Bean
public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
return new TeamFormation(agentRegistry, model);
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
return new TaskManager(agentRegistry, model, toolRegistry);
}
@Bean
public WebSocketService webSocketService() {
return new WebSocketService();
}
public static void main(String[] args) {
var context = SpringApplication.run(IoASystem.class, args);
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class);
// Register some example agents
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"));
AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
agentRegistry.registerAgent(agent1.getId(), agent1);
agentRegistry.registerAgent(agent2.getId(), agent2);
// Create a sample task
Task task = new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
// Form a team for the task
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
// Assign the task to the first agent in the team (simplified)
task.setAssignedAgent(team.get(0));
// Execute the task
taskManager.addTask(task);
taskManager.executeTask(task.getId());
// Print the result
System.out.println("Task result: " + task.getResult());
} }
} }

View File

@ -7,8 +7,7 @@ import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation; import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools; import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel; import com.ioa.model.BedrockLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Arrays; import java.util.Arrays;
@ -28,9 +27,7 @@ public class Main {
} }
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
ChatLanguageModel model = OpenAiChatModel.builder() BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
TeamFormation teamFormation = new TeamFormation(agentRegistry, model); TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);

View File

@ -1,28 +1,19 @@
package com.ioa.agent; package com.ioa.agent;
import lombok.Data;
import java.util.List; import java.util.List;
@Data
public class AgentInfo { public class AgentInfo {
private String id; private String id;
private String name; private String name;
private List<String> capabilities; private List<String> capabilities;
private List<String> tools; private List<String> tools;
// Constructor
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) { public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
this.id = id; this.id = id;
this.name = name; this.name = name;
this.capabilities = capabilities; this.capabilities = capabilities;
this.tools = tools; this.tools = tools;
} }
// Getters and setters
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getName() { return name; }
public void setName(String name) { this.name = name; }
public List<String> getCapabilities() { return capabilities; }
public void setCapabilities(List<String> capabilities) { this.capabilities = capabilities; }
public List<String> getTools() { return tools; }
public void setTools(List<String> tools) { this.tools = tools; }
} }

View File

@ -1,12 +1,14 @@
package com.ioa.agent; package com.ioa.agent;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import org.springframework.stereotype.Component;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Component
public class AgentRegistry { public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>(); private Map<String, AgentInfo> agents = new HashMap<>();
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
@ -17,7 +19,7 @@ public class AgentRegistry {
public void registerAgent(String agentId, AgentInfo agentInfo) { public void registerAgent(String agentId, AgentInfo agentInfo) {
agents.put(agentId, agentInfo); agents.put(agentId, agentInfo);
// Register agent's tools // Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) { for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) { if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool); throw new IllegalArgumentException("Tool not found in registry: " + tool);
@ -34,4 +36,8 @@ public class AgentRegistry {
.filter(agent -> agent.getCapabilities().containsAll(capabilities)) .filter(agent -> agent.getCapabilities().containsAll(capabilities))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
public List<AgentInfo> getAllAgents() {
return List.copyOf(agents.values());
}
} }

View File

@ -0,0 +1,23 @@
package com.ioa.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void configureMessageBroker(MessageBrokerRegistry config) {
config.enableSimpleBroker("/topic");
config.setApplicationDestinationPrefixes("/app");
}
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").withSockJS();
}
}

View File

@ -1,29 +1,34 @@
package com.ioa.conversation; package com.ioa.conversation;
import com.ioa.util.TreeOfThought; import com.ioa.model.BedrockLanguageModel;
import dev.langchain4j.model.chat.ChatLanguageModel; import com.ioa.service.WebSocketService;
import dev.langchain4j.model.output.Response; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@Component
public class ConversationFSM { public class ConversationFSM {
private ConversationState currentState; private ConversationState currentState;
private TreeOfThought treeOfThought; private BedrockLanguageModel model;
public ConversationFSM(ChatLanguageModel model) { @Autowired
private WebSocketService webSocketService;
public ConversationFSM(BedrockLanguageModel model) {
this.currentState = ConversationState.DISCUSSION; this.currentState = ConversationState.DISCUSSION;
this.treeOfThought = new TreeOfThought(model); this.model = model;
} }
public void handleMessage(Message message) { public void handleMessage(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() + String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState; "\nCurrent state: " + currentState;
String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3); String reasoning = model.generate(stateTransitionTask);
String decisionPrompt = "Based on this reasoning:\n" + reasoning + String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION)."; "\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
Response<String> response = treeOfThought.getModel().generate(decisionPrompt); String response = model.generate(decisionPrompt);
ConversationState newState = ConversationState.valueOf(response.content().trim()); ConversationState newState = ConversationState.valueOf(response.trim());
transitionTo(newState); transitionTo(newState);
// Handle the message based on the new state // Handle the message based on the new state
@ -44,8 +49,8 @@ public class ConversationFSM {
} }
private void transitionTo(ConversationState newState) { private void transitionTo(ConversationState newState) {
// Add any transition logic here
this.currentState = newState; this.currentState = newState;
webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState));
} }
private void handleDiscussionMessage(Message message) { private void handleDiscussionMessage(Message message) {
@ -63,4 +68,12 @@ public class ConversationFSM {
private void handleConclusionMessage(Message message) { private void handleConclusionMessage(Message message) {
// Implement conclusion logic // Implement conclusion logic
} }
private class ConversationStateUpdate {
public ConversationState state;
ConversationStateUpdate(ConversationState state) {
this.state = state;
}
}
} }

View File

@ -1,5 +1,8 @@
package com.ioa.conversation; package com.ioa.conversation;
import lombok.Data;
@Data
public class Message { public class Message {
private String sender; private String sender;
private String content; private String content;
@ -8,7 +11,4 @@ public class Message {
this.sender = sender; this.sender = sender;
this.content = content; this.content = content;
} }
public String getSender() { return sender; }
public String getContent() { return content; }
} }

View File

@ -1,16 +1,19 @@
package com.ioa.model; package com.ioa.model;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map; import java.util.Map;
import java.util.HashMap; import java.util.HashMap;
import org.springframework.stereotype.Component;
@Component
public class BedrockLanguageModel { public class BedrockLanguageModel {
private final BedrockRuntimeClient bedrockClient; private final BedrockRuntimeClient bedrockClient;
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;

View File

@ -0,0 +1,16 @@
package com.ioa.service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
@Service
public class WebSocketService {
@Autowired
private SimpMessagingTemplate messagingTemplate;
public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload);
}
}

View File

@ -1,9 +1,11 @@
package com.ioa.task; package com.ioa.task;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import lombok.Data;
import java.util.List; import java.util.List;
@Data
public class Task { public class Task {
private String id; private String id;
private String description; private String description;
@ -12,21 +14,10 @@ public class Task {
private AgentInfo assignedAgent; private AgentInfo assignedAgent;
private String result; private String result;
// Constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) { public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id; this.id = id;
this.description = description; this.description = description;
this.requiredCapabilities = requiredCapabilities; this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools; this.requiredTools = requiredTools;
} }
// Getters and setters
public String getId() { return id; }
public String getDescription() { return description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
} }

View File

@ -2,43 +2,57 @@ package com.ioa.task;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry; import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought; import org.springframework.beans.factory.annotation.Autowired;
import dev.langchain4j.model.chat.ChatLanguageModel; import org.springframework.stereotype.Component;
import dev.langchain4j.model.output.Response;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@Component
public class TaskManager { public class TaskManager {
private Map<String, Task> tasks = new HashMap<>(); private Map<String, Task> tasks = new HashMap<>();
private AgentRegistry agentRegistry; private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought; private BedrockLanguageModel model;
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) { @Autowired
private WebSocketService webSocketService;
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
this.agentRegistry = agentRegistry; this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model); this.model = model;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
} }
public void addTask(Task task) {
tasks.put(task.getId(), task);
}
public void executeTask(String taskId) { public void executeTask(String taskId) {
Task task = tasks.get(taskId); Task task = tasks.get(taskId);
AgentInfo agent = task.getAssignedAgent(); AgentInfo agent = task.getAssignedAgent();
updateTaskProgress(taskId, "STARTED", 0);
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() + String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
"\nAssigned agent capabilities: " + agent.getCapabilities() + "\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools(); "\nAvailable tools: " + agent.getTools();
String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3); String reasoning = model.generate(executionPlanningTask);
updateTaskProgress(taskId, "IN_PROGRESS", 50);
String executionPrompt = "Based on this execution plan:\n" + reasoning + String executionPrompt = "Based on this execution plan:\n" + reasoning +
"\nExecute the task using the available tools and provide the result."; "\nExecute the task using the available tools and provide the result.";
Response<String> response = treeOfThought.getModel().generate(executionPrompt); String response = model.generate(executionPrompt);
String result = executeToolsFromResponse(response.content(), agent); String result = executeToolsFromResponse(response, agent);
task.setResult(result); task.setResult(result);
updateTaskProgress(taskId, "COMPLETED", 100);
} }
private String executeToolsFromResponse(String response, AgentInfo agent) { private String executeToolsFromResponse(String response, AgentInfo agent) {
@ -53,11 +67,20 @@ public class TaskManager {
return result.toString(); return result.toString();
} }
public void addTask(Task task) { private void updateTaskProgress(String taskId, String status, int progressPercentage) {
tasks.put(task.getId(), task); TaskProgress progress = new TaskProgress(taskId, status, progressPercentage);
webSocketService.sendUpdate("task_progress", progress);
} }
public Task getTask(String taskId) { private class TaskProgress {
return tasks.get(taskId); public String taskId;
public String status;
public int progressPercentage;
TaskProgress(String taskId, String status, int progressPercentage) {
this.taskId = taskId;
this.status = status;
this.progressPercentage = progressPercentage;
}
} }
} }

View File

@ -2,22 +2,22 @@ package com.ioa.team;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry; import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.task.Task; import com.ioa.task.Task;
import com.ioa.util.TreeOfThought; import org.springframework.stereotype.Component;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Component
public class TeamFormation { public class TeamFormation {
private AgentRegistry agentRegistry; private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought; private BedrockLanguageModel model;
public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) { public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
this.agentRegistry = agentRegistry; this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model); this.model = model;
} }
public List<AgentInfo> formTeam(Task task) { public List<AgentInfo> formTeam(Task task) {
@ -29,13 +29,13 @@ public class TeamFormation {
"\nRequired tools: " + requiredTools + "\nRequired tools: " + requiredTools +
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents); "\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
String reasoning = treeOfThought.reason(teamFormationTask, 3, 3); String reasoning = model.generate(teamFormationTask);
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning + String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the final team composition as a comma-separated list of agent IDs."; "\nProvide the final team composition as a comma-separated list of agent IDs.";
Response<String> response = treeOfThought.getModel().generate(finalDecisionPrompt); String response = model.generate(finalDecisionPrompt);
return parseTeamComposition(response.content(), potentialAgents); return parseTeamComposition(response, potentialAgents);
} }
private String formatAgentTools(List<AgentInfo> agents) { private String formatAgentTools(List<AgentInfo> agents) {

View File

@ -0,0 +1,12 @@
package com.ioa.tool;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Tool {
String value();
}

View File

@ -1,8 +1,11 @@
package com.ioa.tool; package com.ioa.tool;
import org.springframework.stereotype.Component;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@Component
public class ToolRegistry { public class ToolRegistry {
private Map<String, Object> tools = new HashMap<>(); private Map<String, Object> tools = new HashMap<>();

View File

@ -1,12 +1,11 @@
package com.ioa.util; package com.ioa.util;
import dev.langchain4j.model.chat.ChatLanguageModel; import com.ioa.model.BedrockLanguageModel;
import dev.langchain4j.model.output.Response;
public class TreeOfThought { public class TreeOfThought {
private final ChatLanguageModel model; private final BedrockLanguageModel model;
public TreeOfThought(ChatLanguageModel model) { public TreeOfThought(BedrockLanguageModel model) {
this.model = model; this.model = model;
} }
@ -23,8 +22,7 @@ public class TreeOfThought {
for (int i = 0; i < branches; i++) { for (int i = 0; i < branches; i++) {
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path + String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
Response<String> response = model.generate(branchPrompt); String thought = model.generate(branchPrompt);
String thought = response.content();
result.append("Branch ").append(i + 1).append(":\n"); result.append("Branch ").append(i + 1).append(":\n");
result.append(thought).append("\n"); result.append(thought).append("\n");
@ -35,11 +33,10 @@ public class TreeOfThought {
private String evaluateLeaf(String task, String path) { private String evaluateLeaf(String task, String path) {
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path; String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
Response<String> response = model.generate(prompt); return model.generate(prompt);
return response.content();
} }
public ChatLanguageModel getModel() { public BedrockLanguageModel getModel() {
return model; return model;
} }
} }