First commit

This commit is contained in:
Mahesh Kommareddi 2024-07-16 19:25:40 -04:00
commit a97cd853fe
15 changed files with 658 additions and 0 deletions

12
pom.xml Normal file
View File

@ -0,0 +1,12 @@
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrock-runtime</artifactId>
<version>2.20.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.13.0</version>
</dependency>
</dependencies>

68
src/main/Main.java Normal file
View File

@ -0,0 +1,68 @@
package com.ioa;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
public class Main {
public static void main(String[] args) {
// Initialize the system
ToolRegistry toolRegistry = new ToolRegistry();
CommonTools commonTools = new CommonTools();
// Register all tools from CommonTools
for (Method method : CommonTools.class.getMethods()) {
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
toolRegistry.registerTool(method.getName(), method);
}
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
// Register some example agents
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"),
Arrays.asList("webSearch", "getWeather", "setReminder"));
AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
agentRegistry.registerAgent(agent1.getId(), agent1);
agentRegistry.registerAgent(agent2.getId(), agent2);
// Create a sample task
Task task = new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
// Form a team for the task
List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team);
// Assign the task to the first agent in the team (simplified)
task.setAssignedAgent(team.get(0));
// Execute the task
taskManager.addTask(task);
taskManager.executeTask(task.getId());
// Print the result
System.out.println("Task result: " + task.getResult());
}
}

View File

@ -0,0 +1,35 @@
package com.ioa;
import com.ioa.agent.AgentRegistry;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.lang.reflect.Method;
public class IoASystem {
public static void initialize() {
ToolRegistry toolRegistry = new ToolRegistry();
CommonTools commonTools = new CommonTools();
// Register all tools from CommonTools
for (Method method : CommonTools.class.getMethods()) {
if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
toolRegistry.registerTool(method.getName(), method);
}
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
ChatLanguageModel model = OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
// Initialize other components as needed
}
}

View File

@ -0,0 +1,28 @@
package com.ioa.agent;
import java.util.List;
public class AgentInfo {
private String id;
private String name;
private List<String> capabilities;
private List<String> tools;
// Constructor
public AgentInfo(String id, String name, List<String> capabilities, List<String> tools) {
this.id = id;
this.name = name;
this.capabilities = capabilities;
this.tools = tools;
}
// Getters and setters
public String getId() { return id; }
public void setId(String id) { this.id = id; }
public String getName() { return name; }
public void setName(String name) { this.name = name; }
public List<String> getCapabilities() { return capabilities; }
public void setCapabilities(List<String> capabilities) { this.capabilities = capabilities; }
public List<String> getTools() { return tools; }
public void setTools(List<String> tools) { this.tools = tools; }
}

View File

@ -0,0 +1,37 @@
package com.ioa.agent;
import com.ioa.tool.ToolRegistry;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class AgentRegistry {
private Map<String, AgentInfo> agents = new HashMap<>();
private ToolRegistry toolRegistry;
public AgentRegistry(ToolRegistry toolRegistry) {
this.toolRegistry = toolRegistry;
}
public void registerAgent(String agentId, AgentInfo agentInfo) {
agents.put(agentId, agentInfo);
// Register agent's tools
for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool);
}
}
}
public AgentInfo getAgent(String agentId) {
return agents.get(agentId);
}
public List<AgentInfo> searchAgents(List<String> capabilities) {
return agents.values().stream()
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
.collect(Collectors.toList());
}
}

View File

@ -0,0 +1,66 @@
package com.ioa.conversation;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
public class ConversationFSM {
private ConversationState currentState;
private TreeOfThought treeOfThought;
public ConversationFSM(ChatLanguageModel model) {
this.currentState = ConversationState.DISCUSSION;
this.treeOfThought = new TreeOfThought(model);
}
public void handleMessage(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState;
String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
Response<String> response = treeOfThought.getModel().generate(decisionPrompt);
ConversationState newState = ConversationState.valueOf(response.content().trim());
transitionTo(newState);
// Handle the message based on the new state
switch (newState) {
case DISCUSSION:
handleDiscussionMessage(message);
break;
case TASK_ASSIGNMENT:
handleTaskAssignmentMessage(message);
break;
case EXECUTION:
handleExecutionMessage(message);
break;
case CONCLUSION:
handleConclusionMessage(message);
break;
}
}
private void transitionTo(ConversationState newState) {
// Add any transition logic here
this.currentState = newState;
}
private void handleDiscussionMessage(Message message) {
// Implement discussion logic
}
private void handleTaskAssignmentMessage(Message message) {
// Implement task assignment logic
}
private void handleExecutionMessage(Message message) {
// Implement execution logic
}
private void handleConclusionMessage(Message message) {
// Implement conclusion logic
}
}

View File

@ -0,0 +1,5 @@
package com.ioa.conversation;
public enum ConversationState {
DISCUSSION, TASK_ASSIGNMENT, EXECUTION, CONCLUSION
}

View File

