Mostly working system
This commit is contained in:
parent
9cf98dbaf7
commit
6789f035b2
2
application.properties
Normal file
2
application.properties
Normal file
|
@ -0,0 +1,2 @@
|
|||
logging.level.org.springframework.messaging=TRACE
|
||||
logging.level.org.springframework.web.socket=TRACE
|
|
@ -1,59 +1,51 @@
|
|||
package com.ioa;
|
||||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.conversation.ConversationManager;
|
||||
import com.ioa.task.Task;
|
||||
import com.ioa.task.TaskManager;
|
||||
import com.ioa.team.TeamFormation;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.tool.Tool;
|
||||
import com.ioa.tool.common.*;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
|
||||
import com.ioa.tool.common.AppointmentSchedulerTool;
|
||||
import com.ioa.tool.common.DistanceCalculatorTool;
|
||||
import com.ioa.tool.common.FinancialAdviceTool;
|
||||
import com.ioa.tool.common.FitnessClassFinderTool;
|
||||
import com.ioa.tool.common.MovieRecommendationTool;
|
||||
import com.ioa.tool.common.NewsUpdateTool;
|
||||
import com.ioa.tool.common.PriceComparisonTool;
|
||||
import com.ioa.tool.common.RecipeTool;
|
||||
import com.ioa.tool.common.ReminderTool;
|
||||
import com.ioa.tool.common.RestaurantFinderTool;
|
||||
import com.ioa.tool.common.TranslationTool;
|
||||
import com.ioa.tool.common.TravelBookingTool;
|
||||
import com.ioa.tool.common.WeatherTool;
|
||||
import com.ioa.tool.common.WebSearchTool;
|
||||
|
||||
import org.springframework.context.ConfigurableApplicationContext;
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.context.annotation.ComponentScan;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.task.TaskManager;
|
||||
import com.ioa.websocket.WebSocketHandler;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.conversation.ConversationManager;
|
||||
|
||||
@SpringBootApplication
|
||||
@ComponentScan(basePackages = "com.ioa")
|
||||
public class IoASystem {
|
||||
|
||||
@Bean
|
||||
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
|
||||
return new WebSocketService(messagingTemplate, conversationManager);
|
||||
public WebSocketHandler webSocketHandler(AgentRegistry agentRegistry,
|
||||
TaskManager taskManager,
|
||||
SimpMessagingTemplate messagingTemplate,
|
||||
TreeOfThought treeOfThought,
|
||||
WebSocketService webSocketService,
|
||||
ToolRegistry toolRegistry,
|
||||
BedrockLanguageModel model) {
|
||||
return new WebSocketHandler(agentRegistry, taskManager, messagingTemplate,
|
||||
treeOfThought, webSocketService, toolRegistry, model);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
|
||||
return new ConversationManager(model, webSocketService);
|
||||
public AgentRegistry agentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
|
||||
return new AgentRegistry(messagingTemplate, toolRegistry);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TaskManager taskManager(AgentRegistry agentRegistry,
|
||||
BedrockLanguageModel model,
|
||||
ToolRegistry toolRegistry,
|
||||
TreeOfThought treeOfThought,
|
||||
ConversationManager conversationManager,
|
||||
WebSocketService webSocketService,
|
||||
SimpMessagingTemplate messagingTemplate) {
|
||||
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought,
|
||||
conversationManager, webSocketService, messagingTemplate);
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
@ -67,149 +59,21 @@ public class IoASystem {
|
|||
}
|
||||
|
||||
@Bean
|
||||
public AgentRegistry agentRegistry(ToolRegistry toolRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, ConversationManager conversationManager) {
|
||||
AgentRegistry registry = new AgentRegistry(toolRegistry);
|
||||
|
||||
// Agent creation is now moved to processTasksAndAgents method
|
||||
|
||||
return registry;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, TreeOfThought treeOfThought, ConversationManager conversationManager) {
|
||||
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought, conversationManager);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public TeamFormation teamFormation(AgentRegistry agentRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, BedrockLanguageModel model) {
|
||||
return new TeamFormation(agentRegistry, treeOfThought, webSocketService, model);
|
||||
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate) {
|
||||
return new WebSocketService(messagingTemplate);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ToolRegistry toolRegistry() {
|
||||
ToolRegistry toolRegistry = new ToolRegistry();
|
||||
|
||||
// Register all tools
|
||||
toolRegistry.registerTool("webSearch", new WebSearchTool());
|
||||
toolRegistry.registerTool("getWeather", new WeatherTool());
|
||||
toolRegistry.registerTool("setReminder", new ReminderTool());
|
||||
toolRegistry.registerTool("bookTravel", new TravelBookingTool());
|
||||
toolRegistry.registerTool("calculateDistance", new DistanceCalculatorTool());
|
||||
toolRegistry.registerTool("findRestaurants", new RestaurantFinderTool());
|
||||
toolRegistry.registerTool("scheduleAppointment", new AppointmentSchedulerTool());
|
||||
toolRegistry.registerTool("findFitnessClasses", new FitnessClassFinderTool());
|
||||
toolRegistry.registerTool("getRecipe", new RecipeTool());
|
||||
toolRegistry.registerTool("getNewsUpdates", new NewsUpdateTool());
|
||||
toolRegistry.registerTool("translate", new TranslationTool());
|
||||
toolRegistry.registerTool("compareProductPrices", new PriceComparisonTool());
|
||||
toolRegistry.registerTool("getMovieRecommendations", new MovieRecommendationTool());
|
||||
toolRegistry.registerTool("getFinancialAdvice", new FinancialAdviceTool());
|
||||
|
||||
return toolRegistry;
|
||||
return new ToolRegistry();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
|
||||
return new ConversationManager(model, webSocketService);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
ConfigurableApplicationContext context = SpringApplication.run(IoASystem.class, args);
|
||||
IoASystem system = context.getBean(IoASystem.class);
|
||||
system.processTasksAndAgents(context);
|
||||
SpringApplication.run(IoASystem.class, args);
|
||||
}
|
||||
|
||||
public void processTasksAndAgents(ConfigurableApplicationContext context) {
|
||||
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
|
||||
TeamFormation teamFormation = context.getBean(TeamFormation.class);
|
||||
TaskManager taskManager = context.getBean(TaskManager.class);
|
||||
TreeOfThought treeOfThought = context.getBean(TreeOfThought.class);
|
||||
WebSocketService webSocketService = context.getBean(WebSocketService.class);
|
||||
ToolRegistry toolRegistry = context.getBean(ToolRegistry.class);
|
||||
ConversationManager conversationManager = context.getBean(ConversationManager.class);
|
||||
BedrockLanguageModel model = context.getBean(BedrockLanguageModel.class);
|
||||
|
||||
|
||||
// Register all agents
|
||||
agentRegistry.registerAgent("agent1", new AgentInfo("agent1", "General Assistant",
|
||||
Arrays.asList("general", "search"),
|
||||
Arrays.asList("webSearch", "getWeather", "setReminder"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent2", new AgentInfo("agent2", "Travel Expert",
|
||||
Arrays.asList("travel", "booking"),
|
||||
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent3", new AgentInfo("agent3", "Event Planner Extraordinaire",
|
||||
Arrays.asList("event planning", "team management", "booking"),
|
||||
Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment", "getWeather"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent4", new AgentInfo("agent4", "Fitness Guru",
|
||||
Arrays.asList("health", "nutrition", "motivation"),
|
||||
Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent5", new AgentInfo("agent5", "Research Specialist",
|
||||
Arrays.asList("research", "writing", "analysis"),
|
||||
Arrays.asList("webSearch", "getNewsUpdates", "translate", "compareProductPrices"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent6", new AgentInfo("agent6", "Digital Marketing Expert",
|
||||
Arrays.asList("marketing", "social media", "content creation"),
|
||||
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "getMovieRecommendations"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
agentRegistry.registerAgent("agent7", new AgentInfo("agent7", "Family Travel Coordinator",
|
||||
Arrays.asList("travel", "family planning", "budgeting"),
|
||||
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants", "getFinancialAdvice"),
|
||||
treeOfThought, webSocketService, toolRegistry, model));
|
||||
|
||||
// Create all tasks
|
||||
List<Task> tasks = Arrays.asList(
|
||||
new Task("task1", "Plan a weekend trip to Paris",
|
||||
Arrays.asList("travel", "booking"),
|
||||
Arrays.asList("bookTravel", "findRestaurants", "getWeather"))//,
|
||||
// new Task("task2", "Organize a corporate team-building event in New York",
|
||||
// Arrays.asList("event planning", "team management"),
|
||||
// Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment")),
|
||||
// new Task("task3", "Develop a personalized fitness and nutrition plan",
|
||||
// Arrays.asList("health", "nutrition"),
|
||||
// Arrays.asList("getWeather", "findFitnessClasses", "getRecipe")),
|
||||
// new Task("task4", "Research and summarize recent advancements in renewable energy",
|
||||
// Arrays.asList("research", "writing"),
|
||||
// Arrays.asList("webSearch", "getNewsUpdates", "translate")),
|
||||
// new Task("task5", "Plan and execute a social media marketing campaign for a new product launch",
|
||||
// Arrays.asList("marketing", "social media"),
|
||||
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
|
||||
// new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
|
||||
// Arrays.asList("travel", "family planning"),
|
||||
// Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants")),
|
||||
|
||||
// new Task("task7", "Organize an international tech conference with virtual and in-person components",
|
||||
// Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
|
||||
// Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates"))//,
|
||||
|
||||
// new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
|
||||
// Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
|
||||
// Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
|
||||
|
||||
// new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
|
||||
// Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
|
||||
// Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
|
||||
|
||||
// new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
|
||||
// Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
|
||||
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
|
||||
|
||||
// new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
|
||||
// Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
|
||||
// Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
|
||||
|
||||
);
|
||||
|
||||
for (Task task : tasks) {
|
||||
taskManager.addTask(task); // Add each task to the TaskManager
|
||||
System.out.println("\nProcessing task: " + task.getDescription());
|
||||
List<AgentInfo> team = teamFormation.formTeam(task);
|
||||
System.out.println("Formed team: " + team);
|
||||
|
||||
if (!team.isEmpty()) {
|
||||
taskManager.executeTask(task.getId(), team);
|
||||
System.out.println("Task result: " + task.getResult());
|
||||
} else {
|
||||
System.out.println("No suitable agents found for this task. Consider updating the agent pool or revising the task requirements.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,15 +1,21 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
import com.ioa.conversation.ConversationFSM;
|
||||
import com.ioa.conversation.Message;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import com.ioa.websocket.AgentProcessingUpdate;
|
||||
import com.ioa.websocket.AgentResponseUpdate;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class AgentInfo {
|
||||
private String id;
|
||||
private String name;
|
||||
|
@ -21,59 +27,56 @@ public class AgentInfo {
|
|||
private ToolRegistry toolRegistry;
|
||||
private BedrockLanguageModel model;
|
||||
|
||||
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools,
|
||||
TreeOfThought treeOfThought, WebSocketService webSocketService,
|
||||
ToolRegistry toolRegistry, BedrockLanguageModel model) {
|
||||
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
|
||||
this.id = id;
|
||||
this.name = name;
|
||||
this.capabilities = capabilities;
|
||||
this.tools = tools;
|
||||
this.memory = new Memory();
|
||||
}
|
||||
|
||||
public void setDependencies(TreeOfThought treeOfThought, WebSocketService webSocketService,
|
||||
ToolRegistry toolRegistry, BedrockLanguageModel model) {
|
||||
this.treeOfThought = treeOfThought;
|
||||
this.webSocketService = webSocketService;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.model = model;
|
||||
this.memory = new Memory();
|
||||
}
|
||||
|
||||
public void receiveMessage(Message message) {
|
||||
public void receiveMessageLocal(Message message) {
|
||||
if (this.memory == null)
|
||||
this.memory = new Memory();
|
||||
memory.addToHistory(message.getContent());
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nYou received a message: " + message.getContent() +
|
||||
"\nBased on your memory and context, how would you respond or what actions would you take?" +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nYou received a message: " + message.getContent() +
|
||||
"\nBased on your memory and context, how would you respond or what actions would you take?" +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
String response = model.generate(prompt, null);
|
||||
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
|
||||
System.out.println("DEBUG: " + name + " response: " + response);
|
||||
|
||||
|
||||
// Add the response to memory
|
||||
memory.addToHistory("My response: " + response);
|
||||
}
|
||||
|
||||
public void performTreeOfThought(String task) {
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nTask: " + task +
|
||||
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
|
||||
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
|
||||
String reasoning = (String) totResult.get("reasoning");
|
||||
|
||||
// Add the reasoning to memory
|
||||
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
|
||||
|
||||
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
|
||||
public void receiveMessage(Message message) {
|
||||
if (this.memory == null)
|
||||
this.memory = new Memory();
|
||||
|
||||
|
||||
// This is a turn notification for this agent
|
||||
respondToTurn(message.getConversationId());
|
||||
|
||||
}
|
||||
|
||||
public void notifyTurn(ConversationFSM conversation) {
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nIt's your turn to speak in the conversation. What would you like to say or do?";
|
||||
String response = model.generate(prompt, null);
|
||||
conversation.postMessage(new Message(conversation.getConversationId(), id, response));
|
||||
private void respondToTurn(String conversationId) {
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nIt's your turn to speak in the conversation. What would you like to say or do?" +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
|
||||
// Send the response back to the conversation
|
||||
Message responseMessage = new Message(conversationId, this.id, prompt);
|
||||
webSocketService.sendUpdate("conversation_message", responseMessage);
|
||||
}
|
||||
|
||||
// Getters and setters
|
||||
public String getId() { return id; }
|
||||
public String getName() { return name; }
|
||||
public List<String> getCapabilities() { return capabilities; }
|
||||
public List<String> getTools() { return tools; }
|
||||
}
|
12
src/main/java/com/ioa/agent/AgentMessage.java
Normal file
12
src/main/java/com/ioa/agent/AgentMessage.java
Normal file
|
@ -0,0 +1,12 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
public class AgentMessage {
|
||||
private String agentId;
|
||||
private String content;
|
||||
|
||||
// Getters and setters
|
||||
public String getAgentId() { return agentId; }
|
||||
public void setAgentId(String agentId) { this.agentId = agentId; }
|
||||
public String getContent() { return content; }
|
||||
public void setContent(String content) { this.content = content; }
|
||||
}
|
|
@ -1,34 +1,53 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
public class AgentRegistry {
|
||||
private Map<String, AgentInfo> agents = new HashMap<>();
|
||||
private ToolRegistry toolRegistry;
|
||||
private Map<String, AgentInfo> agents = new ConcurrentHashMap<>();
|
||||
private final SimpMessagingTemplate messagingTemplate;
|
||||
private final ToolRegistry toolRegistry;
|
||||
|
||||
public AgentRegistry(ToolRegistry toolRegistry) {
|
||||
@Autowired
|
||||
public AgentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
|
||||
this.messagingTemplate = messagingTemplate;
|
||||
this.toolRegistry = toolRegistry;
|
||||
}
|
||||
|
||||
public void registerAgent(String agentId, AgentInfo agentInfo) {
|
||||
agents.put(agentId, agentInfo);
|
||||
public void registerAgent(AgentInfo agentInfo) {
|
||||
agents.put(agentInfo.getId(), agentInfo);
|
||||
// Verify that all tools the agent claims to have are registered
|
||||
for (String tool : agentInfo.getTools()) {
|
||||
if (toolRegistry.getTool(tool) == null) {
|
||||
throw new IllegalArgumentException("Tool not found in registry: " + tool);
|
||||
}
|
||||
}
|
||||
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
|
||||
}
|
||||
|
||||
public void unregisterAgent(String agentId) {
|
||||
agents.remove(agentId);
|
||||
messagingTemplate.convertAndSend("/topic/agents", "Agent unregistered: " + agentId);
|
||||
}
|
||||
|
||||
public AgentInfo getAgent(String agentId) {
|
||||
return agents.get(agentId);
|
||||
}
|
||||
|
||||
public List<AgentInfo> getAllAgents() {
|
||||
return new ArrayList<>(agents.values());
|
||||
}
|
||||
|
||||
public List<AgentInfo> searchAgents(List<String> capabilities) {
|
||||
return searchAgents(capabilities, 1.0); // Default to exact match
|
||||
}
|
||||
|
@ -50,8 +69,4 @@ public class AgentRegistry {
|
|||
.count();
|
||||
return (double) matchingCapabilities / requiredCapabilities.size();
|
||||
}
|
||||
|
||||
public List<AgentInfo> getAllAgents() {
|
||||
return new ArrayList<>(agents.values());
|
||||
}
|
||||
}
|
|
@ -1,7 +1,19 @@
|
|||
package com.ioa.config;
|
||||
|
||||
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
|
||||
import org.springframework.messaging.converter.MessageConverter;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.MessageChannel;
|
||||
import org.springframework.messaging.converter.DefaultContentTypeResolver;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.messaging.simp.config.ChannelRegistration;
|
||||
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
||||
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
|
||||
import org.springframework.messaging.support.ChannelInterceptor;
|
||||
import org.springframework.messaging.support.MessageHeaderAccessor;
|
||||
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
||||
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
|
||||
|
@ -18,6 +30,34 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
|
|||
|
||||
@Override
|
||||
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||
registry.addEndpoint("/ws").withSockJS();
|
||||
registry.addEndpoint("/ws")
|
||||
.setAllowedOriginPatterns("*")
|
||||
.withSockJS();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean configureMessageConverters(List<MessageConverter> messageConverters) {
|
||||
MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter();
|
||||
DefaultContentTypeResolver resolver = new DefaultContentTypeResolver();
|
||||
resolver.setDefaultMimeType(MimeTypeUtils.APPLICATION_JSON);
|
||||
converter.setContentTypeResolver(resolver);
|
||||
messageConverters.add(converter);
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void configureClientInboundChannel(ChannelRegistration registration) {
|
||||
registration.interceptors(new ChannelInterceptor() {
|
||||
@Override
|
||||
public Message<?> preSend(Message<?> message, MessageChannel channel) {
|
||||
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
|
||||
if (accessor != null && accessor.getCommand() != null) {
|
||||
System.out.println("Received STOMP Frame: " + accessor.getCommand());
|
||||
System.out.println("Destination: " + accessor.getDestination());
|
||||
System.out.println("Payload: " + new String((byte[]) message.getPayload()));
|
||||
}
|
||||
return message;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -91,18 +91,33 @@ public class ConversationFSM {
|
|||
private void updateState(Message message) {
|
||||
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
||||
"\nCurrent state: " + currentState +
|
||||
"\nParticipants: " + participants;
|
||||
"\nParticipants: " + participants +
|
||||
"\nProvide the next conversation state (" +
|
||||
"DISCUSSION," +
|
||||
"RESEARCH," +
|
||||
"RESEARCH_TASK," +
|
||||
"TASK_GATHERING_INFO," +
|
||||
"TASK," +
|
||||
"TASK_PLANNING," +
|
||||
"TASK_ASSIGNMENT," +
|
||||
"EXECUTION," +
|
||||
"CONCLUSION" +
|
||||
"\n)\nOnly give the single word answer in all caps only from the given options.";
|
||||
|
||||
String reasoning = model.generate(stateTransitionTask, null);
|
||||
|
||||
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||
"\nProvide the next conversation state (DISCUSSION,\n" + //
|
||||
" TASK_GATHERING_INFO,\n" + //
|
||||
" TASK,\n" + //
|
||||
" TASK_PLANNING,\n" + //
|
||||
" TASK_ASSIGNMENT,\n" + //
|
||||
" EXECUTION,\n" + //
|
||||
" CONCLUSION). Only give the single word answer in all caps only from the given options.";
|
||||
"\nProvide the next conversation state (" +
|
||||
"DISCUSSION," +
|
||||
"RESEARCH," +
|
||||
"RESEARCH_TASK," +
|
||||
"TASK_GATHERING_INFO," +
|
||||
"TASK," +
|
||||
"TASK_PLANNING," +
|
||||
"TASK_ASSIGNMENT," +
|
||||
"EXECUTION," +
|
||||
"CONCLUSION" +
|
||||
"\n)\nOnly give the single word answer in all caps only from the given options.";
|
||||
String response = model.generate(decisionPrompt, null);
|
||||
|
||||
ConversationState newState = ConversationState.valueOf(response.trim());
|
||||
|
@ -112,7 +127,10 @@ public class ConversationFSM {
|
|||
private void notifyNextSpeaker() {
|
||||
AgentInfo nextSpeaker = speakingQueue.poll();
|
||||
if (nextSpeaker != null) {
|
||||
nextSpeaker.notifyTurn(this);
|
||||
// Instead of calling notifyTurn, we'll create a turn notification message
|
||||
Message turnNotification = new Message(this.conversationId, "SYSTEM",
|
||||
"It's " + nextSpeaker.getName() + "'s turn to speak.");
|
||||
broadcastMessage(turnNotification);
|
||||
speakingQueue.offer(nextSpeaker);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,23 +3,23 @@ package com.ioa.conversation;
|
|||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
@Component
|
||||
public class ConversationManager {
|
||||
private final Map<String, ConversationFSM> conversations;
|
||||
private final Map<String, ConversationFSM> conversations = new ConcurrentHashMap<>();
|
||||
private final BedrockLanguageModel model;
|
||||
private final WebSocketService webSocketService;
|
||||
private final ScheduledExecutorService executorService;
|
||||
|
||||
@Autowired
|
||||
public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
|
||||
this.conversations = new ConcurrentHashMap<>();
|
||||
this.model = model;
|
||||
this.webSocketService = webSocketService;
|
||||
this.executorService = Executors.newScheduledThreadPool(1);
|
||||
}
|
||||
|
||||
public String createConversation() {
|
||||
|
@ -30,6 +30,21 @@ public class ConversationManager {
|
|||
return conversationId;
|
||||
}
|
||||
|
||||
public void joinConversation(String conversationId) {
|
||||
ConversationFSM conversation = conversations.get(conversationId);
|
||||
if (conversation != null) {
|
||||
// Logic to join a conversation (e.g., add participant)
|
||||
webSocketService.sendUpdate("conversation_joined", conversationId);
|
||||
}
|
||||
}
|
||||
|
||||
public void postMessage(String conversationId, String senderId, String content) {
|
||||
ConversationFSM conversation = conversations.get(conversationId);
|
||||
if (conversation != null) {
|
||||
conversation.postMessage(new Message(conversationId, senderId, content));
|
||||
}
|
||||
}
|
||||
|
||||
public ConversationFSM getConversation(String conversationId) {
|
||||
return conversations.get(conversationId);
|
||||
}
|
||||
|
@ -47,21 +62,6 @@ public class ConversationManager {
|
|||
conversation.removeParticipant(agent);
|
||||
}
|
||||
}
|
||||
|
||||
public void postMessage(String conversationId, String senderId, String content) {
|
||||
System.out.println("DEBUG: Posting message - ConversationId: " + conversationId + ", SenderId: " + senderId + ", Content: " + content);
|
||||
ConversationFSM conversation = conversations.get(conversationId);
|
||||
if (conversation != null) {
|
||||
if (content == null) {
|
||||
Arrays.toString(Thread.currentThread().getStackTrace()).replace( ',', '\n' );
|
||||
System.out.println("WARNING: Attempting to post null content message");
|
||||
return;
|
||||
}
|
||||
conversation.postMessage(new Message(conversationId, senderId, content));
|
||||
} else {
|
||||
System.out.println("WARNING: Conversation not found for id: " + conversationId);
|
||||
}
|
||||
}
|
||||
|
||||
public void startConversation(String conversationId, String initialMessage) {
|
||||
ConversationFSM conversation = conversations.get(conversationId);
|
||||
|
@ -69,7 +69,7 @@ public class ConversationManager {
|
|||
conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage));
|
||||
|
||||
// Start a timer to end the conversation after 10 minutes
|
||||
executorService.schedule(() -> {
|
||||
Executor.schedule(() -> {
|
||||
if (!conversation.isFinished()) {
|
||||
conversation.finish("Time limit reached");
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package com.ioa.conversation;
|
|||
|
||||
public enum ConversationState {
|
||||
DISCUSSION,
|
||||
RESEARCH,
|
||||
RESEARCH_TASK,
|
||||
TASK_GATHERING_INFO,
|
||||
TASK,
|
||||
TASK_PLANNING,
|
||||
|
|
|
@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
|
|||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Base64;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
@Component
|
||||
public class BedrockLanguageModel {
|
||||
|
@ -36,33 +37,7 @@ public class BedrockLanguageModel {
|
|||
public String generate(String prompt, String imagePath) {
|
||||
System.out.println("DEBUG: Generating response for prompt: " + prompt);
|
||||
try {
|
||||
ObjectNode requestBody = objectMapper.createObjectNode();
|
||||
requestBody.put("anthropic_version", "bedrock-2023-05-31");
|
||||
ArrayNode messages = requestBody.putArray("messages");
|
||||
ObjectNode message = messages.addObject();
|
||||
message.put("role", "user");
|
||||
requestBody.put("max_tokens", 20000);
|
||||
requestBody.put("temperature", 0.7);
|
||||
requestBody.put("top_p", 0.9);
|
||||
|
||||
ArrayNode content = message.putArray("content");
|
||||
|
||||
if (imagePath != null && !imagePath.isEmpty()) {
|
||||
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
|
||||
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||
|
||||
ObjectNode imageNode = content.addObject();
|
||||
imageNode.put("type", "image"); // Add type field
|
||||
ObjectNode imageContent = imageNode.putObject("image");
|
||||
imageContent.put("format", "png");
|
||||
ObjectNode source = imageContent.putObject("source");
|
||||
source.put("bytes", base64Image);
|
||||
}
|
||||
|
||||
ObjectNode textNode = content.addObject();
|
||||
textNode.put("type", "text"); // Add type field
|
||||
textNode.put("text", prompt);
|
||||
|
||||
ObjectNode requestBody = createRequestBody(prompt, imagePath);
|
||||
String jsonPayload = objectMapper.writeValueAsString(requestBody);
|
||||
|
||||
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
|
||||
|
@ -100,4 +75,77 @@ public class BedrockLanguageModel {
|
|||
return "Error: " + e.getMessage();
|
||||
}
|
||||
}
|
||||
|
||||
public void generateStream(String prompt, String imagePath, Consumer<String> chunkConsumer) {
|
||||
System.out.println("DEBUG: Generating streaming response for prompt: " + prompt);
|
||||
try {
|
||||
ObjectNode requestBody = createRequestBody(prompt, imagePath);
|
||||
String jsonPayload = objectMapper.writeValueAsString(requestBody);
|
||||
|
||||
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
|
||||
.modelId(modelId)
|
||||
.contentType("application/json")
|
||||
.accept("application/json")
|
||||
.body(SdkBytes.fromUtf8String(jsonPayload))
|
||||
.build();
|
||||
|
||||
InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest);
|
||||
String responseBody = response.body().asUtf8String();
|
||||
System.out.println("DEBUG: Raw response from Bedrock: " + responseBody);
|
||||
|
||||
JsonNode responseJson = objectMapper.readTree(responseBody);
|
||||
JsonNode contentArray = responseJson.path("content");
|
||||
|
||||
if (contentArray.isArray() && contentArray.size() > 0) {
|
||||
for (JsonNode content : contentArray) {
|
||||
String chunk = content.path("text").asText();
|
||||
chunkConsumer.accept(chunk);
|
||||
}
|
||||
} else {
|
||||
System.out.println("WARNING: Unexpected response format. Full response: " + responseBody);
|
||||
chunkConsumer.accept("Unexpected response format");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.out.println("ERROR: Failed to generate streaming text with Bedrock: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
chunkConsumer.accept("Error: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private ObjectNode createRequestBody(String prompt, String imagePath) {
|
||||
ObjectNode requestBody = objectMapper.createObjectNode();
|
||||
requestBody.put("anthropic_version", "bedrock-2023-05-31");
|
||||
requestBody.put("max_tokens", 20000);
|
||||
requestBody.put("temperature", 0.7);
|
||||
requestBody.put("top_p", 0.9);
|
||||
|
||||
ArrayNode messages = requestBody.putArray("messages");
|
||||
ObjectNode message = messages.addObject();
|
||||
message.put("role", "user");
|
||||
|
||||
ArrayNode content = message.putArray("content");
|
||||
|
||||
if (imagePath != null && !imagePath.isEmpty()) {
|
||||
try {
|
||||
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
|
||||
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||
|
||||
ObjectNode imageNode = content.addObject();
|
||||
imageNode.put("type", "image");
|
||||
ObjectNode imageContent = imageNode.putObject("image");
|
||||
imageContent.put("format", "png");
|
||||
ObjectNode source = imageContent.putObject("source");
|
||||
source.put("bytes", base64Image);
|
||||
} catch (Exception e) {
|
||||
System.err.println("Error reading image file: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
ObjectNode textNode = content.addObject();
|
||||
textNode.put("type", "text");
|
||||
textNode.put("text", prompt);
|
||||
|
||||
return requestBody;
|
||||
}
|
||||
}
|
|
@ -1,81 +1,17 @@
|
|||
package com.ioa.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.ioa.conversation.ConversationManager;
|
||||
import com.ioa.conversation.Message;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
|
||||
@Service
|
||||
public class WebSocketService {
|
||||
private final SimpMessagingTemplate messagingTemplate;
|
||||
private final ConversationManager conversationManager;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
@Autowired
|
||||
public WebSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
|
||||
public WebSocketService(SimpMessagingTemplate messagingTemplate) {
|
||||
this.messagingTemplate = messagingTemplate;
|
||||
this.conversationManager = conversationManager;
|
||||
this.objectMapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
public void handleMessage(WebSocketSession session, TextMessage message) {
|
||||
try {
|
||||
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
|
||||
if (messageArray.length > 0) {
|
||||
String content = messageArray[0];
|
||||
// Assume we're using a default conversation ID for simplicity
|
||||
String conversationId = "default";
|
||||
// Assume we're using a default user ID for simplicity
|
||||
String userId = "user";
|
||||
conversationManager.postMessage(conversationId, userId, content);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public void sendUpdate(String topic, Object payload) {
|
||||
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
||||
}
|
||||
|
||||
public void handleWebSocketMessage(String message) {
|
||||
System.out.println("DEBUG: Received WebSocket message: " + message);
|
||||
|
||||
// Parse the WebSocket frame
|
||||
String[] parts = message.split("\n\n", 2);
|
||||
if (parts.length < 2) {
|
||||
System.out.println("DEBUG: Invalid WebSocket message format");
|
||||
return;
|
||||
}
|
||||
|
||||
String headers = parts[0];
|
||||
String payload = parts[1];
|
||||
|
||||
// Parse the JSON payload
|
||||
try {
|
||||
JsonNode jsonNode = objectMapper.readTree(payload);
|
||||
|
||||
// Extract relevant information from the JSON
|
||||
// Adjust this based on the actual structure of your WebSocket messages
|
||||
String conversationId = jsonNode.path("conversationId").asText();
|
||||
String sender = jsonNode.path("sender").asText();
|
||||
String content = jsonNode.path("content").asText();
|
||||
|
||||
// Create a new Message object
|
||||
Message parsedMessage = new Message(conversationId, sender, content);
|
||||
|
||||
System.out.println("DEBUG: WebSocket message: " + payload);
|
||||
// Process the message
|
||||
conversationManager.postMessage(conversationId, sender, content);
|
||||
} catch (Exception e) {
|
||||
System.out.println("DEBUG: Error parsing WebSocket message: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,8 @@
|
|||
package com.ioa.task;
|
||||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class Task {
|
||||
private String id;
|
||||
private String description;
|
||||
|
@ -14,10 +11,28 @@ public class Task {
|
|||
private AgentInfo assignedAgent;
|
||||
private String result;
|
||||
|
||||
// Default constructor
|
||||
public Task() {}
|
||||
|
||||
// Existing constructor
|
||||
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
|
||||
this.id = id;
|
||||
this.description = description;
|
||||
this.requiredCapabilities = requiredCapabilities;
|
||||
this.requiredTools = requiredTools;
|
||||
}
|
||||
|
||||
// Getters and setters for all fields
|
||||
public String getId() { return id; }
|
||||
public void setId(String id) { this.id = id; }
|
||||
public String getDescription() { return description; }
|
||||
public void setDescription(String description) { this.description = description; }
|
||||
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
|
||||
public void setRequiredCapabilities(List<String> requiredCapabilities) { this.requiredCapabilities = requiredCapabilities; }
|
||||
public List<String> getRequiredTools() { return requiredTools; }
|
||||
public void setRequiredTools(List<String> requiredTools) { this.requiredTools = requiredTools; }
|
||||
public AgentInfo getAssignedAgent() { return assignedAgent; }
|
||||
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
|
||||
public String getResult() { return result; }
|
||||
public void setResult(String result) { this.result = result; }
|
||||
}
|
14
src/main/java/com/ioa/task/TaskAssignment.java
Normal file
14
src/main/java/com/ioa/task/TaskAssignment.java
Normal file
|
@ -0,0 +1,14 @@
|
|||
package com.ioa.task;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class TaskAssignment {
|
||||
private String taskId;
|
||||
private List<String> agentIds;
|
||||
|
||||
// Getters and setters
|
||||
public String getTaskId() { return taskId; }
|
||||
public void setTaskId(String taskId) { this.taskId = taskId; }
|
||||
public List<String> getAgentIds() { return agentIds; }
|
||||
public void setAgentIds(List<String> agentIds) { this.agentIds = agentIds; }
|
||||
}
|
|
@ -1,37 +1,84 @@
|
|||
package com.ioa.task;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.conversation.ConversationFSM;
|
||||
import com.ioa.conversation.ConversationManager;
|
||||
import com.ioa.conversation.Message;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.team.TeamFormation;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
import com.ioa.conversation.Message;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Component
|
||||
public class TaskManager {
|
||||
private Map<String, Task> tasks = new HashMap<>();
|
||||
private Map<String, Task> tasks = new ConcurrentHashMap<>();
|
||||
private final SimpMessagingTemplate messagingTemplate;
|
||||
private AgentRegistry agentRegistry;
|
||||
private BedrockLanguageModel model;
|
||||
private ToolRegistry toolRegistry;
|
||||
private TreeOfThought treeOfThought;
|
||||
private ConversationManager conversationManager;
|
||||
private WebSocketService webSocketService;
|
||||
|
||||
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry,
|
||||
TreeOfThought treeOfThought, ConversationManager conversationManager) {
|
||||
@Autowired
|
||||
private TeamFormation teamFormation;
|
||||
|
||||
@Autowired
|
||||
public TaskManager(AgentRegistry agentRegistry,
|
||||
BedrockLanguageModel model,
|
||||
ToolRegistry toolRegistry,
|
||||
TreeOfThought treeOfThought,
|
||||
ConversationManager conversationManager,
|
||||
WebSocketService webSocketService,
|
||||
SimpMessagingTemplate messagingTemplate) {
|
||||
this.agentRegistry = agentRegistry;
|
||||
this.model = model;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.treeOfThought = treeOfThought;
|
||||
this.conversationManager = conversationManager;
|
||||
this.webSocketService = webSocketService;
|
||||
this.messagingTemplate = messagingTemplate;
|
||||
}
|
||||
|
||||
public String createTask(Task task) {
|
||||
String taskId = UUID.randomUUID().toString();
|
||||
task.setId(taskId);
|
||||
tasks.put(taskId, task);
|
||||
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
|
||||
|
||||
// Automatically assign agents and execute the task
|
||||
List<AgentInfo> team = teamFormation.formTeam(task);
|
||||
if (!team.isEmpty()) {
|
||||
executeTask(taskId, team);
|
||||
} else {
|
||||
System.out.println("No suitable agents found for task: " + taskId);
|
||||
}
|
||||
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void assignTask(String taskId, List<String> agentIds) {
|
||||
Task task = tasks.get(taskId);
|
||||
if (task != null) {
|
||||
List<AgentInfo> team = agentIds.stream()
|
||||
.map(agentRegistry::getAgent)
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
executeTask(taskId, team);
|
||||
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task assigned to: " + String.join(", ", agentIds));
|
||||
}
|
||||
}
|
||||
|
||||
public void addTask(Task task) {
|
||||
|
@ -54,9 +101,13 @@ public class TaskManager {
|
|||
|
||||
conversation.postMessage(new Message(conversationId, "SYSTEM", "Let's work on the task: " + task.getDescription()));
|
||||
|
||||
// Allow agents to interact for a maximum of 10 minutes
|
||||
// Use the LLM to generate a plan for the task
|
||||
String planPrompt = "Generate a step-by-step plan to accomplish the following task: " + task.getDescription();
|
||||
String plan = model.generate(planPrompt, null);
|
||||
conversation.postMessage(new Message(conversationId, "SYSTEM", "Here's the plan: " + plan));
|
||||
|
||||
// Allow agents to interact for a maximum of 40 minutes
|
||||
long startTime = System.currentTimeMillis();
|
||||
//while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(10)) {
|
||||
while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(40)) {
|
||||
try {
|
||||
Thread.sleep(1000); // Check every second
|
||||
|
@ -72,8 +123,10 @@ public class TaskManager {
|
|||
String result = conversation.getResult();
|
||||
task.setResult(result);
|
||||
System.out.println("Task completed. Result: " + result);
|
||||
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task completed: " + result);
|
||||
}
|
||||
|
||||
|
||||
public Task getTask(String taskId) {
|
||||
return tasks.get(taskId);
|
||||
}
|
||||
|
|
31
src/main/java/com/ioa/task/TaskResult.java
Normal file
31
src/main/java/com/ioa/task/TaskResult.java
Normal file
|
@ -0,0 +1,31 @@
|
|||
package com.ioa.task;
|
||||
|
||||
public class TaskResult {
|
||||
private String taskId;
|
||||
private String result;
|
||||
|
||||
// Constructors
|
||||
public TaskResult() {}
|
||||
|
||||
public TaskResult(String taskId, String result) {
|
||||
this.taskId = taskId;
|
||||
this.result = result;
|
||||
}
|
||||
|
||||
// Getters and setters
|
||||
public String getTaskId() {
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void setTaskId(String taskId) {
|
||||
this.taskId = taskId;
|
||||
}
|
||||
|
||||
public String getResult() {
|
||||
return result;
|
||||
}
|
||||
|
||||
public void setResult(String result) {
|
||||
this.result = result;
|
||||
}
|
||||
}
|
29
src/main/java/com/ioa/tool/ToolInfo.java
Normal file
29
src/main/java/com/ioa/tool/ToolInfo.java
Normal file
|
@ -0,0 +1,29 @@
|
|||
package com.ioa.tool;
|
||||
|
||||
public class ToolInfo {
|
||||
private String agentId;
|
||||
private String toolName;
|
||||
|
||||
public ToolInfo() {}
|
||||
|
||||
public ToolInfo(String agentId, String toolName) {
|
||||
this.agentId = agentId;
|
||||
this.toolName = toolName;
|
||||
}
|
||||
|
||||
public String getAgentId() {
|
||||
return agentId;
|
||||
}
|
||||
|
||||
public void setAgentId(String agentId) {
|
||||
this.agentId = agentId;
|
||||
}
|
||||
|
||||
public String getToolName() {
|
||||
return toolName;
|
||||
}
|
||||
|
||||
public void setToolName(String toolName) {
|
||||
this.toolName = toolName;
|
||||
}
|
||||
}
|
16
src/main/java/com/ioa/websocket/AgentProcessingUpdate.java
Normal file
16
src/main/java/com/ioa/websocket/AgentProcessingUpdate.java
Normal file
|
@ -0,0 +1,16 @@
|
|||
package com.ioa.websocket;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AgentProcessingUpdate {
|
||||
private String agentId;
|
||||
private String responseId;
|
||||
private String status; // "start" or "complete"
|
||||
|
||||
public AgentProcessingUpdate(String agentId, String responseId, String status) {
|
||||
this.agentId = agentId;
|
||||
this.responseId = responseId;
|
||||
this.status = status;
|
||||
}
|
||||
}
|
16
src/main/java/com/ioa/websocket/AgentResponseUpdate.java
Normal file
16
src/main/java/com/ioa/websocket/AgentResponseUpdate.java
Normal file
|
@ -0,0 +1,16 @@
|
|||
package com.ioa.websocket;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AgentResponseUpdate {
|
||||
private String agentId;
|
||||
private String responseId;
|
||||
private String chunk;
|
||||
|
||||
public AgentResponseUpdate(String agentId, String responseId, String chunk) {
|
||||
this.agentId = agentId;
|
||||
this.responseId = responseId;
|
||||
this.chunk = chunk;
|
||||
}
|
||||
}
|
50
src/main/java/com/ioa/websocket/ChatController.java
Normal file
50
src/main/java/com/ioa/websocket/ChatController.java
Normal file
|
@ -0,0 +1,50 @@
|
|||
package com.ioa.websocket;
|
||||
|
||||
import com.ioa.conversation.ConversationManager;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.messaging.handler.annotation.SendTo;
|
||||
import org.springframework.stereotype.Controller;
|
||||
|
||||
@Controller
|
||||
public class ChatController {
|
||||
private final ConversationManager conversationManager;
|
||||
|
||||
@Autowired
|
||||
public ChatController(ConversationManager conversationManager) {
|
||||
this.conversationManager = conversationManager;
|
||||
}
|
||||
|
||||
@MessageMapping("/chat/create")
|
||||
@SendTo("/topic/rooms")
|
||||
public String createRoom(String roomName) {
|
||||
String roomId = conversationManager.createConversation();
|
||||
return "Room created: " + roomId;
|
||||
}
|
||||
|
||||
@MessageMapping("/chat/join")
|
||||
@SendTo("/topic/rooms")
|
||||
public String joinRoom(String roomId) {
|
||||
conversationManager.joinConversation(roomId);
|
||||
return "Joined room: " + roomId;
|
||||
}
|
||||
|
||||
@MessageMapping("/chat/message")
|
||||
public void sendMessage(ChatMessage message) {
|
||||
conversationManager.postMessage(message.getRoomId(), message.getSender(), message.getContent());
|
||||
}
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
private String roomId;
|
||||
private String sender;
|
||||
private String content;
|
||||
|
||||
// Getters and setters
|
||||
public String getRoomId() { return roomId; }
|
||||
public void setRoomId(String roomId) { this.roomId = roomId; }
|
||||
public String getSender() { return sender; }
|
||||
public void setSender(String sender) { this.sender = sender; }
|
||||
public String getContent() { return content; }
|
||||
public void setContent(String content) { this.content = content; }
|
||||
}
|
91
src/main/java/com/ioa/websocket/WebSocketHandler.java
Normal file
91
src/main/java/com/ioa/websocket/WebSocketHandler.java
Normal file
|
@ -0,0 +1,91 @@
|
|||
package com.ioa.websocket;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||
import org.springframework.messaging.handler.annotation.Payload;
|
||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Controller;
|
||||
|
||||
import com.ioa.agent.AgentInfo;
|
||||
import com.ioa.agent.AgentRegistry;
|
||||
import com.ioa.model.BedrockLanguageModel;
|
||||
import com.ioa.service.WebSocketService;
|
||||
import com.ioa.task.Task;
|
||||
import com.ioa.task.TaskAssignment;
|
||||
import com.ioa.task.TaskManager;
|
||||
import com.ioa.task.TaskResult;
|
||||
import com.ioa.tool.ToolInfo;
|
||||
import com.ioa.tool.ToolRegistry;
|
||||
import com.ioa.util.TreeOfThought;
|
||||
|
||||
@Controller
|
||||
public class WebSocketHandler {
|
||||
|
||||
private final AgentRegistry agentRegistry;
|
||||
private final TaskManager taskManager;
|
||||
private final SimpMessagingTemplate messagingTemplate;
|
||||
private final TreeOfThought treeOfThought;
|
||||
private final WebSocketService webSocketService;
|
||||
private final ToolRegistry toolRegistry;
|
||||
private final BedrockLanguageModel model;
|
||||
|
||||
@Autowired
|
||||
public WebSocketHandler(AgentRegistry agentRegistry, TaskManager taskManager,
|
||||
SimpMessagingTemplate messagingTemplate, TreeOfThought treeOfThought,
|
||||
WebSocketService webSocketService, ToolRegistry toolRegistry,
|
||||
BedrockLanguageModel model) {
|
||||
this.agentRegistry = agentRegistry;
|
||||
this.taskManager = taskManager;
|
||||
this.messagingTemplate = messagingTemplate;
|
||||
this.treeOfThought = treeOfThought;
|
||||
this.webSocketService = webSocketService;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@MessageMapping("/agent/register")
|
||||
public void registerAgent(@Payload AgentInfo agentInfo) {
|
||||
// Initialize dependencies
|
||||
agentInfo.setDependencies(treeOfThought, webSocketService, toolRegistry, model);
|
||||
|
||||
agentRegistry.registerAgent(agentInfo);
|
||||
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
|
||||
}
|
||||
|
||||
@MessageMapping("/tool/register")
|
||||
public void registerTool(@Payload ToolInfo toolInfo) {
|
||||
AgentInfo agent = agentRegistry.getAgent(toolInfo.getAgentId());
|
||||
if (agent != null) {
|
||||
agent.getTools().add(toolInfo.getToolName());
|
||||
messagingTemplate.convertAndSend("/topic/tools", "New tool registered: " + toolInfo.getToolName());
|
||||
}
|
||||
}
|
||||
|
||||
@MessageMapping("/task/create")
|
||||
public void createTask(@Payload Task task) {
|
||||
System.out.println("Creating task: " + task.getId());
|
||||
String taskId = taskManager.createTask(task);
|
||||
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
|
||||
}
|
||||
|
||||
@MessageMapping("/agent/unregister")
|
||||
public void unregisterAgent(@Payload String agentId) {
|
||||
agentRegistry.unregisterAgent(agentId);
|
||||
messagingTemplate.convertAndSend("/topic/agent/" + agentId, "Unregistration successful");
|
||||
}
|
||||
|
||||
@MessageMapping("/task/assign")
|
||||
public void assignTask(@Payload TaskAssignment assignment) {
|
||||
taskManager.assignTask(assignment.getTaskId(), assignment.getAgentIds());
|
||||
messagingTemplate.convertAndSend("/topic/tasks/" + assignment.getTaskId(), "Task assigned");
|
||||
}
|
||||
|
||||
@MessageMapping("/task/result")
|
||||
public void handleTaskResult(@Payload TaskResult result) {
|
||||
Task task = taskManager.getTask(result.getTaskId());
|
||||
if (task != null) {
|
||||
task.setResult(result.getResult());
|
||||
messagingTemplate.convertAndSend("/topic/tasks/" + result.getTaskId(), "Task completed: " + result.getResult());
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user