updating for new langchain, more tunes
This commit is contained in:
parent
e152cd98a4
commit
8eaecaf643
|
@ -1,6 +1,7 @@
|
|||
langchain
|
||||
openai
|
||||
llama-cpp-python
|
||||
sentence_transformers
|
||||
gradio
|
||||
faiss-cpu
|
||||
termcolor
|
|
@ -30,8 +30,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
from .memory import GenerativeAgentMemory
|
||||
from .prompts import get_prompt, get_stop_tokens
|
||||
|
@ -106,6 +107,9 @@ class GenerativeAgent(BaseModel):
|
|||
def summarize_related_memories(self, observation: str) -> str:
|
||||
"""Summarize memories that are most relevant to an observation."""
|
||||
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
|
||||
q1 = f"What is the relationship between the subjects in that interaction?"
|
||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
|
||||
"""
|
||||
entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
|
||||
q1 = f"What is the relationship between {self.name} and {entity_name}"
|
||||
if self.name.strip() in entity_name:
|
||||
|
@ -114,6 +118,7 @@ class GenerativeAgent(BaseModel):
|
|||
entity_action = self._get_entity_action(observation, entity_name)
|
||||
q2 = f"{entity_name} is {entity_action}"
|
||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
|
||||
"""
|
||||
return f'{self.name} {summary}'
|
||||
|
||||
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
|
||||
|
|
|
@ -28,9 +28,10 @@ import re
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||
from langchain.schema import BaseLanguageModel, BaseMemory, Document
|
||||
from langchain.schema import BaseMemory, Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -120,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
|
|||
def _score_memory_importance(self, memory_content: str) -> float:
|
||||
"""Score the absolute importance of the given memory."""
|
||||
prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
|
||||
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/"]), memory_content=memory_content).strip()
|
||||
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/", "("]), memory_content=memory_content).strip()
|
||||
if self.verbose:
|
||||
print(f"Importance score: {score}")
|
||||
try:
|
||||
|
|
|
@ -28,6 +28,7 @@ PROMPTS = {
|
|||
"summarize_related_memories": {
|
||||
"system": (
|
||||
"{relevant_memories_simple}"
|
||||
"{observation}"
|
||||
),
|
||||
"user": (
|
||||
"{q1}?"
|
||||
|
|
35
src/main.py
35
src/main.py
|
@ -3,7 +3,7 @@ import gradio as gr
|
|||
import gradio.utils
|
||||
from termcolor import colored
|
||||
|
||||
from utils import create_agent, agent_observes, interview_agent, run_conversation, get_summary, save_agent, load_agent
|
||||
from utils import create_agent, agent_observes, agent_reacts, interview_agent, run_conversation, get_summary, save_agent, load_agent
|
||||
|
||||
webui = None
|
||||
|
||||
|
@ -42,6 +42,20 @@ def agent_observes_proxy( agents, observations ):
|
|||
messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
|
||||
return "\n".join(messages)
|
||||
|
||||
def agent_reacts_proxy( agents, observations ):
|
||||
if not isinstance( agents, list ):
|
||||
agents = [ agents ]
|
||||
|
||||
messages = []
|
||||
for agent in agents:
|
||||
if agent not in AGENTS:
|
||||
load_agent( agent )
|
||||
agent = AGENTS[agent]
|
||||
observations = observations.split("\n")
|
||||
response = agent_reacts( agent, observations )
|
||||
messages.append(f"[{agent.name}] {response}")
|
||||
return "\n".join(messages)
|
||||
|
||||
def interview_agent_proxy( agents, message ):
|
||||
if not isinstance( agents, list ):
|
||||
agents = [ agents ]
|
||||
|
@ -114,10 +128,18 @@ def view_agent( agents, last_k = 50 ):
|
|||
agent = AGENTS[agent]
|
||||
memories = agent.memory.memory_retriever.memory_stream[-last_k:]
|
||||
memories = "\n".join([ document.page_content for document in memories])
|
||||
message = f"{agent.name}'s summary:\n{agent.summary}\n{agent.name}'s memories:\n{memories}"
|
||||
message = (
|
||||
f"{agent.name}: (sex: {agent.sex}, age: {agent.age})"
|
||||
f"\n{agent.name}'s innate traits:"
|
||||
f"\n{agent.traits}"
|
||||
f"\n{agent.name}'s summary:"
|
||||
f"\n{agent.summary}"
|
||||
f"\n{agent.name}'s memories:"
|
||||
f"\n{memories}"
|
||||
)
|
||||
|
||||
messages.append( message )
|
||||
return "\n".join(messages)
|
||||
return "\n\n\n".join(messages)
|
||||
|
||||
def get_agents_list():
|
||||
return [ k for k in AGENTS ]
|
||||
|
@ -221,6 +243,7 @@ def setup_webui(share=False):
|
|||
OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="")
|
||||
|
||||
with gr.Row():
|
||||
ACTIONS["memorize"] = gr.Button(value="Memorize")
|
||||
ACTIONS["act"] = gr.Button(value="Act")
|
||||
ACTIONS["view"] = gr.Button(value="View")
|
||||
ACTIONS["summarize"] = gr.Button(value="Summarize")
|
||||
|
@ -229,7 +252,11 @@ def setup_webui(share=False):
|
|||
with gr.Column():
|
||||
CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output")
|
||||
|
||||
ACTIONS["act"].click(agent_observes_proxy,
|
||||
ACTIONS["memorize"].click(agent_observes_proxy,
|
||||
inputs=list(OBSERVE_SETTINGS.values()),
|
||||
outputs=CONSOLE_OUTPUTS["agent_actions"]
|
||||
)
|
||||
ACTIONS["act"].click(agent_reacts_proxy,
|
||||
inputs=list(OBSERVE_SETTINGS.values()),
|
||||
outputs=CONSOLE_OUTPUTS["agent_actions"]
|
||||
)
|
||||
|
|
25
src/utils.py
25
src/utils.py
|
@ -13,11 +13,13 @@ import re
|
|||
import pickle
|
||||
import random
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.docstore import InMemoryDocstore
|
||||
|
||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
# shit I can shove behind an env var
|
||||
|
@ -196,6 +198,16 @@ def agent_observes( agent: GenerativeAgent, observations: List[str] ):
|
|||
results.append(agent.memory.add_memory(observation))
|
||||
return results
|
||||
|
||||
def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
|
||||
results = []
|
||||
for observation in observations:
|
||||
observation = observation.replace("{name}", agent.name)
|
||||
print(colored("[Observation]", "magenta"), observation)
|
||||
_, response = agent.generate_response(observation)
|
||||
print(colored("[Reaction]", "magenta"), response)
|
||||
results.append(response)
|
||||
return results
|
||||
|
||||
def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
|
||||
message = message.replace("{name}", agent.name)
|
||||
new_message = f"{username} says {message}"
|
||||
|
@ -204,23 +216,14 @@ def interview_agent(agent: GenerativeAgent, message: str, username: str = "Perso
|
|||
|
||||
|
||||
def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
|
||||
"""Runs a conversation between agents."""
|
||||
print(colored("[Conversation]", "magenta"))
|
||||
agent_observes( agents[0], [observation] )
|
||||
agents = agents[1:] + [agents[0]]
|
||||
|
||||
dialogue = []
|
||||
while True:
|
||||
break_dialogue = False
|
||||
for agent in agents:
|
||||
stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
|
||||
dialogue.append(observation)
|
||||
print(colored("[Conversation]", "magenta"), observation)
|
||||
if not stay_in_dialogue:
|
||||
break_dialogue = True
|
||||
if break_dialogue:
|
||||
break
|
||||
_, observation = agent_reacts( agent, [ observation ] )
|
||||
if limit > 0 and len(dialogue) >= limit:
|
||||
break
|
||||
agent_observes( agent, [observation] )
|
||||
return dialogue
|
Loading…
Reference in New Issue
Block a user