Mostly working system

This commit is contained in:
Mahesh Kommareddi 2024-07-30 23:21:45 -04:00
parent 9cf98dbaf7
commit 6789f035b2
51 changed files with 623 additions and 365 deletions

3
README.md Normal file
View File

@ -0,0 +1,3 @@
mvn spring-boot:run
http://localhost:8080/

2
application.properties Normal file
View File

@ -0,0 +1,2 @@
logging.level.org.springframework.messaging=TRACE
logging.level.org.springframework.web.socket=TRACE

View File

@ -1,59 +1,51 @@
package com.ioa; package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationManager;
import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry;
import com.ioa.tool.Tool;
import com.ioa.tool.common.*;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.util.TreeOfThought;
import com.ioa.tool.common.AppointmentSchedulerTool;
import com.ioa.tool.common.DistanceCalculatorTool;
import com.ioa.tool.common.FinancialAdviceTool;
import com.ioa.tool.common.FitnessClassFinderTool;
import com.ioa.tool.common.MovieRecommendationTool;
import com.ioa.tool.common.NewsUpdateTool;
import com.ioa.tool.common.PriceComparisonTool;
import com.ioa.tool.common.RecipeTool;
import com.ioa.tool.common.ReminderTool;
import com.ioa.tool.common.RestaurantFinderTool;
import com.ioa.tool.common.TranslationTool;
import com.ioa.tool.common.TravelBookingTool;
import com.ioa.tool.common.WeatherTool;
import com.ioa.tool.common.WebSearchTool;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.ComponentScan;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import java.util.Arrays; import com.ioa.agent.AgentRegistry;
import java.util.List; import com.ioa.task.TaskManager;
import com.ioa.websocket.WebSocketHandler;
import java.util.concurrent.ExecutorService; import com.ioa.tool.ToolRegistry;
import java.util.concurrent.Executors; import com.ioa.model.BedrockLanguageModel;
import java.util.concurrent.TimeUnit; import com.ioa.util.TreeOfThought;
import java.util.concurrent.CountDownLatch; import com.ioa.service.WebSocketService;
import com.ioa.conversation.ConversationManager;
@SpringBootApplication @SpringBootApplication
@ComponentScan(basePackages = "com.ioa")
public class IoASystem { public class IoASystem {
@Bean @Bean
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) { public WebSocketHandler webSocketHandler(AgentRegistry agentRegistry,
return new WebSocketService(messagingTemplate, conversationManager); TaskManager taskManager,
SimpMessagingTemplate messagingTemplate,
TreeOfThought treeOfThought,
WebSocketService webSocketService,
ToolRegistry toolRegistry,
BedrockLanguageModel model) {
return new WebSocketHandler(agentRegistry, taskManager, messagingTemplate,
treeOfThought, webSocketService, toolRegistry, model);
} }
@Bean @Bean
public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) { public AgentRegistry agentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
return new ConversationManager(model, webSocketService); return new AgentRegistry(messagingTemplate, toolRegistry);
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought,
conversationManager, webSocketService, messagingTemplate);
} }
@Bean @Bean
@ -67,149 +59,21 @@ public class IoASystem {
} }
@Bean @Bean
public AgentRegistry agentRegistry(ToolRegistry toolRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, ConversationManager conversationManager) { public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate) {
AgentRegistry registry = new AgentRegistry(toolRegistry); return new WebSocketService(messagingTemplate);
// Agent creation is now moved to processTasksAndAgents method
return registry;
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, TreeOfThought treeOfThought, ConversationManager conversationManager) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought, conversationManager);
}
@Bean
public TeamFormation teamFormation(AgentRegistry agentRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, BedrockLanguageModel model) {
return new TeamFormation(agentRegistry, treeOfThought, webSocketService, model);
} }
@Bean @Bean
public ToolRegistry toolRegistry() { public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry(); return new ToolRegistry();
}
// Register all tools
toolRegistry.registerTool("webSearch", new WebSearchTool()); @Bean
toolRegistry.registerTool("getWeather", new WeatherTool()); public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
toolRegistry.registerTool("setReminder", new ReminderTool()); return new ConversationManager(model, webSocketService);
toolRegistry.registerTool("bookTravel", new TravelBookingTool());
toolRegistry.registerTool("calculateDistance", new DistanceCalculatorTool());
toolRegistry.registerTool("findRestaurants", new RestaurantFinderTool());
toolRegistry.registerTool("scheduleAppointment", new AppointmentSchedulerTool());
toolRegistry.registerTool("findFitnessClasses", new FitnessClassFinderTool());
toolRegistry.registerTool("getRecipe", new RecipeTool());
toolRegistry.registerTool("getNewsUpdates", new NewsUpdateTool());
toolRegistry.registerTool("translate", new TranslationTool());
toolRegistry.registerTool("compareProductPrices", new PriceComparisonTool());
toolRegistry.registerTool("getMovieRecommendations", new MovieRecommendationTool());
toolRegistry.registerTool("getFinancialAdvice", new FinancialAdviceTool());
return toolRegistry;
} }
public static void main(String[] args) { public static void main(String[] args) {
ConfigurableApplicationContext context = SpringApplication.run(IoASystem.class, args); SpringApplication.run(IoASystem.class, args);
IoASystem system = context.getBean(IoASystem.class);
system.processTasksAndAgents(context);
} }
}
public void processTasksAndAgents(ConfigurableApplicationContext context) {
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class);
TreeOfThought treeOfThought = context.getBean(TreeOfThought.class);
WebSocketService webSocketService = context.getBean(WebSocketService.class);
ToolRegistry toolRegistry = context.getBean(ToolRegistry.class);
ConversationManager conversationManager = context.getBean(ConversationManager.class);
BedrockLanguageModel model = context.getBean(BedrockLanguageModel.class);
// Register all agents
agentRegistry.registerAgent("agent1", new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent2", new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent3", new AgentInfo("agent3", "Event Planner Extraordinaire",
Arrays.asList("event planning", "team management", "booking"),
Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment", "getWeather"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent4", new AgentInfo("agent4", "Fitness Guru",
Arrays.asList("health", "nutrition", "motivation"),
Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent5", new AgentInfo("agent5", "Research Specialist",
Arrays.asList("research", "writing", "analysis"),
Arrays.asList("webSearch", "getNewsUpdates", "translate", "compareProductPrices"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent6", new AgentInfo("agent6", "Digital Marketing Expert",
Arrays.asList("marketing", "social media", "content creation"),
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "getMovieRecommendations"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent7", new AgentInfo("agent7", "Family Travel Coordinator",
Arrays.asList("travel", "family planning", "budgeting"),
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants", "getFinancialAdvice"),
treeOfThought, webSocketService, toolRegistry, model));
// Create all tasks
List<Task> tasks = Arrays.asList(
new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"))//,
// new Task("task2", "Organize a corporate team-building event in New York",
// Arrays.asList("event planning", "team management"),
// Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment")),
// new Task("task3", "Develop a personalized fitness and nutrition plan",
// Arrays.asList("health", "nutrition"),
// Arrays.asList("getWeather", "findFitnessClasses", "getRecipe")),
// new Task("task4", "Research and summarize recent advancements in renewable energy",
// Arrays.asList("research", "writing"),
// Arrays.asList("webSearch", "getNewsUpdates", "translate")),
// new Task("task5", "Plan and execute a social media marketing campaign for a new product launch",
// Arrays.asList("marketing", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
// new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
// Arrays.asList("travel", "family planning"),
// Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants")),
// new Task("task7", "Organize an international tech conference with virtual and in-person components",
// Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
// Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates"))//,
// new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
// Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
// Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
// new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
// Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
// Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
// new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
// Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
// new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
// Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
// Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
);
for (Task task : tasks) {
taskManager.addTask(task); // Add each task to the TaskManager
System.out.println("\nProcessing task: " + task.getDescription());
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
if (!team.isEmpty()) {
taskManager.executeTask(task.getId(), team);
System.out.println("Task result: " + task.getResult());
} else {
System.out.println("No suitable agents found for this task. Consider updating the agent pool or revising the task requirements.");
}
}
}
}

