more tuning

mrq 2023-05-03 01:46:55 +00:00
parent 8eaecaf643
commit f13d05dbb2
4 changed files with 20 additions and 108 deletions

View File

@@ -86,20 +86,6 @@ class GenerativeAgent(BaseModel):
             llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
         )
 
-    def _get_entity_from_observation(self, observation: str) -> str:
-        prompt = PromptTemplate.from_template(get_prompt('entity_from_observation'))
-        response = self.chain(prompt).run(stop=get_stop_tokens([".", "(", "'"]), observation=observation).strip()
-        if self.verbose:
-            print(response)
-        return response
-
-    def _get_entity_action(self, observation: str, entity_name: str) -> str:
-        prompt = PromptTemplate.from_template(get_prompt('entity_action'))
-        response = self.chain(prompt).run(stop=get_stop_tokens(), entity=entity_name, observation=observation).strip()
-        if self.verbose:
-            print(response)
-        return response
-
     def get_most_recent_memories(self, last_k: int = 4) -> str:
         memories = self.memory.memory_retriever.memory_stream[-last_k:]
         return [ document.page_content for document in memories ]
@@ -107,18 +93,8 @@ class GenerativeAgent(BaseModel):
     def summarize_related_memories(self, observation: str) -> str:
         """Summarize memories that are most relevant to an observation."""
         prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
-        q1 = f"What is the relationship between the subjects in that interaction?"
+        q1 = f"Summarize the relationship between the subjects in that interaction."
         summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
-        """
-        entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
-        q1 = f"What is the relationship between {self.name} and {entity_name}"
-        if self.name.strip() in entity_name:
-            return "N/A"
-        entity_action = self._get_entity_action(observation, entity_name)
-        q2 = f"{entity_name} is {entity_action}"
-        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
-        """
         return f'{self.name} {summary}'
         #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
@@ -151,7 +127,15 @@ class GenerativeAgent(BaseModel):
             #recent_memories=recent_memories if recent_memories else "N/A",
             observation=observation if observation else "N/A",
         )
-        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
+        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), queries=[observation], **kwargs).strip()
+
+        # cleanup
+        reactions = reaction.replace(u"\u200B", "").strip().split("\n")
+        for reaction in reactions:
+            if reaction:
+                break
+
         if self.verbose:
             print(reaction)
         return reaction
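
Note: the cleanup block added above keeps only the first non-empty line of the completion, by letting the loop variable shadow `reaction`. A minimal standalone sketch of the same idea (the `first_line` helper name is hypothetical, not from this repo):

    def first_line(text: str) -> str:
        # Strip zero-width spaces, then return the first non-empty line.
        lines = text.replace("\u200b", "").strip().split("\n")
        # next() with a default avoids raising StopIteration on empty output.
        return next((line for line in lines if line.strip()), "")

    assert first_line("\u200b\nREACT: waves\nextra") == "REACT: waves"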
@@ -173,73 +157,6 @@ class GenerativeAgent(BaseModel):
         return True, full_result
 
-    def generate_reaction(self, observation: str) -> Tuple[bool, str]:
-        """React to a given observation."""
-        full_result = self._generate_reaction(observation, get_prompt('suffix_generate_reaction'))
-        candidates = full_result.replace(u"\u200B", "").strip().split("\n")
-        response = ""
-        results = []
-        for candidate in candidates:
-            if "REACT:" in candidate or "SAY:" in candidate:
-                # can't be assed to iteratively replace
-                candidate = candidate.strip().replace("React:", "REACT:").replace("Say:", "SAY:")
-                results.append(f'{candidate}'.replace("SAY:", "said").replace(f"REACT: {self.name}", "").replace("REACT:", ""))
-        if len(results) > 0:
-            response = " and ".join(results).strip().replace("  ", " ")
-            valid = True
-        else:
-            response = f"did not react in a relevant way"
-            valid = False
-        # AAA
-        self.memory.save_context(
-            {},
-            {
-                self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s reaction: {response}"
-            },
-        )
-        return valid, f"{self.name} {response}"
-        """
-        if "REACT:" in result:
-            reaction = self._clean_response(result.split("REACT:")[-1])
-            return True, f"{self.name} {reaction}"
-        if "SAY:" in result:
-            said_value = self._clean_response(result.split("SAY:")[-1])
-            return True, f"{self.name} said {said_value}"
-        else:
-            return False, f"{self.name} did not react in a relevant way"
-        """
-
-    def generate_dialogue(self, observation: str) -> Tuple[bool, str]:
-        """React to a given observation."""
-        call_to_action_template = (get_prompt('suffix_generate_dialogue'))
-        full_result = self._generate_reaction(observation, call_to_action_template)
-        result = full_result.strip().split("\n")[0]
-        if "GOODBYE:" in result:
-            farewell = self._clean_response(result.split("GOODBYE:")[-1])
-            self.memory.save_context(
-                {},
-                {
-                    self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s farewell response: {farewell}"
-                },
-            )
-            return False, f"{self.name} said {farewell}"
-        if "SAY:" in result:
-            response_text = self._clean_response(result.split("SAY:")[-1])
-            self.memory.save_context(
-                {},
-                {
-                    self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s response: {response_text}"
-                },
-            )
-            return True, f"{self.name} said {response_text}"
-        else:
-            return False, result
-
     ######################################################
     # Agent stateful' summary methods.                   #
     # Each dialog or response prompt includes a header   #
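
Note: for reference, the marker parsing that the deleted `generate_reaction` performed can be isolated into a small helper. A simplified sketch under the same REACT:/SAY: convention; `parse_reaction` is a hypothetical name, not the repo's exact behavior:

    from typing import Tuple

    def parse_reaction(name: str, full_result: str) -> Tuple[bool, str]:
        # Normalize casing, collect REACT:/SAY: lines, and rewrite them
        # into "<name> ... and said ..." phrasing, as the deleted method did.
        results = []
        for line in full_result.replace("\u200b", "").split("\n"):
            line = line.strip().replace("React:", "REACT:").replace("Say:", "SAY:")
            if "REACT:" in line or "SAY:" in line:
                line = line.replace("SAY:", "said").replace(f"REACT: {name}", "").replace("REACT:", "")
                results.append(line.strip())
        if results:
            return True, f"{name} " + " and ".join(results)
        return False, f"{name} did not react in a relevant way"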

View File

@@ -27,11 +27,11 @@ PROMPTS = {
     },
     "summarize_related_memories": {
         "system": (
-            "{relevant_memories_simple}"
-            "{observation}"
+            "{q1}"
         ),
         "user": (
-            "{q1}?"
+            "{relevant_memories_simple}"
+            "{observation}"
         ),
         "assistant": "{name} ",
     },
@@ -83,9 +83,9 @@ PROMPTS = {
     },
     "generate_reaction": {
         "system": (
             "NSFW/Smut is allowed. Assume consent was granted."
             "\nIt is {current_time}."
             "\n{summary}"
-            "\n{relevant_memories_simple}"
+            "\n{memory}"
             #"\nRecent memories: {recent_memories}"
             #"\nRelevant memories: {relevant_memories}"

View File

@@ -95,19 +95,14 @@ def run_conversation_proxy( agents, observation, limit=2 ):
     agent_observes( agents[0], [observation] )
     agents = agents[1:] + [agents[0]]
     dialogue = []
     while True:
-        break_dialogue = False
         for agent in agents:
-            stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
             dialogue.append(f'[{agent.name}] {observation}')
             yield "\n".join(dialogue)
             print(colored("[Conversation]", "magenta"), observation)
-            if not stay_in_dialogue:
-                break_dialogue = True
-        if break_dialogue:
-            break
-        if limit > 0 and len(dialogue) >= limit * len(agents):
+            observation = agent_reacts( agent, [ observation ] )[0]
+            yield observation
+        if limit > 0 and len(dialogue) >= limit:
             break
-    return dialogue
+    print("END")
+    dialogue.append("END")
+    return "\n".join(dialogue)

View File

@@ -223,7 +223,7 @@ def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int
     dialogue = []
     while True:
         for agent in agents:
-            _, observation = agent_reacts( agent, [ observation ] )
+            observation = agent_reacts( agent, [ observation ] )[0]
         if limit > 0 and len(dialogue) >= limit:
             break
     return dialogue