Main code with websockets, Tree-of-thought, and springboot
This commit is contained in:
parent
a97cd853fe
commit
634b4b2561
113
pom.xml
113
pom.xml
|
@ -1,12 +1,101 @@
|
||||||
<dependencies>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<dependency>
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
<groupId>software.amazon.awssdk</groupId>
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
<artifactId>bedrock-runtime</artifactId>
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<version>2.20.0</version>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<groupId>com.ioa</groupId>
|
||||||
<groupId>com.fasterxml.jackson.core</groupId>
|
<artifactId>ioa-system</artifactId>
|
||||||
<artifactId>jackson-databind</artifactId>
|
<version>1.0-SNAPSHOT</version>
|
||||||
<version>2.13.0</version>
|
|
||||||
</dependency>
|
<properties>
|
||||||
</dependencies>
|
<java.version>11</java.version>
|
||||||
|
<spring-boot.version>2.5.5</spring-boot.version>
|
||||||
|
<aws.sdk.version>2.26.9</aws.sdk.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencyManagement>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-dependencies</artifactId>
|
||||||
|
<version>${spring-boot.version}</version>
|
||||||
|
<type>pom</type>
|
||||||
|
<scope>import</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>software.amazon.awssdk</groupId>
|
||||||
|
<artifactId>bom</artifactId>
|
||||||
|
<version>${aws.sdk.version}</version>
|
||||||
|
<type>pom</type>
|
||||||
|
<scope>import</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
</dependencyManagement>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-websocket</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>software.amazon.awssdk</groupId>
|
||||||
|
<artifactId>bedrockruntime</artifactId>
|
||||||
|
<version>${aws.sdk.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.fasterxml.jackson.core</groupId>
|
||||||
|
<artifactId>jackson-databind</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
<optional>true</optional>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||||
|
<version>${spring-boot.version}</version>
|
||||||
|
<configuration>
|
||||||
|
<mainClass>com.ioa.IoASystem</mainClass>
|
||||||
|
<excludes>
|
||||||
|
<exclude>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
</exclude>
|
||||||
|
</excludes>
|
||||||
|
</configuration>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<goals>
|
||||||
|
<goal>repackage</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-compiler-plugin</artifactId>
|
||||||
|
<version>3.8.1</version>
|
||||||
|
<configuration>
|
||||||
|
<source>${java.version}</source>
|
||||||
|
<target>${java.version}</target>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
|
</project>
|
|
@ -1,35 +1,100 @@
|
||||||
package com.ioa;
|
package com.ioa;
|
||||||
|
|
||||||
|
import com.ioa.agent.AgentInfo;
|
||||||
import com.ioa.agent.AgentRegistry;
|
import com.ioa.agent.AgentRegistry;
|
||||||
|
import com.ioa.task.Task;
|
||||||
import com.ioa.task.TaskManager;
|
import com.ioa.task.TaskManager;
|
||||||
import com.ioa.team.TeamFormation;
|
import com.ioa.team.TeamFormation;
|
||||||
import com.ioa.tool.CommonTools;
|
import com.ioa.tool.CommonTools;
|
||||||
import com.ioa.tool.ToolRegistry;
|
import com.ioa.tool.ToolRegistry;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
import com.ioa.service.WebSocketService;
|
||||||
|
import com.ioa.tool.Tool;
|
||||||
|
|
||||||
|
import org.springframework.boot.SpringApplication;
|
||||||
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@SpringBootApplication
|
||||||
public class IoASystem {
|
public class IoASystem {
|
||||||
public static void initialize() {
|
|
||||||
|
@Bean
|
||||||
|
public ToolRegistry toolRegistry() {
|
||||||
ToolRegistry toolRegistry = new ToolRegistry();
|
ToolRegistry toolRegistry = new ToolRegistry();
|
||||||
CommonTools commonTools = new CommonTools();
|
CommonTools commonTools = new CommonTools();
|
||||||
|
|
||||||
// Register all tools from CommonTools
|
|
||||||
for (Method method : CommonTools.class.getMethods()) {
|
for (Method method : CommonTools.class.getMethods()) {
|
||||||
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
|
if (method.isAnnotationPresent(Tool.class)) {
|
||||||
toolRegistry.registerTool(method.getName(), method);
|
toolRegistry.registerTool(method.getName(), method);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return toolRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
|
@Bean
|
||||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
public AgentRegistry agentRegistry(ToolRegistry toolRegistry) {
|
||||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
return new AgentRegistry(toolRegistry);
|
||||||
.build();
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public BedrockLanguageModel bedrockLanguageModel() {
|
||||||
|
return new BedrockLanguageModel("anthropic.claude-v2");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
||||||
|
return new TeamFormation(agentRegistry, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
|
||||||
|
return new TaskManager(agentRegistry, model, toolRegistry);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public WebSocketService webSocketService() {
|
||||||
|
return new WebSocketService();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
var context = SpringApplication.run(IoASystem.class, args);
|
||||||
|
|
||||||
|
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
|
||||||
|
TeamFormation teamFormation = context.getBean(TeamFormation.class);
|
||||||
|
TaskManager taskManager = context.getBean(TaskManager.class);
|
||||||
|
|
||||||
|
// Register some example agents
|
||||||
|
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
|
||||||
|
Arrays.asList("general", "search"),
|
||||||
|
Arrays.asList("webSearch", "getWeather", "setReminder"));
|
||||||
|
AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
|
||||||
|
Arrays.asList("travel", "booking"),
|
||||||
|
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
|
||||||
|
|
||||||
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
|
agentRegistry.registerAgent(agent1.getId(), agent1);
|
||||||
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
|
agentRegistry.registerAgent(agent2.getId(), agent2);
|
||||||
|
|
||||||
// Initialize other components as needed
|
// Create a sample task
|
||||||
|
Task task = new Task("task1", "Plan a weekend trip to Paris",
|
||||||
|
Arrays.asList("travel", "booking"),
|
||||||
|
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
|
||||||
|
|
||||||
|
// Form a team for the task
|
||||||
|
List<AgentInfo> team = teamFormation.formTeam(task);
|
||||||
|
System.out.println("Formed team: " + team);
|
||||||
|
|
||||||
|
// Assign the task to the first agent in the team (simplified)
|
||||||
|
task.setAssignedAgent(team.get(0));
|
||||||
|
|
||||||
|
// Execute the task
|
||||||
|
taskManager.addTask(task);
|
||||||
|
taskManager.executeTask(task.getId());
|
||||||
|
|
||||||
|
// Print the result
|
||||||
|
System.out.println("Task result: " + task.getResult());
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -7,8 +7,7 @@ import com.ioa.task.TaskManager;
|
||||||
import com.ioa.team.TeamFormation;
|
import com.ioa.team.TeamFormation;
|
||||||
import com.ioa.tool.CommonTools;
|
import com.ioa.tool.CommonTools;
|
||||||
import com.ioa.tool.ToolRegistry;
|
import com.ioa.tool.ToolRegistry;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
|
||||||
|
|
||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -28,9 +27,7 @@ public class Main {
|
||||||
}
|
}
|
||||||
|
|
||||||
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
|
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
|
||||||
ChatLanguageModel model = OpenAiChatModel.builder()
|
BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID
|
||||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
|
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
|
||||||
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
|
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
|
||||||
|
@ -65,4 +62,4 @@ public class Main {
|
||||||
// Print the result
|
// Print the result
|
||||||
System.out.println("Task result: " + task.getResult());
|
System.out.println("Task result: " + task.getResult());
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,28 +1,19 @@
|
||||||
package com.ioa.agent;
|
package com.ioa.agent;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
public class AgentInfo {
|
public class AgentInfo {
|
||||||
private String id;
|
private String id;
|
||||||
private String name;
|
private String name;
|
||||||
private List<String> capabilities;
|
private List<String> capabilities;
|
||||||
private List<String> tools;
|
private List<String> tools;
|
||||||
|
|
||||||
// Constructor
|
|
||||||
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
|
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.capabilities = capabilities;
|
this.capabilities = capabilities;
|
||||||
this.tools = tools;
|
this.tools = tools;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Getters and setters
|
|
||||||
public String getId() { return id; }
|
|
||||||
public void setId(String id) { this.id = id; }
|
|
||||||
public String getName() { return name; }
|
|
||||||
public void setName(String name) { this.name = name; }
|
|
||||||
public List<String> getCapabilities() { return capabilities; }
|
|
||||||
public void setCapabilities(List<String> capabilities) { this.capabilities = capabilities; }
|
|
||||||
public List<String> getTools() { return tools; }
|
|
||||||
public void setTools(List<String> tools) { this.tools = tools; }
|
|
||||||
}
|
}
|
|
@ -1,12 +1,14 @@
|
||||||
package com.ioa.agent;
|
package com.ioa.agent;
|
||||||
|
|
||||||
import com.ioa.tool.ToolRegistry;
|
import com.ioa.tool.ToolRegistry;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class AgentRegistry {
|
public class AgentRegistry {
|
||||||
private Map<String, AgentInfo> agents = new HashMap<>();
|
private Map<String, AgentInfo> agents = new HashMap<>();
|
||||||
private ToolRegistry toolRegistry;
|
private ToolRegistry toolRegistry;
|
||||||
|
@ -17,7 +19,7 @@ public class AgentRegistry {
|
||||||
|
|
||||||
public void registerAgent(String agentId, AgentInfo agentInfo) {
|
public void registerAgent(String agentId, AgentInfo agentInfo) {
|
||||||
agents.put(agentId, agentInfo);
|
agents.put(agentId, agentInfo);
|
||||||
// Register agent's tools
|
// Verify that all tools the agent claims to have are registered
|
||||||
for (String tool : agentInfo.getTools()) {
|
for (String tool : agentInfo.getTools()) {
|
||||||
if (toolRegistry.getTool(tool) == null) {
|
if (toolRegistry.getTool(tool) == null) {
|
||||||
throw new IllegalArgumentException("Tool not found in registry: " + tool);
|
throw new IllegalArgumentException("Tool not found in registry: " + tool);
|
||||||
|
@ -34,4 +36,8 @@ public class AgentRegistry {
|
||||||
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
|
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<AgentInfo> getAllAgents() {
|
||||||
|
return List.copyOf(agents.values());
|
||||||
|
}
|
||||||
}
|
}
|
23
src/main/java/com/ioa/config/WebSocketConfig.java
Normal file
23
src/main/java/com/ioa/config/WebSocketConfig.java
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package com.ioa.config;
|
||||||
|
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
||||||
|
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
||||||
|
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@EnableWebSocketMessageBroker
|
||||||
|
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void configureMessageBroker(MessageBrokerRegistry config) {
|
||||||
|
config.enableSimpleBroker("/topic");
|
||||||
|
config.setApplicationDestinationPrefixes("/app");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||||
|
registry.addEndpoint("/ws").withSockJS();
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,29 +1,34 @@
|
||||||
package com.ioa.conversation;
|
package com.ioa.conversation;
|
||||||
|
|
||||||
import com.ioa.util.TreeOfThought;
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import com.ioa.service.WebSocketService;
|
||||||
import dev.langchain4j.model.output.Response;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class ConversationFSM {
|
public class ConversationFSM {
|
||||||
private ConversationState currentState;
|
private ConversationState currentState;
|
||||||
private TreeOfThought treeOfThought;
|
private BedrockLanguageModel model;
|
||||||
|
|
||||||
public ConversationFSM(ChatLanguageModel model) {
|
@Autowired
|
||||||
|
private WebSocketService webSocketService;
|
||||||
|
|
||||||
|
public ConversationFSM(BedrockLanguageModel model) {
|
||||||
this.currentState = ConversationState.DISCUSSION;
|
this.currentState = ConversationState.DISCUSSION;
|
||||||
this.treeOfThought = new TreeOfThought(model);
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void handleMessage(Message message) {
|
public void handleMessage(Message message) {
|
||||||
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
||||||
"\nCurrent state: " + currentState;
|
"\nCurrent state: " + currentState;
|
||||||
|
|
||||||
String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3);
|
String reasoning = model.generate(stateTransitionTask);
|
||||||
|
|
||||||
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||||
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
|
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
|
||||||
Response<String> response = treeOfThought.getModel().generate(decisionPrompt);
|
String response = model.generate(decisionPrompt);
|
||||||
|
|
||||||
ConversationState newState = ConversationState.valueOf(response.content().trim());
|
ConversationState newState = ConversationState.valueOf(response.trim());
|
||||||
transitionTo(newState);
|
transitionTo(newState);
|
||||||
|
|
||||||
// Handle the message based on the new state
|
// Handle the message based on the new state
|
||||||
|
@ -44,8 +49,8 @@ public class ConversationFSM {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void transitionTo(ConversationState newState) {
|
private void transitionTo(ConversationState newState) {
|
||||||
// Add any transition logic here
|
|
||||||
this.currentState = newState;
|
this.currentState = newState;
|
||||||
|
webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleDiscussionMessage(Message message) {
|
private void handleDiscussionMessage(Message message) {
|
||||||
|
@ -63,4 +68,12 @@ public class ConversationFSM {
|
||||||
private void handleConclusionMessage(Message message) {
|
private void handleConclusionMessage(Message message) {
|
||||||
// Implement conclusion logic
|
// Implement conclusion logic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class ConversationStateUpdate {
|
||||||
|
public ConversationState state;
|
||||||
|
|
||||||
|
ConversationStateUpdate(ConversationState state) {
|
||||||
|
this.state = state;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,5 +1,8 @@
|
||||||
package com.ioa.conversation;
|
package com.ioa.conversation;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
public class Message {
|
public class Message {
|
||||||
private String sender;
|
private String sender;
|
||||||
private String content;
|
private String content;
|
||||||
|
@ -8,7 +11,4 @@ public class Message {
|
||||||
this.sender = sender;
|
this.sender = sender;
|
||||||
this.content = content;
|
this.content = content;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getSender() { return sender; }
|
|
||||||
public String getContent() { return content; }
|
|
||||||
}
|
}
|
|
@ -1,16 +1,19 @@
|
||||||
package com.ioa.model;
|
package com.ioa.model;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
||||||
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
||||||
import software.amazon.awssdk.regions.Region;
|
import software.amazon.awssdk.regions.Region;
|
||||||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
|
||||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
|
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
|
||||||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
|
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
|
||||||
import software.amazon.awssdk.core.SdkBytes;
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class BedrockLanguageModel {
|
public class BedrockLanguageModel {
|
||||||
private final BedrockRuntimeClient bedrockClient;
|
private final BedrockRuntimeClient bedrockClient;
|
||||||
private final ObjectMapper objectMapper;
|
private final ObjectMapper objectMapper;
|
||||||
|
|
16
src/main/java/com/ioa/service/WebSocketService.java
Normal file
16
src/main/java/com/ioa/service/WebSocketService.java
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package com.ioa.service;
|
||||||
|
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class WebSocketService {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SimpMessagingTemplate messagingTemplate;
|
||||||
|
|
||||||
|
public void sendUpdate(String topic, Object payload) {
|
||||||
|
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,9 +1,11 @@
|
||||||
package com.ioa.task;
|
package com.ioa.task;
|
||||||
|
|
||||||
import com.ioa.agent.AgentInfo;
|
import com.ioa.agent.AgentInfo;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
public class Task {
|
public class Task {
|
||||||
private String id;
|
private String id;
|
||||||
private String description;
|
private String description;
|
||||||
|
@ -12,21 +14,10 @@ public class Task {
|
||||||
private AgentInfo assignedAgent;
|
private AgentInfo assignedAgent;
|
||||||
private String result;
|
private String result;
|
||||||
|
|
||||||
// Constructor
|
|
||||||
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
|
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.description = description;
|
this.description = description;
|
||||||
this.requiredCapabilities = requiredCapabilities;
|
this.requiredCapabilities = requiredCapabilities;
|
||||||
this.requiredTools = requiredTools;
|
this.requiredTools = requiredTools;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Getters and setters
|
|
||||||
public String getId() { return id; }
|
|
||||||
public String getDescription() { return description; }
|
|
||||||
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
|
|
||||||
public List<String> getRequiredTools() { return requiredTools; }
|
|
||||||
public AgentInfo getAssignedAgent() { return assignedAgent; }
|
|
||||||
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
|
|
||||||
public String getResult() { return result; }
|
|
||||||
public void setResult(String result) { this.result = result; }
|
|
||||||
}
|
}
|
|
@ -2,43 +2,57 @@ package com.ioa.task;
|
||||||
|
|
||||||
import com.ioa.agent.AgentInfo;
|
import com.ioa.agent.AgentInfo;
|
||||||
import com.ioa.agent.AgentRegistry;
|
import com.ioa.agent.AgentRegistry;
|
||||||
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
|
import com.ioa.service.WebSocketService;
|
||||||
import com.ioa.tool.ToolRegistry;
|
import com.ioa.tool.ToolRegistry;
|
||||||
import com.ioa.util.TreeOfThought;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import org.springframework.stereotype.Component;
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class TaskManager {
|
public class TaskManager {
|
||||||
private Map<String, Task> tasks = new HashMap<>();
|
private Map<String, Task> tasks = new HashMap<>();
|
||||||
private AgentRegistry agentRegistry;
|
private AgentRegistry agentRegistry;
|
||||||
private TreeOfThought treeOfThought;
|
private BedrockLanguageModel model;
|
||||||
private ToolRegistry toolRegistry;
|
private ToolRegistry toolRegistry;
|
||||||
|
|
||||||
public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) {
|
@Autowired
|
||||||
|
private WebSocketService webSocketService;
|
||||||
|
|
||||||
|
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
|
||||||
this.agentRegistry = agentRegistry;
|
this.agentRegistry = agentRegistry;
|
||||||
this.treeOfThought = new TreeOfThought(model);
|
this.model = model;
|
||||||
this.toolRegistry = toolRegistry;
|
this.toolRegistry = toolRegistry;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void addTask(Task task) {
|
||||||
|
tasks.put(task.getId(), task);
|
||||||
|
}
|
||||||
|
|
||||||
public void executeTask(String taskId) {
|
public void executeTask(String taskId) {
|
||||||
Task task = tasks.get(taskId);
|
Task task = tasks.get(taskId);
|
||||||
AgentInfo agent = task.getAssignedAgent();
|
AgentInfo agent = task.getAssignedAgent();
|
||||||
|
|
||||||
|
updateTaskProgress(taskId, "STARTED", 0);
|
||||||
|
|
||||||
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
|
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
|
||||||
"\nAssigned agent capabilities: " + agent.getCapabilities() +
|
"\nAssigned agent capabilities: " + agent.getCapabilities() +
|
||||||
"\nAvailable tools: " + agent.getTools();
|
"\nAvailable tools: " + agent.getTools();
|
||||||
|
|
||||||
String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3);
|
String reasoning = model.generate(executionPlanningTask);
|
||||||
|
|
||||||
|
updateTaskProgress(taskId, "IN_PROGRESS", 50);
|
||||||
|
|
||||||
String executionPrompt = "Based on this execution plan:\n" + reasoning +
|
String executionPrompt = "Based on this execution plan:\n" + reasoning +
|
||||||
"\nExecute the task using the available tools and provide the result.";
|
"\nExecute the task using the available tools and provide the result.";
|
||||||
Response<String> response = treeOfThought.getModel().generate(executionPrompt);
|
String response = model.generate(executionPrompt);
|
||||||
|
|
||||||
String result = executeToolsFromResponse(response.content(), agent);
|
String result = executeToolsFromResponse(response, agent);
|
||||||
|
|
||||||
task.setResult(result);
|
task.setResult(result);
|
||||||
|
updateTaskProgress(taskId, "COMPLETED", 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String executeToolsFromResponse(String response, AgentInfo agent) {
|
private String executeToolsFromResponse(String response, AgentInfo agent) {
|
||||||
|
@ -53,11 +67,20 @@ public class TaskManager {
|
||||||
return result.toString();
|
return result.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addTask(Task task) {
|
private void updateTaskProgress(String taskId, String status, int progressPercentage) {
|
||||||
tasks.put(task.getId(), task);
|
TaskProgress progress = new TaskProgress(taskId, status, progressPercentage);
|
||||||
|
webSocketService.sendUpdate("task_progress", progress);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Task getTask(String taskId) {
|
private class TaskProgress {
|
||||||
return tasks.get(taskId);
|
public String taskId;
|
||||||
|
public String status;
|
||||||
|
public int progressPercentage;
|
||||||
|
|
||||||
|
TaskProgress(String taskId, String status, int progressPercentage) {
|
||||||
|
this.taskId = taskId;
|
||||||
|
this.status = status;
|
||||||
|
this.progressPercentage = progressPercentage;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -2,22 +2,22 @@ package com.ioa.team;
|
||||||
|
|
||||||
import com.ioa.agent.AgentInfo;
|
import com.ioa.agent.AgentInfo;
|
||||||
import com.ioa.agent.AgentRegistry;
|
import com.ioa.agent.AgentRegistry;
|
||||||
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import com.ioa.task.Task;
|
import com.ioa.task.Task;
|
||||||
import com.ioa.util.TreeOfThought;
|
import org.springframework.stereotype.Component;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class TeamFormation {
|
public class TeamFormation {
|
||||||
private AgentRegistry agentRegistry;
|
private AgentRegistry agentRegistry;
|
||||||
private TreeOfThought treeOfThought;
|
private BedrockLanguageModel model;
|
||||||
|
|
||||||
public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) {
|
public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
|
||||||
this.agentRegistry = agentRegistry;
|
this.agentRegistry = agentRegistry;
|
||||||
this.treeOfThought = new TreeOfThought(model);
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<AgentInfo> formTeam(Task task) {
|
public List<AgentInfo> formTeam(Task task) {
|
||||||
|
@ -29,13 +29,13 @@ public class TeamFormation {
|
||||||
"\nRequired tools: " + requiredTools +
|
"\nRequired tools: " + requiredTools +
|
||||||
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
|
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
|
||||||
|
|
||||||
String reasoning = treeOfThought.reason(teamFormationTask, 3, 3);
|
String reasoning = model.generate(teamFormationTask);
|
||||||
|
|
||||||
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
|
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||||
"\nProvide the final team composition as a comma-separated list of agent IDs.";
|
"\nProvide the final team composition as a comma-separated list of agent IDs.";
|
||||||
Response<String> response = treeOfThought.getModel().generate(finalDecisionPrompt);
|
String response = model.generate(finalDecisionPrompt);
|
||||||
|
|
||||||
return parseTeamComposition(response.content(), potentialAgents);
|
return parseTeamComposition(response, potentialAgents);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String formatAgentTools(List<AgentInfo> agents) {
|
private String formatAgentTools(List<AgentInfo> agents) {
|
||||||
|
|
12
src/main/java/com/ioa/tool/Tool.java
Normal file
12
src/main/java/com/ioa/tool/Tool.java
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package com.ioa.tool;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
@Target(ElementType.METHOD)
|
||||||
|
public @interface Tool {
|
||||||
|
String value();
|
||||||
|
}
|
|
@ -1,8 +1,11 @@
|
||||||
package com.ioa.tool;
|
package com.ioa.tool;
|
||||||
|
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Component
|
||||||
public class ToolRegistry {
|
public class ToolRegistry {
|
||||||
private Map<String, Object> tools = new HashMap<>();
|
private Map<String, Object> tools = new HashMap<>();
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
package com.ioa.util;
|
package com.ioa.util;
|
||||||
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import com.ioa.model.BedrockLanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
|
|
||||||
public class TreeOfThought {
|
public class TreeOfThought {
|
||||||
private final ChatLanguageModel model;
|
private final BedrockLanguageModel model;
|
||||||
|
|
||||||
public TreeOfThought(ChatLanguageModel model) {
|
public TreeOfThought(BedrockLanguageModel model) {
|
||||||
this.model = model;
|
this.model = model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,8 +22,7 @@ public class TreeOfThought {
|
||||||
for (int i = 0; i < branches; i++) {
|
for (int i = 0; i < branches; i++) {
|
||||||
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
|
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
|
||||||
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
||||||
Response<String> response = model.generate(branchPrompt);
|
String thought = model.generate(branchPrompt);
|
||||||
String thought = response.content();
|
|
||||||
|
|
||||||
result.append("Branch ").append(i + 1).append(":\n");
|
result.append("Branch ").append(i + 1).append(":\n");
|
||||||
result.append(thought).append("\n");
|
result.append(thought).append("\n");
|
||||||
|
@ -35,11 +33,10 @@ public class TreeOfThought {
|
||||||
|
|
||||||
private String evaluateLeaf(String task, String path) {
|
private String evaluateLeaf(String task, String path) {
|
||||||
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
|
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
|
||||||
Response<String> response = model.generate(prompt);
|
return model.generate(prompt);
|
||||||
return response.content();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatLanguageModel getModel() {
|
public BedrockLanguageModel getModel() {
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user