updating for new langchain, more tunes

This commit is contained in:
mrq 2023-05-03 01:01:58 +00:00
parent e152cd98a4
commit 8eaecaf643
6 changed files with 56 additions and 18 deletions

View File

@ -1,6 +1,7 @@
langchain langchain
openai openai
llama-cpp-python llama-cpp-python
sentence_transformers
gradio gradio
faiss-cpu faiss-cpu
termcolor termcolor

View File

@ -30,8 +30,9 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain import LLMChain from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from .memory import GenerativeAgentMemory from .memory import GenerativeAgentMemory
from .prompts import get_prompt, get_stop_tokens from .prompts import get_prompt, get_stop_tokens
@ -106,6 +107,9 @@ class GenerativeAgent(BaseModel):
def summarize_related_memories(self, observation: str) -> str: def summarize_related_memories(self, observation: str) -> str:
"""Summarize memories that are most relevant to an observation.""" """Summarize memories that are most relevant to an observation."""
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories')) prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
q1 = f"What is the relationship between the subjects in that interaction?"
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
"""
entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip() entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
q1 = f"What is the relationship between {self.name} and {entity_name}" q1 = f"What is the relationship between {self.name} and {entity_name}"
if self.name.strip() in entity_name: if self.name.strip() in entity_name:
@ -114,6 +118,7 @@ class GenerativeAgent(BaseModel):
entity_action = self._get_entity_action(observation, entity_name) entity_action = self._get_entity_action(observation, entity_name)
q2 = f"{entity_name} is {entity_action}" q2 = f"{entity_name} is {entity_action}"
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip() summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
"""
return f'{self.name} {summary}' return f'{self.name} {summary}'
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip() #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()

View File

@ -28,9 +28,10 @@ import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain import LLMChain from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseLanguageModel, BaseMemory, Document from langchain.schema import BaseMemory, Document
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -120,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
def _score_memory_importance(self, memory_content: str) -> float: def _score_memory_importance(self, memory_content: str) -> float:
"""Score the absolute importance of the given memory.""" """Score the absolute importance of the given memory."""
prompt = PromptTemplate.from_template(get_prompt("memory_importance")) prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/"]), memory_content=memory_content).strip() score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/", "("]), memory_content=memory_content).strip()
if self.verbose: if self.verbose:
print(f"Importance score: {score}") print(f"Importance score: {score}")
try: try:

View File

