Main code with websockets, Tree-of-thought, and springboot

This commit is contained in:
Mahesh Kommareddi 2024-07-16 20:03:09 -04:00
parent a97cd853fe
commit 634b4b2561
16 changed files with 329 additions and 100 deletions

113
pom.xml
View File

@ -1,12 +1,101 @@
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrock-runtime</artifactId>
<version>2.20.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.13.0</version>
</dependency>
</dependencies>
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.ioa</groupId>
<artifactId>ioa-system</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<java.version>11</java.version>
<spring-boot.version>2.5.5</spring-boot.version>
<aws.sdk.version>2.26.9</aws.sdk.version>
</properties>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-dependencies</artifactId>
<version>${spring-boot.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>${aws.sdk.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrockruntime</artifactId>
<version>${aws.sdk.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<version>${spring-boot.version}</version>
<configuration>
<mainClass>com.ioa.IoASystem</mainClass>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
<executions>
<execution>
<goals>
<goal>repackage</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -1,35 +1,100 @@
package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.tool.Tool;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
@SpringBootApplication
public class IoASystem {
public static void initialize() {
@Bean
public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry();
CommonTools commonTools = new CommonTools();
// Register all tools from CommonTools
for (Method method : CommonTools.class.getMethods()) {
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
if (method.isAnnotationPresent(Tool.class)) {
toolRegistry.registerTool(method.getName(), method);
}
}
return toolRegistry;
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
@Bean
public AgentRegistry agentRegistry(ToolRegistry toolRegistry) {
return new AgentRegistry(toolRegistry);
}
@Bean
public BedrockLanguageModel bedrockLanguageModel() {
return new BedrockLanguageModel("anthropic.claude-v2");
}
@Bean
public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
return new TeamFormation(agentRegistry, model);
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
return new TaskManager(agentRegistry, model, toolRegistry);
}
@Bean
public WebSocketService webSocketService() {
return new WebSocketService();
}
public static void main(String[] args) {
var context = SpringApplication.run(IoASystem.class, args);
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class);
// Register some example agents
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"));
AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
// Initialize other components as needed
agentRegistry.registerAgent(agent1.getId(), agent1);
agentRegistry.registerAgent(agent2.getId(), agent2);
// Create a sample task
Task task = new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
// Form a team for the task
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
// Assign the task to the first agent in the team (simplified)
task.setAssignedAgent(team.get(0));
// Execute the task
taskManager.addTask(task);
taskManager.executeTask(task.getId());
// Print the result
System.out.println("Task result: " + task.getResult());
}
}

View File

@ -7,8 +7,7 @@ import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import com.ioa.model.BedrockLanguageModel;
import java.lang.reflect.Method;
import java.util.Arrays;
@ -28,9 +27,7 @@ public class Main {
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
@ -65,4 +62,4 @@ public class Main {
// Print the result
System.out.println("Task result: " + task.getResult());
}
}
}

View File

@ -1,28 +1,19 @@
package com.ioa.agent;
import lombok.Data;
import java.util.List;
@Data
public class AgentInfo {
private String id;
private String name;
private List<String> capabilities;
private List<String> tools;
// Constructor
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
this.id = id;
this.name = name;
this.capabilities = capabilities;
this.tools = tools;
}
// Getters and setters
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getName() { return name; }
public void setName(String name) { this.name = name; }
public List<String> getCapabilities() { return capabilities; }
public void setCapabilities(List<String> capabilities) { this.capabilities = capabilities; }
public List<String> getTools() { return tools; }
public void setTools(List<String> tools) { this.tools = tools; }
}

View File

@ -1,12 +1,14 @@
package com.ioa.agent;
import com.ioa.tool.ToolRegistry;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Component
public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>();
private ToolRegistry toolRegistry;
@ -17,7 +19,7 @@ public class AgentRegistry {
public void registerAgent(String agentId, AgentInfo agentInfo) {
agents.put(agentId, agentInfo);
// Register agent's tools
// Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool);
@ -34,4 +36,8 @@ public class AgentRegistry {
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
.collect(Collectors.toList());
}
public List<AgentInfo> getAllAgents() {
return List.copyOf(agents.values());
}
}

View File

@ -0,0 +1,23 @@
package com.ioa.config;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void configureMessageBroker(MessageBrokerRegistry config) {
config.enableSimpleBroker("/topic");
config.setApplicationDestinationPrefixes("/app");
}
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").withSockJS();
}
}

View File

@ -1,29 +1,34 @@
package com.ioa.conversation;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@Component
public class ConversationFSM {
private ConversationState currentState;
private TreeOfThought treeOfThought;
private BedrockLanguageModel model;
public ConversationFSM(ChatLanguageModel model) {
@Autowired
private WebSocketService webSocketService;
public ConversationFSM(BedrockLanguageModel model) {
this.currentState = ConversationState.DISCUSSION;
this.treeOfThought = new TreeOfThought(model);
this.model = model;
}
public void handleMessage(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState;
String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3);
String reasoning = model.generate(stateTransitionTask);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
Response<String> response = treeOfThought.getModel().generate(decisionPrompt);
String response = model.generate(decisionPrompt);
ConversationState newState = ConversationState.valueOf(response.content().trim());
ConversationState newState = ConversationState.valueOf(response.trim());
transitionTo(newState);
// Handle the message based on the new state
@ -44,8 +49,8 @@ public class ConversationFSM {
}
private void transitionTo(ConversationState newState) {
// Add any transition logic here
this.currentState = newState;
webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState));
}
private void handleDiscussionMessage(Message message) {
@ -63,4 +68,12 @@ public class ConversationFSM {
private void handleConclusionMessage(Message message) {
// Implement conclusion logic
}
private class ConversationStateUpdate {
public ConversationState state;
ConversationStateUpdate(ConversationState state) {
this.state = state;
}
}
}

View File

@ -1,5 +1,8 @@
package com.ioa.conversation;
import lombok.Data;
@Data
public class Message {
private String sender;
private String content;
@ -8,7 +11,4 @@ public class Message {
this.sender = sender;
this.content = content;
}
public String getSender() { return sender; }
public String getContent() { return content; }
}

View File

@ -1,16 +1,19 @@
package com.ioa.model;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.core.SdkBytes;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;
import java.util.HashMap;
import org.springframework.stereotype.Component;
@Component
public class BedrockLanguageModel {
private final BedrockRuntimeClient bedrockClient;
private final ObjectMapper objectMapper;

View File

@ -0,0 +1,16 @@
package com.ioa.service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
@Service
public class WebSocketService {
@Autowired
private SimpMessagingTemplate messagingTemplate;
public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload);
}
}

