Agents now have memory in addition to context with Tree of thought

This commit is contained in:
Mahesh Kommareddi 2024-07-20 16:25:20 -04:00
parent fd66c81027
commit 9cf98dbaf7
4 changed files with 90 additions and 8 deletions

View File

@ -15,6 +15,7 @@ public class AgentInfo {
private String name;
private List<String> capabilities;
private List<String> tools;
private Memory memory;
private TreeOfThought treeOfThought;
private WebSocketService webSocketService;
private ToolRegistry toolRegistry;
@ -31,16 +32,36 @@ public class AgentInfo {
this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry;
this.model = model;
this.memory = new Memory();
}
public void receiveMessage(Message message) {
// Process the received message
memory.addToHistory(message.getContent());
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nYou received a message: " + message.getContent() +
"\nHow would you respond or what actions would you take based on this message?";
"\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
String response = model.generate(prompt, null);
System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
System.out.println("DEBUG: " + name + " response: " + response);
// Add the response to memory
memory.addToHistory("My response: " + response);
}
public void performTreeOfThought(String task) {
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nTask: " + task +
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
"\n\nMemory:\n" + memory.getFormattedMemory();
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
String reasoning = (String) totResult.get("reasoning");
// Add the reasoning to memory
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
}
public void notifyTurn(ConversationFSM conversation) {

View File

@ -0,0 +1,47 @@
package com.ioa.agent;
import java.util.ArrayList;
import java.util.List;
public class Memory {
private List<String> conversationHistory;
private List<String> contextualFacts;
private int maxHistorySize = 10; // Adjust as needed
public Memory() {
this.conversationHistory = new ArrayList<>();
this.contextualFacts = new ArrayList<>();
}
public void addToHistory(String message) {
conversationHistory.add(message);
if (conversationHistory.size() > maxHistorySize) {
conversationHistory.remove(0);
}
}
public void addContextualFact(String fact) {
contextualFacts.add(fact);
}
public List<String> getConversationHistory() {
return new ArrayList<>(conversationHistory);
}
public List<String> getContextualFacts() {
return new ArrayList<>(contextualFacts);
}
public String getFormattedMemory() {
StringBuilder sb = new StringBuilder();
sb.append("Conversation History:\n");
for (String message : conversationHistory) {
sb.append("- ").append(message).append("\n");
}
sb.append("\nContextual Facts:\n");
for (String fact : contextualFacts) {
sb.append("- ").append(fact).append("\n");
}
return sb.toString();
}
}

View File

@ -8,6 +8,8 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
@Service
public class WebSocketService {
@ -22,6 +24,22 @@ public class WebSocketService {
this.objectMapper = new ObjectMapper();
}
public void handleMessage(WebSocketSession session, TextMessage message) {
try {
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
if (messageArray.length > 0) {
String content = messageArray[0];
// Assume we're using a default conversation ID for simplicity
String conversationId = "default";
// Assume we're using a default user ID for simplicity
String userId = "user";
conversationManager.postMessage(conversationId, userId, content);
}
} catch (Exception e) {
e.printStackTrace();
}
}
public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload);
}

View File

@ -45,12 +45,8 @@ public class TreeOfThought {
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
String thought = model.generate(branchPrompt, null);
if (!thought.equals("No response generated") && !thought.startsWith("Error:")) {
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
children.add(childNode);
} else {
System.out.println("WARNING: Failed to generate thought. Result: " + thought);
}
}
node.put("children", children);