@ -28,6 +28,7 @@ PROMPTS = {
"summarize_related_memories": { "summarize_related_memories": {
"system": ( "system": (
"{relevant_memories_simple}" "{relevant_memories_simple}"
"{observation}"
), ),
"user": ( "user": (
"{q1}?" "{q1}?"

View File

@ -3,7 +3,7 @@ import gradio as gr
import gradio.utils import gradio.utils
from termcolor import colored from termcolor import colored
from utils import create_agent, agent_observes, interview_agent, run_conversation, get_summary, save_agent, load_agent from utils import create_agent, agent_observes, agent_reacts, interview_agent, run_conversation, get_summary, save_agent, load_agent
webui = None webui = None
@ -42,6 +42,20 @@ def agent_observes_proxy( agents, observations ):
messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}") messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
return "\n".join(messages) return "\n".join(messages)
def agent_reacts_proxy( agents, observations ):
if not isinstance( agents, list ):
agents = [ agents ]
messages = []
for agent in agents:
if agent not in AGENTS:
load_agent( agent )
agent = AGENTS[agent]
observations = observations.split("\n")
response = agent_reacts( agent, observations )
messages.append(f"[{agent.name}] {response}")
return "\n".join(messages)
def interview_agent_proxy( agents, message ): def interview_agent_proxy( agents, message ):
if not isinstance( agents, list ): if not isinstance( agents, list ):
agents = [ agents ] agents = [ agents ]
@ -114,10 +128,18 @@ def view_agent( agents, last_k = 50 ):
agent = AGENTS[agent] agent = AGENTS[agent]
memories = agent.memory.memory_retriever.memory_stream[-last_k:] memories = agent.memory.memory_retriever.memory_stream[-last_k:]
memories = "\n".join([ document.page_content for document in memories]) memories = "\n".join([ document.page_content for document in memories])
message = f"{agent.name}'s summary:\n{agent.summary}\n{agent.name}'s memories:\n{memories}" message = (
f"{agent.name}: (sex: {agent.sex}, age: {agent.age})"
f"\n{agent.name}'s innate traits:"
f"\n{agent.traits}"
f"\n{agent.name}'s summary:"
f"\n{agent.summary}"
f"\n{agent.name}'s memories:"
f"\n{memories}"
)
messages.append( message ) messages.append( message )
return "\n".join(messages) return "\n\n\n".join(messages)
def get_agents_list(): def get_agents_list():
return [ k for k in AGENTS ] return [ k for k in AGENTS ]
@ -221,6 +243,7 @@ def setup_webui(share=False):
OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="") OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="")
with gr.Row(): with gr.Row():
ACTIONS["memorize"] = gr.Button(value="Memorize")
ACTIONS["act"] = gr.Button(value="Act") ACTIONS["act"] = gr.Button(value="Act")
ACTIONS["view"] = gr.Button(value="View") ACTIONS["view"] = gr.Button(value="View")
ACTIONS["summarize"] = gr.Button(value="Summarize") ACTIONS["summarize"] = gr.Button(value="Summarize")
@ -229,7 +252,11 @@ def setup_webui(share=False):
with gr.Column(): with gr.Column():
CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output") CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output")
ACTIONS["act"].click(agent_observes_proxy, ACTIONS["memorize"].click(agent_observes_proxy,
inputs=list(OBSERVE_SETTINGS.values()),
outputs=CONSOLE_OUTPUTS["agent_actions"]
)
ACTIONS["act"].click(agent_reacts_proxy,
inputs=list(OBSERVE_SETTINGS.values()), inputs=list(OBSERVE_SETTINGS.values()),
outputs=CONSOLE_OUTPUTS["agent_actions"] outputs=CONSOLE_OUTPUTS["agent_actions"]
) )

View File

@ -13,11 +13,13 @@ import re
import pickle import pickle
import random import random
from langchain.callbacks.base import CallbackManager
from langchain.docstore import InMemoryDocstore from langchain.docstore import InMemoryDocstore
from langchain.retrievers import TimeWeightedVectorStoreRetriever from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import FAISS from langchain.vectorstores import FAISS
# shit I can shove behind an env var # shit I can shove behind an env var
@ -196,6 +198,16 @@ def agent_observes( agent: GenerativeAgent, observations: List[str] ):
results.append(agent.memory.add_memory(observation)) results.append(agent.memory.add_memory(observation))
return results return results
def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
results = []
for observation in observations:
observation = observation.replace("{name}", agent.name)
print(colored("[Observation]", "magenta"), observation)
_, response = agent.generate_response(observation)
print(colored("[Reaction]", "magenta"), response)
results.append(response)
return results
def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str: def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
message = message.replace("{name}", agent.name) message = message.replace("{name}", agent.name)
new_message = f"{username} says {message}" new_message = f"{username} says {message}"
@ -204,23 +216,14 @@ def interview_agent(agent: GenerativeAgent, message: str, username: str = "Perso
def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None: def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
"""Runs a conversation between agents."""
print(colored("[Conversation]", "magenta")) print(colored("[Conversation]", "magenta"))
agent_observes( agents[0], [observation] ) agent_observes( agents[0], [observation] )
agents = agents[1:] + [agents[0]] agents = agents[1:] + [agents[0]]
dialogue = [] dialogue = []
while True: while True:
break_dialogue = False
for agent in agents: for agent in agents:
stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation) _, observation = agent_reacts( agent, [ observation ] )
dialogue.append(observation)
print(colored("[Conversation]", "magenta"), observation)
if not stay_in_dialogue:
break_dialogue = True
if break_dialogue:
break
if limit > 0 and len(dialogue) >= limit: if limit > 0 and len(dialogue) >= limit:
break break
agent_observes( agent, [observation] )
return dialogue return dialogue