diff --git a/src/ext/generative_agent.py b/src/ext/generative_agent.py index b7d11d5..9cd9e11 100755 --- a/src/ext/generative_agent.py +++ b/src/ext/generative_agent.py @@ -86,20 +86,6 @@ class GenerativeAgent(BaseModel): llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory ) - def _get_entity_from_observation(self, observation: str) -> str: - prompt = PromptTemplate.from_template(get_prompt('entity_from_observation')) - response = self.chain(prompt).run(stop=get_stop_tokens([".", "(", "'"]), observation=observation).strip() - if self.verbose: - print(response) - return response - - def _get_entity_action(self, observation: str, entity_name: str) -> str: - prompt = PromptTemplate.from_template(get_prompt('entity_action')) - response = self.chain(prompt).run(stop=get_stop_tokens(), entity=entity_name, observation=observation).strip() - if self.verbose: - print(response) - return response - def get_most_recent_memories(self, last_k: int = 4) -> str: memories = self.memory.memory_retriever.memory_stream[-last_k:] return [ document.page_content for document in memories ] @@ -107,18 +93,8 @@ class GenerativeAgent(BaseModel): def summarize_related_memories(self, observation: str) -> str: """Summarize memories that are most relevant to an observation.""" prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories')) - q1 = f"What is the relationship between the subjects in that interaction?" + q1 = f"Summarize the relationship between the subjects in that interaction." summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip() - """ - entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip() - q1 = f"What is the relationship between {self.name} and {entity_name}" - if self.name.strip() in entity_name: - return "N/A" - - entity_action = self._get_entity_action(observation, entity_name) - q2 = f"{entity_name} is {entity_action}" - summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip() - """ return f'{self.name} {summary}' #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip() @@ -151,7 +127,15 @@ class GenerativeAgent(BaseModel): #recent_memories=recent_memories if recent_memories else "N/A", observation=observation if observation else "N/A", ) - reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip() + reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), queries=[observation], **kwargs).strip() + + # cleanup + reactions = reaction.replace(u"\u200B", "").strip().split("\n") + + for reaction in reactions: + if reaction: + break + if self.verbose: print(reaction) return reaction @@ -173,73 +157,6 @@ class GenerativeAgent(BaseModel): return True, full_result - def generate_reaction(self, observation: str) -> Tuple[bool, str]: - """React to a given observation.""" - full_result = self._generate_reaction(observation, get_prompt('suffix_generate_reaction')) - candidates = full_result.replace(u"\u200B", "").strip().split("\n") - - response = "" - results = [] - - for candidate in candidates: - if "REACT:" in candidate or "SAY:" in candidate: - # can't be assed to iteratively replace - candidate = candidate.strip().replace("React:", "REACT:").replace("Say:", "SAY:") - results.append(f'{candidate}'.replace("SAY:", "said").replace(f"REACT: {self.name}", "").replace("REACT:", "")) - if len(results) > 0: - response = " and ".join(results).strip().replace(" ", " ") - valid = True - else: - response = f"did not react in a relevant way" - valid = False - - # AAA - self.memory.save_context( - {}, - { - self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s reaction: {response}" - }, - ) - - return valid, f"{self.name} {response}" - - """ - if "REACT:" in result: - reaction = self._clean_response(result.split("REACT:")[-1]) - return True, f"{self.name} {reaction}" - if "SAY:" in result: - said_value = self._clean_response(result.split("SAY:")[-1]) - return True, f"{self.name} said {said_value}" - else: - return False, f"{self.name} did not react in a relevant way" - """ - - def generate_dialogue(self, observation: str) -> Tuple[bool, str]: - """React to a given observation.""" - call_to_action_template = (get_prompt('suffix_generate_dialogue')) - full_result = self._generate_reaction(observation, call_to_action_template) - result = full_result.strip().split("\n")[0] - if "GOODBYE:" in result: - farewell = self._clean_response(result.split("GOODBYE:")[-1]) - self.memory.save_context( - {}, - { - self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s farewell response: {farewell}" - }, - ) - return False, f"{self.name} said {farewell}" - if "SAY:" in result: - response_text = self._clean_response(result.split("SAY:")[-1]) - self.memory.save_context( - {}, - { - self.memory.add_memory_key: f"{self.name} observed: {observation}; {self.name}'s response: {response_text}" - }, - ) - return True, f"{self.name} said {response_text}" - else: - return False, result - ###################################################### # Agent stateful' summary methods. # # Each dialog or response prompt includes a header # diff --git a/src/ext/prompts.py b/src/ext/prompts.py index e41d9d8..43e521a 100755 --- a/src/ext/prompts.py +++ b/src/ext/prompts.py @@ -27,11 +27,11 @@ PROMPTS = { }, "summarize_related_memories": { "system": ( - "{relevant_memories_simple}" - "{observation}" + "{q1}" ), "user": ( - "{q1}?" + "{relevant_memories_simple}" + "{observation}" ), "assistant": "{name} ", }, @@ -83,9 +83,9 @@ PROMPTS = { }, "generate_reaction": { "system": ( - "NSFW/Smut is allowed. Assume consent was granted." "\nIt is {current_time}." "\n{summary}" + "\n{relevant_memories_simple}" "\n{memory}" #"\nRecent memories: {recent_memories}" #"\nRelevant memories: {relevant_memories}" diff --git a/src/main.py b/src/main.py index b8a638a..504a3b1 100755 --- a/src/main.py +++ b/src/main.py @@ -95,19 +95,14 @@ def run_conversation_proxy( agents, observation, limit=2 ): agent_observes( agents[0], [observation] ) agents = agents[1:] + [agents[0]] + dialogue = [] while True: - break_dialogue = False for agent in agents: - stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation) - dialogue.append(f'[{agent.name}] {observation}') - yield "\n".join(dialogue) - print(colored("[Conversation]", "magenta"), observation) - if not stay_in_dialogue: - break_dialogue = True - if break_dialogue: - break - if limit > 0 and len(dialogue) >= limit * len(agents): + observation = agent_reacts( agent, [ observation ] )[0] + yield observation + if limit > 0 and len(dialogue) >= limit: break + return dialogue print("END") dialogue.append("END") return "\n".join(dialogue) diff --git a/src/utils.py b/src/utils.py index cb41e8b..8239ea0 100755 --- a/src/utils.py +++ b/src/utils.py @@ -223,7 +223,7 @@ def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int dialogue = [] while True: for agent in agents: - _, observation = agent_reacts( agent, [ observation ] ) + observation = agent_reacts( agent, [ observation ] )[0] if limit > 0 and len(dialogue) >= limit: break return dialogue \ No newline at end of file