View File

@ -1,9 +1,11 @@
package com.ioa.task;
import com.ioa.agent.AgentInfo;
import lombok.Data;
import java.util.List;
@Data
public class Task {
private String id;
private String description;
@ -12,21 +14,10 @@ public class Task {
private AgentInfo assignedAgent;
private String result;
// Constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id;
this.description = description;
this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools;
}
// Getters and setters
public String getId() { return id; }
public String getDescription() { return description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
}

View File

@ -2,43 +2,57 @@ package com.ioa.task;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
public class TaskManager {
private Map<String, Task> tasks = new HashMap<>();
private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought;
private BedrockLanguageModel model;
private ToolRegistry toolRegistry;
public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) {
@Autowired
private WebSocketService webSocketService;
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model);
this.model = model;
this.toolRegistry = toolRegistry;
}
public void addTask(Task task) {
tasks.put(task.getId(), task);
}
public void executeTask(String taskId) {
Task task = tasks.get(taskId);
AgentInfo agent = task.getAssignedAgent();
updateTaskProgress(taskId, "STARTED", 0);
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
"\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools();
String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3);
String reasoning = model.generate(executionPlanningTask);
updateTaskProgress(taskId, "IN_PROGRESS", 50);
String executionPrompt = "Based on this execution plan:\n" + reasoning +
"\nExecute the task using the available tools and provide the result.";
Response<String> response = treeOfThought.getModel().generate(executionPrompt);
String response = model.generate(executionPrompt);
String result = executeToolsFromResponse(response.content(), agent);
String result = executeToolsFromResponse(response, agent);
task.setResult(result);
updateTaskProgress(taskId, "COMPLETED", 100);
}
private String executeToolsFromResponse(String response, AgentInfo agent) {
@ -53,11 +67,20 @@ public class TaskManager {
return result.toString();
}
public void addTask(Task task) {
tasks.put(task.getId(), task);
private void updateTaskProgress(String taskId, String status, int progressPercentage) {
TaskProgress progress = new TaskProgress(taskId, status, progressPercentage);
webSocketService.sendUpdate("task_progress", progress);
}
public Task getTask(String taskId) {
return tasks.get(taskId);
private class TaskProgress {
public String taskId;
public String status;
public int progressPercentage;
TaskProgress(String taskId, String status, int progressPercentage) {
this.taskId = taskId;
this.status = status;
this.progressPercentage = progressPercentage;
}
}
}

View File

@ -2,22 +2,22 @@ package com.ioa.team;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.task.Task;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import org.springframework.stereotype.Component;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
@Component
public class TeamFormation {
private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought;
private BedrockLanguageModel model;
public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) {
public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model);
this.model = model;
}
public List<AgentInfo> formTeam(Task task) {
@ -29,13 +29,13 @@ public class TeamFormation {
"\nRequired tools: " + requiredTools +
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
String reasoning = treeOfThought.reason(teamFormationTask, 3, 3);
String reasoning = model.generate(teamFormationTask);
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the final team composition as a comma-separated list of agent IDs.";
Response<String> response = treeOfThought.getModel().generate(finalDecisionPrompt);
String response = model.generate(finalDecisionPrompt);
return parseTeamComposition(response.content(), potentialAgents);
return parseTeamComposition(response, potentialAgents);
}
private String formatAgentTools(List<AgentInfo> agents) {

View File

@ -0,0 +1,12 @@
package com.ioa.tool;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Tool {
String value();
}

View File

@ -1,8 +1,11 @@
package com.ioa.tool;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component
public class ToolRegistry {
private Map<String, Object> tools = new HashMap<>();

View File

@ -1,12 +1,11 @@
package com.ioa.util;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import com.ioa.model.BedrockLanguageModel;
public class TreeOfThought {
private final ChatLanguageModel model;
private final BedrockLanguageModel model;
public TreeOfThought(ChatLanguageModel model) {
public TreeOfThought(BedrockLanguageModel model) {
this.model = model;
}
@ -23,8 +22,7 @@ public class TreeOfThought {
for (int i = 0; i < branches; i++) {
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
Response<String> response = model.generate(branchPrompt);
String thought = response.content();
String thought = model.generate(branchPrompt);
result.append("Branch ").append(i + 1).append(":\n");
result.append(thought).append("\n");
@ -35,11 +33,10 @@ public class TreeOfThought {
private String evaluateLeaf(String task, String path) {
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
Response<String> response = model.generate(prompt);
return response.content();
return model.generate(prompt);
}
public ChatLanguageModel getModel() {
public BedrockLanguageModel getModel() {
return model;
}
}
}