Fully working Claude 3.5 on Bedrock code that uses agents

This commit is contained in:
Mahesh Kommareddi 2024-07-16 20:57:58 -04:00
parent 72a7555ac6
commit a310b2364b
12 changed files with 61 additions and 27 deletions

View File

@ -62,11 +62,11 @@ public class IoASystem {
public static void main(String[] args) { public static void main(String[] args) {
var context = SpringApplication.run(IoASystem.class, args); var context = SpringApplication.run(IoASystem.class, args);
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class); AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
TeamFormation teamFormation = context.getBean(TeamFormation.class); TeamFormation teamFormation = context.getBean(TeamFormation.class);
TaskManager taskManager = context.getBean(TaskManager.class); TaskManager taskManager = context.getBean(TaskManager.class);
// Register some example agents // Register some example agents
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant", AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
Arrays.asList("general", "search"), Arrays.asList("general", "search"),
@ -77,23 +77,28 @@ public class IoASystem {
agentRegistry.registerAgent(agent1.getId(), agent1); agentRegistry.registerAgent(agent1.getId(), agent1);
agentRegistry.registerAgent(agent2.getId(), agent2); agentRegistry.registerAgent(agent2.getId(), agent2);
// Create a sample task // Create a sample task
Task task = new Task("task1", "Plan a weekend trip to Paris", Task task = new Task("task1", "Plan a weekend trip to Paris",
Arrays.asList("travel", "booking"), Arrays.asList("travel", "booking"),
Arrays.asList("bookTravel", "findRestaurants", "getWeather")); Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
// Form a team for the task // Form a team for the task
List<AgentInfo> team = teamFormation.formTeam(task); List<AgentInfo> team = teamFormation.formTeam(task);
System.out.println("Formed team: " + team); System.out.println("Formed team: " + team);
// Assign the task to the first agent in the team (simplified) if (team.isEmpty()) {
System.out.println("No suitable agents found for the task. Exiting.");
return;
}
// Assign the task to the first agent in the team
task.setAssignedAgent(team.get(0)); task.setAssignedAgent(team.get(0));
// Execute the task // Execute the task
taskManager.addTask(task); taskManager.addTask(task);
taskManager.executeTask(task.getId()); taskManager.executeTask(task.getId());
// Print the result // Print the result
System.out.println("Task result: " + task.getResult()); System.out.println("Task result: " + task.getResult());
} }

View File

@ -22,11 +22,11 @@ public class ConversationFSM {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() + String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState; "\nCurrent state: " + currentState;
String reasoning = model.generate(stateTransitionTask); String reasoning = model.generate(stateTransitionTask, null);
String decisionPrompt = "Based on this reasoning:\n" + reasoning + String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION)."; "\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
String response = model.generate(decisionPrompt); String response = model.generate(decisionPrompt, null);
ConversationState newState = ConversationState.valueOf(response.trim()); ConversationState newState = ConversationState.valueOf(response.trim());
transitionTo(newState); transitionTo(newState);

View File

@ -12,6 +12,10 @@ import com.fasterxml.jackson.databind.node.ArrayNode;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Base64;
@Component @Component
public class BedrockLanguageModel { public class BedrockLanguageModel {
private final BedrockRuntimeClient bedrockClient; private final BedrockRuntimeClient bedrockClient;
@ -27,19 +31,35 @@ public class BedrockLanguageModel {
this.modelId = modelId; this.modelId = modelId;
} }
public String generate(String prompt) { public String generate(String prompt, String imagePath) {
try { try {
ObjectNode requestBody = objectMapper.createObjectNode(); ObjectNode requestBody = objectMapper.createObjectNode();
requestBody.put("anthropic_version", "bedrock-2023-05-31"); requestBody.put("anthropic_version", "bedrock-2023-05-31");
ArrayNode messages = requestBody.putArray("messages"); ArrayNode messages = requestBody.putArray("messages");
ObjectNode message = messages.addObject(); ObjectNode message = messages.addObject();
message.put("role", "user"); message.put("role", "user");
message.put("content", prompt);
requestBody.put("max_tokens", 500); requestBody.put("max_tokens", 500);
requestBody.put("temperature", 0.7); requestBody.put("temperature", 0.7);
requestBody.put("top_p", 0.9); requestBody.put("top_p", 0.9);
ArrayNode content = message.putArray("content");
if (imagePath != null && !imagePath.isEmpty()) {
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
ObjectNode imageNode = content.addObject();
imageNode.put("type", "image"); // Add type field
ObjectNode imageContent = imageNode.putObject("image");
imageContent.put("format", "png");
ObjectNode source = imageContent.putObject("source");
source.put("bytes", base64Image);
}
ObjectNode textNode = content.addObject();
textNode.put("type", "text"); // Add type field
textNode.put("text", prompt);
String jsonPayload = objectMapper.writeValueAsString(requestBody); String jsonPayload = objectMapper.writeValueAsString(requestBody);
InvokeModelRequest invokeRequest = InvokeModelRequest.builder() InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
@ -53,7 +73,7 @@ public class BedrockLanguageModel {
String responseBody = response.body().asUtf8String(); String responseBody = response.body().asUtf8String();
ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody); ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody);
return responseJson.path("content").asText(); return responseJson.path("content").get(0).path("text").asText();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("Error generating text with Bedrock", e); throw new RuntimeException("Error generating text with Bedrock", e);

View File

@ -41,13 +41,13 @@ public class TaskManager {
"\nAssigned agent capabilities: " + agent.getCapabilities() + "\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools(); "\nAvailable tools: " + agent.getTools();
String reasoning = model.generate(executionPlanningTask); String reasoning = model.generate(executionPlanningTask, null);
updateTaskProgress(taskId, "IN_PROGRESS", 50); updateTaskProgress(taskId, "IN_PROGRESS", 50);
String executionPrompt = "Based on this execution plan:\n" + reasoning + String executionPrompt = "Based on this execution plan:\n" + reasoning +
"\nExecute the task using the available tools and provide the result."; "\nExecute the task using the available tools and provide the result.";
String response = model.generate(executionPrompt); String response = model.generate(executionPrompt, null);
String result = executeToolsFromResponse(response, agent); String result = executeToolsFromResponse(response, agent);

View File

@ -25,29 +25,38 @@ public class TeamFormation {
List<String> requiredTools = task.getRequiredTools(); List<String> requiredTools = task.getRequiredTools();
List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities); List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities);
System.out.println("Potential agents: " + potentialAgents);
String teamFormationTask = "Form the best team for this task: " + task.getDescription() + String teamFormationTask = "Form the best team for this task: " + task.getDescription() +
"\nRequired capabilities: " + requiredCapabilities +
"\nRequired tools: " + requiredTools + "\nRequired tools: " + requiredTools +
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents); "\nAvailable agents and their tools: " + formatAgentTools(potentialAgents) +
"\nPlease respond with a comma-separated list of agent IDs that form the best team for this task.";
String reasoning = model.generate(teamFormationTask); System.out.println("Sending prompt to language model: " + teamFormationTask);
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning + String response = model.generate(teamFormationTask, null);
"\nProvide the final team composition as a comma-separated list of agent IDs.";
String response = model.generate(finalDecisionPrompt);
System.out.println("Language model response: " + response);
return parseTeamComposition(response, potentialAgents); return parseTeamComposition(response, potentialAgents);
} }
private String formatAgentTools(List<AgentInfo> agents) { private String formatAgentTools(List<AgentInfo> agents) {
return agents.stream() return agents.stream()
.map(agent -> agent.getId() + ": " + agent.getTools()) .map(agent -> agent.getId() + " (capabilities: " + agent.getCapabilities() + ", tools: " + agent.getTools() + ")")
.collect(Collectors.joining(", ")); .collect(Collectors.joining(", "));
} }
private List<AgentInfo> parseTeamComposition(String composition, List<AgentInfo> potentialAgents) { private List<AgentInfo> parseTeamComposition(String composition, List<AgentInfo> potentialAgents) {
List<String> selectedIds = Arrays.asList(composition.split(",")); List<String> selectedIds = Arrays.asList(composition.split(","));
return potentialAgents.stream() System.out.println("Parsed agent IDs: " + selectedIds);
.filter(agent -> selectedIds.contains(agent.getId()))
List<AgentInfo> team = potentialAgents.stream()
.filter(agent -> selectedIds.contains(agent.getId().trim()))
.collect(Collectors.toList()); .collect(Collectors.toList());
System.out.println("Final team: " + team);
return team;
} }
} }

View File

@ -22,7 +22,7 @@ public class TreeOfThought {
for (int i = 0; i < branches; i++) { for (int i = 0; i < branches; i++) {
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path + String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
String thought = model.generate(branchPrompt); String thought = model.generate(branchPrompt, null);
result.append("Branch ").append(i + 1).append(":\n"); result.append("Branch ").append(i + 1).append(":\n");
result.append(thought).append("\n"); result.append(thought).append("\n");
@ -33,7 +33,7 @@ public class TreeOfThought {
private String evaluateLeaf(String task, String path) { private String evaluateLeaf(String task, String path) {
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path; String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
return model.generate(prompt); return model.generate(prompt, null);
} }
public BedrockLanguageModel getModel() { public BedrockLanguageModel getModel() {