View File

@ -1,15 +1,21 @@
package com.ioa.agent; package com.ioa.agent;
import com.ioa.conversation.ConversationFSM;
import com.ioa.conversation.Message; import com.ioa.conversation.Message;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought; import com.ioa.util.TreeOfThought;
import com.ioa.websocket.AgentProcessingUpdate;
import com.ioa.websocket.AgentResponseUpdate;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.UUID;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
public class AgentInfo { public class AgentInfo {
private String id; private String id;
private String name; private String name;
@ -21,59 +27,56 @@ public class AgentInfo {
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
private BedrockLanguageModel model; private BedrockLanguageModel model;
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools, public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, BedrockLanguageModel model) {
this.id = id; this.id = id;
this.name = name; this.name = name;
this.capabilities = capabilities; this.capabilities = capabilities;
this.tools = tools; this.tools = tools;
this.memory = new Memory();
}
public void setDependencies(TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, BedrockLanguageModel model) {
this.treeOfThought = treeOfThought; this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.model = model; this.model = model;
this.memory = new Memory();
} }
public void receiveMessage(Message message) { public void receiveMessageLocal(Message message) {
if (this.memory == null)
this.memory = new Memory();
memory.addToHistory(message.getContent()); memory.addToHistory(message.getContent());
String prompt = "You are " + name + " with capabilities: " + capabilities + String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nYou received a message: " + message.getContent() + "\nYou received a message: " + message.getContent() +
"\nBased on your memory and context, how would you respond or what actions would you take?" + "\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory(); "\n\nMemory:\n" + memory.getFormattedMemory();
String response = model.generate(prompt, null); String response = model.generate(prompt, null);
System.out.println("DEBUG: " + name + " processed message: " + message.getContent()); System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
System.out.println("DEBUG: " + name + " response: " + response); System.out.println("DEBUG: " + name + " response: " + response);
// Add the response to memory // Add the response to memory
memory.addToHistory("My response: " + response); memory.addToHistory("My response: " + response);
} }
public void performTreeOfThought(String task) { public void receiveMessage(Message message) {
String prompt = "You are " + name + " with capabilities: " + capabilities + if (this.memory == null)
"\nTask: " + task + this.memory = new Memory();
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
"\n\nMemory:\n" + memory.getFormattedMemory();
// This is a turn notification for this agent
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2); respondToTurn(message.getConversationId());
String reasoning = (String) totResult.get("reasoning");
// Add the reasoning to memory
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
} }
public void notifyTurn(ConversationFSM conversation) { private void respondToTurn(String conversationId) {
String prompt = "You are " + name + " with capabilities: " + capabilities + String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nIt's your turn to speak in the conversation. What would you like to say or do?"; "\nIt's your turn to speak in the conversation. What would you like to say or do?" +
String response = model.generate(prompt, null); "\n\nMemory:\n" + memory.getFormattedMemory();
conversation.postMessage(new Message(conversation.getConversationId(), id, response));
// Send the response back to the conversation
Message responseMessage = new Message(conversationId, this.id, prompt);
webSocketService.sendUpdate("conversation_message", responseMessage);
} }
// Getters and setters
public String getId() { return id; }
public String getName() { return name; }
public List<String> getCapabilities() { return capabilities; }
public List<String> getTools() { return tools; }
} }

