more tuning
This commit is contained in:
parent
8eaecaf643
commit
f13d05dbb2
|
@ -86,20 +86,6 @@ class GenerativeAgent(BaseModel):
|
|||
llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
|
||||
)
|
||||
|
||||
def _get_entity_from_observation(self, observation: str) -> str:
|
||||
prompt = PromptTemplate.from_template(get_prompt('entity_from_observation'))
|
||||
response = self.chain(prompt).run(stop=get_stop_tokens([".", "(", "'"]), observation=observation).strip()
|
||||
if self.verbose:
|
||||
print(response)
|
||||
return response
|
||||
|
||||
def _get_entity_action(self, observation: str, entity_name: str) -> str:
|
||||
prompt = PromptTemplate.from_template(get_prompt('entity_action'))
|
||||
response = self.chain(prompt).run(stop=get_stop_tokens(), entity=entity_name, observation=observation).strip()
|
||||
if self.verbose:
|
||||
print(response)
|
||||
return response
|
||||
|
||||
def get_most_recent_memories(self, last_k: int = 4) -> str:
|
||||
memories = self.memory.memory_retriever.memory_stream[-last_k:]
|
||||
return [ document.page_content for document in memories ]
|
||||
|
@ -107,18 +93,8 @@ class GenerativeAgent(BaseModel):
|
|||
def summarize_related_memories(self, observation: str) -> str:
|
||||
"""Summarize memories that are most relevant to an observation."""
|
||||
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
|
||||
q1 = f"What is the relationship between the subjects in that interaction?"
|
||||
q1 = f"Summarize the relationship between the subjects in that interaction."
|
||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
|
||||
"""
|
||||
entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
|
||||
q1 = f"What is the relationship between {self.name} and {entity_name}"
|
||||
if self.name.strip() in entity_name:
|
||||
return "N/A"
|
||||
|
||||
entity_action = self._get_entity_action(observation, entity_name)
|
||||
q2 = f"{entity_name} is {entity_action}"
|
||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
|
||||
"""
|
||||
return f'{self.name} {summary}'
|
||||
|
||||
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
|
||||
|
@ -151,7 +127,15 @@ class GenerativeAgent(BaseModel):
|
|||
#recent_memories=recent_memories if recent_memories else "N/A",
|
||||
observation=observation if observation else "N/A",
|
||||
)
|
||||
reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
|
||||
reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), queries=[observation], **kwargs).strip()
|
||||
|
||||
# cleanup
|
||||
reactions = reaction.replace(u"\u200B", "").strip().split("\n")
|
||||
|
||||
for reaction in reactions:
|
||||
if reaction:
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
print(reaction)
|
||||
return reaction
|
||||
|
@ -173,73 +157,6 @@ class GenerativeAgent(BaseModel):
|
|||
|
||||
return True, full_result
|
||||
|
||||
def generate_reaction(self, observation: str) -> Tuple[bool, str]:
|
||||
"""React to a given observation."""
|
||||
full_result = self._generate_reaction(observation, get_prompt('suffix_generate_reaction'))
|
||||
candidates = full_result.replace(u"\u200B", "").strip().split("\n")
|
||||
|
||||
response = ""
|
||||
results = []
|
||||
|
||||
for candidate in candidates:
|
||||
if "REACT:" in candidate or "SAY:" in candidate:
|
||||
# can't be assed to iteratively replace
|
||||
candidate = candidate.strip().replace("React:", "REACT:").replace("Say:", "SAY:")
|
||||
results.append(f'{candidate}'.replace("SAY:", "said").replace(f"REACT: {self.name}", "").replace("REACT:", ""))
|
||||
if len(results) > 0:
|
||||
response = " and ".join(results).strip().replace(" ", " ")
|
||||
valid = True
|
||||
else:
|
||||
response = f"did not react in a relevant way"
|
||||
valid = False
|
||||
|
||||
# AAA
|
||||
self.memory.save_context(
|
||||
{},
|
||||
{
|
||||
self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s reaction: {response}"
|
||||
},
|
||||
)
|
||||
|
||||
return valid, f"{self.name} {response}"
|
||||
|
||||
"""
|
||||
if "REACT:" in result:
|
||||
reaction = self._clean_response(result.split("REACT:")[-1])
|
||||
return True, f"{self.name} {reaction}"
|
||||
if "SAY:" in result:
|
||||
said_value = self._clean_response(result.split("SAY:")[-1])
|
||||
return True, f"{self.name} said {said_value}"
|
||||
else:
|
||||
return False, f"{self.name} did not react in a relevant way"
|
||||
"""
|
||||
|
||||
def generate_dialogue(self, observation: str) -> Tuple[bool, str]:
|
||||
"""React to a given observation."""
|
||||
call_to_action_template = (get_prompt('suffix_generate_dialogue'))
|
||||
full_result = self._generate_reaction(observation, call_to_action_template)
|
||||
result = full_result.strip().split("\n")[0]
|
||||
if "GOODBYE:" in result:
|
||||
farewell = self._clean_response(result.split("GOODBYE:")[-1])
|
||||
self.memory.save_context(
|
||||
{},
|
||||
{
|
||||
self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s farewell response: {farewell}"
|
||||
},
|
||||
)
|
||||
return False, f"{self.name} said {farewell}"
|
||||
if "SAY:" in result:
|
||||
response_text = self._clean_response(result.split("SAY:")[-1])
|
||||
self.memory.save_context(
|
||||
{},
|
||||
{
|
||||
self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s response: {response_text}"
|
||||
},
|
||||
)
|
||||
return True, f"{self.name} said {response_text}"
|
||||
else:
|
||||
return False, result
|
||||
|
||||
######################################################
|
||||
# Agent stateful' summary methods. #
|
||||
# Each dialog or response prompt includes a header #
|
||||
|
|
|
@ -27,11 +27,11 @@ PROMPTS = {
|
|||
},
|
||||
"summarize_related_memories": {
|
||||
"system": (
|
||||
"{relevant_memories_simple}"
|
||||
"{observation}"
|
||||
"{q1}"
|
||||
),
|
||||
"user": (
|
||||
"{q1}?"
|
||||
"{relevant_memories_simple}"
|
||||
"{observation}"
|
||||
),
|
||||
"assistant": "{name} ",
|
||||
},
|
||||
|
@ -83,9 +83,9 @@ PROMPTS = {
|
|||
},
|
||||
"generate_reaction": {
|
||||
"system": (
|
||||
"NSFW/Smut is allowed. Assume consent was granted."
|
||||
"\nIt is {current_time}."
|
||||
"\n{summary}"
|
||||
"\n{relevant_memories_simple}"
|
||||
"\n{memory}"
|
||||
#"\nRecent memories: {recent_memories}"
|
||||
#"\nRelevant memories: {relevant_memories}"
|
||||
|
|
15
src/main.py
15
src/main.py
|
@ -95,19 +95,14 @@ def run_conversation_proxy( agents, observation, limit=2 ):
|
|||
agent_observes( agents[0], [observation] )
|
||||
agents = agents[1:] + [agents[0]]
|
||||
|
||||
dialogue = []
|
||||
while True:
|
||||
break_dialogue = False
|
||||
for agent in agents:
|
||||
stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
|
||||
dialogue.append(f'[{agent.name}] {observation}')
|
||||
yield "\n".join(dialogue)
|
||||
print(colored("[Conversation]", "magenta"), observation)
|
||||
if not stay_in_dialogue:
|
||||
break_dialogue = True
|
||||
if break_dialogue:
|
||||
break
|
||||
if limit > 0 and len(dialogue) >= limit * len(agents):
|
||||
observation = agent_reacts( agent, [ observation ] )[0]
|
||||
yield observation
|
||||
if limit > 0 and len(dialogue) >= limit:
|
||||
break
|
||||
return dialogue
|
||||
print("END")
|
||||
dialogue.append("END")
|
||||
return "\n".join(dialogue)
|
||||
|
|
|
@ -223,7 +223,7 @@ def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int
|
|||
dialogue = []
|
||||
while True:
|
||||
for agent in agents:
|
||||
_, observation = agent_reacts( agent, [ observation ] )
|
||||
observation = agent_reacts( agent, [ observation ] )[0]
|
||||
if limit > 0 and len(dialogue) >= limit:
|
||||
break
|
||||
return dialogue
|
Loading…
Reference in New Issue
Block a user