Fully working Claude 3.5 on Bedrock code that uses agents
This commit is contained in:
parent
72a7555ac6
commit
a310b2364b
|
@ -62,11 +62,11 @@ public class IoASystem {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
var context = SpringApplication.run(IoASystem.class, args);
|
var context = SpringApplication.run(IoASystem.class, args);
|
||||||
|
|
||||||
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
|
AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
|
||||||
TeamFormation teamFormation = context.getBean(TeamFormation.class);
|
TeamFormation teamFormation = context.getBean(TeamFormation.class);
|
||||||
TaskManager taskManager = context.getBean(TaskManager.class);
|
TaskManager taskManager = context.getBean(TaskManager.class);
|
||||||
|
|
||||||
// Register some example agents
|
// Register some example agents
|
||||||
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
|
AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
|
||||||
Arrays.asList("general", "search"),
|
Arrays.asList("general", "search"),
|
||||||
|
@ -77,23 +77,28 @@ public class IoASystem {
|
||||||
|
|
||||||
agentRegistry.registerAgent(agent1.getId(), agent1);
|
agentRegistry.registerAgent(agent1.getId(), agent1);
|
||||||
agentRegistry.registerAgent(agent2.getId(), agent2);
|
agentRegistry.registerAgent(agent2.getId(), agent2);
|
||||||
|
|
||||||
// Create a sample task
|
// Create a sample task
|
||||||
Task task = new Task("task1", "Plan a weekend trip to Paris",
|
Task task = new Task("task1", "Plan a weekend trip to Paris",
|
||||||
Arrays.asList("travel", "booking"),
|
Arrays.asList("travel", "booking"),
|
||||||
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
|
Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
|
||||||
|
|
||||||
// Form a team for the task
|
// Form a team for the task
|
||||||
List<AgentInfo> team = teamFormation.formTeam(task);
|
List<AgentInfo> team = teamFormation.formTeam(task);
|
||||||
System.out.println("Formed team: " + team);
|
System.out.println("Formed team: " + team);
|
||||||
|
|
||||||
// Assign the task to the first agent in the team (simplified)
|
if (team.isEmpty()) {
|
||||||
|
System.out.println("No suitable agents found for the task. Exiting.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign the task to the first agent in the team
|
||||||
task.setAssignedAgent(team.get(0));
|
task.setAssignedAgent(team.get(0));
|
||||||
|
|
||||||
// Execute the task
|
// Execute the task
|
||||||
taskManager.addTask(task);
|
taskManager.addTask(task);
|
||||||
taskManager.executeTask(task.getId());
|
taskManager.executeTask(task.getId());
|
||||||
|
|
||||||
// Print the result
|
// Print the result
|
||||||
System.out.println("Task result: " + task.getResult());
|
System.out.println("Task result: " + task.getResult());
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,11 +22,11 @@ public class ConversationFSM {
|
||||||
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
|
||||||
"\nCurrent state: " + currentState;
|
"\nCurrent state: " + currentState;
|
||||||
|
|
||||||
String reasoning = model.generate(stateTransitionTask);
|
String reasoning = model.generate(stateTransitionTask, null);
|
||||||
|
|
||||||
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
|
||||||
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
|
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
|
||||||
String response = model.generate(decisionPrompt);
|
String response = model.generate(decisionPrompt, null);
|
||||||
|
|
||||||
ConversationState newState = ConversationState.valueOf(response.trim());
|
ConversationState newState = ConversationState.valueOf(response.trim());
|
||||||
transitionTo(newState);
|
transitionTo(newState);
|
||||||
|
|
|
@ -12,6 +12,10 @@ import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||||
|
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.util.Base64;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
public class BedrockLanguageModel {
|
public class BedrockLanguageModel {
|
||||||
private final BedrockRuntimeClient bedrockClient;
|
private final BedrockRuntimeClient bedrockClient;
|
||||||
|
@ -27,19 +31,35 @@ public class BedrockLanguageModel {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String generate(String prompt) {
|
public String generate(String prompt, String imagePath) {
|
||||||
try {
|
try {
|
||||||
ObjectNode requestBody = objectMapper.createObjectNode();
|
ObjectNode requestBody = objectMapper.createObjectNode();
|
||||||
requestBody.put("anthropic_version", "bedrock-2023-05-31");
|
requestBody.put("anthropic_version", "bedrock-2023-05-31");
|
||||||
ArrayNode messages = requestBody.putArray("messages");
|
ArrayNode messages = requestBody.putArray("messages");
|
||||||
ObjectNode message = messages.addObject();
|
ObjectNode message = messages.addObject();
|
||||||
message.put("role", "user");
|
message.put("role", "user");
|
||||||
message.put("content", prompt);
|
|
||||||
|
|
||||||
requestBody.put("max_tokens", 500);
|
requestBody.put("max_tokens", 500);
|
||||||
requestBody.put("temperature", 0.7);
|
requestBody.put("temperature", 0.7);
|
||||||
requestBody.put("top_p", 0.9);
|
requestBody.put("top_p", 0.9);
|
||||||
|
|
||||||
|
ArrayNode content = message.putArray("content");
|
||||||
|
|
||||||
|
if (imagePath != null && !imagePath.isEmpty()) {
|
||||||
|
byte[] imageBytes = Files.readAllBytes(Paths.get(imagePath));
|
||||||
|
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||||
|
|
||||||
|
ObjectNode imageNode = content.addObject();
|
||||||
|
imageNode.put("type", "image"); // Add type field
|
||||||
|
ObjectNode imageContent = imageNode.putObject("image");
|
||||||
|
imageContent.put("format", "png");
|
||||||
|
ObjectNode source = imageContent.putObject("source");
|
||||||
|
source.put("bytes", base64Image);
|
||||||
|
}
|
||||||
|
|
||||||
|
ObjectNode textNode = content.addObject();
|
||||||
|
textNode.put("type", "text"); // Add type field
|
||||||
|
textNode.put("text", prompt);
|
||||||
|
|
||||||
String jsonPayload = objectMapper.writeValueAsString(requestBody);
|
String jsonPayload = objectMapper.writeValueAsString(requestBody);
|
||||||
|
|
||||||
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
|
InvokeModelRequest invokeRequest = InvokeModelRequest.builder()
|
||||||
|
@ -53,7 +73,7 @@ public class BedrockLanguageModel {
|
||||||
String responseBody = response.body().asUtf8String();
|
String responseBody = response.body().asUtf8String();
|
||||||
|
|
||||||
ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody);
|
ObjectNode responseJson = (ObjectNode) objectMapper.readTree(responseBody);
|
||||||
return responseJson.path("content").asText();
|
return responseJson.path("content").get(0).path("text").asText();
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("Error generating text with Bedrock", e);
|
throw new RuntimeException("Error generating text with Bedrock", e);
|
||||||
|
|
|
@ -41,13 +41,13 @@ public class TaskManager {
|
||||||
"\nAssigned agent capabilities: " + agent.getCapabilities() +
|
"\nAssigned agent capabilities: " + agent.getCapabilities() +
|
||||||
"\nAvailable tools: " + agent.getTools();
|
"\nAvailable tools: " + agent.getTools();
|
||||||
|
|
||||||
String reasoning = model.generate(executionPlanningTask);
|
String reasoning = model.generate(executionPlanningTask, null);
|
||||||
|
|
||||||
updateTaskProgress(taskId, "IN_PROGRESS", 50);
|
updateTaskProgress(taskId, "IN_PROGRESS", 50);
|
||||||
|
|
||||||
String executionPrompt = "Based on this execution plan:\n" + reasoning +
|
String executionPrompt = "Based on this execution plan:\n" + reasoning +
|
||||||
"\nExecute the task using the available tools and provide the result.";
|
"\nExecute the task using the available tools and provide the result.";
|
||||||
String response = model.generate(executionPrompt);
|
String response = model.generate(executionPrompt, null);
|
||||||
|
|
||||||
String result = executeToolsFromResponse(response, agent);
|
String result = executeToolsFromResponse(response, agent);
|
||||||
|
|
||||||
|
|
|
@ -25,29 +25,38 @@ public class TeamFormation {
|
||||||
List<String> requiredTools = task.getRequiredTools();
|
List<String> requiredTools = task.getRequiredTools();
|
||||||
List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities);
|
List<AgentInfo> potentialAgents = agentRegistry.searchAgents(requiredCapabilities);
|
||||||
|
|
||||||
|
System.out.println("Potential agents: " + potentialAgents);
|
||||||
|
|
||||||
String teamFormationTask = "Form the best team for this task: " + task.getDescription() +
|
String teamFormationTask = "Form the best team for this task: " + task.getDescription() +
|
||||||
|
"\nRequired capabilities: " + requiredCapabilities +
|
||||||
"\nRequired tools: " + requiredTools +
|
"\nRequired tools: " + requiredTools +
|
||||||
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
|
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents) +
|
||||||
|
"\nPlease respond with a comma-separated list of agent IDs that form the best team for this task.";
|
||||||
|
|
||||||
String reasoning = model.generate(teamFormationTask);
|
System.out.println("Sending prompt to language model: " + teamFormationTask);
|
||||||
|
|
||||||
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
|
String response = model.generate(teamFormationTask, null);
|
||||||
"\nProvide the final team composition as a comma-separated list of agent IDs.";
|
|
||||||
String response = model.generate(finalDecisionPrompt);
|
|
||||||
|
|
||||||
|
System.out.println("Language model response: " + response);
|
||||||
|
|
||||||
return parseTeamComposition(response, potentialAgents);
|
return parseTeamComposition(response, potentialAgents);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String formatAgentTools(List<AgentInfo> agents) {
|
private String formatAgentTools(List<AgentInfo> agents) {
|
||||||
return agents.stream()
|
return agents.stream()
|
||||||
.map(agent -> agent.getId() + ": " + agent.getTools())
|
.map(agent -> agent.getId() + " (capabilities: " + agent.getCapabilities() + ", tools: " + agent.getTools() + ")")
|
||||||
.collect(Collectors.joining(", "));
|
.collect(Collectors.joining(", "));
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AgentInfo> parseTeamComposition(String composition, List<AgentInfo> potentialAgents) {
|
private List<AgentInfo> parseTeamComposition(String composition, List<AgentInfo> potentialAgents) {
|
||||||
List<String> selectedIds = Arrays.asList(composition.split(","));
|
List<String> selectedIds = Arrays.asList(composition.split(","));
|
||||||
return potentialAgents.stream()
|
System.out.println("Parsed agent IDs: " + selectedIds);
|
||||||
.filter(agent -> selectedIds.contains(agent.getId()))
|
|
||||||
|
List<AgentInfo> team = potentialAgents.stream()
|
||||||
|
.filter(agent -> selectedIds.contains(agent.getId().trim()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
System.out.println("Final team: " + team);
|
||||||
|
return team;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -22,7 +22,7 @@ public class TreeOfThought {
|
||||||
for (int i = 0; i < branches; i++) {
|
for (int i = 0; i < branches; i++) {
|
||||||
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
|
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
|
||||||
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
||||||
String thought = model.generate(branchPrompt);
|
String thought = model.generate(branchPrompt, null);
|
||||||
|
|
||||||
result.append("Branch ").append(i + 1).append(":\n");
|
result.append("Branch ").append(i + 1).append(":\n");
|
||||||
result.append(thought).append("\n");
|
result.append(thought).append("\n");
|
||||||
|
@ -33,7 +33,7 @@ public class TreeOfThought {
|
||||||
|
|
||||||
private String evaluateLeaf(String task, String path) {
|
private String evaluateLeaf(String task, String path) {
|
||||||
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
|
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
|
||||||
return model.generate(prompt);
|
return model.generate(prompt, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public BedrockLanguageModel getModel() {
|
public BedrockLanguageModel getModel() {
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user