Agents now have memory in addition to context with Tree of thought
This commit is contained in:
parent
fd66c81027
commit
9cf98dbaf7
|
@ -15,6 +15,7 @@ public class AgentInfo {
|
||||||
private String name;
|
private String name;
|
||||||
private List<String> capabilities;
|
private List<String> capabilities;
|
||||||
private List<String> tools;
|
private List<String> tools;
|
||||||
|
private Memory memory;
|
||||||
private TreeOfThought treeOfThought;
|
private TreeOfThought treeOfThought;
|
||||||
private WebSocketService webSocketService;
|
private WebSocketService webSocketService;
|
||||||
private ToolRegistry toolRegistry;
|
private ToolRegistry toolRegistry;
|
||||||
|
@ -31,16 +32,36 @@ public class AgentInfo {
|
||||||
this.webSocketService = webSocketService;
|
this.webSocketService = webSocketService;
|
||||||
this.toolRegistry = toolRegistry;
|
this.toolRegistry = toolRegistry;
|
||||||
this.model = model;
|
this.model = model;
|
||||||
|
this.memory = new Memory();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void receiveMessage(Message message) {
|
public void receiveMessage(Message message) {
|
||||||
// Process the received message
|
memory.addToHistory(message.getContent());
|
||||||
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||||
"\nYou received a message: " + message.getContent() +
|
"\nYou received a message: " + message.getContent() +
|
||||||
"\nHow would you respond or what actions would you take based on this message?";
|
"\nBased on your memory and context, how would you respond or what actions would you take?" +
|
||||||
|
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||||
String response = model.generate(prompt, null);
|
String response = model.generate(prompt, null);
|
||||||
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
|
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
|
||||||
System.out.println("DEBUG: " + name + " response: " + response);
|
System.out.println("DEBUG: " + name + " response: " + response);
|
||||||
|
|
||||||
|
// Add the response to memory
|
||||||
|
memory.addToHistory("My response: " + response);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performTreeOfThought(String task) {
|
||||||
|
String prompt = "You are " + name + " with capabilities: " + capabilities +
|
||||||
|
"\nTask: " + task +
|
||||||
|
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
|
||||||
|
"\n\nMemory:\n" + memory.getFormattedMemory();
|
||||||
|
|
||||||
|
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
|
||||||
|
String reasoning = (String) totResult.get("reasoning");
|
||||||
|
|
||||||
|
// Add the reasoning to memory
|
||||||
|
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
|
||||||
|
|
||||||
|
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void notifyTurn(ConversationFSM conversation) {
|
public void notifyTurn(ConversationFSM conversation) {
|
||||||
|
|
47
src/main/java/com/ioa/agent/Memory.java
Normal file
47
src/main/java/com/ioa/agent/Memory.java
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package com.ioa.agent;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class Memory {
|
||||||
|
private List<String> conversationHistory;
|
||||||
|
private List<String> contextualFacts;
|
||||||
|
private int maxHistorySize = 10; // Adjust as needed
|
||||||
|
|
||||||
|
public Memory() {
|
||||||
|
this.conversationHistory = new ArrayList<>();
|
||||||
|
this.contextualFacts = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addToHistory(String message) {
|
||||||
|
conversationHistory.add(message);
|
||||||
|
if (conversationHistory.size() > maxHistorySize) {
|
||||||
|
conversationHistory.remove(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addContextualFact(String fact) {
|
||||||
|
contextualFacts.add(fact);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getConversationHistory() {
|
||||||
|
return new ArrayList<>(conversationHistory);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getContextualFacts() {
|
||||||
|
return new ArrayList<>(contextualFacts);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getFormattedMemory() {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
sb.append("Conversation History:\n");
|
||||||
|
for (String message : conversationHistory) {
|
||||||
|
sb.append("- ").append(message).append("\n");
|
||||||
|
}
|
||||||
|
sb.append("\nContextual Facts:\n");
|
||||||
|
for (String fact : contextualFacts) {
|
||||||
|
sb.append("- ").append(fact).append("\n");
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,8 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
import org.springframework.messaging.simp.SimpMessagingTemplate;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.context.annotation.Lazy;
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
import org.springframework.web.socket.TextMessage;
|
||||||
|
import org.springframework.web.socket.WebSocketSession;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class WebSocketService {
|
public class WebSocketService {
|
||||||
|
@ -22,6 +24,22 @@ public class WebSocketService {
|
||||||
this.objectMapper = new ObjectMapper();
|
this.objectMapper = new ObjectMapper();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void handleMessage(WebSocketSession session, TextMessage message) {
|
||||||
|
try {
|
||||||
|
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
|
||||||
|
if (messageArray.length > 0) {
|
||||||
|
String content = messageArray[0];
|
||||||
|
// Assume we're using a default conversation ID for simplicity
|
||||||
|
String conversationId = "default";
|
||||||
|
// Assume we're using a default user ID for simplicity
|
||||||
|
String userId = "user";
|
||||||
|
conversationManager.postMessage(conversationId, userId, content);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void sendUpdate(String topic, Object payload) {
|
public void sendUpdate(String topic, Object payload) {
|
||||||
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
messagingTemplate.convertAndSend("/topic/" + topic, payload);
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,12 +45,8 @@ public class TreeOfThought {
|
||||||
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
|
||||||
String thought = model.generate(branchPrompt, null);
|
String thought = model.generate(branchPrompt, null);
|
||||||
|
|
||||||
if (!thought.equals("No response generated") && !thought.startsWith("Error:")) {
|
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
|
||||||
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
|
children.add(childNode);
|
||||||
children.add(childNode);
|
|
||||||
} else {
|
|
||||||
System.out.println("WARNING: Failed to generate thought. Result: " + thought);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
node.put("children", children);
|
node.put("children", children);
|
||||||
|
|
Loading…
Reference in New Issue
Block a user