Compare commits

..

3 Commits

55 changed files with 799 additions and 565 deletions

3
README.md Normal file
View File

@ -0,0 +1,3 @@
mvn spring-boot:run
http://localhost:8080/

2
application.properties Normal file
View File

@ -0,0 +1,2 @@
logging.level.org.springframework.messaging=TRACE
logging.level.org.springframework.web.socket=TRACE

View File

@ -1,59 +1,51 @@
package com.ioa; package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationManager;
import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry;
import com.ioa.tool.Tool;
import com.ioa.tool.common.*;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.util.TreeOfThought;
import com.ioa.tool.common.AppointmentSchedulerTool;
import com.ioa.tool.common.DistanceCalculatorTool;
import com.ioa.tool.common.FinancialAdviceTool;
import com.ioa.tool.common.FitnessClassFinderTool;
import com.ioa.tool.common.MovieRecommendationTool;
import com.ioa.tool.common.NewsUpdateTool;
import com.ioa.tool.common.PriceComparisonTool;
import com.ioa.tool.common.RecipeTool;
import com.ioa.tool.common.ReminderTool;
import com.ioa.tool.common.RestaurantFinderTool;
import com.ioa.tool.common.TranslationTool;
import com.ioa.tool.common.TravelBookingTool;
import com.ioa.tool.common.WeatherTool;
import com.ioa.tool.common.WebSearchTool;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.ComponentScan;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import java.util.Arrays; import com.ioa.agent.AgentRegistry;
import java.util.List; import com.ioa.task.TaskManager;
import com.ioa.websocket.WebSocketHandler;
import java.util.concurrent.ExecutorService; import com.ioa.tool.ToolRegistry;
import java.util.concurrent.Executors; import com.ioa.model.BedrockLanguageModel;
import java.util.concurrent.TimeUnit; import com.ioa.util.TreeOfThought;
import java.util.concurrent.CountDownLatch; import com.ioa.service.WebSocketService;
import com.ioa.conversation.ConversationManager;
@SpringBootApplication @SpringBootApplication
@ComponentScan(basePackages = "com.ioa")
public class IoASystem { public class IoASystem {
@Bean @Bean
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) { public WebSocketHandler webSocketHandler(AgentRegistry agentRegistry,
return new WebSocketService(messagingTemplate, conversationManager); TaskManager taskManager,
SimpMessagingTemplate messagingTemplate,
TreeOfThought treeOfThought,
WebSocketService webSocketService,
ToolRegistry toolRegistry,
BedrockLanguageModel model) {
return new WebSocketHandler(agentRegistry, taskManager, messagingTemplate,
treeOfThought, webSocketService, toolRegistry, model);
} }
@Bean @Bean
public ConversationManager conversationManager(BedrockLanguageModel model, @Lazy WebSocketService webSocketService) { public AgentRegistry agentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
return new ConversationManager(model, webSocketService); return new AgentRegistry(messagingTemplate, toolRegistry);
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought,
conversationManager, webSocketService, messagingTemplate);
} }
@Bean @Bean
@ -67,195 +59,21 @@ public class IoASystem {
} }
@Bean @Bean
public AgentRegistry agentRegistry(ToolRegistry toolRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, ConversationManager conversationManager) { public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate) {
AgentRegistry registry = new AgentRegistry(toolRegistry); return new WebSocketService(messagingTemplate);
// Agent creation is now moved to processTasksAndAgents method
return registry;
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, TreeOfThought treeOfThought, ConversationManager conversationManager) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought, conversationManager);
}
@Bean
public TeamFormation teamFormation(AgentRegistry agentRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, BedrockLanguageModel model) {
return new TeamFormation(agentRegistry, treeOfThought, webSocketService, model);
} }
@Bean @Bean
public ToolRegistry toolRegistry() { public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry(); return new ToolRegistry();
}
// Register all tools @Bean
toolRegistry.registerTool("webSearch", new WebSearchTool()); public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
toolRegistry.registerTool("getWeather", new WeatherTool()); return new ConversationManager(model, webSocketService);
toolRegistry.registerTool("setReminder", new ReminderTool());
toolRegistry.registerTool("bookTravel", new TravelBookingTool());
toolRegistry.registerTool("calculateDistance", new DistanceCalculatorTool());
toolRegistry.registerTool("findRestaurants", new RestaurantFinderTool());
toolRegistry.registerTool("scheduleAppointment", new AppointmentSchedulerTool());
toolRegistry.registerTool("findFitnessClasses", new FitnessClassFinderTool());
toolRegistry.registerTool("getRecipe", new RecipeTool());
toolRegistry.registerTool("getNewsUpdates", new NewsUpdateTool());
toolRegistry.registerTool("translate", new TranslationTool());
toolRegistry.registerTool("compareProductPrices", new PriceComparisonTool());
toolRegistry.registerTool("getMovieRecommendations", new MovieRecommendationTool());
toolRegistry.registerTool("getFinancialAdvice", new FinancialAdviceTool());
return toolRegistry;
} }
public static void main(String[] args) { public static void main(String[] args) {
ConfigurableApplicationContext context = SpringApplication.run(IoASystem.class, args); SpringApplication.run(IoASystem.class, args);
IoASystem system = context.getBean(IoASystem.class);
system.processTasksAndAgents(context);
}
public void processTasksAndAgents(ConfigurableApplicationContext context) {
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class);
TreeOfThought treeOfThought = context.getBean(TreeOfThought.class);
WebSocketService webSocketService = context.getBean(WebSocketService.class);
ToolRegistry toolRegistry = context.getBean(ToolRegistry.class);
ConversationManager conversationManager = context.getBean(ConversationManager.class);
// Register all agents
agentRegistry.registerAgent("agent1", new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent2", new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent3", new AgentInfo("agent3", "Event Planner Extraordinaire",
Arrays.asList("event planning", "team management", "booking"),
Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment", "getWeather"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent4", new AgentInfo("agent4", "Fitness Guru",
Arrays.asList("health", "nutrition", "motivation"),
Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent5", new AgentInfo("agent5", "Research Specialist",
Arrays.asList("research", "writing", "analysis"),
Arrays.asList("webSearch", "getNewsUpdates", "translate", "compareProductPrices"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent6", new AgentInfo("agent6", "Digital Marketing Expert",
Arrays.asList("marketing", "social media", "content creation"),
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "getMovieRecommendations"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
agentRegistry.registerAgent("agent7", new AgentInfo("agent7", "Family Travel Coordinator",
Arrays.asList("travel", "family planning", "budgeting"),
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants", "getFinancialAdvice"),
treeOfThought, webSocketService, toolRegistry, conversationManager));
// Create all tasks
List<Task> tasks = Arrays.asList(
new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"))//,
// new Task("task2", "Organize a corporate team-building event in New York",
// Arrays.asList("event planning", "team management"),
// Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment")),
// new Task("task3", "Develop a personalized fitness and nutrition plan",
// Arrays.asList("health", "nutrition"),
// Arrays.asList("getWeather", "findFitnessClasses", "getRecipe")),
// new Task("task4", "Research and summarize recent advancements in renewable energy",
// Arrays.asList("research", "writing"),
// Arrays.asList("webSearch", "getNewsUpdates", "translate")),
// new Task("task5", "Plan and execute a social media marketing campaign for a new product launch",
// Arrays.asList("marketing", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
// new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
// Arrays.asList("travel", "family planning"),
// Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants")),
// new Task("task7", "Organize an international tech conference with virtual and in-person components",
// Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
// Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates"))//,
// new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
// Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
// Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
// new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
// Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
// Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
// new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
// Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
// new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
// Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
// Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
);
// Create a thread pool
ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
for (Task task : tasks) {
System.out.println("\nProcessing task: " + task.getDescription());
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
if (!team.isEmpty()) {
// Create a conversation for the team
String conversationId = conversationManager.createConversation();
// Add team members to the conversation
for (AgentInfo agent : team) {
conversationManager.addParticipant(conversationId, agent);
}
// Create a CountDownLatch to wait for all agents to complete their tasks
CountDownLatch latch = new CountDownLatch(team.size());
// Assign the task to all team members and execute in parallel
for (AgentInfo agent : team) {
executorService.submit(() -> {
try {
Task agentTask = new Task(task.getId() + "_" + agent.getId(), task.getDescription(), task.getRequiredCapabilities(), task.getRequiredTools());
agentTask.setAssignedAgent(agent);
taskManager.addTask(agentTask);
taskManager.executeTask(agentTask.getId(), conversationId);
} finally {
latch.countDown();
}
});
}
// Start the conversation
conversationManager.startConversation(conversationId, "Let's work on the task: " + task.getDescription());
// Wait for all agents to complete their tasks
try {
latch.await(40, TimeUnit.MINUTES); // Wait for up to 5 minutes
} catch (InterruptedException e) {
e.printStackTrace();
}
// Get the result
String result = conversationManager.getConversationResult(conversationId);
System.out.println("Task result: " + result);
} else {
System.out.println("No suitable agents found for this task. Consider updating the agent pool or revising the task requirements.");
}
}
// Shutdown the executor service
executorService.shutdown();
try {
if (!executorService.awaitTermination(60, TimeUnit.SECONDS)) {
executorService.shutdownNow();
}
} catch (InterruptedException e) {
executorService.shutdownNow();
}
} }
} }

View File

@ -1,125 +1,82 @@
package com.ioa.agent; package com.ioa.agent;
import com.ioa.conversation.Message; import com.ioa.conversation.Message;
import com.ioa.util.TreeOfThought; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import com.ioa.conversation.ConversationFSM; import com.ioa.util.TreeOfThought;
import com.ioa.conversation.ConversationManager; import com.ioa.websocket.AgentProcessingUpdate;
import com.ioa.websocket.AgentResponseUpdate;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.UUID;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
public class AgentInfo { public class AgentInfo {
private String id; private String id;
private String name; private String name;
private List<String> capabilities; private List<String> capabilities;
private List<String> tools; private List<String> tools;
private Memory memory;
private TreeOfThought treeOfThought; private TreeOfThought treeOfThought;
private WebSocketService webSocketService; private WebSocketService webSocketService;
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
private final ConversationManager conversationManager; private BedrockLanguageModel model;
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools, public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, ConversationManager conversationManager) {
this.id = id; this.id = id;
this.name = name; this.name = name;
this.capabilities = capabilities; this.capabilities = capabilities;
this.tools = tools; this.tools = tools;
this.memory = new Memory();
}
public void setDependencies(TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, BedrockLanguageModel model) {
this.treeOfThought = treeOfThought; this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.conversationManager = conversationManager; this.model = model;
} }
public void sendMessage(String conversationId, String content) { public void receiveMessageLocal(Message message) {
conversationManager.postMessage(conversationId, this.id, content); if (this.memory == null)
this.memory = new Memory();
memory.addToHistory(message.getContent());
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nYou received a message: " + message.getContent() +
"\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
String response = model.generate(prompt, null);
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
System.out.println("DEBUG: " + name + " response: " + response);
// Add the response to memory
memory.addToHistory("My response: " + response);
} }
public void receiveMessage(Message message) { public void receiveMessage(Message message) {
// Process the received message if (this.memory == null)
Map<String, Object> reasoningResult = treeOfThought.reason("Respond to message: " + message.getContent(), 2, 2); this.memory = new Memory();
String response = (String) reasoningResult.get("response"); // Assuming the response is stored under the key "response"
// This is a turn notification for this agent
respondToTurn(message.getConversationId());
}
private void respondToTurn(String conversationId) {
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nIt's your turn to speak in the conversation. What would you like to say or do?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
// Send the response back to the conversation // Send the response back to the conversation
sendMessage(message.getConversationId(), response); Message responseMessage = new Message(conversationId, this.id, prompt);
webSocketService.sendUpdate("conversation_message", responseMessage);
} }
public List<String> getCapabilities() {
return this.capabilities;
}
public String getId() {
return this.id;
}
public List<String> getTools() {
return this.tools;
}
@Override
public String toString() {
return "AgentInfo{id='" + id + "', name='" + name + "'}";
}
public String executeTask(String taskDescription) {
System.out.println("DEBUG: Agent " + id + " executing task: " + taskDescription);
webSocketService.sendUpdate("agent_task", Map.of("agentId", id, "task", taskDescription));
// Use Tree of Thought to decide on tools and actions
Map<String, Object> reasoning = treeOfThought.reason("Select tools and actions for task: " + taskDescription +
"\nAvailable tools: " + tools, 2, 2);
String reasoningString = formatReasoning(reasoning);
System.out.println("DEBUG: Agent " + id + " reasoning:\n" + reasoningString);
webSocketService.sendUpdate("agent_reasoning", Map.of("agentId", id, "reasoning", reasoningString));
// Extract tool selection from reasoning
List<String> selectedTools = extractToolSelection(reasoningString);
System.out.println("DEBUG: Agent " + id + " selected tools: " + selectedTools);
webSocketService.sendUpdate("agent_tools_selected", Map.of("agentId", id, "tools", selectedTools));
// Execute actions using selected tools
StringBuilder result = new StringBuilder();
for (String tool : selectedTools) {
String actionResult = executeTool(tool, taskDescription);
result.append(actionResult).append("\n");
}
String finalResult = result.toString().trim();
System.out.println("DEBUG: Agent " + id + " task result: " + finalResult);
webSocketService.sendUpdate("agent_task_result", Map.of("agentId", id, "result", finalResult));
return finalResult;
}
private String formatReasoning(Map<String, Object> reasoning) {
// Implement a method to format the reasoning tree into a string
// This is a placeholder implementation
return reasoning.toString();
}
private List<String> extractToolSelection(String reasoning) {
// Implement a method to extract tool selection from reasoning
// This is a placeholder implementation
return new ArrayList<>(tools);
}
private String executeTool(String tool, String context) {
System.out.println("DEBUG: Agent " + id + " executing tool: " + tool);
webSocketService.sendUpdate("agent_tool_execution", Map.of("agentId", id, "tool", tool));
// Placeholder for tool execution
// In a real implementation, you would call the actual tool method from the ToolRegistry
String result = "Simulated result of using " + tool + " for context: " + context;
System.out.println("DEBUG: Agent " + id + " tool result: " + result);
webSocketService.sendUpdate("agent_tool_result", Map.of("agentId", id, "tool", tool, "result", result));
return result;
}
public void voteToFinish(String conversationId) {
conversationManager.postMessage(conversationId, this.id, "/vote");
}
} }

