Main code with websockets, Tree-of-thought, and springboot
This commit is contained in:
parent
a97cd853fe
commit
634b4b2561
113
pom.xml
113
pom.xml
|
@ -1,12 +1,101 @@
|
|||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>bedrock-runtime</artifactId>
|
||||
<version>2.20.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.13.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.ioa</groupId>
|
||||
<artifactId>ioa-system</artifactId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
|
||||
<properties>
|
||||
<java.version>11</java.version>
|
||||
<spring-boot.version>2.5.5</spring-boot.version>
|
||||
<aws.sdk.version>2.26.9</aws.sdk.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-dependencies</artifactId>
|
||||
<version>${spring-boot.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>bom</artifactId>
|
||||
<version>${aws.sdk.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>software.amazon.awssdk</groupId>
|
||||
<artifactId>bedrockruntime</artifactId>
|
||||
<version>${aws.sdk.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<optional>true</optional>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||
<version>${spring-boot.version}</version>
|
||||
<configuration>
|
||||
<mainClass>com.ioa.IoASystem</mainClass>
|
||||
<excludes>
|
||||
<exclude>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
</exclude>
|
||||
</excludes>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>repackage</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.8.1</version>
|
||||
<configuration>
|
||||
<source>${java.version}</source>
|
||||
<target>${java.version}</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
|
@ -1,35 +1,100 @@
|
|||
package com.ioa;
|
||||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.task.Task;
|
||||
import com.ioa.task.TaskManager;
|
||||
import com.ioa.team.TeamFormation;
|
||||
import com.ioa.tool.CommonTools;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.tool.Tool;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@SpringBootApplication
|
||||
public class IoASystem {
|
||||
public static void initialize() {
|
||||
|
||||
@Bean
|
||||
public ToolRegistry toolRegistry() {
|
||||
ToolRegistry toolRegistry = new ToolRegistry();
|
||||
CommonTools commonTools = new CommonTools();
|
||||
|
||||
// Register all tools from CommonTools
|
||||
for (Method method : CommonTools.class.getMethods()) {
|
||||
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
|
||||
if (method.isAnnotationPresent(Tool.class)) {
|
||||
toolRegistry.registerTool(method.getName(), method);
|
||||
}
|
||||
}
|
||||
return toolRegistry;
|
||||
}
|
||||
|
||||
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
|
||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.build();
|
||||
@Bean
|
||||
public AgentRegistry agentRegistry(ToolRegistry toolRegistry) {
|
||||
return new AgentRegistry(toolRegistry);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockLanguageModel bedrockLanguageModel() {
|
||||
return new BedrockLanguageModel("anthropic.claude-v2");
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
||||
return new TeamFormation(agentRegistry, model);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
|
||||
return new TaskManager(agentRegistry, model, toolRegistry);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public WebSocketService webSocketService() {
|
||||
return new WebSocketService();
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
var context = SpringApplication.run(IoASystem.class, args);
|
||||
|
||||
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
|
||||
TeamFormation teamFormation = context.getBean(TeamFormation.class);
|
||||
TaskManager taskManager = context.getBean(TaskManager.class);
|
||||
|
||||
// Register some example agents
|
||||
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
|
||||
Arrays.asList("general", "search"),
|
||||
Arrays.asList("webSearch", "getWeather", "setReminder"));
|
||||
AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
|
||||
Arrays.asList("travel", "booking"),
|
||||
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
|
||||
|
||||
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
|
||||
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
|
||||
|
||||
// Initialize other components as needed
|
||||
agentRegistry.registerAgent(agent1.getId(), agent1);
|
||||
agentRegistry.registerAgent(agent2.getId(), agent2);
|
||||
|
||||
// Create a sample task
|
||||
Task task = new Task("task1", "Plan a weekend trip to Paris",
|
||||
Arrays.asList("travel", "booking"),
|
||||
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
|
||||
|
||||
// Form a team for the task
|
||||
List<AgentInfo> team = teamFormation.formTeam(task);
|
||||
System.out.println("Formed team: " + team);
|
||||
|
||||
// Assign the task to the first agent in the team (simplified)
|
||||
task.setAssignedAgent(team.get(0));
|
||||
|
||||
// Execute the task
|
||||
taskManager.addTask(task);
|
||||
taskManager.executeTask(task.getId());
|
||||
|
||||
// Print the result
|
||||
System.out.println("Task result: " + task.getResult());
|
||||
}
|
||||
}
|
|
@ -7,8 +7,7 @@ import com.ioa.task.TaskManager;
|
|||
import com.ioa.team.TeamFormation;
|
||||
import com.ioa.tool.CommonTools;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.Arrays;
|
||||
|
@ -28,9 +27,7 @@ public class Main {
|
|||
}
|
||||
|
||||
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
|
||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.build();
|
||||
BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID
|
||||
|
||||
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
|
||||
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
|
||||
|
@ -65,4 +62,4 @@ public class Main {
|
|||
// Print the result
|
||||
System.out.println("Task result: " + task.getResult());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,28 +1,19 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class AgentInfo {
|
||||
private String id;
|
||||
private String name;
|
||||
private List<String> capabilities;
|
||||
private List<String> tools;
|
||||
|
||||
// Constructor
|
||||
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
|
||||
this.id = id;
|
||||
this.name = name;
|
||||
this.capabilities = capabilities;
|
||||
this.tools = tools;
|
||||
}
|
||||
|
||||
// Getters and setters
|
||||
public String getId() { return id; }
|
||||
public void setId(String id) { this.id = id; }
|
||||
public String getName() { return name; }
|
||||
public void setName(String name) { this.name = name; }
|
||||
public List<String> getCapabilities() { return capabilities; }
|
||||
public void setCapabilities(List<String> capabilities) { this.capabilities = capabilities; }
|
||||
public List<String> getTools() { return tools; }
|
||||
public void setTools(List<String> tools) { this.tools = tools; }
|
||||
}
|
|
@ -1,12 +1,14 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
public class AgentRegistry {
|
||||
private Map<String, AgentInfo> agents = new HashMap<>();
|
||||
private ToolRegistry toolRegistry;
|
||||
|
@ -17,7 +19,7 @@ public class AgentRegistry {
|
|||
|
||||
public void registerAgent(String agentId, AgentInfo agentInfo) {
|
||||
agents.put(agentId, agentInfo);
|
||||
// Register agent's tools
|
||||
// Verify that all tools the agent claims to have are registered
|
||||
for (String tool : agentInfo.getTools()) {
|
||||
if (toolRegistry.getTool(tool) == null) {
|
||||
throw new IllegalArgumentException("Tool not found in registry: " + tool);
|
||||
|
@ -34,4 +36,8 @@ public class AgentRegistry {
|
|||
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<AgentInfo> getAllAgents() {
|
||||
return List.copyOf(agents.values());
|
||||
}
|
||||
}
|
23
src/main/java/com/ioa/config/WebSocketConfig.java
Normal file
23
src/main/java/com/ioa/config/WebSocketConfig.java
Normal file
|
@ -0,0 +1,23 @@
|
|||
package com.ioa.config;
|
||||
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
||||
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
||||
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
|
||||
|
||||
@Configuration
|
||||
@EnableWebSocketMessageBroker
|
||||
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
|
||||
|
||||
@Override
|
||||
public void configureMessageBroker(MessageBrokerRegistry config) {
|
||||
config.enableSimpleBroker("/topic");
|
||||
config.setApplicationDestinationPrefixes("/app");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry.addEndpoint("/ws").withSockJS();
|
||||
}
|
||||
}
|
|
@ -1,29 +1,34 @@
|
|||
package com.ioa.conversation;
|
||||
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class ConversationFSM {
|
||||
private ConversationState currentState;
|
||||
private TreeOfThought treeOfThought;
|
||||
private BedrockLanguageModel model;
|
||||
|
||||
public ConversationFSM(ChatLanguageModel model) {
|
||||
@Autowired
|
||||
private WebSocketService webSocketService;
|
||||
|
||||
public ConversationFSM(BedrockLanguageModel model) {
|
||||
this.currentState = ConversationState.DISCUSSION;
|
||||
this.treeOfThought = new TreeOfThought(model);
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public void handleMessage(Message message) {
|
||||
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
||||
"\nCurrent state: " + currentState;
|
||||
|
||||
String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3);
|
||||
String reasoning = model.generate(stateTransitionTask);
|
||||
|
||||
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
|
||||
Response<String> response = treeOfThought.getModel().generate(decisionPrompt);
|
||||
String response = model.generate(decisionPrompt);
|
||||
|
||||
ConversationState newState = ConversationState.valueOf(response.content().trim());
|
||||
ConversationState newState = ConversationState.valueOf(response.trim());
|
||||
transitionTo(newState);
|
||||
|
||||
// Handle the message based on the new state
|
||||
|
@ -44,8 +49,8 @@ public class ConversationFSM {
|
|||
}
|
||||
|
||||
private void transitionTo(ConversationState newState) {
|
||||
// Add any transition logic here
|
||||
this.currentState = newState;
|
||||
webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState));
|
||||
}
|
||||
|
||||
private void handleDiscussionMessage(Message message) {
|
||||
|
@ -63,4 +68,12 @@ public class ConversationFSM {
|
|||
private void handleConclusionMessage(Message message) {
|
||||
// Implement conclusion logic
|
||||
}
|
||||
|
||||
private class ConversationStateUpdate {
|
||||
public ConversationState state;
|
||||
|
||||
ConversationStateUpdate(ConversationState state) {
|
||||
this.state = state;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +1,8 @@
|
|||
package com.ioa.conversation;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class Message {
|
||||
private String sender;
|
||||
private String content;
|
||||
|
@ -8,7 +11,4 @@ public class Message {
|
|||
this.sender = sender;
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public String getSender() { return sender; }
|
||||
public String getContent() { return content; }
|
||||
}
|
|
@ -1,16 +1,19 @@
|
|||
package com.ioa.model;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class BedrockLanguageModel {
|
||||
private final BedrockRuntimeClient bedrockClient;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
|
16
src/main/java/com/ioa/service/WebSocketService.java
Normal file
16
src/main/java/com/ioa/service/WebSocketService.java
Normal file
|
@ -0,0 +1,16 @@
|
|||
package com.ioa.service;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class WebSocketService {
|
||||
|
||||
@Autowired
|
||||
private SimpMessagingTemplate messagingTemplate;
|
||||
|
||||
public void sendUpdate(String topic, Object payload) {
|
||||
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
||||
}
|
||||
}
|
|
@ -1,9 +1,11 @@
|
|||
package com.ioa.task;
|
||||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class Task {
|
||||
private String id;
|
||||
private String description;
|
||||
|
@ -12,21 +14,10 @@ public class Task {
|
|||
private AgentInfo assignedAgent;
|
||||
private String result;
|
||||
|
||||
// Constructor
|
||||
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
|
||||
this.id = id;
|
||||
this.description = description;
|
||||
this.requiredCapabilities = requiredCapabilities;
|
||||
this.requiredTools = requiredTools;
|
||||
}
|
||||
|
||||
// Getters and setters
|
||||
public String getId() { return id; }
|
||||
public String getDescription() { return description; }
|
||||
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
|
||||
public List<String> getRequiredTools() { return requiredTools; }
|
||||
public AgentInfo getAssignedAgent() { return assignedAgent; }
|
||||
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
|
||||
public String getResult() { return result; }
|
||||
public void setResult(String result) { this.result = result; }
|
||||
}
|
|
@ -2,43 +2,57 @@ package com.ioa.task;
|
|||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class TaskManager {
|
||||
private Map<String, Task> tasks = new HashMap<>();
|
||||
private AgentRegistry agentRegistry;
|
||||
private TreeOfThought treeOfThought;
|
||||
private BedrockLanguageModel model;
|
||||
private ToolRegistry toolRegistry;
|
||||
|
||||
public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) {
|
||||
@Autowired
|
||||
private WebSocketService webSocketService;
|
||||
|
||||
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
|
||||
this.agentRegistry = agentRegistry;
|
||||
this.treeOfThought = new TreeOfThought(model);
|
||||
this.model = model;
|
||||
this.toolRegistry = toolRegistry;
|
||||
}
|
||||
|
||||
public void addTask(Task task) {
|
||||
tasks.put(task.getId(), task);
|
||||
}
|
||||
|
||||
public void executeTask(String taskId) {
|
||||
Task task = tasks.get(taskId);
|
||||
AgentInfo agent = task.getAssignedAgent();
|
||||
|
||||
updateTaskProgress(taskId, "STARTED", 0);
|
||||
|
||||
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
|
||||
"\nAssigned agent capabilities: " + agent.getCapabilities() +
|
||||
"\nAvailable tools: " + agent.getTools();
|
||||
|
||||
String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3);
|
||||
String reasoning = model.generate(executionPlanningTask);
|
||||
|
||||
updateTaskProgress(taskId, "IN_PROGRESS", 50);
|
||||
|
||||
String executionPrompt = "Based on this execution plan:\n" + reasoning +
|
||||
"\nExecute the task using the available tools and provide the result.";
|
||||
Response<String> response = treeOfThought.getModel().generate(executionPrompt);
|
||||
String response = model.generate(executionPrompt);
|
||||
|
||||
String result = executeToolsFromResponse(response.content(), agent);
|
||||
String result = executeToolsFromResponse(response, agent);
|
||||
|
||||
task.setResult(result);
|
||||
updateTaskProgress(taskId, "COMPLETED", 100);
|
||||
}
|
||||
|
||||
private String executeToolsFromResponse(String response, AgentInfo agent) {
|
||||
|
@ -53,11 +67,20 @@ public class TaskManager {
|
|||
return result.toString();
|
||||
}
|
||||
|
||||
public void addTask(Task task) {
|
||||
tasks.put(task.getId(), task);
|
||||
private void updateTaskProgress(String taskId, String status, int progressPercentage) {
|
||||
TaskProgress progress = new TaskProgress(taskId, status, progressPercentage);
|
||||
webSocketService.sendUpdate("task_progress", progress);
|
||||
}
|
||||
|
||||
public Task getTask(String taskId) {
|
||||
return tasks.get(taskId);
|
||||
private class TaskProgress {
|
||||
public String taskId;
|
||||
public String status;
|
||||
public int progressPercentage;
|
||||
|
||||
TaskProgress(String taskId, String status, int progressPercentage) {
|
||||
this.taskId = taskId;
|
||||
this.status = status;
|
||||
this.progressPercentage = progressPercentage;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,22 +2,22 @@ package com.ioa.team;
|
|||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.task.Task;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
public class TeamFormation {
|
||||
private AgentRegistry agentRegistry;
|
||||
private TreeOfThought treeOfThought;
|
||||
private BedrockLanguageModel model;
|
||||
|
||||
public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) {
|
||||
public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
||||
this.agentRegistry = agentRegistry;
|
||||
this.treeOfThought = new TreeOfThought(model);
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public List<AgentInfo> formTeam(Task task) {
|
||||
|
@ -29,13 +29,13 @@ public class TeamFormation {
|
|||
"\nRequired tools: " + requiredTools +
|
||||
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
|
||||
|
||||
String reasoning = treeOfThought.reason(teamFormationTask, 3, 3);
|
||||
String reasoning = model.generate(teamFormationTask);
|
||||
|
||||
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||
"\nProvide the final team composition as a comma-separated list of agent IDs.";
|
||||
Response<String> response = treeOfThought.getModel().generate(finalDecisionPrompt);
|
||||
String response = model.generate(finalDecisionPrompt);
|
||||
|
||||
return parseTeamComposition(response.content(), potentialAgents);
|
||||
return parseTeamComposition(response, potentialAgents);
|
||||
}
|
||||
|
||||
private String formatAgentTools(List<AgentInfo> agents) {
|
||||
|
|
12
src/main/java/com/ioa/tool/Tool.java
Normal file
12
src/main/java/com/ioa/tool/Tool.java
Normal file
|
@ -0,0 +1,12 @@
|
|||
package com.ioa.tool;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target(ElementType.METHOD)
|
||||
public @interface Tool {
|
||||
String value();
|
||||
}
|
|
@ -1,8 +1,11 @@
|
|||
package com.ioa.tool;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
public class ToolRegistry {
|
||||
private Map<String, Object> tools = new HashMap<>();
|
||||
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
package com.ioa.util;
|
||||
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
|
||||
public class TreeOfThought {
|
||||
private final ChatLanguageModel model;
|
||||
private final BedrockLanguageModel model;
|
||||
|
||||
public TreeOfThought(ChatLanguageModel model) {
|
||||
public TreeOfThought(BedrockLanguageModel model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
|
@ -23,8 +22,7 @@ public class TreeOfThought {
|
|||
for (int i = 0; i < branches; i++) {
|
||||
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
|
||||
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
||||
Response<String> response = model.generate(branchPrompt);
|
||||
String thought = response.content();
|
||||
String thought = model.generate(branchPrompt);
|
||||
|
||||
result.append("Branch ").append(i + 1).append(":\n");
|
||||
result.append(thought).append("\n");
|
||||
|
@ -35,11 +33,10 @@ public class TreeOfThought {
|
|||
|
||||
private String evaluateLeaf(String task, String path) {
|
||||
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
|
||||
Response<String> response = model.generate(prompt);
|
||||
return response.content();
|
||||
return model.generate(prompt);
|
||||
}
|
||||
|
||||
public ChatLanguageModel getModel() {
|
||||
public BedrockLanguageModel getModel() {
|
||||
return model;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user