Mostly working system

This commit is contained in:
Mahesh Kommareddi 2024-07-30 23:21:45 -04:00
parent 9cf98dbaf7
commit 6789f035b2
51 changed files with 623 additions and 365 deletions

3
README.md Normal file
View File

@ -0,0 +1,3 @@
mvn spring-boot:run
http://localhost:8080/

2
application.properties Normal file
View File

@ -0,0 +1,2 @@
logging.level.org.springframework.messaging=TRACE
logging.level.org.springframework.web.socket=TRACE

View File

@ -1,59 +1,51 @@
package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationManager;
import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry;
import com.ioa.tool.Tool;
import com.ioa.tool.common.*;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.util.TreeOfThought;
import com.ioa.tool.common.AppointmentSchedulerTool;
import com.ioa.tool.common.DistanceCalculatorTool;
import com.ioa.tool.common.FinancialAdviceTool;
import com.ioa.tool.common.FitnessClassFinderTool;
import com.ioa.tool.common.MovieRecommendationTool;
import com.ioa.tool.common.NewsUpdateTool;
import com.ioa.tool.common.PriceComparisonTool;
import com.ioa.tool.common.RecipeTool;
import com.ioa.tool.common.ReminderTool;
import com.ioa.tool.common.RestaurantFinderTool;
import com.ioa.tool.common.TranslationTool;
import com.ioa.tool.common.TravelBookingTool;
import com.ioa.tool.common.WeatherTool;
import com.ioa.tool.common.WebSearchTool;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Lazy;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.CountDownLatch;
import com.ioa.agent.AgentRegistry;
import com.ioa.task.TaskManager;
import com.ioa.websocket.WebSocketHandler;
import com.ioa.tool.ToolRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.util.TreeOfThought;
import com.ioa.service.WebSocketService;
import com.ioa.conversation.ConversationManager;
@SpringBootApplication
@ComponentScan(basePackages = "com.ioa")
public class IoASystem {
@Bean
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
return new WebSocketService(messagingTemplate, conversationManager);
public WebSocketHandler webSocketHandler(AgentRegistry agentRegistry,
TaskManager taskManager,
SimpMessagingTemplate messagingTemplate,
TreeOfThought treeOfThought,
WebSocketService webSocketService,
ToolRegistry toolRegistry,
BedrockLanguageModel model) {
return new WebSocketHandler(agentRegistry, taskManager, messagingTemplate,
treeOfThought, webSocketService, toolRegistry, model);
}
@Bean
public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
return new ConversationManager(model, webSocketService);
public AgentRegistry agentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
return new AgentRegistry(messagingTemplate, toolRegistry);
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought,
conversationManager, webSocketService, messagingTemplate);
}
@Bean
@ -67,149 +59,21 @@ public class IoASystem {
}
@Bean
public AgentRegistry agentRegistry(ToolRegistry toolRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, ConversationManager conversationManager) {
AgentRegistry registry = new AgentRegistry(toolRegistry);
// Agent creation is now moved to processTasksAndAgents method
return registry;
}
@Bean
public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry, TreeOfThought treeOfThought, ConversationManager conversationManager) {
return new TaskManager(agentRegistry, model, toolRegistry, treeOfThought, conversationManager);
}
@Bean
public TeamFormation teamFormation(AgentRegistry agentRegistry, TreeOfThought treeOfThought, WebSocketService webSocketService, BedrockLanguageModel model) {
return new TeamFormation(agentRegistry, treeOfThought, webSocketService, model);
public WebSocketService webSocketService(SimpMessagingTemplate messagingTemplate) {
return new WebSocketService(messagingTemplate);
}
@Bean
public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry();
return new ToolRegistry();
}
// Register all tools
toolRegistry.registerTool("webSearch", new WebSearchTool());
toolRegistry.registerTool("getWeather", new WeatherTool());
toolRegistry.registerTool("setReminder", new ReminderTool());
toolRegistry.registerTool("bookTravel", new TravelBookingTool());
toolRegistry.registerTool("calculateDistance", new DistanceCalculatorTool());
toolRegistry.registerTool("findRestaurants", new RestaurantFinderTool());
toolRegistry.registerTool("scheduleAppointment", new AppointmentSchedulerTool());
toolRegistry.registerTool("findFitnessClasses", new FitnessClassFinderTool());
toolRegistry.registerTool("getRecipe", new RecipeTool());
toolRegistry.registerTool("getNewsUpdates", new NewsUpdateTool());
toolRegistry.registerTool("translate", new TranslationTool());
toolRegistry.registerTool("compareProductPrices", new PriceComparisonTool());
toolRegistry.registerTool("getMovieRecommendations", new MovieRecommendationTool());
toolRegistry.registerTool("getFinancialAdvice", new FinancialAdviceTool());
return toolRegistry;
@Bean
public ConversationManager conversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
return new ConversationManager(model, webSocketService);
}
public static void main(String[] args) {
ConfigurableApplicationContext context = SpringApplication.run(IoASystem.class, args);
IoASystem system = context.getBean(IoASystem.class);
system.processTasksAndAgents(context);
}
public void processTasksAndAgents(ConfigurableApplicationContext context) {
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class);
TreeOfThought treeOfThought = context.getBean(TreeOfThought.class);
WebSocketService webSocketService = context.getBean(WebSocketService.class);
ToolRegistry toolRegistry = context.getBean(ToolRegistry.class);
ConversationManager conversationManager = context.getBean(ConversationManager.class);
BedrockLanguageModel model = context.getBean(BedrockLanguageModel.class);
// Register all agents
agentRegistry.registerAgent("agent1", new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent2", new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent3", new AgentInfo("agent3", "Event Planner Extraordinaire",
Arrays.asList("event planning", "team management", "booking"),
Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment", "getWeather"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent4", new AgentInfo("agent4", "Fitness Guru",
Arrays.asList("health", "nutrition", "motivation"),
Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent5", new AgentInfo("agent5", "Research Specialist",
Arrays.asList("research", "writing", "analysis"),
Arrays.asList("webSearch", "getNewsUpdates", "translate", "compareProductPrices"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent6", new AgentInfo("agent6", "Digital Marketing Expert",
Arrays.asList("marketing", "social media", "content creation"),
Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "getMovieRecommendations"),
treeOfThought, webSocketService, toolRegistry, model));
agentRegistry.registerAgent("agent7", new AgentInfo("agent7", "Family Travel Coordinator",
Arrays.asList("travel", "family planning", "budgeting"),
Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants", "getFinancialAdvice"),
treeOfThought, webSocketService, toolRegistry, model));
// Create all tasks
List<Task> tasks = Arrays.asList(
new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"))//,
// new Task("task2", "Organize a corporate team-building event in New York",
// Arrays.asList("event planning", "team management"),
// Arrays.asList("findRestaurants", "bookTravel", "scheduleAppointment")),
// new Task("task3", "Develop a personalized fitness and nutrition plan",
// Arrays.asList("health", "nutrition"),
// Arrays.asList("getWeather", "findFitnessClasses", "getRecipe")),
// new Task("task4", "Research and summarize recent advancements in renewable energy",
// Arrays.asList("research", "writing"),
// Arrays.asList("webSearch", "getNewsUpdates", "translate")),
// new Task("task5", "Plan and execute a social media marketing campaign for a new product launch",
// Arrays.asList("marketing", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment")),
// new Task("task6", "Assist in planning a multi-city European vacation for a family of four",
// Arrays.asList("travel", "family planning"),
// Arrays.asList("bookTravel", "calculateDistance", "getWeather", "findRestaurants")),
// new Task("task7", "Organize an international tech conference with virtual and in-person components",
// Arrays.asList("event planning", "tech expertise", "marketing", "travel coordination", "content creation"),
// Arrays.asList("scheduleAppointment", "webSearch", "bookTravel", "getWeather", "findRestaurants", "getNewsUpdates"))//,
// new Task("task8", "Develop and launch a multi-lingual mobile app for sustainable tourism",
// Arrays.asList("software development", "travel", "language expertise", "environmental science", "user experience design"),
// Arrays.asList("webSearch", "translate", "getWeather", "findRestaurants", "getNewsUpdates", "compareProductPrices")),
// new Task("task9", "Create a comprehensive health and wellness program for a large corporation, including mental health support",
// Arrays.asList("health", "nutrition", "psychology", "corporate wellness", "data analysis"),
// Arrays.asList("findFitnessClasses", "getRecipe", "setReminder", "getWeather", "scheduleAppointment", "getFinancialAdvice")),
// new Task("task10", "Plan and execute a global product launch campaign for a revolutionary eco-friendly technology",
// Arrays.asList("marketing", "environmental science", "international business", "public relations", "social media"),
// Arrays.asList("webSearch", "getNewsUpdates", "scheduleAppointment", "translate", "compareProductPrices", "bookTravel")),
// new Task("task11", "Design and implement a smart city initiative focusing on transportation, energy, and public safety",
// Arrays.asList("urban planning", "environmental science", "data analysis", "public policy", "technology integration"),
// Arrays.asList("webSearch", "getWeather", "calculateDistance", "getNewsUpdates", "getFinancialAdvice", "findHomeServices"))
);
for (Task task : tasks) {
taskManager.addTask(task); // Add each task to the TaskManager
System.out.println("\nProcessing task: " + task.getDescription());
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
if (!team.isEmpty()) {
taskManager.executeTask(task.getId(), team);
System.out.println("Task result: " + task.getResult());
} else {
System.out.println("No suitable agents found for this task. Consider updating the agent pool or revising the task requirements.");
}
}
SpringApplication.run(IoASystem.class, args);
}
}