View File

@ -0,0 +1,12 @@
package com.ioa.agent;
public class AgentMessage {
private String agentId;
private String content;
// Getters and setters
public String getAgentId() { return agentId; }
public void setAgentId(String agentId) { this.agentId = agentId; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -1,34 +1,53 @@
package com.ioa.agent; package com.ioa.agent;
import com.ioa.tool.ToolRegistry; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import com.ioa.tool.ToolRegistry;
import java.util.*; import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Component @Component
public class AgentRegistry { public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>(); private Map<String, AgentInfo> agents = new ConcurrentHashMap<>();
private ToolRegistry toolRegistry; private final SimpMessagingTemplate messagingTemplate;
private final ToolRegistry toolRegistry;
public AgentRegistry(ToolRegistry toolRegistry) { @Autowired
public AgentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
this.messagingTemplate = messagingTemplate;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
} }
public void registerAgent(String agentId, AgentInfo agentInfo) { public void registerAgent(AgentInfo agentInfo) {
agents.put(agentId, agentInfo); agents.put(agentInfo.getId(), agentInfo);
// Verify that all tools the agent claims to have are registered // Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) { for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) { if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool); throw new IllegalArgumentException("Tool not found in registry: " + tool);
} }
} }
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
public void unregisterAgent(String agentId) {
agents.remove(agentId);
messagingTemplate.convertAndSend("/topic/agents", "Agent unregistered: " + agentId);
} }
public AgentInfo getAgent(String agentId) { public AgentInfo getAgent(String agentId) {
return agents.get(agentId); return agents.get(agentId);
} }
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
public List<AgentInfo> searchAgents(List<String> capabilities) { public List<AgentInfo> searchAgents(List<String> capabilities) {
return searchAgents(capabilities, 1.0); // Default to exact match return searchAgents(capabilities, 1.0); // Default to exact match
} }
@ -50,8 +69,4 @@ public class AgentRegistry {
.count(); .count();
return (double) matchingCapabilities / requiredCapabilities.size(); return (double) matchingCapabilities / requiredCapabilities.size();
} }
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
} }

