updating for new langchain, more tunes

mrq 2023-05-03 01:01:58 +00:00
parent e152cd98a4
commit 8eaecaf643
6 changed files with 56 additions and 18 deletions

View File

@@ -1,6 +1,7 @@
langchain
openai
llama-cpp-python
sentence_transformers
gradio
faiss-cpu
termcolor
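
Note: sentence_transformers appears to be the new pin here; it backs the HuggingFace embeddings that feed the FAISS memory store. A minimal sketch of that stack, assuming the stock all-MiniLM model (the model name is illustrative, nothing in this commit pins it):

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# embed one memory and query it back; needs sentence_transformers and faiss-cpu
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
store = FAISS.from_texts(["Sam bought a new dog."], embeddings)
print(store.similarity_search("What pet does Sam have?", k=1))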

View File

@@ -30,8 +30,9 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from .memory import GenerativeAgentMemory
from .prompts import get_prompt, get_stop_tokens
@@ -106,6 +107,9 @@ class GenerativeAgent(BaseModel):
    def summarize_related_memories(self, observation: str) -> str:
        """Summarize memories that are most relevant to an observation."""
        prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
        q1 = "What is the relationship between the subjects in that interaction?"
        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
        """
        entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
        q1 = f"What is the relationship between {self.name} and {entity_name}"
        if self.name.strip() in entity_name:
@@ -114,6 +118,7 @@ class GenerativeAgent(BaseModel):
            entity_action = self._get_entity_action(observation, entity_name)
            q2 = f"{entity_name} is {entity_action}"
        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
        """
        return f'{self.name} {summary}'
        #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
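
For reference: newer langchain moved BaseLanguageModel out of langchain.schema into langchain.base_language, hence the import swap above. The rewritten summarize_related_memories also drops the entity-extraction round trips and asks one question directly. A standalone sketch of that chain pattern, with a fake LLM standing in for llama.cpp/OpenAI and the template string assumed from the prompts file below:

from langchain import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

llm = FakeListLLM(responses=["knows Tersa from the office."])  # canned reply
prompt = PromptTemplate.from_template("{relevant_memories_simple}{observation}\n{q1}?")
chain = LLMChain(llm=llm, prompt=prompt)
summary = chain.run(
    relevant_memories_simple="Sam and Tersa chatted in the break room.\n",
    observation="Sam waves at Tersa.",
    q1="What is the relationship between the subjects in that interaction",
).strip()
print(f"Sam {summary}")  # mirrors the f'{self.name} {summary}' return above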

View File

@@ -28,9 +28,10 @@ import re
from typing import Any, Dict, List, Optional
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseLanguageModel, BaseMemory, Document
from langchain.schema import BaseMemory, Document
logger = logging.getLogger(__name__)
@@ -120,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
    def _score_memory_importance(self, memory_content: str) -> float:
        """Score the absolute importance of the given memory."""
        prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
        score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/"]), memory_content=memory_content).strip()
        score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/", "("]), memory_content=memory_content).strip()
        if self.verbose:
            print(f"Importance score: {score}")
        try:
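
Adding "(" to the stop list guards against replies like "7 (fairly important)": generation halts before the parenthetical, so what remains parses cleanly as a number. A sketch of the kind of parse the try block implies (the regex is an assumption, not from this commit):

import re

def parse_importance(score: str) -> float:
    # take the first run of digits; text past a stop token never arrives
    match = re.search(r"\d+", score)
    return float(match.group(0)) / 10 if match else 0.0

print(parse_importance("7"))                     # 0.7
print(parse_importance("7 (fairly important)"))  # still 0.7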

View File

@@ -28,6 +28,7 @@ PROMPTS = {
"summarize_related_memories": {
"system": (
"{relevant_memories_simple}"
"{observation}"
),
"user": (
"{q1}?"

View File

@@ -3,7 +3,7 @@ import gradio as gr
import gradio.utils
from termcolor import colored
from utils import create_agent, agent_observes, interview_agent, run_conversation, get_summary, save_agent, load_agent
from utils import create_agent, agent_observes, agent_reacts, interview_agent, run_conversation, get_summary, save_agent, load_agent
webui = None
@@ -42,6 +42,20 @@ def agent_observes_proxy( agents, observations ):
messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
return "\n".join(messages)
def agent_reacts_proxy( agents, observations ):
if not isinstance( agents, list ):
agents = [ agents ]
messages = []
for agent in agents:
if agent not in AGENTS:
load_agent( agent )
agent = AGENTS[agent]
observations = observations.split("\n")
response = agent_reacts( agent, observations )
messages.append(f"[{agent.name}] {response}")
return "\n".join(messages)
def interview_agent_proxy( agents, message ):
    if not isinstance( agents, list ):
        agents = [ agents ]
@@ -114,10 +128,18 @@ def view_agent( agents, last_k = 50 ):
        agent = AGENTS[agent]
        memories = agent.memory.memory_retriever.memory_stream[-last_k:]
        memories = "\n".join([ document.page_content for document in memories])
        message = f"{agent.name}'s summary:\n{agent.summary}\n{agent.name}'s memories:\n{memories}"
        message = (
            f"{agent.name}: (sex: {agent.sex}, age: {agent.age})"
            f"\n{agent.name}'s innate traits:"
            f"\n{agent.traits}"
            f"\n{agent.name}'s summary:"
            f"\n{agent.summary}"
            f"\n{agent.name}'s memories:"
            f"\n{memories}"
        )
        messages.append( message )
    return "\n".join(messages)
    return "\n\n\n".join(messages)

def get_agents_list():
    return [ k for k in AGENTS ]
@@ -221,6 +243,7 @@ def setup_webui(share=False):
OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="")
with gr.Row():
ACTIONS["memorize"] = gr.Button(value="Memorize")
ACTIONS["act"] = gr.Button(value="Act")
ACTIONS["view"] = gr.Button(value="View")
ACTIONS["summarize"] = gr.Button(value="Summarize")
@@ -229,7 +252,11 @@ def setup_webui(share=False):
        with gr.Column():
            CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output")

    ACTIONS["act"].click(agent_observes_proxy,
    ACTIONS["memorize"].click(agent_observes_proxy,
        inputs=list(OBSERVE_SETTINGS.values()),
        outputs=CONSOLE_OUTPUTS["agent_actions"]
    )
    ACTIONS["act"].click(agent_reacts_proxy,
        inputs=list(OBSERVE_SETTINGS.values()),
        outputs=CONSOLE_OUTPUTS["agent_actions"]
    )
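
Previously the Act button was bound to agent_observes_proxy; now Memorize takes that path and Act routes to the new agent_reacts_proxy. A self-contained sketch of the same Blocks wiring pattern, assuming gradio 3.x:

import gradio as gr

def act(text):
    return f"reacting to: {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(lines=4, label="Input")
    with gr.Row():
        memorize_btn = gr.Button(value="Memorize")
        act_btn = gr.Button(value="Act")
    out = gr.Textbox(lines=8, label="Console Output")
    # separate handlers, shared inputs/outputs, as in the diff above
    memorize_btn.click(lambda t: f"memorized: {t}", inputs=[box], outputs=out)
    act_btn.click(act, inputs=[box], outputs=out)

# demo.launch()  # uncomment to serve the UI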

View File

@@ -13,11 +13,13 @@ import re
import pickle
import random
from langchain.callbacks.base import CallbackManager
from langchain.docstore import InMemoryDocstore
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import FAISS
# shit I can shove behind an env var
@@ -196,6 +198,16 @@ def agent_observes( agent: GenerativeAgent, observations: List[str] ):
        results.append(agent.memory.add_memory(observation))
    return results

def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
    results = []
    for observation in observations:
        observation = observation.replace("{name}", agent.name)
        print(colored("[Observation]", "magenta"), observation)
        _, response = agent.generate_response(observation)
        print(colored("[Reaction]", "magenta"), response)
        results.append(response)
    return results
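
agent_reacts mirrors agent_observes but routes each observation through generate_response and collects the replies rather than importance scores. The {name} placeholder lets one canned observation address any agent:

observation = "{name} notices the door is open."
print(observation.replace("{name}", "Sam"))  # Sam notices the door is open.
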
def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
    message = message.replace("{name}", agent.name)
    new_message = f"{username} says {message}"
@@ -204,23 +216,14 @@ def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1) -> List[str]:
    """Runs a conversation between agents."""
    print(colored("[Conversation]", "magenta"))
    agent_observes( agents[0], [observation] )
    agents = agents[1:] + [agents[0]]
    dialogue = []
    while True:
        break_dialogue = False
        for agent in agents:
            stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
            dialogue.append(observation)
            print(colored("[Conversation]", "magenta"), observation)
            if not stay_in_dialogue:
                break_dialogue = True
            if break_dialogue:
                break
            observation = agent_reacts( agent, [ observation ] )[0]
            if limit > 0 and len(dialogue) >= limit:
                break
            agent_observes( agent, [observation] )
    return dialogue
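
A hypothetical driver for the rewritten loop (create_agent's signature is assumed; it is not shown in this diff):

# agents = [ create_agent(name="Sam"), create_agent(name="Tersa") ]
# dialogue = run_conversation(agents, "Sam greets Tersa.", limit=6)
# for line in dialogue:
#     print(line)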