View File

@ -1,15 +1,21 @@
package com.ioa.agent;
import com.ioa.conversation.ConversationFSM;
import com.ioa.conversation.Message;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
import com.ioa.websocket.AgentProcessingUpdate;
import com.ioa.websocket.AgentResponseUpdate;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
public class AgentInfo {
private String id;
private String name;
@ -21,26 +27,30 @@ public class AgentInfo {
private ToolRegistry toolRegistry;
private BedrockLanguageModel model;
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools,
TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, BedrockLanguageModel model) {
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
this.id = id;
this.name = name;
this.capabilities = capabilities;
this.tools = tools;
this.memory = new Memory();
}
public void setDependencies(TreeOfThought treeOfThought, WebSocketService webSocketService,
ToolRegistry toolRegistry, BedrockLanguageModel model) {
this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry;
this.model = model;
this.memory = new Memory();
}
public void receiveMessage(Message message) {
public void receiveMessageLocal(Message message) {
if (this.memory == null)
this.memory = new Memory();
memory.addToHistory(message.getContent());
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nYou received a message: " + message.getContent() +
"\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
"\nYou received a message: " + message.getContent() +
"\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
String response = model.generate(prompt, null);
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
System.out.println("DEBUG: " + name + " response: " + response);
@ -49,31 +59,24 @@ public class AgentInfo {
memory.addToHistory("My response: " + response);
}
public void performTreeOfThought(String task) {
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nTask: " + task +
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
"\n\nMemory:\n" + memory.getFormattedMemory();
public void receiveMessage(Message message) {
if (this.memory == null)
this.memory = new Memory();
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
String reasoning = (String) totResult.get("reasoning");
// Add the reasoning to memory
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
// This is a turn notification for this agent
respondToTurn(message.getConversationId());
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
}
public void notifyTurn(ConversationFSM conversation) {
private void respondToTurn(String conversationId) {
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nIt's your turn to speak in the conversation. What would you like to say or do?";
String response = model.generate(prompt, null);
conversation.postMessage(new Message(conversation.getConversationId(), id, response));
"\nIt's your turn to speak in the conversation. What would you like to say or do?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
// Send the response back to the conversation
Message responseMessage = new Message(conversationId, this.id, prompt);
webSocketService.sendUpdate("conversation_message", responseMessage);
}
// Getters and setters
public String getId() { return id; }
public String getName() { return name; }
public List<String> getCapabilities() { return capabilities; }
public List<String> getTools() { return tools; }
}

View File

@ -0,0 +1,12 @@
package com.ioa.agent;
public class AgentMessage {
private String agentId;
private String content;
// Getters and setters
public String getAgentId() { return agentId; }
public void setAgentId(String agentId) { this.agentId = agentId; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -1,34 +1,53 @@
package com.ioa.agent;
import com.ioa.tool.ToolRegistry;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component;
import com.ioa.tool.ToolRegistry;
import java.util.*;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@Component
public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>();
private ToolRegistry toolRegistry;
private Map<String, AgentInfo> agents = new ConcurrentHashMap<>();
private final SimpMessagingTemplate messagingTemplate;
private final ToolRegistry toolRegistry;
public AgentRegistry(ToolRegistry toolRegistry) {
@Autowired
public AgentRegistry(SimpMessagingTemplate messagingTemplate, ToolRegistry toolRegistry) {
this.messagingTemplate = messagingTemplate;
this.toolRegistry = toolRegistry;
}
public void registerAgent(String agentId, AgentInfo agentInfo) {
agents.put(agentId, agentInfo);
public void registerAgent(AgentInfo agentInfo) {
agents.put(agentInfo.getId(), agentInfo);
// Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool);
}
}
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
public void unregisterAgent(String agentId) {
agents.remove(agentId);
messagingTemplate.convertAndSend("/topic/agents", "Agent unregistered: " + agentId);
}
public AgentInfo getAgent(String agentId) {
return agents.get(agentId);
}
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
public List<AgentInfo> searchAgents(List<String> capabilities) {
return searchAgents(capabilities, 1.0); // Default to exact match
}
@ -50,8 +69,4 @@ public class AgentRegistry {
.count();
return (double) matchingCapabilities / requiredCapabilities.size();
}
public List<AgentInfo> getAllAgents() {
return new ArrayList<>(agents.values());
}
}

View File

@ -1,7 +1,19 @@
package com.ioa.config;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.util.MimeTypeUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.converter.DefaultContentTypeResolver;
import java.util.List;
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
@ -18,6 +30,34 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry.addEndpoint("/ws").withSockJS();
registry.addEndpoint("/ws")
.setAllowedOriginPatterns("*")
.withSockJS();
}
@Override
public boolean configureMessageConverters(List<MessageConverter> messageConverters) {
MappingJackson2MessageConverter converter = new MappingJackson2MessageConverter();
DefaultContentTypeResolver resolver = new DefaultContentTypeResolver();
resolver.setDefaultMimeType(MimeTypeUtils.APPLICATION_JSON);
converter.setContentTypeResolver(resolver);
messageConverters.add(converter);
return false;
}
@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(new ChannelInterceptor() {
@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
if (accessor != null && accessor.getCommand() != null) {
System.out.println("Received STOMP Frame: " + accessor.getCommand());
System.out.println("Destination: " + accessor.getDestination());
System.out.println("Payload: " + new String((byte[]) message.getPayload()));
}
return message;
}
});
}
}

View File

@ -91,18 +91,33 @@ public class ConversationFSM {
private void updateState(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState +
"\nParticipants: " + participants;
"\nParticipants: " + participants +
"\nProvide the next conversation state (" +
"DISCUSSION," +
"RESEARCH," +
"RESEARCH_TASK," +
"TASK_GATHERING_INFO," +
"TASK," +
"TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String reasoning = model.generate(stateTransitionTask, null);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION,\n" + //
" TASK_GATHERING_INFO,\n" + //
" TASK,\n" + //
" TASK_PLANNING,\n" + //
" TASK_ASSIGNMENT,\n" + //
" EXECUTION,\n" + //
" CONCLUSION). Only give the single word answer in all caps only from the given options.";
"\nProvide the next conversation state (" +
"DISCUSSION," +
"RESEARCH," +
"RESEARCH_TASK," +
"TASK_GATHERING_INFO," +
"TASK," +
"TASK_PLANNING," +
"TASK_ASSIGNMENT," +
"EXECUTION," +
"CONCLUSION" +
"\n)\nOnly give the single word answer in all caps only from the given options.";
String response = model.generate(decisionPrompt, null);
ConversationState newState = ConversationState.valueOf(response.trim());
@ -112,7 +127,10 @@ public class ConversationFSM {
private void notifyNextSpeaker() {
AgentInfo nextSpeaker = speakingQueue.poll();
if (nextSpeaker != null) {
nextSpeaker.notifyTurn(this);
// Instead of calling notifyTurn, we'll create a turn notification message
Message turnNotification = new Message(this.conversationId, "SYSTEM",
"It's " + nextSpeaker.getName() + "'s turn to speak.");
broadcastMessage(turnNotification);
speakingQueue.offer(nextSpeaker);
}
}

View File

@ -3,23 +3,23 @@ package com.ioa.conversation;
import com.ioa.agent.AgentInfo;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.*;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@Component
public class ConversationManager {
private final Map<String, ConversationFSM> conversations;
private final Map<String, ConversationFSM> conversations = new ConcurrentHashMap<>();
private final BedrockLanguageModel model;
private final WebSocketService webSocketService;
private final ScheduledExecutorService executorService;
@Autowired
public ConversationManager(BedrockLanguageModel model, WebSocketService webSocketService) {
this.conversations = new ConcurrentHashMap<>();
this.model = model;
this.webSocketService = webSocketService;
this.executorService = Executors.newScheduledThreadPool(1);
}
public String createConversation() {
@ -30,6 +30,21 @@ public class ConversationManager {
return conversationId;
}
public void joinConversation(String conversationId) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
// Logic to join a conversation (e.g., add participant)
webSocketService.sendUpdate("conversation_joined", conversationId);
}
}
public void postMessage(String conversationId, String senderId, String content) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
conversation.postMessage(new Message(conversationId, senderId, content));
}
}
public ConversationFSM getConversation(String conversationId) {
return conversations.get(conversationId);
}
@ -48,28 +63,13 @@ public class ConversationManager {
}
}
public void postMessage(String conversationId, String senderId, String content) {
System.out.println("DEBUG: Posting message - ConversationId: " + conversationId + ", SenderId: " + senderId + ", Content: " + content);
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
if (content == null) {
Arrays.toString(Thread.currentThread().getStackTrace()).replace( ',', '\n' );
System.out.println("WARNING: Attempting to post null content message");
return;
}
conversation.postMessage(new Message(conversationId, senderId, content));
} else {
System.out.println("WARNING: Conversation not found for id: " + conversationId);
}
}
public void startConversation(String conversationId, String initialMessage) {
ConversationFSM conversation = conversations.get(conversationId);
if (conversation != null) {
conversation.postMessage(new Message(conversationId, "SYSTEM", initialMessage));
// Start a timer to end the conversation after 10 minutes
executorService.schedule(() -> {
Executor.schedule(() -> {
if (!conversation.isFinished()) {
conversation.finish("Time limit reached");
}

View File

@ -2,6 +2,8 @@ package com.ioa.conversation;
public enum ConversationState {
DISCUSSION,
RESEARCH,
RESEARCH_TASK,
TASK_GATHERING_INFO,
TASK,
TASK_PLANNING,

View File

@ -17,6 +17,7 @@ import org.springframework.stereotype.Component;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Base64;
import java.util.function.Consumer;
@Component
public class BedrockLanguageModel {
@ -36,33 +37,7 @@ public class BedrockLanguageModel {
public String generate(String prompt, String imagePath) {
System.out.println("DEBUG: Generating response for prompt: " + prompt);
try {
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("anthropic_version", "bedrock-2023-05-31");
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image"); // Add type field
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text"); // Add type field
textNode.put("text", prompt);
ObjectNode requestBody = createRequestBody(prompt, imagePath);
String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
@ -100,4 +75,77 @@ public class BedrockLanguageModel {
return "Error: " + e.getMessage();
}
}
public void generateStream(String prompt, String imagePath, Consumer<String> chunkConsumer) {
System.out.println("DEBUG: Generating streaming response for prompt: " + prompt);
try {
ObjectNode requestBody = createRequestBody(prompt, imagePath);
String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
.modelId(modelId)
.contentType("application/json")
.accept("application/json")
.body(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeModelResponse response = bedrockClient.invokeModel(invokeRequest);
String responseBody = response.body().asUtf8String();
System.out.println("DEBUG: Raw response from Bedrock: " + responseBody);
JsonNode responseJson = objectMapper.readTree(responseBody);
JsonNode contentArray = responseJson.path("content");
if (contentArray.isArray() && contentArray.size() > 0) {
for (JsonNode content : contentArray) {
String chunk = content.path("text").asText();
chunkConsumer.accept(chunk);
}
} else {
System.out.println("WARNING: Unexpected response format. Full response: " + responseBody);
chunkConsumer.accept("Unexpected response format");
}
} catch (Exception e) {
System.out.println("ERROR: Failed to generate streaming text with Bedrock: " + e.getMessage());
e.printStackTrace();
chunkConsumer.accept("Error: " + e.getMessage());
}
}
private ObjectNode createRequestBody(String prompt, String imagePath) {
ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("anthropic_version", "bedrock-2023-05-31");
requestBody.put("max_tokens", 20000);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject();
message.put("role", "user");
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
try {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image");
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
} catch (Exception e) {
System.err.println("Error reading image file: " + e.getMessage());
e.printStackTrace();
}
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text");
textNode.put("text", prompt);
return requestBody;
}
}

View File

@ -1,81 +1,17 @@
package com.ioa.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
@Service
public class WebSocketService {
private final SimpMessagingTemplate messagingTemplate;
private final ConversationManager conversationManager;
private final ObjectMapper objectMapper;
@Autowired
public WebSocketService(SimpMessagingTemplate messagingTemplate, @Lazy ConversationManager conversationManager) {
public WebSocketService(SimpMessagingTemplate messagingTemplate) {
this.messagingTemplate = messagingTemplate;
this.conversationManager = conversationManager;
this.objectMapper = new ObjectMapper();
}
public void handleMessage(WebSocketSession session, TextMessage message) {
try {
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
if (messageArray.length > 0) {
String content = messageArray[0];
// Assume we're using a default conversation ID for simplicity
String conversationId = "default";
// Assume we're using a default user ID for simplicity
String userId = "user";
conversationManager.postMessage(conversationId, userId, content);
}
} catch (Exception e) {
e.printStackTrace();
}
}
public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload);
}
public void handleWebSocketMessage(String message) {
System.out.println("DEBUG: Received WebSocket message: " + message);
// Parse the WebSocket frame
String[] parts = message.split("\n\n", 2);
if (parts.length < 2) {
System.out.println("DEBUG: Invalid WebSocket message format");
return;
}
String headers = parts[0];
String payload = parts[1];
// Parse the JSON payload
try {
JsonNode jsonNode = objectMapper.readTree(payload);
// Extract relevant information from the JSON
// Adjust this based on the actual structure of your WebSocket messages
String conversationId = jsonNode.path("conversationId").asText();
String sender = jsonNode.path("sender").asText();
String content = jsonNode.path("content").asText();
// Create a new Message object
Message parsedMessage = new Message(conversationId, sender, content);
System.out.println("DEBUG: WebSocket message: " + payload);
// Process the message
conversationManager.postMessage(conversationId, sender, content);
} catch (Exception e) {
System.out.println("DEBUG: Error parsing WebSocket message: " + e.getMessage());
e.printStackTrace();
}
}
}

View File

@ -1,11 +1,8 @@
package com.ioa.task;
import com.ioa.agent.AgentInfo;
import lombok.Data;
import java.util.List;
@Data
public class Task {
private String id;
private String description;
@ -14,10 +11,28 @@ public class Task {
private AgentInfo assignedAgent;
private String result;
// Default constructor
public Task() {}
// Existing constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id;
this.description = description;
this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools;
}
// Getters and setters for all fields
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getDescription() { return description; }
public void setDescription(String description) { this.description = description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public void setRequiredCapabilities(List<String> requiredCapabilities) { this.requiredCapabilities = requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public void setRequiredTools(List<String> requiredTools) { this.requiredTools = requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
}

View File

@ -0,0 +1,14 @@
package com.ioa.task;
import java.util.List;
public class TaskAssignment {
private String taskId;
private List<String> agentIds;
// Getters and setters
public String getTaskId() { return taskId; }
public void setTaskId(String taskId) { this.taskId = taskId; }
public List<String> getAgentIds() { return agentIds; }
public void setAgentIds(List<String> agentIds) { this.agentIds = agentIds; }
}

View File

@ -1,37 +1,84 @@
package com.ioa.task;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Component;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.conversation.ConversationFSM;
import com.ioa.conversation.ConversationManager;
import com.ioa.conversation.Message;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.team.TeamFormation;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
import com.ioa.conversation.Message;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@Component
public class TaskManager {
private Map<String, Task> tasks = new HashMap<>();
private Map<String, Task> tasks = new ConcurrentHashMap<>();
private final SimpMessagingTemplate messagingTemplate;
private AgentRegistry agentRegistry;
private BedrockLanguageModel model;
private ToolRegistry toolRegistry;
private TreeOfThought treeOfThought;
private ConversationManager conversationManager;
private WebSocketService webSocketService;
public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry,
TreeOfThought treeOfThought, ConversationManager conversationManager) {
@Autowired
private TeamFormation teamFormation;
@Autowired
public TaskManager(AgentRegistry agentRegistry,
BedrockLanguageModel model,
ToolRegistry toolRegistry,
TreeOfThought treeOfThought,
ConversationManager conversationManager,
WebSocketService webSocketService,
SimpMessagingTemplate messagingTemplate) {
this.agentRegistry = agentRegistry;
this.model = model;
this.toolRegistry = toolRegistry;
this.treeOfThought = treeOfThought;
this.conversationManager = conversationManager;
this.webSocketService = webSocketService;
this.messagingTemplate = messagingTemplate;
}
public String createTask(Task task) {
String taskId = UUID.randomUUID().toString();
task.setId(taskId);
tasks.put(taskId, task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
// Automatically assign agents and execute the task
List<AgentInfo> team = teamFormation.formTeam(task);
if (!team.isEmpty()) {
executeTask(taskId, team);
} else {
System.out.println("No suitable agents found for task: " + taskId);
}
return taskId;
}
public void assignTask(String taskId, List<String> agentIds) {
Task task = tasks.get(taskId);
if (task != null) {
List<AgentInfo> team = agentIds.stream()
.map(agentRegistry::getAgent)
.filter(Objects::nonNull)
.collect(Collectors.toList());
executeTask(taskId, team);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task assigned to: " + String.join(", ", agentIds));
}
}
public void addTask(Task task) {
@ -54,9 +101,13 @@ public class TaskManager {
conversation.postMessage(new Message(conversationId, "SYSTEM", "Let's work on the task: " + task.getDescription()));
// Allow agents to interact for a maximum of 10 minutes
// Use the LLM to generate a plan for the task
String planPrompt = "Generate a step-by-step plan to accomplish the following task: " + task.getDescription();
String plan = model.generate(planPrompt, null);
conversation.postMessage(new Message(conversationId, "SYSTEM", "Here's the plan: " + plan));
// Allow agents to interact for a maximum of 40 minutes
long startTime = System.currentTimeMillis();
//while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(10)) {
while (!conversation.isFinished() && (System.currentTimeMillis() - startTime) < TimeUnit.MINUTES.toMillis(40)) {
try {
Thread.sleep(1000); // Check every second
@ -72,8 +123,10 @@ public class TaskManager {
String result = conversation.getResult();
task.setResult(result);
System.out.println("Task completed. Result: " + result);
messagingTemplate.convertAndSend("/topic/tasks/" + taskId, "Task completed: " + result);
}
public Task getTask(String taskId) {
return tasks.get(taskId);
}

View File

@ -0,0 +1,31 @@
package com.ioa.task;
public class TaskResult {
private String taskId;
private String result;
// Constructors
public TaskResult() {}
public TaskResult(String taskId, String result) {
this.taskId = taskId;
this.result = result;
}
// Getters and setters
public String getTaskId() {
return taskId;
}
public void setTaskId(String taskId) {
this.taskId = taskId;
}
public String getResult() {
return result;
}
public void setResult(String result) {
this.result = result;
}
}

View File

@ -0,0 +1,29 @@
package com.ioa.tool;
public class ToolInfo {
private String agentId;
private String toolName;
public ToolInfo() {}
public ToolInfo(String agentId, String toolName) {
this.agentId = agentId;
this.toolName = toolName;
}
public String getAgentId() {
return agentId;
}
public void setAgentId(String agentId) {
this.agentId = agentId;
}
public String getToolName() {
return toolName;
}
public void setToolName(String toolName) {
this.toolName = toolName;
}
}

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentProcessingUpdate {
private String agentId;
private String responseId;
private String status; // "start" or "complete"
public AgentProcessingUpdate(String agentId, String responseId, String status) {
this.agentId = agentId;
this.responseId = responseId;
this.status = status;
}
}

View File

@ -0,0 +1,16 @@
package com.ioa.websocket;
import lombok.Data;
@Data
public class AgentResponseUpdate {
private String agentId;
private String responseId;
private String chunk;
public AgentResponseUpdate(String agentId, String responseId, String chunk) {
this.agentId = agentId;
this.responseId = responseId;
this.chunk = chunk;
}
}

View File

@ -0,0 +1,50 @@
package com.ioa.websocket;
import com.ioa.conversation.ConversationManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo;
import org.springframework.stereotype.Controller;
@Controller
public class ChatController {
private final ConversationManager conversationManager;
@Autowired
public ChatController(ConversationManager conversationManager) {
this.conversationManager = conversationManager;
}
@MessageMapping("/chat/create")
@SendTo("/topic/rooms")
public String createRoom(String roomName) {
String roomId = conversationManager.createConversation();
return "Room created: " + roomId;
}
@MessageMapping("/chat/join")
@SendTo("/topic/rooms")
public String joinRoom(String roomId) {
conversationManager.joinConversation(roomId);
return "Joined room: " + roomId;
}
@MessageMapping("/chat/message")
public void sendMessage(ChatMessage message) {
conversationManager.postMessage(message.getRoomId(), message.getSender(), message.getContent());
}
}
class ChatMessage {
private String roomId;
private String sender;
private String content;
// Getters and setters
public String getRoomId() { return roomId; }
public void setRoomId(String roomId) { this.roomId = roomId; }
public String getSender() { return sender; }
public void setSender(String sender) { this.sender = sender; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}

View File

@ -0,0 +1,91 @@
package com.ioa.websocket;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Controller;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.model.BedrockLanguageModel;
import com.ioa.service.WebSocketService;
import com.ioa.task.Task;
import com.ioa.task.TaskAssignment;
import com.ioa.task.TaskManager;
import com.ioa.task.TaskResult;
import com.ioa.tool.ToolInfo;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
@Controller
public class WebSocketHandler {
private final AgentRegistry agentRegistry;
private final TaskManager taskManager;
private final SimpMessagingTemplate messagingTemplate;
private final TreeOfThought treeOfThought;
private final WebSocketService webSocketService;
private final ToolRegistry toolRegistry;
private final BedrockLanguageModel model;
@Autowired
public WebSocketHandler(AgentRegistry agentRegistry, TaskManager taskManager,
SimpMessagingTemplate messagingTemplate, TreeOfThought treeOfThought,
WebSocketService webSocketService, ToolRegistry toolRegistry,
BedrockLanguageModel model) {
this.agentRegistry = agentRegistry;
this.taskManager = taskManager;
this.messagingTemplate = messagingTemplate;
this.treeOfThought = treeOfThought;
this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry;
this.model = model;
}
@MessageMapping("/agent/register")
public void registerAgent(@Payload AgentInfo agentInfo) {
// Initialize dependencies
agentInfo.setDependencies(treeOfThought, webSocketService, toolRegistry, model);
agentRegistry.registerAgent(agentInfo);
messagingTemplate.convertAndSend("/topic/agents", "New agent registered: " + agentInfo.getId());
}
@MessageMapping("/tool/register")
public void registerTool(@Payload ToolInfo toolInfo) {
AgentInfo agent = agentRegistry.getAgent(toolInfo.getAgentId());
if (agent != null) {
agent.getTools().add(toolInfo.getToolName());
messagingTemplate.convertAndSend("/topic/tools", "New tool registered: " + toolInfo.getToolName());
}
}
@MessageMapping("/task/create")
public void createTask(@Payload Task task) {
System.out.println("Creating task: " + task.getId());
String taskId = taskManager.createTask(task);
messagingTemplate.convertAndSend("/topic/tasks", "New task created: " + taskId);
}
@MessageMapping("/agent/unregister")
public void unregisterAgent(@Payload String agentId) {
agentRegistry.unregisterAgent(agentId);
messagingTemplate.convertAndSend("/topic/agent/" + agentId, "Unregistration successful");
}
@MessageMapping("/task/assign")
public void assignTask(@Payload TaskAssignment assignment) {
taskManager.assignTask(assignment.getTaskId(), assignment.getAgentIds());
messagingTemplate.convertAndSend("/topic/tasks/" + assignment.getTaskId(), "Task assigned");
}
@MessageMapping("/task/result")
public void handleTaskResult(@Payload TaskResult result) {
Task task = taskManager.getTask(result.getTaskId());
if (task != null) {
task.setResult(result.getResult());
messagingTemplate.convertAndSend("/topic/tasks/" + result.getTaskId(), "Task completed: " + result.getResult());
}
}
}