View File

@ -0,0 +1,47 @@
package com.ioa.agent;
import java.util.ArrayList;
import java.util.List;
public class Memory {
private List<String> conversationHistory;
private List<String> contextualFacts;
private int maxHistorySize = 10; // Adjust as needed
public Memory() {
this.conversationHistory = new ArrayList<>();
this.contextualFacts = new ArrayList<>();
}
public void addToHistory(String message) {
conversationHistory.add(message);
if (conversationHistory.size() > maxHistorySize) {
conversationHistory.remove(0);
}
}
public void addContextualFact(String fact) {
contextualFacts.add(fact);
}
public List<String> getConversationHistory() {
return new ArrayList<>(conversationHistory);
}
public List<String> getContextualFacts() {
return new ArrayList<>(contextualFacts);
}
public String getFormattedMemory() {
StringBuilder sb = new StringBuilder();
sb.append("Conversation History:\n");
for (String message : conversationHistory) {
sb.append("- ").append(message).append("\n");
}
sb.append("\nContextual Facts:\n");
for (String fact : contextualFacts) {
sb.append("- ").append(fact).append("\n");
}
return sb.toString();
}
}

View File

@ -1,7 +1,19 @@
package com.ioa.config; package com.ioa.config;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.MimeTypeUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.converter.DefaultContentTypeResolver;
import java.util.List;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@ -18,6 +30,34 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").withSockJS(); registry.addEndpoint("/ws")
.setAllowedOriginPatterns("*")
.withSockJS();
}
@Override
public boolean configureMessageConverters(List<MessageConverter> messageConverters) {
MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter();
DefaultContentTypeResolver resolver = new DefaultContentTypeResolver();
resolver.setDefaultMimeType(MimeTypeUtils.APPLICATION_JSON);
converter.setContentTypeResolver(resolver);
messageConverters.add(converter);
return false;
}
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(new ChannelInterceptor() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (accessor != null && accessor.getCommand() != null) {
System.out.println("Received STOMP Frame: " + accessor.getCommand());
System.out.println("Destination: " + accessor.getDestination());
System.out.println("Payload: " + new String((byte[]) message.getPayload()));
}
return message;
}
});
} }
} }

