commit a97cd853fe911d4a072a61f61043c38aba2dca6a Author: Mahesh Kommareddi Date: Tue Jul 16 19:25:40 2024 -0400 First commit diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..f3d8047 --- /dev/null +++ b/pom.xml @@ -0,0 +1,12 @@ + + + software.amazon.awssdk + bedrock-runtime + 2.20.0 + + + com.fasterxml.jackson.core + jackson-databind + 2.13.0 + + \ No newline at end of file diff --git a/src/main/Main.java b/src/main/Main.java new file mode 100644 index 0000000..1ad7458 --- /dev/null +++ b/src/main/Main.java @@ -0,0 +1,68 @@ +package com.ioa; + +import com.ioa.agent.AgentInfo; +import com.ioa.agent.AgentRegistry; +import com.ioa.task.Task; +import com.ioa.task.TaskManager; +import com.ioa.team.TeamFormation; +import com.ioa.tool.CommonTools; +import com.ioa.tool.ToolRegistry; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; + +public class Main { + public static void main(String[] args) { + // Initialize the system + ToolRegistry toolRegistry = new ToolRegistry(); + CommonTools commonTools = new CommonTools(); + + // Register all tools from CommonTools + for (Method method : CommonTools.class.getMethods()) { + if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) { + toolRegistry.registerTool(method.getName(), method); + } + } + + AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); + ChatLanguageModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .build(); + + TeamFormation teamFormation = new TeamFormation(agentRegistry, model); + TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); + + // Register some example agents + AgentInfo agent1 = new AgentInfo("agent1", "General Assistant", + Arrays.asList("general", "search"), + Arrays.asList("webSearch", "getWeather", "setReminder")); + AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert", + Arrays.asList("travel", "booking"), + Arrays.asList("bookTravel", "calculateDistance", "findRestaurants")); + + agentRegistry.registerAgent(agent1.getId(), agent1); + agentRegistry.registerAgent(agent2.getId(), agent2); + + // Create a sample task + Task task = new Task("task1", "Plan a weekend trip to Paris", + Arrays.asList("travel", "booking"), + Arrays.asList("bookTravel", "findRestaurants", "getWeather")); + + // Form a team for the task + List team = teamFormation.formTeam(task); + System.out.println("Formed team: " + team); + + // Assign the task to the first agent in the team (simplified) + task.setAssignedAgent(team.get(0)); + + // Execute the task + taskManager.addTask(task); + taskManager.executeTask(task.getId()); + + // Print the result + System.out.println("Task result: " + task.getResult()); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/IoASystem.java b/src/main/java/com/ioa/IoASystem.java new file mode 100644 index 0000000..69d5db0 --- /dev/null +++ b/src/main/java/com/ioa/IoASystem.java @@ -0,0 +1,35 @@ +package com.ioa; + +import com.ioa.agent.AgentRegistry; +import com.ioa.task.TaskManager; +import com.ioa.team.TeamFormation; +import com.ioa.tool.CommonTools; +import com.ioa.tool.ToolRegistry; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; + +import java.lang.reflect.Method; + +public class IoASystem { + public static void initialize() { + ToolRegistry toolRegistry = new ToolRegistry(); + CommonTools commonTools = new CommonTools(); + + // Register all tools from CommonTools + for (Method method : CommonTools.class.getMethods()) { + if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) { + toolRegistry.registerTool(method.getName(), method); + } + } + + AgentRegistry agentRegistry = new AgentRegistry(toolRegistry); + ChatLanguageModel model = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .build(); + + TeamFormation teamFormation = new TeamFormation(agentRegistry, model); + TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry); + + // Initialize other components as needed + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/agent/AgentInfo.java b/src/main/java/com/ioa/agent/AgentInfo.java new file mode 100644 index 0000000..acd3a68 --- /dev/null +++ b/src/main/java/com/ioa/agent/AgentInfo.java @@ -0,0 +1,28 @@ +package com.ioa.agent; + +import java.util.List; + +public class AgentInfo { + private String id; + private String name; + private List capabilities; + private List tools; + + // Constructor + public AgentInfo(String id, String name, List capabilities, List tools) { + this.id = id; + this.name = name; + this.capabilities = capabilities; + this.tools = tools; + } + + // Getters and setters + public String getId() { return id; } + public void setId(String id) { this.id = id; } + public String getName() { return name; } + public void setName(String name) { this.name = name; } + public List getCapabilities() { return capabilities; } + public void setCapabilities(List capabilities) { this.capabilities = capabilities; } + public List getTools() { return tools; } + public void setTools(List tools) { this.tools = tools; } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/agent/AgentRegistry.java b/src/main/java/com/ioa/agent/AgentRegistry.java new file mode 100644 index 0000000..6bda082 --- /dev/null +++ b/src/main/java/com/ioa/agent/AgentRegistry.java @@ -0,0 +1,37 @@ +package com.ioa.agent; + +import com.ioa.tool.ToolRegistry; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class AgentRegistry { + private Map agents = new HashMap<>(); + private ToolRegistry toolRegistry; + + public AgentRegistry(ToolRegistry toolRegistry) { + this.toolRegistry = toolRegistry; + } + + public void registerAgent(String agentId, AgentInfo agentInfo) { + agents.put(agentId, agentInfo); + // Register agent's tools + for (String tool : agentInfo.getTools()) { + if (toolRegistry.getTool(tool) == null) { + throw new IllegalArgumentException("Tool not found in registry: " + tool); + } + } + } + + public AgentInfo getAgent(String agentId) { + return agents.get(agentId); + } + + public List searchAgents(List capabilities) { + return agents.values().stream() + .filter(agent -> agent.getCapabilities().containsAll(capabilities)) + .collect(Collectors.toList()); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/conversation/ConversationFSM.java b/src/main/java/com/ioa/conversation/ConversationFSM.java new file mode 100644 index 0000000..08f5a7b --- /dev/null +++ b/src/main/java/com/ioa/conversation/ConversationFSM.java @@ -0,0 +1,66 @@ +package com.ioa.conversation; + +import com.ioa.util.TreeOfThought; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; + +public class ConversationFSM { + private ConversationState currentState; + private TreeOfThought treeOfThought; + + public ConversationFSM(ChatLanguageModel model) { + this.currentState = ConversationState.DISCUSSION; + this.treeOfThought = new TreeOfThought(model); + } + + public void handleMessage(Message message) { + String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() + + "\nCurrent state: " + currentState; + + String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3); + + String decisionPrompt = "Based on this reasoning:\n" + reasoning + + "\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION)."; + Response response = treeOfThought.getModel().generate(decisionPrompt); + + ConversationState newState = ConversationState.valueOf(response.content().trim()); + transitionTo(newState); + + // Handle the message based on the new state + switch (newState) { + case DISCUSSION: + handleDiscussionMessage(message); + break; + case TASK_ASSIGNMENT: + handleTaskAssignmentMessage(message); + break; + case EXECUTION: + handleExecutionMessage(message); + break; + case CONCLUSION: + handleConclusionMessage(message); + break; + } + } + + private void transitionTo(ConversationState newState) { + // Add any transition logic here + this.currentState = newState; + } + + private void handleDiscussionMessage(Message message) { + // Implement discussion logic + } + + private void handleTaskAssignmentMessage(Message message) { + // Implement task assignment logic + } + + private void handleExecutionMessage(Message message) { + // Implement execution logic + } + + private void handleConclusionMessage(Message message) { + // Implement conclusion logic + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/conversation/ConversationState.java b/src/main/java/com/ioa/conversation/ConversationState.java new file mode 100644 index 0000000..b8921b4 --- /dev/null +++ b/src/main/java/com/ioa/conversation/ConversationState.java @@ -0,0 +1,5 @@ +package com.ioa.conversation; + +public enum ConversationState { + DISCUSSION, TASK_ASSIGNMENT, EXECUTION, CONCLUSION +} \ No newline at end of file diff --git a/src/main/java/com/ioa/conversation/Message.java b/src/main/java/com/ioa/conversation/Message.java new file mode 100644 index 0000000..278e4f1 --- /dev/null +++ b/src/main/java/com/ioa/conversation/Message.java @@ -0,0 +1,14 @@ +package com.ioa.conversation; + +public class Message { + private String sender; + private String content; + + public Message(String sender, String content) { + this.sender = sender; + this.content = content; + } + + public String getSender() { return sender; } + public String getContent() { return content; } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/model/BedrockLanguageModel.java b/src/main/java/com/ioa/model/BedrockLanguageModel.java new file mode 100644 index 0000000..a380393 --- /dev/null +++ b/src/main/java/com/ioa/model/BedrockLanguageModel.java @@ -0,0 +1,54 @@ +package com.ioa.model; + +import com.fasterxml.jackson.databind.ObjectMapper; +import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.core.SdkBytes; + +import java.util.Map; +import java.util.HashMap; + +public class BedrockLanguageModel { + private final BedrockRuntimeClient bedrockClient; + private final ObjectMapper objectMapper; + private final String modelId; + + public BedrockLanguageModel(String modelId) { + this.bedrockClient = BedrockRuntimeClient.builder() + .region(Region.US_EAST_1) + .credentialsProvider(ProfileCredentialsProvider.create()) + .build(); + this.objectMapper = new ObjectMapper(); + this.modelId = modelId; + } + + public String generate(String prompt) { + try { + Map requestBody = new HashMap<>(); + requestBody.put("prompt", prompt); + requestBody.put("max_tokens_to_sample", 500); + requestBody.put("temperature", 0.7); + requestBody.put("top_p", 0.9); + + String jsonPayload = objectMapper.writeValueAsString(requestBody); + + InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder() + .modelId(modelId) + .contentType("application/json") + .accept("application/json") + .body(SdkBytes.fromUtf8String(jsonPayload)) + .build(); + + InvokeModelResponse response = bedrockClient.invokeModel(invokeModelRequest); + String responseBody = response.body().asUtf8String(); + + Map responseMap = objectMapper.readValue(responseBody, Map.class); + return (String) responseMap.get("completion"); + } catch (Exception e) { + throw new RuntimeException("Error generating text with Bedrock", e); + } + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/task/Task.java b/src/main/java/com/ioa/task/Task.java new file mode 100644 index 0000000..6cf5677 --- /dev/null +++ b/src/main/java/com/ioa/task/Task.java @@ -0,0 +1,32 @@ +package com.ioa.task; + +import com.ioa.agent.AgentInfo; + +import java.util.List; + +public class Task { + private String id; + private String description; + private List requiredCapabilities; + private List requiredTools; + private AgentInfo assignedAgent; + private String result; + + // Constructor + public Task(String id, String description, List requiredCapabilities, List requiredTools) { + this.id = id; + this.description = description; + this.requiredCapabilities = requiredCapabilities; + this.requiredTools = requiredTools; + } + + // Getters and setters + public String getId() { return id; } + public String getDescription() { return description; } + public List getRequiredCapabilities() { return requiredCapabilities; } + public List getRequiredTools() { return requiredTools; } + public AgentInfo getAssignedAgent() { return assignedAgent; } + public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; } + public String getResult() { return result; } + public void setResult(String result) { this.result = result; } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/task/TaskManager.java b/src/main/java/com/ioa/task/TaskManager.java new file mode 100644 index 0000000..990e9d0 --- /dev/null +++ b/src/main/java/com/ioa/task/TaskManager.java @@ -0,0 +1,63 @@ +package com.ioa.task; + +import com.ioa.agent.AgentInfo; +import com.ioa.agent.AgentRegistry; +import com.ioa.tool.ToolRegistry; +import com.ioa.util.TreeOfThought; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; + +import java.util.HashMap; +import java.util.Map; + +public class TaskManager { + private Map tasks = new HashMap<>(); + private AgentRegistry agentRegistry; + private TreeOfThought treeOfThought; + private ToolRegistry toolRegistry; + + public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) { + this.agentRegistry = agentRegistry; + this.treeOfThought = new TreeOfThought(model); + this.toolRegistry = toolRegistry; + } + + public void executeTask(String taskId) { + Task task = tasks.get(taskId); + AgentInfo agent = task.getAssignedAgent(); + + String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() + + "\nAssigned agent capabilities: " + agent.getCapabilities() + + "\nAvailable tools: " + agent.getTools(); + + String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3); + + String executionPrompt = "Based on this execution plan:\n" + reasoning + + "\nExecute the task using the available tools and provide the result."; + Response response = treeOfThought.getModel().generate(executionPrompt); + + String result = executeToolsFromResponse(response.content(), agent); + + task.setResult(result); + } + + private String executeToolsFromResponse(String response, AgentInfo agent) { + StringBuilder result = new StringBuilder(); + for (String tool : agent.getTools()) { + if (response.contains(tool)) { + Object toolInstance = toolRegistry.getTool(tool); + // Execute the tool (this is a simplified representation) + result.append(tool).append(" result: ").append(toolInstance.toString()).append("\n"); + } + } + return result.toString(); + } + + public void addTask(Task task) { + tasks.put(task.getId(), task); + } + + public Task getTask(String taskId) { + return tasks.get(taskId); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/team/TeamFormation.java b/src/main/java/com/ioa/team/TeamFormation.java new file mode 100644 index 0000000..3d925da --- /dev/null +++ b/src/main/java/com/ioa/team/TeamFormation.java @@ -0,0 +1,53 @@ +package com.ioa.team; + +import com.ioa.agent.AgentInfo; +import com.ioa.agent.AgentRegistry; +import com.ioa.task.Task; +import com.ioa.util.TreeOfThought; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class TeamFormation { + private AgentRegistry agentRegistry; + private TreeOfThought treeOfThought; + + public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) { + this.agentRegistry = agentRegistry; + this.treeOfThought = new TreeOfThought(model); + } + + public List formTeam(Task task) { + List requiredCapabilities = task.getRequiredCapabilities(); + List requiredTools = task.getRequiredTools(); + List potentialAgents = agentRegistry.searchAgents(requiredCapabilities); + + String teamFormationTask = "Form the best team for this task: " + task.getDescription() + + "\nRequired tools: " + requiredTools + + "\nAvailable agents and their tools: " + formatAgentTools(potentialAgents); + + String reasoning = treeOfThought.reason(teamFormationTask, 3, 3); + + String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning + + "\nProvide the final team composition as a comma-separated list of agent IDs."; + Response response = treeOfThought.getModel().generate(finalDecisionPrompt); + + return parseTeamComposition(response.content(), potentialAgents); + } + + private String formatAgentTools(List agents) { + return agents.stream() + .map(agent -> agent.getId() + ": " + agent.getTools()) + .collect(Collectors.joining(", ")); + } + + private List parseTeamComposition(String composition, List potentialAgents) { + List selectedIds = Arrays.asList(composition.split(",")); + return potentialAgents.stream() + .filter(agent -> selectedIds.contains(agent.getId())) + .collect(Collectors.toList()); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/tool/CommonTools.java b/src/main/java/com/ioa/tool/CommonTools.java new file mode 100644 index 0000000..6ecd50c --- /dev/null +++ b/src/main/java/com/ioa/tool/CommonTools.java @@ -0,0 +1,126 @@ +package com.ioa.tool; + +import dev.langchain4j.agent.tool.Tool; + +public class CommonTools { + + @Tool("Search the web for information") + public String webSearch(String query) { + // Implement web search functionality + return "Web search results for: " + query; + } + + @Tool("Get current weather information") + public String getWeather(String location) { + // Implement weather API call + return "Weather information for " + location; + } + + @Tool("Set a reminder") + public String setReminder(String task, String time) { + // Implement reminder functionality + return "Reminder set for " + task + " at " + time; + } + + @Tool("Calculate distances between locations") + public String calculateDistance(String from, String to) { + // Implement distance calculation + return "Distance from " + from + " to " + to; + } + + @Tool("Translate text between languages") + public String translate(String text, String fromLang, String toLang) { + // Implement translation API call + return "Translated text from " + fromLang + " to " + toLang; + } + + @Tool("Get recipe suggestions") + public String getRecipe(String ingredients) { + // Implement recipe suggestion logic + return "Recipe suggestions for: " + ingredients; + } + + @Tool("Check product prices and compare") + public String compareProductPrices(String product) { + // Implement price comparison logic + return "Price comparison for: " + product; + } + + @Tool("Book travel arrangements") + public String bookTravel(String destination, String dates) { + // Implement travel booking logic + return "Travel arrangements for " + destination + " on " + dates; + } + + @Tool("Find nearby restaurants") + public String findRestaurants(String location, String cuisine) { + // Implement restaurant search + return "Restaurants near " + location + " serving " + cuisine; + } + + @Tool("Schedule appointments") + public String scheduleAppointment(String service, String date) { + // Implement appointment scheduling + return "Appointment scheduled for " + service + " on " + date; + } + + @Tool("Get movie recommendations") + public String getMovieRecommendations(String genres, String mood) { + // Implement movie recommendation logic + return "Movie recommendations for " + genres + " matching " + mood + " mood"; + } + + @Tool("Find and book fitness classes") + public String findFitnessClasses(String type, String location) { + // Implement fitness class search and booking + return "Fitness classes for " + type + " near " + location; + } + + @Tool("Get public transport information") + public String getPublicTransport(String from, String to) { + // Implement public transport routing + return "Public transport options from " + from + " to " + to; + } + + @Tool("Track package deliveries") + public String trackPackage(String trackingNumber) { + // Implement package tracking + return "Tracking information for package: " + trackingNumber; + } + + @Tool("Get news updates") + public String getNewsUpdates(String topics) { + // Implement news aggregation + return "Latest news updates on: " + topics; + } + + @Tool("Find and apply for jobs") + public String jobSearch(String field, String location) { + // Implement job search functionality + return "Job openings in " + field + " near " + location; + } + + @Tool("Get health and medical advice") + public String getMedicalAdvice(String symptoms) { + // Implement medical advice lookup (with disclaimer) + return "General health information for symptoms: " + symptoms; + } + + @Tool("Find and book event tickets") + public String findEventTickets(String event, String location) { + // Implement event ticket search and booking + return "Ticket options for " + event + " in " + location; + } + + @Tool("Get financial advice and budgeting tips") + public String getFinancialAdvice(String income, String expenses) { + // Implement financial advice generation + return "Financial advice based on income: " + income + " and expenses: " + expenses; + } + + @Tool("Find and book home services") + public String findHomeServices(String service, String location) { + // Implement home service search and booking + return "Home service options for " + service + " in " + location; + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/tool/ToolRegistry.java b/src/main/java/com/ioa/tool/ToolRegistry.java new file mode 100644 index 0000000..d0626af --- /dev/null +++ b/src/main/java/com/ioa/tool/ToolRegistry.java @@ -0,0 +1,20 @@ +package com.ioa.tool; + +import java.util.HashMap; +import java.util.Map; + +public class ToolRegistry { + private Map tools = new HashMap<>(); + + public void registerTool(String name, Object tool) { + tools.put(name, tool); + } + + public Object getTool(String name) { + return tools.get(name); + } + + public Map getAllTools() { + return new HashMap<>(tools); + } +} \ No newline at end of file diff --git a/src/main/java/com/ioa/util/TreeOfThought.java b/src/main/java/com/ioa/util/TreeOfThought.java new file mode 100644 index 0000000..c5203ba --- /dev/null +++ b/src/main/java/com/ioa/util/TreeOfThought.java @@ -0,0 +1,45 @@ +package com.ioa.util; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; + +public class TreeOfThought { + private final ChatLanguageModel model; + + public TreeOfThought(ChatLanguageModel model) { + this.model = model; + } + + public String reason(String task, int depth, int branches) { + return exploreThought(task, depth, branches, ""); + } + + private String exploreThought(String task, int depth, int branches, String path) { + if (depth == 0) { + return evaluateLeaf(task, path); + } + + StringBuilder result = new StringBuilder(); + for (int i = 0; i < branches; i++) { + String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path + + "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; + Response response = model.generate(branchPrompt); + String thought = response.content(); + + result.append("Branch ").append(i + 1).append(":\n"); + result.append(thought).append("\n"); + result.append(exploreThought(task, depth - 1, branches, path + " -> " + thought)).append("\n\n"); + } + return result.toString(); + } + + private String evaluateLeaf(String task, String path) { + String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path; + Response response = model.generate(prompt); + return response.content(); + } + + public ChatLanguageModel getModel() { + return model; + } +} \ No newline at end of file