updating for new langchain, more tunes
This commit is contained in:
parent
e152cd98a4
commit
8eaecaf643
@ -1,6 +1,7 @@
|
|||||||
langchain
|
langchain
|
||||||
openai
|
openai
|
||||||
llama-cpp-python
|
llama-cpp-python
|
||||||
|
sentence_transformers
|
||||||
gradio
|
gradio
|
||||||
faiss-cpu
|
faiss-cpu
|
||||||
termcolor
|
termcolor
|
@ -30,8 +30,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain import LLMChain
|
from langchain import LLMChain
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.schema import BaseLanguageModel
|
|
||||||
|
|
||||||
from .memory import GenerativeAgentMemory
|
from .memory import GenerativeAgentMemory
|
||||||
from .prompts import get_prompt, get_stop_tokens
|
from .prompts import get_prompt, get_stop_tokens
|
||||||
@ -106,6 +107,9 @@ class GenerativeAgent(BaseModel):
|
|||||||
def summarize_related_memories(self, observation: str) -> str:
|
def summarize_related_memories(self, observation: str) -> str:
|
||||||
"""Summarize memories that are most relevant to an observation."""
|
"""Summarize memories that are most relevant to an observation."""
|
||||||
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
|
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
|
||||||
|
q1 = f"What is the relationship between the subjects in that interaction?"
|
||||||
|
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
|
||||||
|
"""
|
||||||
entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
|
entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
|
||||||
q1 = f"What is the relationship between {self.name} and {entity_name}"
|
q1 = f"What is the relationship between {self.name} and {entity_name}"
|
||||||
if self.name.strip() in entity_name:
|
if self.name.strip() in entity_name:
|
||||||
@ -114,6 +118,7 @@ class GenerativeAgent(BaseModel):
|
|||||||
entity_action = self._get_entity_action(observation, entity_name)
|
entity_action = self._get_entity_action(observation, entity_name)
|
||||||
q2 = f"{entity_name} is {entity_action}"
|
q2 = f"{entity_name} is {entity_action}"
|
||||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
|
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
|
||||||
|
"""
|
||||||
return f'{self.name} {summary}'
|
return f'{self.name} {summary}'
|
||||||
|
|
||||||
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
|
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
|
||||||
|
@ -28,9 +28,10 @@ import re
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain import LLMChain
|
from langchain import LLMChain
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||||
from langchain.schema import BaseLanguageModel, BaseMemory, Document
|
from langchain.schema import BaseMemory, Document
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -120,7 +121,7 @@ class GenerativeAgentMemory(BaseMemory):
|
|||||||
def _score_memory_importance(self, memory_content: str) -> float:
|
def _score_memory_importance(self, memory_content: str) -> float:
|
||||||
"""Score the absolute importance of the given memory."""
|
"""Score the absolute importance of the given memory."""
|
||||||
prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
|
prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
|
||||||
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/"]), memory_content=memory_content).strip()
|
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/", "("]), memory_content=memory_content).strip()
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Importance score: {score}")
|
print(f"Importance score: {score}")
|
||||||
try:
|
try:
|
||||||
|
@ -28,6 +28,7 @@ PROMPTS = {
|
|||||||
"summarize_related_memories": {
|
"summarize_related_memories": {
|
||||||
"system": (
|
"system": (
|
||||||
"{relevant_memories_simple}"
|
"{relevant_memories_simple}"
|
||||||
|
"{observation}"
|
||||||
),
|
),
|
||||||
"user": (
|
"user": (
|
||||||
"{q1}?"
|
"{q1}?"
|
||||||
|
35
src/main.py
35
src/main.py
@ -3,7 +3,7 @@ import gradio as gr
|
|||||||
import gradio.utils
|
import gradio.utils
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from utils import create_agent, agent_observes, interview_agent, run_conversation, get_summary, save_agent, load_agent
|
from utils import create_agent, agent_observes, agent_reacts, interview_agent, run_conversation, get_summary, save_agent, load_agent
|
||||||
|
|
||||||
webui = None
|
webui = None
|
||||||
|
|
||||||
@ -42,6 +42,20 @@ def agent_observes_proxy( agents, observations ):
|
|||||||
messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
|
messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
|
def agent_reacts_proxy( agents, observations ):
|
||||||
|
if not isinstance( agents, list ):
|
||||||
|
agents = [ agents ]
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for agent in agents:
|
||||||
|
if agent not in AGENTS:
|
||||||
|
load_agent( agent )
|
||||||
|
agent = AGENTS[agent]
|
||||||
|
observations = observations.split("\n")
|
||||||
|
response = agent_reacts( agent, observations )
|
||||||
|
messages.append(f"[{agent.name}] {response}")
|
||||||
|
return "\n".join(messages)
|
||||||
|
|
||||||
def interview_agent_proxy( agents, message ):
|
def interview_agent_proxy( agents, message ):
|
||||||
if not isinstance( agents, list ):
|
if not isinstance( agents, list ):
|
||||||
agents = [ agents ]
|
agents = [ agents ]
|
||||||
@ -114,10 +128,18 @@ def view_agent( agents, last_k = 50 ):
|
|||||||
agent = AGENTS[agent]
|
agent = AGENTS[agent]
|
||||||
memories = agent.memory.memory_retriever.memory_stream[-last_k:]
|
memories = agent.memory.memory_retriever.memory_stream[-last_k:]
|
||||||
memories = "\n".join([ document.page_content for document in memories])
|
memories = "\n".join([ document.page_content for document in memories])
|
||||||
message = f"{agent.name}'s summary:\n{agent.summary}\n{agent.name}'s memories:\n{memories}"
|
message = (
|
||||||
|
f"{agent.name}: (sex: {agent.sex}, age: {agent.age})"
|
||||||
|
f"\n{agent.name}'s innate traits:"
|
||||||
|
f"\n{agent.traits}"
|
||||||
|
f"\n{agent.name}'s summary:"
|
||||||
|
f"\n{agent.summary}"
|
||||||
|
f"\n{agent.name}'s memories:"
|
||||||
|
f"\n{memories}"
|
||||||
|
)
|
||||||
|
|
||||||
messages.append( message )
|
messages.append( message )
|
||||||
return "\n".join(messages)
|
return "\n\n\n".join(messages)
|
||||||
|
|
||||||
def get_agents_list():
|
def get_agents_list():
|
||||||
return [ k for k in AGENTS ]
|
return [ k for k in AGENTS ]
|
||||||
@ -221,6 +243,7 @@ def setup_webui(share=False):
|
|||||||
OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="")
|
OBSERVE_SETTINGS["input"] = gr.Textbox(lines=4, label="Input", value="")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
ACTIONS["memorize"] = gr.Button(value="Memorize")
|
||||||
ACTIONS["act"] = gr.Button(value="Act")
|
ACTIONS["act"] = gr.Button(value="Act")
|
||||||
ACTIONS["view"] = gr.Button(value="View")
|
ACTIONS["view"] = gr.Button(value="View")
|
||||||
ACTIONS["summarize"] = gr.Button(value="Summarize")
|
ACTIONS["summarize"] = gr.Button(value="Summarize")
|
||||||
@ -229,7 +252,11 @@ def setup_webui(share=False):
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output")
|
CONSOLE_OUTPUTS["agent_actions"] = gr.Textbox(lines=8, label="Console Output")
|
||||||
|
|
||||||
ACTIONS["act"].click(agent_observes_proxy,
|
ACTIONS["memorize"].click(agent_observes_proxy,
|
||||||
|
inputs=list(OBSERVE_SETTINGS.values()),
|
||||||
|
outputs=CONSOLE_OUTPUTS["agent_actions"]
|
||||||
|
)
|
||||||
|
ACTIONS["act"].click(agent_reacts_proxy,
|
||||||
inputs=list(OBSERVE_SETTINGS.values()),
|
inputs=list(OBSERVE_SETTINGS.values()),
|
||||||
outputs=CONSOLE_OUTPUTS["agent_actions"]
|
outputs=CONSOLE_OUTPUTS["agent_actions"]
|
||||||
)
|
)
|
||||||
|
25
src/utils.py
25
src/utils.py
@ -13,11 +13,13 @@ import re
|
|||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from langchain.callbacks.base import CallbackManager
|
|
||||||
from langchain.docstore import InMemoryDocstore
|
from langchain.docstore import InMemoryDocstore
|
||||||
|
|
||||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
|
|
||||||
from langchain.vectorstores import FAISS
|
from langchain.vectorstores import FAISS
|
||||||
|
|
||||||
# shit I can shove behind an env var
|
# shit I can shove behind an env var
|
||||||
@ -196,6 +198,16 @@ def agent_observes( agent: GenerativeAgent, observations: List[str] ):
|
|||||||
results.append(agent.memory.add_memory(observation))
|
results.append(agent.memory.add_memory(observation))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
|
||||||
|
results = []
|
||||||
|
for observation in observations:
|
||||||
|
observation = observation.replace("{name}", agent.name)
|
||||||
|
print(colored("[Observation]", "magenta"), observation)
|
||||||
|
_, response = agent.generate_response(observation)
|
||||||
|
print(colored("[Reaction]", "magenta"), response)
|
||||||
|
results.append(response)
|
||||||
|
return results
|
||||||
|
|
||||||
def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
|
def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
|
||||||
message = message.replace("{name}", agent.name)
|
message = message.replace("{name}", agent.name)
|
||||||
new_message = f"{username} says {message}"
|
new_message = f"{username} says {message}"
|
||||||
@ -204,23 +216,14 @@ def interview_agent(agent: GenerativeAgent, message: str, username: str = "Perso
|
|||||||
|
|
||||||
|
|
||||||
def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
|
def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
|
||||||
"""Runs a conversation between agents."""
|
|
||||||
print(colored("[Conversation]", "magenta"))
|
print(colored("[Conversation]", "magenta"))
|
||||||
agent_observes( agents[0], [observation] )
|
agent_observes( agents[0], [observation] )
|
||||||
agents = agents[1:] + [agents[0]]
|
agents = agents[1:] + [agents[0]]
|
||||||
|
|
||||||
dialogue = []
|
dialogue = []
|
||||||
while True:
|
while True:
|
||||||
break_dialogue = False
|
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
|
_, observation = agent_reacts( agent, [ observation ] )
|
||||||
dialogue.append(observation)
|
|
||||||
print(colored("[Conversation]", "magenta"), observation)
|
|
||||||
if not stay_in_dialogue:
|
|
||||||
break_dialogue = True
|
|
||||||
if break_dialogue:
|
|
||||||
break
|
|
||||||
if limit > 0 and len(dialogue) >= limit:
|
if limit > 0 and len(dialogue) >= limit:
|
||||||
break
|
break
|
||||||
agent_observes( agent, [observation] )
|
|
||||||
return dialogue
|
return dialogue
|
Loading…
Reference in New Issue
Block a user