Agents now have memory in addition to context with Tree of thought
This commit is contained in:
parent
fd66c81027
commit
9cf98dbaf7
|
@ -15,6 +15,7 @@ public class AgentInfo {
|
|||
private String name;
|
||||
private List<String> capabilities;
|
||||
private List<String> tools;
|
||||
private Memory memory;
|
||||
private TreeOfThought treeOfThought;
|
||||
private WebSocketService webSocketService;
|
||||
private ToolRegistry toolRegistry;
|
||||
|
@ -31,16 +32,36 @@ public class AgentInfo {
|
|||
this.webSocketService = webSocketService;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.model = model;
|
||||
this.memory = new Memory();
|
||||
}
|
||||
|
||||
public void receiveMessage(Message message) {
|
||||
// Process the received message
|
||||
memory.addToHistory(message.getContent());
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nYou received a message: " + message.getContent() +
|
||||
"\nHow would you respond or what actions would you take based on this message?";
|
||||
"\nBased on your memory and context, how would you respond or what actions would you take?" +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
String response = model.generate(prompt, null);
|
||||
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
|
||||
System.out.println("DEBUG: " + name + " response: " + response);
|
||||
|
||||
// Add the response to memory
|
||||
memory.addToHistory("My response: " + response);
|
||||
}
|
||||
|
||||
public void performTreeOfThought(String task) {
|
||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||
"\nTask: " + task +
|
||||
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
|
||||
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||
|
||||
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
|
||||
String reasoning = (String) totResult.get("reasoning");
|
||||
|
||||
// Add the reasoning to memory
|
||||
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
|
||||
|
||||
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
|
||||
}
|
||||
|
||||
public void notifyTurn(ConversationFSM conversation) {
|
||||
|
|
47
src/main/java/com/ioa/agent/Memory.java
Normal file
47
src/main/java/com/ioa/agent/Memory.java
Normal file
|
@ -0,0 +1,47 @@
|
|||
package com.ioa.agent;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class Memory {
|
||||
private List<String> conversationHistory;
|
||||
private List<String> contextualFacts;
|
||||
private int maxHistorySize = 10; // Adjust as needed
|
||||
|
||||
public Memory() {
|
||||
this.conversationHistory = new ArrayList<>();
|
||||
this.contextualFacts = new ArrayList<>();
|
||||
}
|
||||
|
||||
public void addToHistory(String message) {
|
||||
conversationHistory.add(message);
|
||||
if (conversationHistory.size() > maxHistorySize) {
|
||||
conversationHistory.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
public void addContextualFact(String fact) {
|
||||
contextualFacts.add(fact);
|
||||
}
|
||||
|
||||
public List<String> getConversationHistory() {
|
||||
return new ArrayList<>(conversationHistory);
|
||||
}
|
||||
|
||||
public List<String> getContextualFacts() {
|
||||
return new ArrayList<>(contextualFacts);
|
||||
}
|
||||
|
||||
public String getFormattedMemory() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Conversation History:\n");
|
||||
for (String message : conversationHistory) {
|
||||
sb.append("- ").append(message).append("\n");
|
||||
}
|
||||
sb.append("\nContextual Facts:\n");
|
||||
for (String fact : contextualFacts) {
|
||||
sb.append("- ").append(fact).append("\n");
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
|
@ -8,6 +8,8 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
|
||||
@Service
|
||||
public class WebSocketService {
|
||||
|
@ -22,6 +24,22 @@ public class WebSocketService {
|
|||
this.objectMapper = new ObjectMapper();
|
||||
}
|
||||
|
||||
public void handleMessage(WebSocketSession session, TextMessage message) {
|
||||
try {
|
||||
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
|
||||
if (messageArray.length > 0) {
|
||||
String content = messageArray[0];
|
||||
// Assume we're using a default conversation ID for simplicity
|
||||
String conversationId = "default";
|
||||
// Assume we're using a default user ID for simplicity
|
||||
String userId = "user";
|
||||
conversationManager.postMessage(conversationId, userId, content);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public void sendUpdate(String topic, Object payload) {
|
||||
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
||||
}
|
||||
|
|
|
@ -45,12 +45,8 @@ public class TreeOfThought {
|
|||
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
||||
String thought = model.generate(branchPrompt, null);
|
||||
|
||||
if (!thought.equals("No response generated") && !thought.startsWith("Error:")) {
|
||||
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
|
||||
children.add(childNode);
|
||||
} else {
|
||||
System.out.println("WARNING: Failed to generate thought. Result: " + thought);
|
||||
}
|
||||
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
|
||||
children.add(childNode);
|
||||
}
|
||||
|
||||
node.put("children", children);
|
||||
|
|
Loading…
Reference in New Issue
Block a user