View File

@ -1,6 +1,5 @@
package com.ioa.conversation; package com.ioa.conversation;
import java.util.stream.Collectors;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
@ -8,9 +7,8 @@ import org.springframework.stereotype.Component;
import java.util.*; import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
@Component @Component
public class ConversationFSM { public class ConversationFSM {
@ -19,31 +17,42 @@ public class ConversationFSM {
private final WebSocketService webSocketService; private final WebSocketService webSocketService;
private final Queue<Message> messageQueue; private final Queue<Message> messageQueue;
private final List<AgentInfo> participants; private final List<AgentInfo> participants;
private final Queue<AgentInfo> speakingQueue;
private final Map<String, Boolean> votes; private final Map<String, Boolean> votes;
private final AtomicBoolean finished; private final AtomicBoolean finished;
private String result; private String result;
private final ScheduledExecutorService executorService; private String conversationId;
public ConversationFSM(BedrockLanguageModel model, WebSocketService webSocketService) { public ConversationFSM(BedrockLanguageModel model, WebSocketService webSocketService) {
this.executorService = Executors.newScheduledThreadPool(1);
this.currentState = ConversationState.DISCUSSION; this.currentState = ConversationState.DISCUSSION;
this.model = model; this.model = model;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.messageQueue = new ConcurrentLinkedQueue<>(); this.messageQueue = new ConcurrentLinkedQueue<>();
this.participants = new ArrayList<>(); this.participants = new ArrayList<>();
this.speakingQueue = new ConcurrentLinkedQueue<>();
this.votes = new HashMap<>(); this.votes = new HashMap<>();
this.finished = new AtomicBoolean(false); this.finished = new AtomicBoolean(false);
this.result = ""; this.result = "";
} }
public void initialize(String conversationId) {
this.conversationId = conversationId;
}
public String getConversationId() {
return conversationId;
}
public void addParticipant(AgentInfo agent) { public void addParticipant(AgentInfo agent) {
participants.add(agent); participants.add(agent);
speakingQueue.offer(agent);
votes.put(agent.getId(), false); votes.put(agent.getId(), false);
webSocketService.sendUpdate("conversation_participants", participants); webSocketService.sendUpdate("conversation_participants", participants);
} }
public void removeParticipant(AgentInfo agent) { public void removeParticipant(AgentInfo agent) {
participants.remove(agent); participants.remove(agent);
speakingQueue.remove(agent);
votes.remove(agent.getId()); votes.remove(agent.getId());
webSocketService.sendUpdate("conversation_participants", participants); webSocketService.sendUpdate("conversation_participants", participants);
} }
@ -61,62 +70,69 @@ public class ConversationFSM {
} }
private void handleMessage(Message message) { private void handleMessage(Message message) {
if (message == null) { if (message.getContent().startsWith("/vote")) {
System.out.println("DEBUG: Received null message");
return;
}
String content = message.getContent();
if (content == null) {
System.out.println("DEBUG: Message content is null");
return;
}
System.out.println("DEBUG: Received message: " + content);
if (content.startsWith("/vote")) {
handleVote(message.getSender()); handleVote(message.getSender());
} else { } else {
String stateTransitionTask = "Decide the next conversation state based on this message: " + content + broadcastMessage(message);
"\nCurrent state: " + currentState + updateState(message);
"\nParticipants: " + participants + notifyNextSpeaker();
"\nPossible states: " + Arrays.toString(ConversationState.values());
String reasoning = model.generate(stateTransitionTask, null);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state. Choose from: " +
Arrays.toString(ConversationState.values()) +
"\nResponse format: STATE: <state_name>";
String response = model.generate(decisionPrompt, null);
ConversationState newState = parseStateFromResponse(response);
transitionTo(newState);
// Broadcast the message to all participants
for (AgentInfo agent : participants) {
if (!agent.getId().equals(message.getSender())) {
agent.receiveMessage(message);
}
}
webSocketService.sendUpdate("conversation_message", message);
} }
} }
private ConversationState parseStateFromResponse(String response) { private void broadcastMessage(Message message) {
String[] parts = response.split(":"); for (AgentInfo agent : participants) {
if (parts.length > 1) { if (!agent.getId().equals(message.getSender())) {
String stateName = parts[1].trim().toUpperCase(); agent.receiveMessage(message);
try {
return ConversationState.valueOf(stateName);
} catch (IllegalArgumentException e) {
System.out.println("Invalid state name: " + stateName + ". Defaulting to DISCUSSION.");
return ConversationState.DISCUSSION;
} }
} }
System.out.println("Could not parse state from response: " + response + ". Defaulting to DISCUSSION."); webSocketService.sendUpdate("conversation_message", message);
return ConversationState.DISCUSSION; }
private void updateState(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState +
"\nParticipants: " + participants +
"\nProvide the next conversation state (" +
"DISCUSSION," +
"RESEARCH," +
"RESEARCH_TASK," +
"TASK_GATHERING_INFO," +
"TASK," +
"TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String reasoning = model.generate(stateTransitionTask, null);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (" +
"DISCUSSION," +
"RESEARCH," +
"RESEARCH_TASK," +
"TASK_GATHERING_INFO," +
"TASK," +
"TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String response = model.generate(decisionPrompt, null);
ConversationState newState = ConversationState.valueOf(response.trim());
transitionTo(newState);
}
private void notifyNextSpeaker() {
AgentInfo nextSpeaker = speakingQueue.poll();
if (nextSpeaker != null) {
// Instead of calling notifyTurn, we'll create a turn notification message
Message turnNotification = new Message(this.conversationId, "SYSTEM",
"It's " + nextSpeaker.getName() + "'s turn to speak.");
broadcastMessage(turnNotification);
speakingQueue.offer(nextSpeaker);
}
} }
private void handleVote(String agentId) { private void handleVote(String agentId) {
@ -152,8 +168,7 @@ public class ConversationFSM {
} }
private List<String> getPossibleTransitions() { private List<String> getPossibleTransitions() {
// This is a simplified version. You might want to implement more complex logic. return Arrays.stream(ConversationState.values())
return Arrays.asList(ConversationState.values()).stream()
.map(Enum::name) .map(Enum::name)
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View File

@ -3,32 +3,52 @@ package com.ioa.conversation;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.*; import java.util.Map;
import java.util.concurrent.*; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@Component @Component
public class ConversationManager { public class ConversationManager {
private final Map<String, ConversationFSM> conversations; private final Map<String, ConversationFSM> conversations = new ConcurrentHashMap<>();
private final BedrockLanguageModel model; private final BedrockLanguageModel model;
private final WebSocketService webSocketService; private final WebSocketService webSocketService;
private final ScheduledExecutorService executorService;
@Autowired
public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) { public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
this.conversations = new ConcurrentHashMap<>();
this.model = model; this.model = model;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.executorService = Executors.newScheduledThreadPool(1);
} }
public String createConversation() { public String createConversation() {
String conversationId = UUID.randomUUID().toString(); String conversationId = UUID.randomUUID().toString();
ConversationFSM conversation = new ConversationFSM(model, webSocketService); ConversationFSM conversation = new ConversationFSM(model, webSocketService);
conversation.initialize(conversationId);
conversations.put(conversationId, conversation); conversations.put(conversationId, conversation);
return conversationId; return conversationId;
} }
public void joinConversation(String conversationId) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
// Logic to join a conversation (e.g., add participant)
webSocketService.sendUpdate("conversation_joined", conversationId);
}
}
public void postMessage(String conversationId, String senderId, String content) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
conversation.postMessage(new Message(conversationId, senderId, content));
}
}
public ConversationFSM getConversation(String conversationId) {
return conversations.get(conversationId);
}
public void addParticipant(String conversationId, AgentInfo agent) { public void addParticipant(String conversationId, AgentInfo agent) {
ConversationFSM conversation = conversations.get(conversationId); ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) { if (conversation != null) {
@ -43,27 +63,13 @@ public class ConversationManager {
} }
} }
public void postMessage(String conversationId, String senderId, String content) {
System.out.println("DEBUG: Posting message - ConversationId: " + conversationId + ", SenderId: " + senderId + ", Content: " + content);
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
if (content == null) {
System.out.println("WARNING: Attempting to post null content message");
return;
}
conversation.postMessage(new Message(conversationId, senderId, content));
} else {
System.out.println("WARNING: Conversation not found for id: " + conversationId);
}
}
public void startConversation(String conversationId, String initialMessage) { public void startConversation(String conversationId, String initialMessage) {
ConversationFSM conversation = conversations.get(conversationId); ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) { if (conversation != null) {
conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage)); conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage));
// Start a timer to end the conversation after 10 minutes // Start a timer to end the conversation after 10 minutes
executorService.schedule(() -> { Executor.schedule(() -> {
if (!conversation.isFinished()) { if (!conversation.isFinished()) {
conversation.finish("Time limit reached"); conversation.finish("Time limit reached");
} }

View File

@ -2,7 +2,11 @@ package com.ioa.conversation;
public enum ConversationState { public enum ConversationState {
DISCUSSION, DISCUSSION,
RESEARCH,
RESEARCH_TASK,
TASK_GATHERING_INFO, TASK_GATHERING_INFO,
TASK,
TASK_PLANNING,
TASK_ASSIGNMENT, TASK_ASSIGNMENT,
EXECUTION, EXECUTION,
CONCLUSION CONCLUSION

View File

@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Base64; import java.util.Base64;
import java.util.function.Consumer;
@Component @Component
public class BedrockLanguageModel { public class BedrockLanguageModel {
@ -36,33 +37,7 @@ public class BedrockLanguageModel {
public String generate(String prompt, String imagePath) { public String generate(String prompt, String imagePath) {
System.out.println("DEBUG: Generating response for prompt: " + prompt); System.out.println("DEBUG: Generating response for prompt: " + prompt);
try { try {
ObjectNode requestBody = objectMapper.createObjectNode(); ObjectNode requestBody = createRequestBody(prompt, imagePath);
requestBody.put("anthropic_version", "bedrock-2023-05-31");
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image"); // Add type field
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text"); // Add type field
textNode.put("text", prompt);
String jsonPayload = objectMapper.writeValueAsString(requestBody); String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder() InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
@ -100,4 +75,77 @@ public class BedrockLanguageModel {
return "Error: " + e.getMessage(); return "Error: " + e.getMessage();
} }
} }
public void generateStream(String prompt, String imagePath, Consumer<String> chunkConsumer) {
System.out.println("DEBUG: Generating streaming response for prompt: " + prompt);
try {
ObjectNode requestBody = createRequestBody(prompt, imagePath);
String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
.modelId(modelId)
.contentType("application/json")
.accept("application/json")
.body(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest);
String responseBody = response.body().asUtf8String();
System.out.println("DEBUG: Raw response from Bedrock: " + responseBody);
JsonNode responseJson = objectMapper.readTree(responseBody);
JsonNode contentArray = responseJson.path("content");
if (contentArray.isArray() && contentArray.size() > 0) {
for (JsonNode content : contentArray) {
String chunk = content.path("text").asText();
chunkConsumer.accept(chunk);
}
} else {
System.out.println("WARNING: Unexpected response format. Full response: " + responseBody);
chunkConsumer.accept("Unexpected response format");
}
} catch (Exception e) {
System.out.println("ERROR: Failed to generate streaming text with Bedrock: " + e.getMessage());
e.printStackTrace();
chunkConsumer.accept("Error: " + e.getMessage());
}
}
private ObjectNode createRequestBody(String prompt, String imagePath) {
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("anthropic_version", "bedrock-2023-05-31");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
try {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image");
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
} catch (Exception e) {
System.err.println("Error reading image file: " + e.getMessage());
e.printStackTrace();
}
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text");
textNode.put("text", prompt);
return requestBody;
}
} }

View File

@ -1,62 +1,17 @@
package com.ioa.service; package com.ioa.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy;
@Service @Service
public class WebSocketService { public class WebSocketService {
private final SimpMessagingTemplate messagingTemplate; private final SimpMessagingTemplate messagingTemplate;
private final ConversationManager conversationManager;
private final ObjectMapper objectMapper;
@Autowired public WebSocketService(SimpMessagingTemplate messagingTemplate) {
public WebSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
this.messagingTemplate = messagingTemplate; this.messagingTemplate = messagingTemplate;
this.conversationManager = conversationManager;
this.objectMapper = new ObjectMapper();
} }
public void sendUpdate(String topic, Object payload) { public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload); messagingTemplate.convertAndSend("/topic/" + topic, payload);
} }
public void handleWebSocketMessage(String message) {
System.out.println("DEBUG: Received WebSocket message: " + message);
// Parse the WebSocket frame
String[] parts = message.split("\n\n", 2);
if (parts.length < 2) {
System.out.println("DEBUG: Invalid WebSocket message format");
return;
}
String headers = parts[0];
String payload = parts[1];
// Parse the JSON payload
try {
JsonNode jsonNode = objectMapper.readTree(payload);
// Extract relevant information from the JSON
// Adjust this based on the actual structure of your WebSocket messages
String conversationId = jsonNode.path("conversationId").asText();
String sender = jsonNode.path("sender").asText();
String content = jsonNode.path("content").asText();
// Create a new Message object
Message parsedMessage = new Message(conversationId, sender, content);
// Process the message
conversationManager.postMessage(conversationId, sender, content);
} catch (Exception e) {
System.out.println("DEBUG: Error parsing WebSocket message: " + e.getMessage());
e.printStackTrace();
}
}
} }

View File

@ -1,11 +1,8 @@
package com.ioa.task; package com.ioa.task;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import lombok.Data;
import java.util.List; import java.util.List;
@Data
public class Task { public class Task {
private String id; private String id;
private String description; private String description;
@ -14,10 +11,28 @@ public class Task {
private AgentInfo assignedAgent; private AgentInfo assignedAgent;
private String result; private String result;
// Default constructor
public Task() {}
// Existing constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) { public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id; this.id = id;
this.description = description; this.description = description;
this.requiredCapabilities = requiredCapabilities; this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools; this.requiredTools = requiredTools;
} }
// Getters and setters for all fields
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getDescription() { return description; }
public void setDescription(String description) { this.description = description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public void setRequiredCapabilities(List<String> requiredCapabilities) { this.requiredCapabilities = requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public void setRequiredTools(List<String> requiredTools) { this.requiredTools = requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
} }

View File

@ -0,0 +1,14 @@
package com.ioa.task;
import java.util.List;
public class TaskAssignment {
private String taskId;
private List<String> agentIds;
// Getters and setters
public String getTaskId() { return taskId; }
public void setTaskId(String taskId) { this.taskId = taskId; }
public List<String> getAgentIds() { return agentIds; }
public void setAgentIds(List<String> agentIds) { this.agentIds = agentIds; }
}

View File

@ -1,93 +1,133 @@
package com.ioa.task; package com.ioa.task;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry; import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationFSM;
import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought; import com.ioa.util.TreeOfThought;
import com.ioa.conversation.ConversationManager;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
@Component @Component
public class TaskManager { public class TaskManager {
private Map<String, Task> tasks = new HashMap<>(); private Map<String, Task> tasks = new ConcurrentHashMap<>();
private final SimpMessagingTemplate messagingTemplate;
private AgentRegistry agentRegistry; private AgentRegistry agentRegistry;
private BedrockLanguageModel model; private BedrockLanguageModel model;
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
private TreeOfThought treeOfThought; private TreeOfThought treeOfThought;
private ConversationManager conversationManager; private ConversationManager conversationManager;
private WebSocketService webSocketService;
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, TreeOfThought treeOfThought, ConversationManager conversationManager) { @Autowired
private TeamFormation teamFormation;
@Autowired
public TaskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
this.agentRegistry = agentRegistry; this.agentRegistry = agentRegistry;
this.model = model; this.model = model;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.treeOfThought = treeOfThought; this.treeOfThought = treeOfThought;
this.conversationManager = conversationManager; this.conversationManager = conversationManager;
this.webSocketService = webSocketService;
this.messagingTemplate = messagingTemplate;
}
public String createTask(Task task) {
String taskId = UUID.randomUUID().toString();
task.setId(taskId);
tasks.put(taskId, task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
// Automatically assign agents and execute the task
List<AgentInfo> team = teamFormation.formTeam(task);
if (!team.isEmpty()) {
executeTask(taskId, team);
} else {
System.out.println("No suitable agents found for task: " + taskId);
}
return taskId;
}
public void assignTask(String taskId, List<String> agentIds) {
Task task = tasks.get(taskId);
if (task != null) {
List<AgentInfo> team = agentIds.stream()
.map(agentRegistry::getAgent)
.filter(Objects::nonNull)
.collect(Collectors.toList());
executeTask(taskId, team);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task assigned to: " + String.join(", ", agentIds));
}
} }
public void addTask(Task task) { public void addTask(Task task) {
tasks.put(task.getId(), task); tasks.put(task.getId(), task);
} }
public void executeTask(String taskId, String conversationId) { public void executeTask(String taskId, List<AgentInfo> team) {
Task task = tasks.get(taskId); Task task = tasks.get(taskId);
AgentInfo agent = task.getAssignedAgent(); if (task == null) {
System.out.println("ERROR: Task with ID " + taskId + " not found.");
System.out.println("DEBUG: Executing task: " + taskId + " for agent: " + agent.getId()); return;
conversationManager.postMessage(conversationId, agent.getId(), "Starting task: " + task.getDescription());
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
"\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools();
System.out.println("DEBUG: Generating execution plan for task: " + taskId);
Map<String, Object> reasoningResult = treeOfThought.reason(executionPlanningTask, 3, 2);
String reasoning = (String) reasoningResult.get("reasoning");
System.out.println("DEBUG: Execution plan generated: " + reasoning);
if (reasoning == null || reasoning.isEmpty()) {
System.out.println("WARNING: Empty execution plan generated for task: " + taskId);
reasoning = "No execution plan generated. Proceeding with a general approach to organize execution plan.";
} }
conversationManager.postMessage(conversationId, agent.getId(), "Task execution plan:\n" + reasoning); String conversationId = conversationManager.createConversation();
ConversationFSM conversation = conversationManager.getConversation(conversationId);
String executionPrompt = "Based on this execution plan:\n" + reasoning + for (AgentInfo agent : team) {
"\nExecute the task using the available tools and provide the result."; conversation.addParticipant(agent);
Map<String, Object> executionResult = treeOfThought.reason(executionPrompt, 1, 1);
String response = (String) executionResult.get("reasoning");
if (response == null || response.isEmpty()) {
System.out.println("WARNING: Empty response generated for task execution: " + taskId);
response = "Unable to execute the task due to technical difficulties. Please try again or seek assistance.";
} }
String result = executeToolsFromResponse(response, agent); conversation.postMessage(new Message(conversationId, "SYSTEM", "Let's work on the task: " + task.getDescription()));
if (result == null || result.isEmpty()) { // Use the LLM to generate a plan for the task
result = "No specific actions were taken based on the execution plan. Please review the plan and provide more detailed instructions if necessary."; String planPrompt = "Generate a step-by-step plan to accomplish the following task: " + task.getDescription();
} String plan = model.generate(planPrompt, null);
conversation.postMessage(new Message(conversationId, "SYSTEM", "Here's the plan: " + plan));
task.setResult(result); // Allow agents to interact for a maximum of 40 minutes
long startTime = System.currentTimeMillis();
conversationManager.postMessage(conversationId, agent.getId(), "Task result: " + result); while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(40)) {
} try {
Thread.sleep(1000); // Check every second
private String executeToolsFromResponse(String response, AgentInfo agent) { } catch (InterruptedException e) {
StringBuilder result = new StringBuilder(); e.printStackTrace();
for (String tool : agent.getTools()) {
if (response.contains(tool)) {
Object toolInstance = toolRegistry.getTool(tool);
// Execute the tool (this is a simplified representation)
result.append(tool).append(" result: ").append(toolInstance.toString()).append("\n");
} }
} }
return result.toString();
if (!conversation.isFinished()) {
conversation.finish("Time limit reached");
}
String result = conversation.getResult();
task.setResult(result);
System.out.println("Task completed. Result: " + result);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task completed: " + result);
}
public Task getTask(String taskId) {
return tasks.get(taskId);
} }
} }

View File

@ -0,0 +1,31 @@
package com.ioa.task;
public class TaskResult {
private String taskId;
private String result;
// Constructors
public TaskResult() {}
public TaskResult(String taskId, String result) {
this.taskId = taskId;
this.result = result;
}
// Getters and setters
public String getTaskId() {
return taskId;
}
public void setTaskId(String taskId) {
this.taskId = taskId;
}
public String getResult() {
return result;
}
public void setResult(String result) {
this.result = result;
}
}

View File

@ -0,0 +1,29 @@
package com.ioa.tool;
public class ToolInfo {
private String agentId;
private String toolName;
public ToolInfo() {}
public ToolInfo(String agentId, String toolName) {
this.agentId = agentId;
this.toolName = toolName;
}
public String getAgentId() {
return agentId;
}
public void setAgentId(String agentId) {
this.agentId = agentId;
}
public String getToolName() {
return toolName;
}
public void setToolName(String toolName) {
this.toolName = toolName;
}
}

View File

@ -45,12 +45,8 @@ public class TreeOfThought {
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
String thought = model.generate(branchPrompt, null); String thought = model.generate(branchPrompt, null);
if (!thought.equals("No response generated") && !thought.startsWith("Error:")) { Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought); children.add(childNode);
children.add(childNode);
} else {
System.out.println("WARNING: Failed to generate thought. Result: " + thought);
}
} }
node.put("children", children); node.put("children", children);

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentProcessingUpdate {
private String agentId;
private String responseId;
private String status; // "start" or "complete"
public AgentProcessingUpdate(String agentId, String responseId, String status) {
this.agentId = agentId;
this.responseId = responseId;
this.status = status;
}
}

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentResponseUpdate {
private String agentId;
private String responseId;
private String chunk;
public AgentResponseUpdate(String agentId, String responseId, String chunk) {
this.agentId = agentId;
this.responseId = responseId;
this.chunk = chunk;
}
}

View File

@ -0,0 +1,50 @@
package com.ioa.websocket;
import com.ioa.conversation.ConversationManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.stereotype.Controller;
@Controller
public class ChatController {
private final ConversationManager conversationManager;
@Autowired
public ChatController(ConversationManager conversationManager) {
this.conversationManager = conversationManager;
}
@MessageMapping("/chat/create")
@SendTo("/topic/rooms")
public String createRoom(String roomName) {
String roomId = conversationManager.createConversation();
return "Room created: " + roomId;
}
@MessageMapping("/chat/join")
@SendTo("/topic/rooms")
public String joinRoom(String roomId) {
conversationManager.joinConversation(roomId);
return "Joined room: " + roomId;
}
@MessageMapping("/chat/message")
public void sendMessage(ChatMessage message) {
conversationManager.postMessage(message.getRoomId(), message.getSender(), message.getContent());
}
}
class ChatMessage {
private String roomId;
private String sender;
private String content;
// Getters and setters
public String getRoomId() { return roomId; }
public void setRoomId(String roomId) { this.roomId = roomId; }
public String getSender() { return sender; }
public void setSender(String sender) { this.sender = sender; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -0,0 +1,91 @@
package com.ioa.websocket;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Controller;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.task.Task;
import com.ioa.task.TaskAssignment;
import com.ioa.task.TaskManager;
import com.ioa.task.TaskResult;
import com.ioa.tool.ToolInfo;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
@Controller
public class WebSocketHandler {
private final AgentRegistry agentRegistry;
private final TaskManager taskManager;
private final SimpMessagingTemplate messagingTemplate;
private final TreeOfThought treeOfThought;
private final WebSocketService webSocketService;
private final ToolRegistry toolRegistry;
private final BedrockLanguageModel model;
@Autowired
public WebSocketHandler(AgentRegistry agentRegistry, TaskManager taskManager,
SimpMessagingTemplate messagingTemplate, TreeOfThought treeOfThought,
WebSocketService webSocketService, ToolRegistry toolRegistry,
BedrockLanguageModel model) {
this.agentRegistry = agentRegistry;
this.taskManager = taskManager;
this.messagingTemplate = messagingTemplate;
this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry;
this.model = model;
}
@MessageMapping("/agent/register")
public void registerAgent(@Payload AgentInfo agentInfo) {
// Initialize dependencies
agentInfo.setDependencies(treeOfThought, webSocketService, toolRegistry, model);
agentRegistry.registerAgent(agentInfo);
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
@MessageMapping("/tool/register")
public void registerTool(@Payload ToolInfo toolInfo) {
AgentInfo agent = agentRegistry.getAgent(toolInfo.getAgentId());
if (agent != null) {
agent.getTools().add(toolInfo.getToolName());
messagingTemplate.convertAndSend("/topic/tools", "New tool registered: " + toolInfo.getToolName());
}
}
@MessageMapping("/task/create")
public void createTask(@Payload Task task) {
System.out.println("Creating task: " + task.getId());
String taskId = taskManager.createTask(task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
}
@MessageMapping("/agent/unregister")
public void unregisterAgent(@Payload String agentId) {
agentRegistry.unregisterAgent(agentId);
messagingTemplate.convertAndSend("/topic/agent/" + agentId, "Unregistration successful");
}
@MessageMapping("/task/assign")
public void assignTask(@Payload TaskAssignment assignment) {
taskManager.assignTask(assignment.getTaskId(), assignment.getAgentIds());
messagingTemplate.convertAndSend("/topic/tasks/" + assignment.getTaskId(), "Task assigned");
}
@MessageMapping("/task/result")
public void handleTaskResult(@Payload TaskResult result) {
Task task = taskManager.getTask(result.getTaskId());
if (task != null) {
task.setResult(result.getResult());
messagingTemplate.convertAndSend("/topic/tasks/" + result.getTaskId(), "Task completed: " + result.getResult());
}
}
}

View File

@ -1,13 +1,26 @@
com/ioa/tool/common/TranslationTool.class
com/ioa/tool/common/PriceComparisonTool.class
com/ioa/service/WebSocketService.class com/ioa/service/WebSocketService.class
com/ioa/tool/common/WebSearchTool.class com/ioa/tool/common/WebSearchTool.class
com/ioa/conversation/ConversationFSM.class com/ioa/conversation/ConversationFSM.class
com/ioa/conversation/Message.class com/ioa/conversation/Message.class
com/ioa/util/TreeOfThought.class com/ioa/util/TreeOfThought.class
com/ioa/tool/common/NewsUpdateTool.class com/ioa/tool/common/NewsUpdateTool.class
com/ioa/tool/common/RestaurantFinderTool.class
com/ioa/conversation/ConversationFSM$ConversationStateUpdate.class com/ioa/conversation/ConversationFSM$ConversationStateUpdate.class
com/ioa/IoASystem.class
com/ioa/tool/ToolRegistry.class
com/ioa/conversation/ConversationState.class
com/ioa/tool/common/FinancialAdviceTool.class com/ioa/tool/common/FinancialAdviceTool.class
com/ioa/agent/AgentRegistry.class
com/ioa/team/TeamFormation.class com/ioa/team/TeamFormation.class
com/ioa/task/TaskManager.class
com/ioa/model/BedrockLanguageModel.class
com/ioa/tool/common/DistanceCalculatorTool.class
com/ioa/tool/common/MovieRecommendationTool.class com/ioa/tool/common/MovieRecommendationTool.class
com/ioa/tool/common/AppointmentSchedulerTool.class
com/ioa/agent/AgentInfo.class
com/ioa/task/Task.class
com/ioa/config/WebSocketConfig.class com/ioa/config/WebSocketConfig.class
com/ioa/conversation/ConversationManager.class com/ioa/conversation/ConversationManager.class
com/ioa/tool/common/WeatherTool.class com/ioa/tool/common/WeatherTool.class
@ -15,17 +28,4 @@ com/ioa/tool/common/TravelBookingTool.class
com/ioa/tool/common/FitnessClassFinderTool.class com/ioa/tool/common/FitnessClassFinderTool.class
com/ioa/tool/Tool.class com/ioa/tool/Tool.class
com/ioa/tool/common/ReminderTool.class com/ioa/tool/common/ReminderTool.class
com/ioa/tool/common/TranslationTool.class
com/ioa/tool/common/PriceComparisonTool.class
com/ioa/tool/common/RestaurantFinderTool.class
com/ioa/IoASystem.class
com/ioa/tool/ToolRegistry.class
com/ioa/conversation/ConversationState.class
com/ioa/agent/AgentRegistry.class
com/ioa/task/TaskManager.class
com/ioa/model/BedrockLanguageModel.class
com/ioa/tool/common/DistanceCalculatorTool.class
com/ioa/tool/common/AppointmentSchedulerTool.class
com/ioa/agent/AgentInfo.class
com/ioa/task/Task.class
com/ioa/tool/common/RecipeTool.class com/ioa/tool/common/RecipeTool.class

View File

@ -1,16 +1,30 @@
/Users/emkay/Projects/totioa/src/main/java/com/ioa/util/TreeOfThought.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/util/TreeOfThought.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/WeatherTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/ConversationFSM.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/ConversationFSM.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/model/BedrockLanguageModel.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/model/BedrockLanguageModel.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/ConversationState.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/ConversationState.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/task/Task.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/FinancialAdviceTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/Message.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/Message.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/service/WebSocketService.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/TravelBookingTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/agent/AgentInfo.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/agent/AgentInfo.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/DistanceCalculatorTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/TranslationTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/AppointmentSchedulerTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/agent/AgentRegistry.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/task/Task.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/MovieRecommendationTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/RestaurantFinderTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/conversation/ConversationManager.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/RecipeTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/service/WebSocketService.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/PriceComparisonTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/FitnessClassFinderTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/ReminderTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/Tool.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/Tool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/team/TeamFormation.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/team/TeamFormation.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/ToolRegistry.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/ToolRegistry.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/CommonTools.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/task/TaskManager.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/task/TaskManager.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/IoASystem.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/IoASystem.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/agent/AgentRegistry.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/NewsUpdateTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/tool/common/WebSearchTool.java
/Users/emkay/Projects/totioa/src/main/java/com/ioa/config/WebSocketConfig.java /Users/emkay/Projects/totioa/src/main/java/com/ioa/config/WebSocketConfig.java