@ -0,0 +1,14 @@
package com.ioa.conversation;
public class Message {
private String sender;
private String content;
public Message(String sender, String content) {
this.sender = sender;
this.content = content;
}
public String getSender() { return sender; }
public String getContent() { return content; }
}

View File

@ -0,0 +1,54 @@
package com.ioa.model;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.core.SdkBytes;
import java.util.Map;
import java.util.HashMap;
public class BedrockLanguageModel {
private final BedrockRuntimeClient bedrockClient;
private final ObjectMapper objectMapper;
private final String modelId;
public BedrockLanguageModel(String modelId) {
this.bedrockClient = BedrockRuntimeClient.builder()
.region(Region.US_EAST_1)
.credentialsProvider(ProfileCredentialsProvider.create())
.build();
this.objectMapper = new ObjectMapper();
this.modelId = modelId;
}
public String generate(String prompt) {
try {
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("prompt", prompt);
requestBody.put("max_tokens_to_sample", 500);
requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9);
String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder()
.modelId(modelId)
.contentType("application/json")
.accept("application/json")
.body(SdkBytes.fromUtf8String(jsonPayload))
.build();
InvokeModelResponse response = bedrockClient.invokeModel(invokeModelRequest);
String responseBody = response.body().asUtf8String();
Map<String, Object> responseMap = objectMapper.readValue(responseBody, Map.class);
return (String) responseMap.get("completion");
} catch (Exception e) {
throw new RuntimeException("Error generating text with Bedrock", e);
}
}
}

View File

@ -0,0 +1,32 @@
package com.ioa.task;
import com.ioa.agent.AgentInfo;
import java.util.List;
public class Task {
private String id;
private String description;
private List<String> requiredCapabilities;
private List<String> requiredTools;
private AgentInfo assignedAgent;
private String result;
// Constructor
public Task(String id, String description, List<String> requiredCapabilities, List<String> requiredTools) {
this.id = id;
this.description = description;
this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools;
}
// Getters and setters
public String getId() { return id; }
public String getDescription() { return description; }
public List<String> getRequiredCapabilities() { return requiredCapabilities; }
public List<String> getRequiredTools() { return requiredTools; }
public AgentInfo getAssignedAgent() { return assignedAgent; }
public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
public String getResult() { return result; }
public void setResult(String result) { this.result = result; }
}

View File

@ -0,0 +1,63 @@
package com.ioa.task;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.tool.ToolRegistry;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.Map;
public class TaskManager {
private Map<String, Task> tasks = new HashMap<>();
private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought;
private ToolRegistry toolRegistry;
public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) {
this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model);
this.toolRegistry = toolRegistry;
}
public void executeTask(String taskId) {
Task task = tasks.get(taskId);
AgentInfo agent = task.getAssignedAgent();
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
"\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools();
String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3);
String executionPrompt = "Based on this execution plan:\n" + reasoning +
"\nExecute the task using the available tools and provide the result.";
Response<String> response = treeOfThought.getModel().generate(executionPrompt);
String result = executeToolsFromResponse(response.content(), agent);
task.setResult(result);
}
private String executeToolsFromResponse(String response, AgentInfo agent) {
StringBuilder result = new StringBuilder();
for (String tool : agent.getTools()) {
if (response.contains(tool)) {
Object toolInstance = toolRegistry.getTool(tool);
// Execute the tool (this is a simplified representation)
result.append(tool).append(" result: ").append(toolInstance.toString()).append("\n");
}
}
return result.toString();
}
public void addTask(Task task) {
tasks.put(task.getId(), task);
}
public Task getTask(String taskId) {
return tasks.get(taskId);
}
}

View File

@ -0,0 +1,53 @@
package com.ioa.team;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
import com.ioa.task.Task;
import com.ioa.util.TreeOfThought;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class TeamFormation {
private AgentRegistry agentRegistry;
private TreeOfThought treeOfThought;
public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) {
this.agentRegistry = agentRegistry;
this.treeOfThought = new TreeOfThought(model);
}
public List<AgentInfo> formTeam(Task task) {
List<String> requiredCapabilities = task.getRequiredCapabilities();
List<String> requiredTools = task.getRequiredTools();
List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities);
String teamFormationTask = "Form the best team for this task: " + task.getDescription() +
"\nRequired tools: " + requiredTools +
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
String reasoning = treeOfThought.reason(teamFormationTask, 3, 3);
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the final team composition as a comma-separated list of agent IDs.";
Response<String> response = treeOfThought.getModel().generate(finalDecisionPrompt);
return parseTeamComposition(response.content(), potentialAgents);
}
private String formatAgentTools(List<AgentInfo> agents) {
return agents.stream()
.map(agent -> agent.getId() + ": " + agent.getTools())
.collect(Collectors.joining(", "));
}
private List<AgentInfo> parseTeamComposition(String composition, List<AgentInfo> potentialAgents) {
List<String> selectedIds = Arrays.asList(composition.split(","));
return potentialAgents.stream()
.filter(agent -> selectedIds.contains(agent.getId()))
.collect(Collectors.toList());
}
}

