Agents now have memory in addition to context with Tree of thought

This commit is contained in:
Mahesh Kommareddi 2024-07-20 16:25:20 -04:00
parent fd66c81027
commit 9cf98dbaf7
4 changed files with 90 additions and 8 deletions

View File

@ -15,6 +15,7 @@ public class AgentInfo {
private String name; private String name;
private List<String> capabilities; private List<String> capabilities;
private List<String> tools; private List<String> tools;
private Memory memory;
private TreeOfThought treeOfThought; private TreeOfThought treeOfThought;
private WebSocketService webSocketService; private WebSocketService webSocketService;
private ToolRegistry toolRegistry; private ToolRegistry toolRegistry;
@ -31,16 +32,36 @@ public class AgentInfo {
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.model = model; this.model = model;
this.memory = new Memory();
} }
public void receiveMessage(Message message) { public void receiveMessage(Message message) {
// Process the received message memory.addToHistory(message.getContent());
String prompt = "You are " + name + " with capabilities: " + capabilities + String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nYou received a message: " + message.getContent() + "\nYou received a message: " + message.getContent() +
"\nHow would you respond or what actions would you take based on this message?"; "\nBased on your memory and context, how would you respond or what actions would you take?" +
"\n\nMemory:\n" + memory.getFormattedMemory();
String response = model.generate(prompt, null); String response = model.generate(prompt, null);
System.out.println("DEBUG: " + name + " processed message: " + message.getContent()); System.out.println("DEBUG: " + name + " processed message: " + message.getContent());
System.out.println("DEBUG: " + name + " response: " + response); System.out.println("DEBUG: " + name + " response: " + response);
// Add the response to memory
memory.addToHistory("My response: " + response);
}
public void performTreeOfThought(String task) {
String prompt = "You are " + name + " with capabilities: " + capabilities +
"\nTask: " + task +
"\nBased on your memory and context, perform a tree of thought reasoning to approach this task." +
"\n\nMemory:\n" + memory.getFormattedMemory();
Map<String, Object> totResult = treeOfThought.reason(prompt, 3, 2);
String reasoning = (String) totResult.get("reasoning");
// Add the reasoning to memory
memory.addContextualFact("Tree of Thought for task '" + task + "': " + reasoning);
System.out.println("DEBUG: " + name + " Tree of Thought reasoning: " + reasoning);
} }
public void notifyTurn(ConversationFSM conversation) { public void notifyTurn(ConversationFSM conversation) {

View File

@ -0,0 +1,47 @@
package com.ioa.agent;
import java.util.ArrayList;
import java.util.List;
public class Memory {
private List<String> conversationHistory;
private List<String> contextualFacts;
private int maxHistorySize = 10; // Adjust as needed
public Memory() {
this.conversationHistory = new ArrayList<>();
this.contextualFacts = new ArrayList<>();
}
public void addToHistory(String message) {
conversationHistory.add(message);
if (conversationHistory.size() > maxHistorySize) {
conversationHistory.remove(0);
}
}
public void addContextualFact(String fact) {
contextualFacts.add(fact);
}
public List<String> getConversationHistory() {
return new ArrayList<>(conversationHistory);
}
public List<String> getContextualFacts() {
return new ArrayList<>(contextualFacts);
}
public String getFormattedMemory() {
StringBuilder sb = new StringBuilder();
sb.append("Conversation History:\n");
for (String message : conversationHistory) {
sb.append("- ").append(message).append("\n");
}
sb.append("\nContextual Facts:\n");
for (String fact : contextualFacts) {
sb.append("- ").append(fact).append("\n");
}
return sb.toString();
}
}

View File

@ -8,6 +8,8 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
@Service @Service
public class WebSocketService { public class WebSocketService {
@ -22,6 +24,22 @@ public class WebSocketService {
this.objectMapper = new ObjectMapper(); this.objectMapper = new ObjectMapper();
} }
public void handleMessage(WebSocketSession session, TextMessage message) {
try {
String[] messageArray = objectMapper.readValue(message.getPayload(), String[].class);
if (messageArray.length > 0) {
String content = messageArray[0];
// Assume we're using a default conversation ID for simplicity
String conversationId = "default";
// Assume we're using a default user ID for simplicity
String userId = "user";
conversationManager.postMessage(conversationId, userId, content);
}
} catch (Exception e) {
e.printStackTrace();
}
}
public void sendUpdate(String topic, Object payload) { public void sendUpdate(String topic, Object payload) {
messagingTemplate.convertAndSend("/topic/" + topic, payload); messagingTemplate.convertAndSend("/topic/" + topic, payload);
} }

View File

@ -45,12 +45,8 @@ public class TreeOfThought {
"\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):"; "\nExplore a new branch of thought (branch " + (i+1) + "/" + branches + "):";
String thought = model.generate(branchPrompt, null); String thought = model.generate(branchPrompt, null);
if (!thought.equals("No response generated") && !thought.startsWith("Error:")) { Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought);
Map<String, Object> childNode = exploreThought(task, depth - 1, branches, thought); children.add(childNode);
children.add(childNode);
} else {
System.out.println("WARNING: Failed to generate thought. Result: " + thought);
}
} }
node.put("children", children); node.put("children", children);