View File

@ -0,0 +1,12 @@
package com.ioa.agent;
public class AgentMessage {
private String agentId;
private String content;
// Getters and setters
public String getAgentId() { return agentId; }
public void setAgentId(String agentId) { this.agentId = agentId; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -1,34 +1,53 @@
package com.ioa.agent; package com.ioa.agent;
import com.ioa.tool.ToolRegistry; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import com.ioa.tool.ToolRegistry;
import java.util.*; import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Component @Component
public class AgentRegistry { public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>(); private Map<String, AgentInfo> agents = new ConcurrentHashMap<>();
private ToolRegistry toolRegistry; private final SimpMessagingTemplate messagingTemplate;
private final ToolRegistry toolRegistry;
public AgentRegistry(ToolRegistry toolRegistry) { @Autowired
public AgentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
this.messagingTemplate = messagingTemplate;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
} }
public void registerAgent(String agentId, AgentInfo agentInfo) { public void registerAgent(AgentInfo agentInfo) {
agents.put(agentId, agentInfo); agents.put(agentInfo.getId(), agentInfo);
// Verify that all tools the agent claims to have are registered // Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) { for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) { if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool); throw new IllegalArgumentException("Tool not found in registry: " + tool);
} }
} }
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
public void unregisterAgent(String agentId) {
agents.remove(agentId);
messagingTemplate.convertAndSend("/topic/agents", "Agent unregistered: " + agentId);
} }
public AgentInfo getAgent(String agentId) { public AgentInfo getAgent(String agentId) {
return agents.get(agentId); return agents.get(agentId);
} }
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
public List<AgentInfo> searchAgents(List<String> capabilities) { public List<AgentInfo> searchAgents(List<String> capabilities) {
return searchAgents(capabilities, 1.0); // Default to exact match return searchAgents(capabilities, 1.0); // Default to exact match
} }
@ -50,8 +69,4 @@ public class AgentRegistry {
.count(); .count();
return (double) matchingCapabilities / requiredCapabilities.size(); return (double) matchingCapabilities / requiredCapabilities.size();
} }
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
} }

