diff --git a/pom.xml b/pom.xml
index f3d8047..6f5cbfc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,12 +1,101 @@
-
-
- software.amazon.awssdk
- bedrock-runtime
- 2.20.0
-
-
- com.fasterxml.jackson.core
- jackson-databind
- 2.13.0
-
-
\ No newline at end of file
+
+
+ 4.0.0
+
+ com.ioa
+ ioa-system
+ 1.0-SNAPSHOT
+
+
+ 11
+ 2.5.5
+ 2.26.9
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-dependencies
+ ${spring-boot.version}
+ pom
+ import
+
+
+ software.amazon.awssdk
+ bom
+ ${aws.sdk.version}
+ pom
+ import
+
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-web
+
+
+ org.springframework.boot
+ spring-boot-starter-websocket
+
+
+ software.amazon.awssdk
+ bedrockruntime
+ ${aws.sdk.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+ org.projectlombok
+ lombok
+ true
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-maven-plugin
+ ${spring-boot.version}
+
+ com.ioa.IoASystem
+
+
+ org.projectlombok
+ lombok
+
+
+
+
+
+
+ repackage
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.8.1
+
+
+ ${java.version}
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/main/java/com/ioa/IoASystem.java b/src/main/java/com/ioa/IoASystem.java
index 69d5db0..f4abdd0 100644
--- a/src/main/java/com/ioa/IoASystem.java
+++ b/src/main/java/com/ioa/IoASystem.java
@@ -1,35 +1,100 @@
package com.ioa;
+import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
+import com.ioa.task.Task;
import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
+import com.ioa.model.BedrockLanguageModel;
+import com.ioa.service.WebSocketService;
+import com.ioa.tool.Tool;
+
+import org.springframework.boot.SpringApplication;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.context.annotation.Bean;
import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.List;
+@SpringBootApplication
public class IoASystem {
- public static void initialize() {
+
+ @Bean
+ public ToolRegistry toolRegistry() {
ToolRegistry toolRegistry = new ToolRegistry();
CommonTools commonTools = new CommonTools();
- // Register all tools from CommonTools
for (Method method : CommonTools.class.getMethods()) {
- if (method.isAnnotationPresent(dev.langchain4j.agent.tool.Tool.class)) {
+ if (method.isAnnotationPresent(Tool.class)) {
toolRegistry.registerTool(method.getName(), method);
}
}
+ return toolRegistry;
+ }
- AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
- ChatLanguageModel model = OpenAiChatModel.builder()
- .apiKey(System.getenv("OPENAI_API_KEY"))
- .build();
+ @Bean
+ public AgentRegistry agentRegistry(ToolRegistry toolRegistry) {
+ return new AgentRegistry(toolRegistry);
+ }
+
+ @Bean
+ public BedrockLanguageModel bedrockLanguageModel() {
+ return new BedrockLanguageModel("anthropic.claude-v2");
+ }
+
+ @Bean
+ public TeamFormation teamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
+ return new TeamFormation(agentRegistry, model);
+ }
+
+ @Bean
+ public TaskManager taskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
+ return new TaskManager(agentRegistry, model, toolRegistry);
+ }
+
+ @Bean
+ public WebSocketService webSocketService() {
+ return new WebSocketService();
+ }
+
+ public static void main(String[] args) {
+ var context = SpringApplication.run(IoASystem.class, args);
+
+ AgentRegistry agentRegistry = context.getBean(AgentRegistry.class);
+ TeamFormation teamFormation = context.getBean(TeamFormation.class);
+ TaskManager taskManager = context.getBean(TaskManager.class);
+
+ // Register some example agents
+ AgentInfo agent1 = new AgentInfo("agent1", "General Assistant",
+ Arrays.asList("general", "search"),
+ Arrays.asList("webSearch", "getWeather", "setReminder"));
+ AgentInfo agent2 = new AgentInfo("agent2", "Travel Expert",
+ Arrays.asList("travel", "booking"),
+ Arrays.asList("bookTravel", "calculateDistance", "findRestaurants"));
- TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
- TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
-
- // Initialize other components as needed
+ agentRegistry.registerAgent(agent1.getId(), agent1);
+ agentRegistry.registerAgent(agent2.getId(), agent2);
+
+ // Create a sample task
+ Task task = new Task("task1", "Plan a weekend trip to Paris",
+ Arrays.asList("travel", "booking"),
+ Arrays.asList("bookTravel", "findRestaurants", "getWeather"));
+
+ // Form a team for the task
+ List team = teamFormation.formTeam(task);
+ System.out.println("Formed team: " + team);
+
+ // Assign the task to the first agent in the team (simplified)
+ task.setAssignedAgent(team.get(0));
+
+ // Execute the task
+ taskManager.addTask(task);
+ taskManager.executeTask(task.getId());
+
+ // Print the result
+ System.out.println("Task result: " + task.getResult());
}
}
\ No newline at end of file
diff --git a/src/main/Main.java b/src/main/java/com/ioa/Main.java
similarity index 90%
rename from src/main/Main.java
rename to src/main/java/com/ioa/Main.java
index 1ad7458..34f4695 100644
--- a/src/main/Main.java
+++ b/src/main/java/com/ioa/Main.java
@@ -7,8 +7,7 @@ import com.ioa.task.TaskManager;
import com.ioa.team.TeamFormation;
import com.ioa.tool.CommonTools;
import com.ioa.tool.ToolRegistry;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
+import com.ioa.model.BedrockLanguageModel;
import java.lang.reflect.Method;
import java.util.Arrays;
@@ -28,9 +27,7 @@ public class Main {
}
AgentRegistry agentRegistry = new AgentRegistry(toolRegistry);
- ChatLanguageModel model = OpenAiChatModel.builder()
- .apiKey(System.getenv("OPENAI_API_KEY"))
- .build();
+ BedrockLanguageModel model = new BedrockLanguageModel("anthropic.claude-v2"); // or another model ID
TeamFormation teamFormation = new TeamFormation(agentRegistry, model);
TaskManager taskManager = new TaskManager(agentRegistry, model, toolRegistry);
@@ -65,4 +62,4 @@ public class Main {
// Print the result
System.out.println("Task result: " + task.getResult());
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/com/ioa/agent/AgentInfo.java b/src/main/java/com/ioa/agent/AgentInfo.java
index acd3a68..55b156d 100644
--- a/src/main/java/com/ioa/agent/AgentInfo.java
+++ b/src/main/java/com/ioa/agent/AgentInfo.java
@@ -1,28 +1,19 @@
package com.ioa.agent;
+import lombok.Data;
import java.util.List;
+@Data
public class AgentInfo {
private String id;
private String name;
private List capabilities;
private List tools;
- // Constructor
public AgentInfo(String id, String name, List capabilities, List tools) {
this.id = id;
this.name = name;
this.capabilities = capabilities;
this.tools = tools;
}
-
- // Getters and setters
- public String getId() { return id; }
- public void setId(String id) { this.id = id; }
- public String getName() { return name; }
- public void setName(String name) { this.name = name; }
- public List getCapabilities() { return capabilities; }
- public void setCapabilities(List capabilities) { this.capabilities = capabilities; }
- public List getTools() { return tools; }
- public void setTools(List tools) { this.tools = tools; }
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/agent/AgentRegistry.java b/src/main/java/com/ioa/agent/AgentRegistry.java
index 6bda082..272f8f4 100644
--- a/src/main/java/com/ioa/agent/AgentRegistry.java
+++ b/src/main/java/com/ioa/agent/AgentRegistry.java
@@ -1,12 +1,14 @@
package com.ioa.agent;
import com.ioa.tool.ToolRegistry;
+import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
+@Component
public class AgentRegistry {
private Map agents = new HashMap<>();
private ToolRegistry toolRegistry;
@@ -17,7 +19,7 @@ public class AgentRegistry {
public void registerAgent(String agentId, AgentInfo agentInfo) {
agents.put(agentId, agentInfo);
- // Register agent's tools
+ // Verify that all tools the agent claims to have are registered
for (String tool : agentInfo.getTools()) {
if (toolRegistry.getTool(tool) == null) {
throw new IllegalArgumentException("Tool not found in registry: " + tool);
@@ -34,4 +36,8 @@ public class AgentRegistry {
.filter(agent -> agent.getCapabilities().containsAll(capabilities))
.collect(Collectors.toList());
}
+
+ public List getAllAgents() {
+ return List.copyOf(agents.values());
+ }
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/config/WebSocketConfig.java b/src/main/java/com/ioa/config/WebSocketConfig.java
new file mode 100644
index 0000000..70919ec
--- /dev/null
+++ b/src/main/java/com/ioa/config/WebSocketConfig.java
@@ -0,0 +1,23 @@
+package com.ioa.config;
+
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.simp.config.MessageBrokerRegistry;
+import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
+import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
+import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
+
+@Configuration
+@EnableWebSocketMessageBroker
+public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
+
+ @Override
+ public void configureMessageBroker(MessageBrokerRegistry config) {
+ config.enableSimpleBroker("/topic");
+ config.setApplicationDestinationPrefixes("/app");
+ }
+
+ @Override
+ public void registerStompEndpoints(StompEndpointRegistry registry) {
+ registry.addEndpoint("/ws").withSockJS();
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/conversation/ConversationFSM.java b/src/main/java/com/ioa/conversation/ConversationFSM.java
index 08f5a7b..893871c 100644
--- a/src/main/java/com/ioa/conversation/ConversationFSM.java
+++ b/src/main/java/com/ioa/conversation/ConversationFSM.java
@@ -1,29 +1,34 @@
package com.ioa.conversation;
-import com.ioa.util.TreeOfThought;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.output.Response;
+import com.ioa.model.BedrockLanguageModel;
+import com.ioa.service.WebSocketService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+@Component
public class ConversationFSM {
private ConversationState currentState;
- private TreeOfThought treeOfThought;
+ private BedrockLanguageModel model;
- public ConversationFSM(ChatLanguageModel model) {
+ @Autowired
+ private WebSocketService webSocketService;
+
+ public ConversationFSM(BedrockLanguageModel model) {
this.currentState = ConversationState.DISCUSSION;
- this.treeOfThought = new TreeOfThought(model);
+ this.model = model;
}
public void handleMessage(Message message) {
String stateTransitionTask = "Decide the next conversation state based on this message: " + message.getContent() +
"\nCurrent state: " + currentState;
- String reasoning = treeOfThought.reason(stateTransitionTask, 2, 3);
+ String reasoning = model.generate(stateTransitionTask);
String decisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the next conversation state (DISCUSSION, TASK_ASSIGNMENT, EXECUTION, or CONCLUSION).";
- Response response = treeOfThought.getModel().generate(decisionPrompt);
+ String response = model.generate(decisionPrompt);
- ConversationState newState = ConversationState.valueOf(response.content().trim());
+ ConversationState newState = ConversationState.valueOf(response.trim());
transitionTo(newState);
// Handle the message based on the new state
@@ -44,8 +49,8 @@ public class ConversationFSM {
}
private void transitionTo(ConversationState newState) {
- // Add any transition logic here
this.currentState = newState;
+ webSocketService.sendUpdate("conversation_state", new ConversationStateUpdate(currentState));
}
private void handleDiscussionMessage(Message message) {
@@ -63,4 +68,12 @@ public class ConversationFSM {
private void handleConclusionMessage(Message message) {
// Implement conclusion logic
}
+
+ private class ConversationStateUpdate {
+ public ConversationState state;
+
+ ConversationStateUpdate(ConversationState state) {
+ this.state = state;
+ }
+ }
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/conversation/Message.java b/src/main/java/com/ioa/conversation/Message.java
index 278e4f1..b5d825b 100644
--- a/src/main/java/com/ioa/conversation/Message.java
+++ b/src/main/java/com/ioa/conversation/Message.java
@@ -1,5 +1,8 @@
package com.ioa.conversation;
+import lombok.Data;
+
+@Data
public class Message {
private String sender;
private String content;
@@ -8,7 +11,4 @@ public class Message {
this.sender = sender;
this.content = content;
}
-
- public String getSender() { return sender; }
- public String getContent() { return content; }
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/model/BedrockLanguageModel.java b/src/main/java/com/ioa/model/BedrockLanguageModel.java
index a380393..bb530cd 100644
--- a/src/main/java/com/ioa/model/BedrockLanguageModel.java
+++ b/src/main/java/com/ioa/model/BedrockLanguageModel.java
@@ -1,16 +1,19 @@
package com.ioa.model;
-import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import software.amazon.awssdk.core.SdkBytes;
+import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;
import java.util.HashMap;
+import org.springframework.stereotype.Component;
+
+@Component
public class BedrockLanguageModel {
private final BedrockRuntimeClient bedrockClient;
private final ObjectMapper objectMapper;
diff --git a/src/main/java/com/ioa/service/WebSocketService.java b/src/main/java/com/ioa/service/WebSocketService.java
new file mode 100644
index 0000000..4e144b8
--- /dev/null
+++ b/src/main/java/com/ioa/service/WebSocketService.java
@@ -0,0 +1,16 @@
+package com.ioa.service;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.messaging.simp.SimpMessagingTemplate;
+import org.springframework.stereotype.Service;
+
+@Service
+public class WebSocketService {
+
+ @Autowired
+ private SimpMessagingTemplate messagingTemplate;
+
+ public void sendUpdate(String topic, Object payload) {
+ messagingTemplate.convertAndSend("/topic/" + topic, payload);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/task/Task.java b/src/main/java/com/ioa/task/Task.java
index 6cf5677..c6457fa 100644
--- a/src/main/java/com/ioa/task/Task.java
+++ b/src/main/java/com/ioa/task/Task.java
@@ -1,9 +1,11 @@
package com.ioa.task;
import com.ioa.agent.AgentInfo;
+import lombok.Data;
import java.util.List;
+@Data
public class Task {
private String id;
private String description;
@@ -12,21 +14,10 @@ public class Task {
private AgentInfo assignedAgent;
private String result;
- // Constructor
public Task(String id, String description, List requiredCapabilities, List requiredTools) {
this.id = id;
this.description = description;
this.requiredCapabilities = requiredCapabilities;
this.requiredTools = requiredTools;
}
-
- // Getters and setters
- public String getId() { return id; }
- public String getDescription() { return description; }
- public List getRequiredCapabilities() { return requiredCapabilities; }
- public List getRequiredTools() { return requiredTools; }
- public AgentInfo getAssignedAgent() { return assignedAgent; }
- public void setAssignedAgent(AgentInfo assignedAgent) { this.assignedAgent = assignedAgent; }
- public String getResult() { return result; }
- public void setResult(String result) { this.result = result; }
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/task/TaskManager.java b/src/main/java/com/ioa/task/TaskManager.java
index 990e9d0..33169ac 100644
--- a/src/main/java/com/ioa/task/TaskManager.java
+++ b/src/main/java/com/ioa/task/TaskManager.java
@@ -2,43 +2,57 @@ package com.ioa.task;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
+import com.ioa.model.BedrockLanguageModel;
+import com.ioa.service.WebSocketService;
import com.ioa.tool.ToolRegistry;
-import com.ioa.util.TreeOfThought;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.output.Response;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
+@Component
public class TaskManager {
private Map tasks = new HashMap<>();
private AgentRegistry agentRegistry;
- private TreeOfThought treeOfThought;
+ private BedrockLanguageModel model;
private ToolRegistry toolRegistry;
- public TaskManager(AgentRegistry agentRegistry, ChatLanguageModel model, ToolRegistry toolRegistry) {
+ @Autowired
+ private WebSocketService webSocketService;
+
+ public TaskManager(AgentRegistry agentRegistry, BedrockLanguageModel model, ToolRegistry toolRegistry) {
this.agentRegistry = agentRegistry;
- this.treeOfThought = new TreeOfThought(model);
+ this.model = model;
this.toolRegistry = toolRegistry;
}
+ public void addTask(Task task) {
+ tasks.put(task.getId(), task);
+ }
+
public void executeTask(String taskId) {
Task task = tasks.get(taskId);
AgentInfo agent = task.getAssignedAgent();
+ updateTaskProgress(taskId, "STARTED", 0);
+
String executionPlanningTask = "Plan the execution of this task: " + task.getDescription() +
"\nAssigned agent capabilities: " + agent.getCapabilities() +
"\nAvailable tools: " + agent.getTools();
- String reasoning = treeOfThought.reason(executionPlanningTask, 3, 3);
+ String reasoning = model.generate(executionPlanningTask);
+
+ updateTaskProgress(taskId, "IN_PROGRESS", 50);
String executionPrompt = "Based on this execution plan:\n" + reasoning +
"\nExecute the task using the available tools and provide the result.";
- Response response = treeOfThought.getModel().generate(executionPrompt);
+ String response = model.generate(executionPrompt);
- String result = executeToolsFromResponse(response.content(), agent);
+ String result = executeToolsFromResponse(response, agent);
task.setResult(result);
+ updateTaskProgress(taskId, "COMPLETED", 100);
}
private String executeToolsFromResponse(String response, AgentInfo agent) {
@@ -53,11 +67,20 @@ public class TaskManager {
return result.toString();
}
- public void addTask(Task task) {
- tasks.put(task.getId(), task);
+ private void updateTaskProgress(String taskId, String status, int progressPercentage) {
+ TaskProgress progress = new TaskProgress(taskId, status, progressPercentage);
+ webSocketService.sendUpdate("task_progress", progress);
}
- public Task getTask(String taskId) {
- return tasks.get(taskId);
+ private class TaskProgress {
+ public String taskId;
+ public String status;
+ public int progressPercentage;
+
+ TaskProgress(String taskId, String status, int progressPercentage) {
+ this.taskId = taskId;
+ this.status = status;
+ this.progressPercentage = progressPercentage;
+ }
}
}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/team/TeamFormation.java b/src/main/java/com/ioa/team/TeamFormation.java
index 3d925da..0da1e67 100644
--- a/src/main/java/com/ioa/team/TeamFormation.java
+++ b/src/main/java/com/ioa/team/TeamFormation.java
@@ -2,22 +2,22 @@ package com.ioa.team;
import com.ioa.agent.AgentInfo;
import com.ioa.agent.AgentRegistry;
+import com.ioa.model.BedrockLanguageModel;
import com.ioa.task.Task;
-import com.ioa.util.TreeOfThought;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.output.Response;
+import org.springframework.stereotype.Component;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
+@Component
public class TeamFormation {
private AgentRegistry agentRegistry;
- private TreeOfThought treeOfThought;
+ private BedrockLanguageModel model;
- public TeamFormation(AgentRegistry agentRegistry, ChatLanguageModel model) {
+ public TeamFormation(AgentRegistry agentRegistry, BedrockLanguageModel model) {
this.agentRegistry = agentRegistry;
- this.treeOfThought = new TreeOfThought(model);
+ this.model = model;
}
public List formTeam(Task task) {
@@ -29,13 +29,13 @@ public class TeamFormation {
"\nRequired tools: " + requiredTools +
"\nAvailable agents and their tools: " + formatAgentTools(potentialAgents);
- String reasoning = treeOfThought.reason(teamFormationTask, 3, 3);
+ String reasoning = model.generate(teamFormationTask);
String finalDecisionPrompt = "Based on this reasoning:\n" + reasoning +
"\nProvide the final team composition as a comma-separated list of agent IDs.";
- Response response = treeOfThought.getModel().generate(finalDecisionPrompt);
+ String response = model.generate(finalDecisionPrompt);
- return parseTeamComposition(response.content(), potentialAgents);
+ return parseTeamComposition(response, potentialAgents);
}
private String formatAgentTools(List agents) {
diff --git a/src/main/java/com/ioa/tool/Tool.java b/src/main/java/com/ioa/tool/Tool.java
new file mode 100644
index 0000000..3db8918
--- /dev/null
+++ b/src/main/java/com/ioa/tool/Tool.java
@@ -0,0 +1,12 @@
+package com.ioa.tool;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.METHOD)
+public @interface Tool {
+ String value();
+}
\ No newline at end of file
diff --git a/src/main/java/com/ioa/tool/ToolRegistry.java b/src/main/java/com/ioa/tool/ToolRegistry.java
index d0626af..03f6c90 100644
--- a/src/main/java/com/ioa/tool/ToolRegistry.java
+++ b/src/main/java/com/ioa/tool/ToolRegistry.java
@@ -1,8 +1,11 @@
package com.ioa.tool;
+import org.springframework.stereotype.Component;
+
import java.util.HashMap;
import java.util.Map;
+@Component
public class ToolRegistry {
private Map tools = new HashMap<>();
diff --git a/src/main/java/com/ioa/util/TreeOfThought.java b/src/main/java/com/ioa/util/TreeOfThought.java
index c5203ba..fbbecb5 100644
--- a/src/main/java/com/ioa/util/TreeOfThought.java
+++ b/src/main/java/com/ioa/util/TreeOfThought.java
@@ -1,12 +1,11 @@
package com.ioa.util;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.output.Response;
+import com.ioa.model.BedrockLanguageModel;
public class TreeOfThought {
- private final ChatLanguageModel model;
+ private final BedrockLanguageModel model;
- public TreeOfThought(ChatLanguageModel model) {
+ public TreeOfThought(BedrockLanguageModel model) {
this.model = model;
}
@@ -23,8 +22,7 @@ public class TreeOfThought {
for (int i = 0; i < branches; i++) {
String branchPrompt = "Consider the task: " + task + "\nCurrent path: " + path +
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
- Response response = model.generate(branchPrompt);
- String thought = response.content();
+ String thought = model.generate(branchPrompt);
result.append("Branch ").append(i + 1).append(":\n");
result.append(thought).append("\n");
@@ -35,11 +33,10 @@ public class TreeOfThought {
private String evaluateLeaf(String task, String path) {
String prompt = "Evaluate the effectiveness of this approach for the task: " + task + "\nPath: " + path;
- Response response = model.generate(prompt);
- return response.content();
+ return model.generate(prompt);
}
- public ChatLanguageModel getModel() {
+ public BedrockLanguageModel getModel() {
return model;
}
-}
\ No newline at end of file
+}