diff --git a/pom.xml b/pom.xml index f3d8047..6f5cbfc 100644 --- a/pom.xml +++ b/pom.xml @@ -1,12 +1,101 @@ - - - software.amazon.awssdk - bedrock-runtime - 2.20.0 - - - com.fasterxml.jackson.core - jackson-databind - 2.13.0 - - \ No newline at end of file + + + 4.0.0 + + com.ioa + ioa-system + 1.0-SNAPSHOT + + + 11 + 2.5.5 + 2.26.9 + + + + + + org.springframework.boot + spring-boot-dependencies + ${spring-boot.version} + pom + import + + + software.amazon.awssdk + bom + ${aws.sdk.version} + pom + import + + + + + + + org.springframework.boot + spring-boot-starter-web + + + org.springframework.boot + spring-boot-starter-websocket + + + software.amazon.awssdk + bedrockruntime + ${aws.sdk.version} + + + com.fasterxml.jackson.core + jackson-databind + + + org.projectlombok + lombok + true + + + org.springframework.boot + spring-boot-starter-test + test + + + + + + + org.springframework.boot + spring-boot-maven-plugin + ${spring-boot.version} + + com.ioa.IoASystem + + + org.projectlombok + lombok + + + + + + + repackage + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.1 + + ${java.version} + ${java.version} + + + + + + \ No newline at end of file diff --git a/src/main/java/com/ioa/IoASystem.java b/src/main/java/com/ioa/IoASystem.java index 69d5db0..f4abdd0 100644 --- a/src/main/java/com/ioa/IoASystem.java +++ b/src/main/java/com/ioa/IoASystem.java @@ -1,35 +1,100 @@ package com.ioa; +import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentRegistry; +import com.ioa.task.Task; import com.ioa.task.TaskManager; import com.ioa.team.TeamFormation; import com.ioa.tool.CommonTools; import com.ioa.tool.ToolRegistry; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiChatModel; +import com.ioa.model.BedrockLanguageModel; +import com.ioa.service.WebSocketService; +import com.ioa.tool.Tool; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +@SpringBootApplication public class IoASystem { - public static void initialize() { + + @Bean + public ToolRegistry toolRegistry() { ToolRegistry toolRegistry = new ToolRegistry(); CommonTools commonTools = new CommonTools(); - // Register all tools from CommonTools for (Method method : CommonTools.class.getMethods()) { - if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) { + if (method.isAnnotationPresent(Tool.class)) { toolRegistry.registerTool(method.getName(), method); } } + return toolRegistry; + } - AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); - ChatLanguageModel model = OpenAiChatModel.builder() - .apiKey(System.getenv("OPENAI_API_KEY")) - .build(); + @Bean + public AgentRegistry agentRegistry(ToolRegistry toolRegistry) { + return new AgentRegistry(toolRegistry); + } + + @Bean + public BedrockLanguageModel bedrockLanguageModel() { + return new BedrockLanguageModel("anthropic.claude-v2"); + } + + @Bean + public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) { + return new TeamFormation(agentRegistry, model); + } + + @Bean + public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) { + return new TaskManager(agentRegistry, model, toolRegistry); + } + + @Bean + public WebSocketService webSocketService() { + return new WebSocketService(); + } + + public static void main(String[] args) { + var context = SpringApplication.run(IoASystem.class, args); + + AgentRegistry agentRegistry = context.getBean(AgentRegistry.class); + TeamFormation teamFormation = context.getBean(TeamFormation.class); + TaskManager taskManager = context.getBean(TaskManager.class); + + // Register some example agents + AgentInfo agent1 = new AgentInfo("agent1", "General Assistant", + Arrays.asList("general", "search"), + Arrays.asList("webSearch", "getWeather", "setReminder")); + AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert", + Arrays.asList("travel", "booking"), + Arrays.asList("bookTravel", "calculateDistance", "findRestaurants")); - TeamFormation teamFormation = new TeamFormation(agentRegistry, model); - TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); - - // Initialize other components as needed + agentRegistry.registerAgent(agent1.getId(), agent1); + agentRegistry.registerAgent(agent2.getId(), agent2); + + // Create a sample task + Task task = new Task("task1", "Plan a weekend trip to Paris", + Arrays.asList("travel", "booking"), + Arrays.asList("bookTravel", "findRestaurants", "getWeather")); + + // Form a team for the task + List team = teamFormation.formTeam(task); + System.out.println("Formed team: " + team); + + // Assign the task to the first agent in the team (simplified) + task.setAssignedAgent(team.get(0)); + + // Execute the task + taskManager.addTask(task); + taskManager.executeTask(task.getId()); + + // Print the result + System.out.println("Task result: " + task.getResult()); } } \ No newline at end of file diff --git a/src/main/Main.java b/src/main/java/com/ioa/Main.java similarity index 90% rename from src/main/Main.java rename to src/main/java/com/ioa/Main.java index 1ad7458..34f4695 100644 --- a/src/main/Main.java +++ b/src/main/java/com/ioa/Main.java @@ -7,8 +7,7 @@ import com.ioa.task.TaskManager; import com.ioa.team.TeamFormation; import com.ioa.tool.CommonTools; import com.ioa.tool.ToolRegistry; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiChatModel; +import com.ioa.model.BedrockLanguageModel; import java.lang.reflect.Method; import java.util.Arrays; @@ -28,9 +27,7 @@ public class Main { } AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); - ChatLanguageModel model = OpenAiChatModel.builder() - .apiKey(System.getenv("OPENAI_API_KEY")) - .build(); + BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID TeamFormation teamFormation = new TeamFormation(agentRegistry, model); TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); @@ -65,4 +62,4 @@ public class Main { // Print the result System.out.println("Task result: " + task.getResult()); } -} \ No newline at end of file +} diff --git a/src/main/java/com/ioa/agent/AgentInfo.java b/src/main/java/com/ioa/agent/AgentInfo.java index acd3a68..55b156d 100644 --- a/src/main/java/com/ioa/agent/AgentInfo.java +++ b/src/main/java/com/ioa/agent/AgentInfo.java @@ -1,28 +1,19 @@ package com.ioa.agent; +import lombok.Data; import java.util.List; +@Data public class AgentInfo { private String id; private String name; private List capabilities; private List tools; - // Constructor public AgentInfo(String id, String name, List capabilities, List tools) { this.id = id; this.name = name; this.capabilities = capabilities; this.tools = tools; } - - // Getters and setters - public String getId() { return id; } - public void setId(String id) { this.id = id; } - public String getName() { return name; } - public void setName(String name) { this.name = name; } - public List getCapabilities() { return capabilities; } - public void setCapabilities(List capabilities) { this.capabilities = capabilities; } - public List getTools() { return tools; } - public void setTools(List tools) { this.tools = tools; } } \ No newline at end of file diff --git a/src/main/java/com/ioa/agent/AgentRegistry.java b/src/main/java/com/ioa/agent/AgentRegistry.java index 6bda082..272f8f4 100644 --- a/src/main/java/com/ioa/agent/AgentRegistry.java +++ b/src/main/java/com/ioa/agent/AgentRegistry.java @@ -1,12 +1,14 @@ package com.ioa.agent; import com.ioa.tool.ToolRegistry; +import org.springframework.stereotype.Component; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +@Component public class AgentRegistry { private Map agents = new HashMap<>(); private ToolRegistry toolRegistry; @@ -17,7 +19,7 @@ public class AgentRegistry { public void registerAgent(String agentId, AgentInfo agentInfo) { agents.put(agentId, agentInfo); - // Register agent's tools + // Verify that all tools the agent claims to have are registered for (String tool : agentInfo.getTools()) { if (toolRegistry.getTool(tool) == null) { throw new IllegalArgumentException("Tool not found in registry: " + tool); @@ -34,4 +36,8 @@ public class AgentRegistry { .filter(agent -> agent.getCapabilities().containsAll(capabilities)) .collect(Collectors.toList()); } + + public List getAllAgents() { + return List.copyOf(agents.values()); + } } \ No newline at end of file diff --git a/src/main/java/com/ioa/config/WebSocketConfig.java b/src/main/java/com/ioa/config/WebSocketConfig.java new file mode 100644 index 0000000..70919ec --- /dev/null +++ b/src/main/java/com/ioa/config/WebSocketConfig.java @@ -0,0 +1,23 @@ +package com.ioa.config; + +import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.MessageBrokerRegistry; +import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; +import org.springframework.web.socket.config.annotation.StompEndpointRegistry; +import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; + +@Configuration +@EnableWebSocketMessageBroker +public class WebSocketConfig implements WebSocketMessageBrokerConfigurer { + + @Override + public void configureMessageBroker(MessageBrokerRegistry config) { + config.enableSimpleBroker("/topic"); + config.setApplicationDestinationPrefixes("/app"); + } + + @Override + public void registerStompEndpoints(StompEndpointRegistry registry) { + registry.addEndpoint("/ws").withSockJS(); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/conversation/ConversationFSM.java b/src/main/java/com/ioa/conversation/ConversationFSM.java index 08f5a7b..893871c 100644 --- a/src/main/java/com/ioa/conversation/ConversationFSM.java +++ b/src/main/java/com/ioa/conversation/ConversationFSM.java @@ -1,29 +1,34 @@ package com.ioa.conversation; -import com.ioa.util.TreeOfThought; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.output.Response; +import com.ioa.model.BedrockLanguageModel; +import com.ioa.service.WebSocketService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +@Component public class ConversationFSM { private ConversationState currentState; - private TreeOfThought treeOfThought; + private BedrockLanguageModel model; - public ConversationFSM(ChatLanguageModel model) { + @Autowired + private WebSocketService webSocketService; + + public ConversationFSM(BedrockLanguageModel model) { this.currentState = ConversationState.DISCUSSION; - this.treeOfThought = new TreeOfThought(model); + this.model = model; } public void handleMessage(Message message) { String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() + "\nCurrent state: " + currentState; - String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3); + String reasoning = model.generate(stateTransitionTask); String decisionPrompt = "Based on this reasoning:\n" + reasoning + "\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION)."; - Response response = treeOfThought.getModel().generate(decisionPrompt); + String response = model.generate(decisionPrompt); - ConversationState newState = ConversationState.valueOf(response.content().trim()); + ConversationState newState = ConversationState.valueOf(response.trim()); transitionTo(newState); // Handle the message based on the new state @@ -44,8 +49,8 @@ public class ConversationFSM { } private void transitionTo(ConversationState newState) { - // Add any transition logic here this.currentState = newState; + webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState)); } private void handleDiscussionMessage(Message message) { @@ -63,4 +68,12 @@ public class ConversationFSM { private void handleConclusionMessage(Message message) { // Implement conclusion logic } + + private class ConversationStateUpdate { + public ConversationState state; + + ConversationStateUpdate(ConversationState state) { + this.state = state; + } + } } \ No newline at end of file diff --git a/src/main/java/com/ioa/conversation/Message.java b/src/main/java/com/ioa/conversation/Message.java index 278e4f1..b5d825b 100644 --- a/src/main/java/com/ioa/conversation/Message.java +++ b/src/main/java/com/ioa/conversation/Message.java @@ -1,5 +1,8 @@ package com.ioa.conversation; +import lombok.Data; + +@Data public class Message { private String sender; private String content; @@ -8,7 +11,4 @@ public class Message { this.sender = sender; this.content = content; } - - public String getSender() { return sender; } - public String getContent() { return content; } } \ No newline at end of file diff --git a/src/main/java/com/ioa/model/BedrockLanguageModel.java b/src/main/java/com/ioa/model/BedrockLanguageModel.java index a380393..bb530cd 100644 --- a/src/main/java/com/ioa/model/BedrockLanguageModel.java +++ b/src/main/java/com/ioa/model/BedrockLanguageModel.java @@ -1,16 +1,19 @@ package com.ioa.model; -import com.fasterxml.jackson.databind.ObjectMapper; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import software.amazon.awssdk.core.SdkBytes; +import com.fasterxml.jackson.databind.ObjectMapper; import java.util.Map; import java.util.HashMap; +import org.springframework.stereotype.Component; + +@Component public class BedrockLanguageModel { private final BedrockRuntimeClient bedrockClient; private final ObjectMapper objectMapper; diff --git a/src/main/java/com/ioa/service/WebSocketService.java b/src/main/java/com/ioa/service/WebSocketService.java new file mode 100644 index 0000000..4e144b8 --- /dev/null +++ b/src/main/java/com/ioa/service/WebSocketService.java @@ -0,0 +1,16 @@ +package com.ioa.service; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.stereotype.Service; + +@Service +public class WebSocketService { + + @Autowired + private SimpMessagingTemplate messagingTemplate; + + public void sendUpdate(String topic, Object payload) { + messagingTemplate.convertAndSend("/topic/" + topic, payload); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/task/Task.java b/src/main/java/com/ioa/task/Task.java index 6cf5677..c6457fa 100644 --- a/src/main/java/com/ioa/task/Task.java +++ b/src/main/java/com/ioa/task/Task.java @@ -1,9 +1,11 @@ package com.ioa.task; import com.ioa.agent.AgentInfo; +import lombok.Data; import java.util.List; +@Data public class Task { private String id; private String description; @@ -12,21 +14,10 @@ public class Task { private AgentInfo assignedAgent; private String result; - // Constructor public Task(String id, String description, List requiredCapabilities, List requiredTools) { this.id = id; this.description = description; this.requiredCapabilities = requiredCapabilities; this.requiredTools = requiredTools; } - - // Getters and setters - public String getId() { return id; } - public String getDescription() { return description; } - public List getRequiredCapabilities() { return requiredCapabilities; } - public List getRequiredTools() { return requiredTools; } - public AgentInfo getAssignedAgent() { return assignedAgent; } - public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; } - public String getResult() { return result; } - public void setResult(String result) { this.result = result; } } \ No newline at end of file diff --git a/src/main/java/com/ioa/task/TaskManager.java b/src/main/java/com/ioa/task/TaskManager.java index 990e9d0..33169ac 100644 --- a/src/main/java/com/ioa/task/TaskManager.java +++ b/src/main/java/com/ioa/task/TaskManager.java @@ -2,43 +2,57 @@ package com.ioa.task; import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentRegistry; +import com.ioa.model.BedrockLanguageModel; +import com.ioa.service.WebSocketService; import com.ioa.tool.ToolRegistry; -import com.ioa.util.TreeOfThought; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.output.Response; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; import java.util.HashMap; import java.util.Map; +@Component public class TaskManager { private Map tasks = new HashMap<>(); private AgentRegistry agentRegistry; - private TreeOfThought treeOfThought; + private BedrockLanguageModel model; private ToolRegistry toolRegistry; - public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) { + @Autowired + private WebSocketService webSocketService; + + public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) { this.agentRegistry = agentRegistry; - this.treeOfThought = new TreeOfThought(model); + this.model = model; this.toolRegistry = toolRegistry; } + public void addTask(Task task) { + tasks.put(task.getId(), task); + } + public void executeTask(String taskId) { Task task = tasks.get(taskId); AgentInfo agent = task.getAssignedAgent(); + updateTaskProgress(taskId, "STARTED", 0); + String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() + "\nAssigned agent capabilities: " + agent.getCapabilities() + "\nAvailable tools: " + agent.getTools(); - String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3); + String reasoning = model.generate(executionPlanningTask); + + updateTaskProgress(taskId, "IN_PROGRESS", 50); String executionPrompt = "Based on this execution plan:\n" + reasoning + "\nExecute the task using the available tools and provide the result."; - Response response = treeOfThought.getModel().generate(executionPrompt); + String response = model.generate(executionPrompt); - String result = executeToolsFromResponse(response.content(), agent); + String result = executeToolsFromResponse(response, agent); task.setResult(result); + updateTaskProgress(taskId, "COMPLETED", 100); } private String executeToolsFromResponse(String response, AgentInfo agent) { @@ -53,11 +67,20 @@ public class TaskManager { return result.toString(); } - public void addTask(Task task) { - tasks.put(task.getId(), task); + private void updateTaskProgress(String taskId, String status, int progressPercentage) { + TaskProgress progress = new TaskProgress(taskId, status, progressPercentage); + webSocketService.sendUpdate("task_progress", progress); } - public Task getTask(String taskId) { - return tasks.get(taskId); + private class TaskProgress { + public String taskId; + public String status; + public int progressPercentage; + + TaskProgress(String taskId, String status, int progressPercentage) { + this.taskId = taskId; + this.status = status; + this.progressPercentage = progressPercentage; + } } } \ No newline at end of file diff --git a/src/main/java/com/ioa/team/TeamFormation.java b/src/main/java/com/ioa/team/TeamFormation.java index 3d925da..0da1e67 100644 --- a/src/main/java/com/ioa/team/TeamFormation.java +++ b/src/main/java/com/ioa/team/TeamFormation.java @@ -2,22 +2,22 @@ package com.ioa.team; import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentRegistry; +import com.ioa.model.BedrockLanguageModel; import com.ioa.task.Task; -import com.ioa.util.TreeOfThought; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.output.Response; +import org.springframework.stereotype.Component; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +@Component public class TeamFormation { private AgentRegistry agentRegistry; - private TreeOfThought treeOfThought; + private BedrockLanguageModel model; - public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) { + public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) { this.agentRegistry = agentRegistry; - this.treeOfThought = new TreeOfThought(model); + this.model = model; } public List formTeam(Task task) { @@ -29,13 +29,13 @@ public class TeamFormation { "\nRequired tools: " + requiredTools + "\nAvailable agents and their tools: " + formatAgentTools(potentialAgents); - String reasoning = treeOfThought.reason(teamFormationTask, 3, 3); + String reasoning = model.generate(teamFormationTask); String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning + "\nProvide the final team composition as a comma-separated list of agent IDs."; - Response response = treeOfThought.getModel().generate(finalDecisionPrompt); + String response = model.generate(finalDecisionPrompt); - return parseTeamComposition(response.content(), potentialAgents); + return parseTeamComposition(response, potentialAgents); } private String formatAgentTools(List agents) { diff --git a/src/main/java/com/ioa/tool/Tool.java b/src/main/java/com/ioa/tool/Tool.java new file mode 100644 index 0000000..3db8918 --- /dev/null +++ b/src/main/java/com/ioa/tool/Tool.java @@ -0,0 +1,12 @@ +package com.ioa.tool; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Tool { + String value(); +} \ No newline at end of file diff --git a/src/main/java/com/ioa/tool/ToolRegistry.java b/src/main/java/com/ioa/tool/ToolRegistry.java index d0626af..03f6c90 100644 --- a/src/main/java/com/ioa/tool/ToolRegistry.java +++ b/src/main/java/com/ioa/tool/ToolRegistry.java @@ -1,8 +1,11 @@ package com.ioa.tool; +import org.springframework.stereotype.Component; + import java.util.HashMap; import java.util.Map; +@Component public class ToolRegistry { private Map tools = new HashMap<>(); diff --git a/src/main/java/com/ioa/util/TreeOfThought.java b/src/main/java/com/ioa/util/TreeOfThought.java index c5203ba..fbbecb5 100644 --- a/src/main/java/com/ioa/util/TreeOfThought.java +++ b/src/main/java/com/ioa/util/TreeOfThought.java @@ -1,12 +1,11 @@ package com.ioa.util; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.output.Response; +import com.ioa.model.BedrockLanguageModel; public class TreeOfThought { - private final ChatLanguageModel model; + private final BedrockLanguageModel model; - public TreeOfThought(ChatLanguageModel model) { + public TreeOfThought(BedrockLanguageModel model) { this.model = model; } @@ -23,8 +22,7 @@ public class TreeOfThought { for (int i = 0; i < branches; i++) { String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path + "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; - Response response = model.generate(branchPrompt); - String thought = response.content(); + String thought = model.generate(branchPrompt); result.append("Branch ").append(i + 1).append(":\n"); result.append(thought).append("\n"); @@ -35,11 +33,10 @@ public class TreeOfThought { private String evaluateLeaf(String task, String path) { String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path; - Response response = model.generate(prompt); - return response.content(); + return model.generate(prompt); } - public ChatLanguageModel getModel() { + public BedrockLanguageModel getModel() { return model; } -} \ No newline at end of file +}