View File

@ -1,7 +1,19 @@
package com.ioa.config; package com.ioa.config;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.MimeTypeUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.converter.DefaultContentTypeResolver;
import java.util.List;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@ -18,6 +30,34 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override @Override
public void registerStompEndpoints(StompEndpointRegistry registry) { public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").withSockJS(); registry.addEndpoint("/ws")
.setAllowedOriginPatterns("*")
.withSockJS();
} }
}
@Override
public boolean configureMessageConverters(List<MessageConverter> messageConverters) {
MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter();
DefaultContentTypeResolver resolver = new DefaultContentTypeResolver();
resolver.setDefaultMimeType(MimeTypeUtils.APPLICATION_JSON);
converter.setContentTypeResolver(resolver);
messageConverters.add(converter);
return false;
}
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(new ChannelInterceptor() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (accessor != null && accessor.getCommand() != null) {
System.out.println("Received STOMP Frame: " + accessor.getCommand());
System.out.println("Destination: " + accessor.getDestination());
System.out.println("Payload: " + new String((byte[]) message.getPayload()));
}
return message;
}
});
}
}

View File

@ -91,18 +91,33 @@ public class ConversationFSM {
private void updateState(Message message) { private void updateState(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() + String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState + "\nCurrent state: " + currentState +
"\nParticipants: " + participants; "\nParticipants: " + participants +
"\nProvide the next conversation state (" +
"DISCUSSION," +
"RESEARCH," +
"RESEARCH_TASK," +
"TASK_GATHERING_INFO," +
"TASK," +
"TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String reasoning = model.generate(stateTransitionTask, null); String reasoning = model.generate(stateTransitionTask, null);
String decisionPrompt = "Based on this reasoning:\n" + reasoning + String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION,\n" + // "\nProvide the next conversation state (" +
" TASK_GATHERING_INFO,\n" + // "DISCUSSION," +
" TASK,\n" + // "RESEARCH," +
" TASK_PLANNING,\n" + // "RESEARCH_TASK," +
" TASK_ASSIGNMENT,\n" + // "TASK_GATHERING_INFO," +
" EXECUTION,\n" + // "TASK," +
" CONCLUSION). Only give the single word answer in all caps only from the given options."; "TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String response = model.generate(decisionPrompt, null); String response = model.generate(decisionPrompt, null);
ConversationState newState = ConversationState.valueOf(response.trim()); ConversationState newState = ConversationState.valueOf(response.trim());
@ -112,7 +127,10 @@ public class ConversationFSM {
private void notifyNextSpeaker() { private void notifyNextSpeaker() {
AgentInfo nextSpeaker = speakingQueue.poll(); AgentInfo nextSpeaker = speakingQueue.poll();
if (nextSpeaker != null) { if (nextSpeaker != null) {
nextSpeaker.notifyTurn(this); // Instead of calling notifyTurn, we'll create a turn notification message
Message turnNotification = new Message(this.conversationId, "SYSTEM",
"It's " + nextSpeaker.getName() + "'s turn to speak.");
broadcastMessage(turnNotification);
speakingQueue.offer(nextSpeaker); speakingQueue.offer(nextSpeaker);
} }
} }

View File

@ -3,23 +3,23 @@ package com.ioa.conversation;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.*; import java.util.Map;
import java.util.concurrent.*; import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@Component @Component
public class ConversationManager { public class ConversationManager {
private final Map<String, ConversationFSM> conversations; private final Map<String, ConversationFSM> conversations = new ConcurrentHashMap<>();
private final BedrockLanguageModel model; private final BedrockLanguageModel model;
private final WebSocketService webSocketService; private final WebSocketService webSocketService;
private final ScheduledExecutorService executorService;
@Autowired
public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) { public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
this.conversations = new ConcurrentHashMap<>();
this.model = model; this.model = model;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.executorService = Executors.newScheduledThreadPool(1);
} }
public String createConversation() { public String createConversation() {
@ -30,6 +30,21 @@ public class ConversationManager {
return conversationId; return conversationId;
} }
public void joinConversation(String conversationId) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
// Logic to join a conversation (e.g., add participant)
webSocketService.sendUpdate("conversation_joined", conversationId);
}
}
public void postMessage(String conversationId, String senderId, String content) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
conversation.postMessage(new Message(conversationId, senderId, content));
}
}
public ConversationFSM getConversation(String conversationId) { public ConversationFSM getConversation(String conversationId) {
return conversations.get(conversationId); return conversations.get(conversationId);
} }
@ -47,21 +62,6 @@ public class ConversationManager {
conversation.removeParticipant(agent); conversation.removeParticipant(agent);
} }
} }
public void postMessage(String conversationId, String senderId, String content) {
System.out.println("DEBUG: Posting message - ConversationId: " + conversationId + ", SenderId: " + senderId + ", Content: " + content);
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
if (content == null) {
Arrays.toString(Thread.currentThread().getStackTrace()).replace( ',', '\n' );
System.out.println("WARNING: Attempting to post null content message");
return;
}
conversation.postMessage(new Message(conversationId, senderId, content));
} else {
System.out.println("WARNING: Conversation not found for id: " + conversationId);
}
}
public void startConversation(String conversationId, String initialMessage) { public void startConversation(String conversationId, String initialMessage) {
ConversationFSM conversation = conversations.get(conversationId); ConversationFSM conversation = conversations.get(conversationId);
@ -69,7 +69,7 @@ public class ConversationManager {
conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage)); conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage));
// Start a timer to end the conversation after 10 minutes // Start a timer to end the conversation after 10 minutes
executorService.schedule(() -> { Executor.schedule(() -> {
if (!conversation.isFinished()) { if (!conversation.isFinished()) {
conversation.finish("Time limit reached"); conversation.finish("Time limit reached");
} }

View File

@ -2,6 +2,8 @@ package com.ioa.conversation;
public enum ConversationState { public enum ConversationState {
DISCUSSION, DISCUSSION,
RESEARCH,
RESEARCH_TASK,
TASK_GATHERING_INFO, TASK_GATHERING_INFO,
TASK, TASK,
TASK_PLANNING, TASK_PLANNING,

View File

@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Base64; import java.util.Base64;
import java.util.function.Consumer;
@Component @Component
public class BedrockLanguageModel { public class BedrockLanguageModel {
@ -36,33 +37,7 @@ public class BedrockLanguageModel {
public String generate(String prompt, String imagePath) { public String generate(String prompt, String imagePath) {
System.out.println("DEBUG: Generating response for prompt: " + prompt); System.out.println("DEBUG: Generating response for prompt: " + prompt);
try { try {
ObjectNode requestBody = objectMapper.createObjectNode(); ObjectNode requestBody = createRequestBody(prompt, imagePath);
requestBody.put("anthropic_version", "bedrock-2023-05-31");
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image"); // Add type field
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text"); // Add type field
textNode.put("text", prompt);
String jsonPayload = objectMapper.writeValueAsString(requestBody); String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder() InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
@ -100,4 +75,77 @@ public class BedrockLanguageModel {
return "Error: " + e.getMessage(); return "Error: " + e.getMessage();
} }
} }
public void generateStream(String prompt, String imagePath, Consumer<String> chunkConsumer) {
System.out.println("DEBUG: Generating streaming response for prompt: " + prompt);
try {
ObjectNode requestBody = createRequestBody(prompt, imagePath);
String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
.modelId(modelId)
.contentType("application/json")
.accept("application/json")
.body(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest);
String responseBody = response.body().asUtf8String();
System.out.println("DEBUG: Raw response from Bedrock: " + responseBody);
JsonNode responseJson = objectMapper.readTree(responseBody);
JsonNode contentArray = responseJson.path("content");
if (contentArray.isArray() && contentArray.size() > 0) {
for (JsonNode content : contentArray) {
String chunk = content.path("text").asText();
chunkConsumer.accept(chunk);
}
} else {
System.out.println("WARNING: Unexpected response format. Full response: " + responseBody);
chunkConsumer.accept("Unexpected response format");
}
} catch (Exception e) {
System.out.println("ERROR: Failed to generate streaming text with Bedrock: " + e.getMessage());
e.printStackTrace();
chunkConsumer.accept("Error: " + e.getMessage());
}
}
private ObjectNode createRequestBody(String prompt, String imagePath) {
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("anthropic_version", "bedrock-2023-05-31");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
try {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image");
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
} catch (Exception e) {
System.err.println("Error reading image file: " + e.getMessage());
e.printStackTrace();
}
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text");
textNode.put("text", prompt);
return requestBody;
}
} }

View File

@ -1,81 +1,17 @@
package com.ioa.service; package com.ioa.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
@Service @Service
public class WebSocketService { public class WebSocketService {
private final SimpMessagingTemplate messagingTemplate; private final SimpMessagingTemplate messagingTemplate;
private final ConversationManager conversationManager;
private final ObjectMapper objectMapper;
@Autowired public WebSocketService(SimpMessagingTemplate messagingTemplate) {
public WebSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
this.messagingTemplate = messagingTemplate; this.messagingTemplate = messagingTemplate;
this.conversationManager = conversationManager;
this.objectMapper = new ObjectMapper();
}
public void handleMessage(WebSocketSession session, TextMessage message) {
try {
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
if (messageArray.length > 0) {
String content = messageArray[0];
// Assume we're using a default conversation ID for simplicity
String conversationId = "default";
// Assume we're using a default user ID for simplicity
String userId = "user";
conversationManager.postMessage(conversationId, userId, content);
}
} catch (Exception e) {
e.printStackTrace();
}
} }
public void sendUpdate(String topic, Object payload) { public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload); messagingTemplate.convertAndSend("/topic/" + topic, payload);
} }
public void handleWebSocketMessage(String message) {
System.out.println("DEBUG: Received WebSocket message: " + message);
// Parse the WebSocket frame
String[] parts = message.split("\n\n", 2);
if (parts.length < 2) {
System.out.println("DEBUG: Invalid WebSocket message format");
return;
}
String headers = parts[0];
String payload = parts[1];
// Parse the JSON payload
try {
JsonNode jsonNode = objectMapper.readTree(payload);
// Extract relevant information from the JSON
// Adjust this based on the actual structure of your WebSocket messages
String conversationId = jsonNode.path("conversationId").asText();
String sender = jsonNode.path("sender").asText();
String content = jsonNode.path("content").asText();
// Create a new Message object
Message parsedMessage = new Message(conversationId, sender, content);
System.out.println("DEBUG: WebSocket message: " + payload);
// Process the message
conversationManager.postMessage(conversationId, sender, content);
} catch (Exception e) {
System.out.println("DEBUG: Error parsing WebSocket message: " + e.getMessage());
e.printStackTrace();
}
}
} }

View File

@ -1,11 +1,8 @@
package com.ioa.task; package com.ioa.task;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import lombok.Data;
import java.util.List; import java.util.List;
@Data
public class Task { public class Task {
private String id; private String id;
private String description; private String description;
@ -14,10 +11,28 @@ public class Task {
private AgentInfo assignedAgent; private AgentInfo assignedAgent;
private String result; private String result;
// Default constructor
public Task() {}
// Existing constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) { public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id; this.id = id;
this.description = description; this.description = description;
this.requiredCapabilities = requiredCapabilities; this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools; this.requiredTools = requiredTools;
} }
// Getters and setters for all fields
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getDescription() { return description; }
public void setDescription(String description) { this.description = description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public void setRequiredCapabilities(List<String> requiredCapabilities) { this.requiredCapabilities = requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public void setRequiredTools(List<String> requiredTools) { this.requiredTools = requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
} }

View File

@ -0,0 +1,14 @@
package com.ioa.task;
import java.util.List;
public class TaskAssignment {
private String taskId;
private List<String> agentIds;
// Getters and setters
public String getTaskId() { return taskId; }
public void setTaskId(String taskId) { this.taskId = taskId; }
public List<String> getAgentIds() { return agentIds; }
public void setAgentIds(List<String> agentIds) { this.agentIds = agentIds; }
}

View File

@ -1,37 +1,84 @@
package com.ioa.task; package com.ioa.task;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component;
import com.ioa.agent.AgentInfo; import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry; import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationFSM; import com.ioa.conversation.ConversationFSM;
import com.ioa.conversation.ConversationManager; import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import com.ioa.model.BedrockLanguageModel; import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService; import com.ioa.service.WebSocketService;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry; import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought; import com.ioa.util.TreeOfThought;
import com.ioa.conversation.Message;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@Component @Component
public class TaskManager { public class TaskManager {
private Map<String, Task> tasks = new HashMap<>(); private Map<String, Task> tasks = new ConcurrentHashMap<>();
private final SimpMessagingTemplate messagingTemplate;
private AgentRegistry agentRegistry; private AgentRegistry agentRegistry;
private BedrockLanguageModel model; private BedrockLanguageModel model;
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
private TreeOfThought treeOfThought; private TreeOfThought treeOfThought;
private ConversationManager conversationManager; private ConversationManager conversationManager;
private WebSocketService webSocketService;
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, @Autowired
TreeOfThought treeOfThought, ConversationManager conversationManager) { private TeamFormation teamFormation;
@Autowired
public TaskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
this.agentRegistry = agentRegistry; this.agentRegistry = agentRegistry;
this.model = model; this.model = model;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.treeOfThought = treeOfThought; this.treeOfThought = treeOfThought;
this.conversationManager = conversationManager; this.conversationManager = conversationManager;
this.webSocketService = webSocketService;
this.messagingTemplate = messagingTemplate;
}
public String createTask(Task task) {
String taskId = UUID.randomUUID().toString();
task.setId(taskId);
tasks.put(taskId, task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
// Automatically assign agents and execute the task
List<AgentInfo> team = teamFormation.formTeam(task);
if (!team.isEmpty()) {
executeTask(taskId, team);
} else {
System.out.println("No suitable agents found for task: " + taskId);
}
return taskId;
}
public void assignTask(String taskId, List<String> agentIds) {
Task task = tasks.get(taskId);
if (task != null) {
List<AgentInfo> team = agentIds.stream()
.map(agentRegistry::getAgent)
.filter(Objects::nonNull)
.collect(Collectors.toList());
executeTask(taskId, team);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task assigned to: " + String.join(", ", agentIds));
}
} }
public void addTask(Task task) { public void addTask(Task task) {
@ -54,9 +101,13 @@ public class TaskManager {
conversation.postMessage(new Message(conversationId, "SYSTEM", "Let's work on the task: " + task.getDescription())); conversation.postMessage(new Message(conversationId, "SYSTEM", "Let's work on the task: " + task.getDescription()));
// Allow agents to interact for a maximum of 10 minutes // Use the LLM to generate a plan for the task
String planPrompt = "Generate a step-by-step plan to accomplish the following task: " + task.getDescription();
String plan = model.generate(planPrompt, null);
conversation.postMessage(new Message(conversationId, "SYSTEM", "Here's the plan: " + plan));
// Allow agents to interact for a maximum of 40 minutes
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
//while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(10)) {
while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(40)) { while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(40)) {
try { try {
Thread.sleep(1000); // Check every second Thread.sleep(1000); // Check every second
@ -72,8 +123,10 @@ public class TaskManager {
String result = conversation.getResult(); String result = conversation.getResult();
task.setResult(result); task.setResult(result);
System.out.println("Task completed. Result: " + result); System.out.println("Task completed. Result: " + result);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task completed: " + result);
} }
public Task getTask(String taskId) { public Task getTask(String taskId) {
return tasks.get(taskId); return tasks.get(taskId);
} }

View File

@ -0,0 +1,31 @@
package com.ioa.task;
public class TaskResult {
private String taskId;
private String result;
// Constructors
public TaskResult() {}
public TaskResult(String taskId, String result) {
this.taskId = taskId;
this.result = result;
}
// Getters and setters
public String getTaskId() {
return taskId;
}
public void setTaskId(String taskId) {
this.taskId = taskId;
}
public String getResult() {
return result;
}
public void setResult(String result) {
this.result = result;
}
}

View File

@ -0,0 +1,29 @@
package com.ioa.tool;
public class ToolInfo {
private String agentId;
private String toolName;
public ToolInfo() {}
public ToolInfo(String agentId, String toolName) {
this.agentId = agentId;
this.toolName = toolName;
}
public String getAgentId() {
return agentId;
}
public void setAgentId(String agentId) {
this.agentId = agentId;
}
public String getToolName() {
return toolName;
}
public void setToolName(String toolName) {
this.toolName = toolName;
}
}

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentProcessingUpdate {
private String agentId;
private String responseId;
private String status; // "start" or "complete"
public AgentProcessingUpdate(String agentId, String responseId, String status) {
this.agentId = agentId;
this.responseId = responseId;
this.status = status;
}
}

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentResponseUpdate {
private String agentId;
private String responseId;
private String chunk;
public AgentResponseUpdate(String agentId, String responseId, String chunk) {
this.agentId = agentId;
this.responseId = responseId;
this.chunk = chunk;
}
}

View File

@ -0,0 +1,50 @@
package com.ioa.websocket;
import com.ioa.conversation.ConversationManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.stereotype.Controller;
@Controller
public class ChatController {
private final ConversationManager conversationManager;
@Autowired
public ChatController(ConversationManager conversationManager) {
this.conversationManager = conversationManager;
}
@MessageMapping("/chat/create")
@SendTo("/topic/rooms")
public String createRoom(String roomName) {
String roomId = conversationManager.createConversation();
return "Room created: " + roomId;
}
@MessageMapping("/chat/join")
@SendTo("/topic/rooms")
public String joinRoom(String roomId) {
conversationManager.joinConversation(roomId);
return "Joined room: " + roomId;
}
@MessageMapping("/chat/message")
public void sendMessage(ChatMessage message) {
conversationManager.postMessage(message.getRoomId(), message.getSender(), message.getContent());
}
}
class ChatMessage {
private String roomId;
private String sender;
private String content;
// Getters and setters
public String getRoomId() { return roomId; }
public void setRoomId(String roomId) { this.roomId = roomId; }
public String getSender() { return sender; }
public void setSender(String sender) { this.sender = sender; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -0,0 +1,91 @@
package com.ioa.websocket;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Controller;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.task.Task;
import com.ioa.task.TaskAssignment;
import com.ioa.task.TaskManager;
import com.ioa.task.TaskResult;
import com.ioa.tool.ToolInfo;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
@Controller
public class WebSocketHandler {
private final AgentRegistry agentRegistry;
private final TaskManager taskManager;
private final SimpMessagingTemplate messagingTemplate;
private final TreeOfThought treeOfThought;
private final WebSocketService webSocketService;
private final ToolRegistry toolRegistry;
private final BedrockLanguageModel model;
@Autowired
public WebSocketHandler(AgentRegistry agentRegistry, TaskManager taskManager,
SimpMessagingTemplate messagingTemplate, TreeOfThought treeOfThought,
WebSocketService webSocketService, ToolRegistry toolRegistry,
BedrockLanguageModel model) {
this.agentRegistry = agentRegistry;
this.taskManager = taskManager;
this.messagingTemplate = messagingTemplate;
this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry;
this.model = model;
}
@MessageMapping("/agent/register")
public void registerAgent(@Payload AgentInfo agentInfo) {
// Initialize dependencies
agentInfo.setDependencies(treeOfThought, webSocketService, toolRegistry, model);
agentRegistry.registerAgent(agentInfo);
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
@MessageMapping("/tool/register")
public void registerTool(@Payload ToolInfo toolInfo) {
AgentInfo agent = agentRegistry.getAgent(toolInfo.getAgentId());
if (agent != null) {
agent.getTools().add(toolInfo.getToolName());
messagingTemplate.convertAndSend("/topic/tools", "New tool registered: " + toolInfo.getToolName());
}
}
@MessageMapping("/task/create")
public void createTask(@Payload Task task) {
System.out.println("Creating task: " + task.getId());
String taskId = taskManager.createTask(task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
}
@MessageMapping("/agent/unregister")
public void unregisterAgent(@Payload String agentId) {
agentRegistry.unregisterAgent(agentId);
messagingTemplate.convertAndSend("/topic/agent/" + agentId, "Unregistration successful");
}
@MessageMapping("/task/assign")
public void assignTask(@Payload TaskAssignment assignment) {
taskManager.assignTask(assignment.getTaskId(), assignment.getAgentIds());
messagingTemplate.convertAndSend("/topic/tasks/" + assignment.getTaskId(), "Task assigned");
}
@MessageMapping("/task/result")
public void handleTaskResult(@Payload TaskResult result) {
Task task = taskManager.getTask(result.getTaskId());
if (task != null) {
task.setResult(result.getResult());
messagingTemplate.convertAndSend("/topic/tasks/" + result.getTaskId(), "Task completed: " + result.getResult());
}
}
}