View File

@ -0,0 +1,126 @@
package com.ioa.tool;
import dev.langchain4j.agent.tool.Tool;
public class CommonTools {
@Tool("Search the web for information")
public String webSearch(String query) {
// Implement web search functionality
return "Web search results for: " + query;
}
@Tool("Get current weather information")
public String getWeather(String location) {
// Implement weather API call
return "Weather information for " + location;
}
@Tool("Set a reminder")
public String setReminder(String task, String time) {
// Implement reminder functionality
return "Reminder set for " + task + " at " + time;
}
@Tool("Calculate distances between locations")
public String calculateDistance(String from, String to) {
// Implement distance calculation
return "Distance from " + from + " to " + to;
}
@Tool("Translate text between languages")
public String translate(String text, String fromLang, String toLang) {
// Implement translation API call
return "Translated text from " + fromLang + " to " + toLang;
}
@Tool("Get recipe suggestions")
public String getRecipe(String ingredients) {
// Implement recipe suggestion logic
return "Recipe suggestions for: " + ingredients;
}
@Tool("Check product prices and compare")
public String compareProductPrices(String product) {
// Implement price comparison logic
return "Price comparison for: " + product;
}
@Tool("Book travel arrangements")
public String bookTravel(String destination, String dates) {
// Implement travel booking logic
return "Travel arrangements for " + destination + " on " + dates;
}
@Tool("Find nearby restaurants")
public String findRestaurants(String location, String cuisine) {
// Implement restaurant search
return "Restaurants near " + location + " serving " + cuisine;
}
@Tool("Schedule appointments")
public String scheduleAppointment(String service, String date) {
// Implement appointment scheduling
return "Appointment scheduled for " + service + " on " + date;
}
@Tool("Get movie recommendations")
public String getMovieRecommendations(String genres, String mood) {
// Implement movie recommendation logic
return "Movie recommendations for " + genres + " matching " + mood + " mood";
}
@Tool("Find and book fitness classes")
public String findFitnessClasses(String type, String location) {
// Implement fitness class search and booking
return "Fitness classes for " + type + " near " + location;
}
@Tool("Get public transport information")
public String getPublicTransport(String from, String to) {
// Implement public transport routing
return "Public transport options from " + from + " to " + to;
}
@Tool("Track package deliveries")
public String trackPackage(String trackingNumber) {
// Implement package tracking
return "Tracking information for package: " + trackingNumber;
}
@Tool("Get news updates")
public String getNewsUpdates(String topics) {
// Implement news aggregation
return "Latest news updates on: " + topics;
}
@Tool("Find and apply for jobs")
public String jobSearch(String field, String location) {
// Implement job search functionality
return "Job openings in " + field + " near " + location;
}
@Tool("Get health and medical advice")
public String getMedicalAdvice(String symptoms) {
// Implement medical advice lookup (with disclaimer)
return "General health information for symptoms: " + symptoms;
}
@Tool("Find and book event tickets")
public String findEventTickets(String event, String location) {
// Implement event ticket search and booking
return "Ticket options for " + event + " in " + location;
}
@Tool("Get financial advice and budgeting tips")
public String getFinancialAdvice(String income, String expenses) {
// Implement financial advice generation
return "Financial advice based on income: " + income + " and expenses: " + expenses;
}
@Tool("Find and book home services")
public String findHomeServices(String service, String location) {
// Implement home service search and booking
return "Home service options for " + service + " in " + location;
}
}

View File

@ -0,0 +1,20 @@
package com.ioa.tool;
import java.util.HashMap;
import java.util.Map;
public class ToolRegistry {
private Map<String, Object> tools = new HashMap<>();
public void registerTool(String name, Object tool) {
tools.put(name, tool);
}
public Object getTool(String name) {
return tools.get(name);
}
public Map<String, Object> getAllTools() {
return new HashMap<>(tools);
}
}

View File

@ -0,0 +1,45 @@
package com.ioa.util;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
public class TreeOfThought {
private final ChatLanguageModel model;
public TreeOfThought(ChatLanguageModel model) {
this.model = model;
}
public String reason(String task, int depth, int branches) {
return exploreThought(task, depth, branches, "");
}
private String exploreThought(String task, int depth, int branches, String path) {
if (depth == 0) {
return evaluateLeaf(task, path);
}
StringBuilder result = new StringBuilder();
for (int i = 0; i < branches; i++) {
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
Response<String> response = model.generate(branchPrompt);
String thought = response.content();
result.append("Branch ").append(i + 1).append(":\n");
result.append(thought).append("\n");
result.append(exploreThought(task, depth - 1, branches, path + " -> " + thought)).append("\n\n");
}
return result.toString();
}
private String evaluateLeaf(String task, String path) {
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
Response<String> response = model.generate(prompt);
return response.content();
}
public ChatLanguageModel getModel() {
return model;
}
}