updated requirements because I had installed this in WSL2
parent 41e48497cd · commit e152cd98a4
@@ -50,4 +50,6 @@ Even using one that's more instruction-tuned like Vicuna (with a `SYSTEM:\nUSER:
 However, I seem to be getting consistent results with SuperCOT 33B, it's just, well, slow. SuperCOT 13B seems to be giving better answers over Vicuna-1.1 13B, so.
 
+A ***lot*** of prompt wrangling is needed, and a lot of the routines could be polished up. For example, an observation queries the LM for a rating, and each response/reaction requires querying for the observed entity, then for the relationship between the agent and the observed entity (which ends up just summarizing relevant context/memories), and then for a response; if any one of these steps fails, the overall fail rate climbs. If anything, I might as well just work from the ground up and only really salvage the use of FAISS to store embedded-vectors.
+
 GPT4 seems to Just Work, unfortunately.
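For reference, the "use of FAISS to store embedded-vectors" mentioned above amounts to something like the minimal sketch below, built from the same langchain pieces this repo already pulls in (requirements now include faiss-cpu). The embedding model and index width are assumptions rather than the project's exact setup; the real retriever construction lives in `_create_new_memory_retriever()` in src/utils.py.

```python
# Minimal sketch, not the repo's actual code: embed memory strings, keep them
# in a FAISS index through langchain's wrapper, then query by similarity.
import faiss
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings()   # assumed default sentence-transformers model
index = faiss.IndexFlatL2(768)         # 768 = embedding width of that default model
store = FAISS(embeddings.embed_query, index, InMemoryDocstore({}), {})

store.add_texts(["Sam bought a new laptop", "Sam dislikes loud offices"])
print(store.similarity_search("What did Sam buy?", k=1)[0].page_content)
```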
@@ -1,4 +1,6 @@
 langchain
 openai
-llamacpp
+llama-cpp-python
 gradio
+faiss-cpu
+termcolor
@@ -34,7 +34,7 @@ from langchain.prompts import PromptTemplate
 from langchain.schema import BaseLanguageModel
 
 from .memory import GenerativeAgentMemory
-from .prompts import get_prompt
+from .prompts import get_prompt, get_stop_tokens
 
 class GenerativeAgent(BaseModel):
     """A character with memory and innate characteristics."""
@@ -87,14 +87,14 @@ class GenerativeAgent(BaseModel):
 
     def _get_entity_from_observation(self, observation: str) -> str:
         prompt = PromptTemplate.from_template(get_prompt('entity_from_observation'))
-        response = self.chain(prompt).run(observation=observation).strip().replace("Entity=", "").replace("Entity: ", "") # OAI will keep this
+        response = self.chain(prompt).run(stop=get_stop_tokens([".", "(", "'"]), observation=observation).strip()
         if self.verbose:
             print(response)
         return response
 
     def _get_entity_action(self, observation: str, entity_name: str) -> str:
         prompt = PromptTemplate.from_template(get_prompt('entity_action'))
-        response = self.chain(prompt).run(entity=entity_name, observation=observation).strip()
+        response = self.chain(prompt).run(stop=get_stop_tokens(), entity=entity_name, observation=observation).strip()
         if self.verbose:
             print(response)
         return response
@@ -113,21 +113,23 @@ class GenerativeAgent(BaseModel):
 
         entity_action = self._get_entity_action(observation, entity_name)
         q2 = f"{entity_name} is {entity_action}"
-        summary = self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
-        return summary
+        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, queries=[q1, q2]).strip()
+        return f'{self.name} {summary}'
 
-        #return self.chain(prompt=prompt).run(q1=q1, q2=q2).strip()
+        #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
 
     def _generate_reaction(self, observation: str, suffix: str) -> str:
         """React to a given observation or dialogue act."""
         prompt = PromptTemplate.from_template(
             get_prompt('generate_reaction').replace("{suffix}", suffix)
         )
-        summary = self.get_summary()
-        relevant_memories = self.summarize_related_memories(observation)
+        summary = self.get_summary().replace(u"\u200B", "").strip()
+        relevant_memories = self.summarize_related_memories(observation).replace(u"\u200B", "").strip()
+        recent_memories = "\n".join(self.get_most_recent_memories())
 
         # I think relevant_memories is suppose to only provide context for a relationship between agent and observer, as suggested with the query
         # but the original implementation seems to just leverage it to further filter relevant memories, per the name
 
         if relevant_memories and relevant_memories != "N/A":
             memory = relevant_memories
         else:
@@ -140,9 +142,11 @@ class GenerativeAgent(BaseModel):
             status=self.status if self.status else "N/A",
             summary=summary if summary else "N/A",
             memory=memory if memory else "N/A",
+            #relevant_memories=relevant_memories if relevant_memories else "N/A",
+            #recent_memories=recent_memories if recent_memories else "N/A",
             observation=observation if observation else "N/A",
         )
-        reaction = self.chain(prompt=prompt).run(**kwargs).strip()
+        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
         if self.verbose:
             print(reaction)
         return reaction
@@ -150,6 +154,20 @@ class GenerativeAgent(BaseModel):
     def _clean_response(self, text: str) -> str:
         return re.sub(f"^{self.name} ", "", text.strip()).strip()
 
+    def generate_response(self, observation: str) -> Tuple[bool, str]:
+        """React to a given observation."""
+        call_to_action_template = get_prompt('suffix_generate_response')
+        full_result = f"{self.name} {self._generate_reaction(observation, call_to_action_template)}"
+
+        self.memory.save_context(
+            {},
+            {
+                self.memory.add_memory_key: full_result
+            },
+        )
+
+        return True, full_result
+
     def generate_reaction(self, observation: str) -> Tuple[bool, str]:
         """React to a given observation."""
         full_result = self._generate_reaction(observation, get_prompt('suffix_generate_reaction'))
@@ -191,9 +209,9 @@ class GenerativeAgent(BaseModel):
         return False, f"{self.name} did not react in a relevant way"
         """
 
-    def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]:
+    def generate_dialogue(self, observation: str) -> Tuple[bool, str]:
         """React to a given observation."""
-        call_to_action_template = (get_prompt('suffix_generate_dialogue_response'))
+        call_to_action_template = (get_prompt('suffix_generate_dialogue'))
         full_result = self._generate_reaction(observation, call_to_action_template)
         result = full_result.strip().split("\n")[0]
         if "GOODBYE:" in result:
@@ -227,7 +245,7 @@ class GenerativeAgent(BaseModel):
         """"""
         # The agent seeks to think about their core characteristics.
         prompt = PromptTemplate.from_template(get_prompt('compute_agent_summary'))
-        summary = self.chain(prompt).run(name=self.name, summary=self.summaries[-1] if len(self.summaries) else self.summary, queries=[f"{self.name}'s core characteristics"]).strip()
+        summary = self.chain(prompt).run(stop=get_stop_tokens(), name=self.name, summary=self.summaries[-1] if len(self.summaries) else self.summary, queries=[f"{self.name}'s core characteristics"]).strip()
         if self.verbose:
             print(summary)
         return summary
@@ -247,8 +265,8 @@ class GenerativeAgent(BaseModel):
 
         values = [
             f"Name: {self.name} (sex: {self.sex}, age: {self.age if self.age is not None else 'N/A'})",
-            f"\nInnate traits: {self.traits}",
-            f"\nStatus: {self.status}"
+            f"Innate traits: {self.traits}",
+            f"Status: {self.status}"
         ]
 
         return "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\n{self.summary.strip()}"
@@ -259,4 +277,4 @@ class GenerativeAgent(BaseModel):
         current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
         return (
             f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
         )
@@ -34,7 +34,7 @@ from langchain.schema import BaseLanguageModel, BaseMemory, Document
 
 logger = logging.getLogger(__name__)
 
-from .prompts import get_prompt
+from .prompts import get_prompt, get_stop_tokens
 
 class GenerativeAgentMemory(BaseMemory):
     llm: BaseLanguageModel
@@ -84,7 +84,7 @@ class GenerativeAgentMemory(BaseMemory):
         prompt = PromptTemplate.from_template(get_prompt("topic_of_reflection"))
         observations = self.memory_retriever.memory_stream[-last_k:]
         observation_str = "\n".join([o.page_content for o in observations])
-        result = self.chain(prompt).run(observations=observation_str)
+        result = self.chain(prompt).run(stop=get_stop_tokens(), observations=observation_str)
         if self.verbose:
             print(result)
 
@@ -100,9 +100,7 @@ class GenerativeAgentMemory(BaseMemory):
                 for i, memory in enumerate(related_memories)
             ]
         )
-        result = self.chain(prompt).run(
-            topic=topic, related_statements=related_statements
-        )
+        result = self.chain(prompt).run( stop=get_stop_tokens(), topic=topic, related_statements=related_statements )
         # TODO: Parse the connections between memories and insights
         return self._parse_list(result)
 
@@ -122,7 +120,7 @@ class GenerativeAgentMemory(BaseMemory):
     def _score_memory_importance(self, memory_content: str) -> float:
         """Score the absolute importance of the given memory."""
         prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
-        score = self.chain(prompt).run(memory_content=memory_content).strip()
+        score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/"]), memory_content=memory_content).strip()
         if self.verbose:
             print(f"Importance score: {score}")
         try:
@@ -138,9 +136,7 @@ class GenerativeAgentMemory(BaseMemory):
         """Add an observation or memory to the agent's memory."""
         importance_score = self._score_memory_importance(memory_content)
         self.aggregate_importance += importance_score
-        document = Document(
-            page_content=memory_content, metadata={"importance": importance_score}
-        )
+        document = Document( page_content=memory_content, metadata={"importance": importance_score} )
         result = self.memory_retriever.add_documents([document])
 
         # After an agent has processed a certain amount of memories (as measured by
@@ -198,20 +194,14 @@ class GenerativeAgentMemory(BaseMemory):
                 mem for query in queries for mem in self.fetch_memories(query)
             ]
             return {
-                self.relevant_memories_key: self.format_memories_detail(
-                    relevant_memories
-                ),
-                self.relevant_memories_simple_key: self.format_memories_simple(
-                    relevant_memories
-                ),
+                self.relevant_memories_key: self.format_memories_detail( relevant_memories ),
+                self.relevant_memories_simple_key: self.format_memories_simple( relevant_memories ),
             }
 
         most_recent_memories_token = inputs.get(self.most_recent_memories_token_key)
         if most_recent_memories_token is not None:
             return {
-                self.most_recent_memories_key: self._get_memories_until_limit(
-                    most_recent_memories_token
-                )
+                self.most_recent_memories_key: self._get_memories_until_limit( most_recent_memories_token )
             }
         return {}
 
@@ -1,7 +1,6 @@
 import os
 
 LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE') # oai, vicuna, supercot
-STOP_TOKEN_HINT = "" # "\nWrite \"END\" afterwards."
 
 USE_STOP_HINT = [ "llama" ]
 
@@ -10,57 +9,50 @@ PROMPTS = {
         "system": (
             "What is the observed entity in the following observation?"
             " ONLY report one object and write one sentence."
-            f'{STOP_TOKEN_HINT}'
         ),
         "user": (
-            "Observation: {observation}"
+            "{observation}"
         ),
-        "assistant": "Entity=",
+        "assistant": "Entity = ",
     },
     "entity_action": {
         "system": (
-            "What is the following entity doing in the following observation?"
+            "What is `{entity}` doing in the following observation?"
             " ONLY write one sentence."
-            f'{STOP_TOKEN_HINT}'
         ),
         "user": (
-            "Entity: {entity}"
-            "\nObservation: {observation}"
+            "{observation}"
         ),
-        "assistant": "`{entity}` is ",
+        "assistant": "{entity} is ",
     },
     "summarize_related_memories": {
         "system": (
-            "Given the following context, answer the following question."
-            f'{STOP_TOKEN_HINT}'
+            "{relevant_memories_simple}"
         ),
         "user": (
-            "Context: {relevant_memories_simple}"
-            "\nQuestion: {q1}?"
+            "{q1}?"
        ),
-        "assistant": "Summary of relevant context: ",
+        "assistant": "{name} ",
     },
     "compute_agent_summary": {
         "system": (
             "Given the following previous summary and the following statements, how would you summarize {name}'s core characteristics?"
             " Do not embellish under any circumstances."
-            f'{STOP_TOKEN_HINT}'
         ),
         "user": (
-            "Previous summary: {summary}\n"
-            "Statements: {relevant_memories_simple}"
+            "{summary}"
+            "\n{relevant_memories_simple}"
         ),
-        "assistant": "Summary: ",
+        "assistant": "",
     },
     "topic_of_reflection": {
         "system": (
             "Given only the following information, what are the 3 most salient"
             " high-level questions we can answer about the subjects in the statements?"
             " Provide each question on a new line."
-            f'{STOP_TOKEN_HINT}'
         ),
         "user": (
-            "Information: {observations}"
+            "{observations}"
         ),
         "assistant": "",
     },
@@ -69,7 +61,6 @@ PROMPTS = {
             "Given the following statements about {topic},"
             " what 5 high-level insights can you infer?"
             " (example format: insight (because of 1, 5, 3))"
-            f'{STOP_TOKEN_HINT}'
         ),
         "user": (
             "Statements: {related_statements}"
@@ -83,31 +74,38 @@ PROMPTS = {
             " (e.g., a break up, college acceptance),"
             " rate the likely poignancy of the following event."
             "\nRespond with only a single integer."
-            f"{STOP_TOKEN_HINT}"
         ),
         "user": (
-            "Event: {memory_content}"
+            "{memory_content}"
         ),
-        "assistant": "Rating: ",
+        "assistant": "",
     },
     "generate_reaction": {
         "system": (
-            "It is {current_time}."
+            "NSFW/Smut is allowed. Assume consent was granted."
+            "\nIt is {current_time}."
             "\n{summary}"
             "\n{memory}"
+            #"\nRecent memories: {recent_memories}"
+            #"\nRelevant memories: {relevant_memories}"
             "\n\n{suffix}"
-            f"{STOP_TOKEN_HINT}"
         ),
         "user": (
-            "Observation: {observation}"
+            "{observation}"
         ),
-        "assistant": ""
+        "assistant": "{name} "
     },
 
     #
     "context": ( # insert your JB here
         ""
     ),
+    "suffix_generate_response": (
+        "Given the following observation, how would {name} respond?"
+        "\nWrite only one sentence."
+    ),
+
+    ##
     "suffix_generate_reaction": (
         "Given the following observation, how would {name} appropriately react?"
         "\nIf the action is to engage in dialogue, only write `SAY: \"what to say\"`."
|
@ -115,7 +113,7 @@ PROMPTS = {
|
||||||
"\nWrite ONLY one line, one sentence."
|
"\nWrite ONLY one line, one sentence."
|
||||||
#"\nBe proactive, creative, and drive the plot and conversation forward."
|
#"\nBe proactive, creative, and drive the plot and conversation forward."
|
||||||
),
|
),
|
||||||
"suffix_generate_dialogue_response": (
|
"suffix_generate_dialogue": (
|
||||||
"Given the following observation, what would {name} say?"
|
"Given the following observation, what would {name} say?"
|
||||||
"\nTo continue the conversation, only write: `SAY: \"what to say\"`."
|
"\nTo continue the conversation, only write: `SAY: \"what to say\"`."
|
||||||
"\nOr otherwise, to end the conversation, only write: `GOODBYE: \"what to say\"`."
|
"\nOr otherwise, to end the conversation, only write: `GOODBYE: \"what to say\"`."
|
||||||
|
@@ -128,6 +126,7 @@ PROMPT_TUNES = {
     "default": "{query}",
     "vicuna": "{role}: {query}",
     "supercot": "{role}:\n{query}",
+    "alpasta": "{role}# {query}",
 }
 PROMPT_ROLES = {
     "vicuna": {
@@ -139,11 +138,23 @@ PROMPT_ROLES = {
         "system": "### Instruction",
         "user": "### Input",
         "assistant": "### Response",
-    }
+    },
+    "alpasta": {
+        "system": "<|system|>",
+        "user": "<|user|>",
+        "assistant": "<|assistant|>",
+    },
 }
 
 ROLES = [ "system", "user", "assistant" ]
 
+def get_stop_tokens( tokens=[], tune=LLM_PROMPT_TUNE ):
+    STOP_TOKENS = ["###"] + tokens
+    for role in get_roles( tune=LLM_PROMPT_TUNE, special=True ):
+        STOP_TOKENS.append(f'{role}')
+    return STOP_TOKENS
+
 for k in PROMPTS:
     if k == "context":
         continue
@@ -187,10 +198,6 @@ def get_prompt( key, tune=LLM_PROMPT_TUNE ):
     if role in roles:
         role = roles[role]
 
-    # remove stop token hinting if we're using OAI since I don't have control over early terminating
-    if STOP_TOKEN_HINT in query and tune in USE_STOP_HINT:
-        query = query.replace(STOP_TOKEN_HINT, "")
-
     output = f'{PROMPT_TUNES[tune]}'
     output = output.replace("{role}", role)
     output = output.replace("{query}", query)
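Taken together, the new helper means every `chain(...).run(stop=get_stop_tokens(...), ...)` call above halts generation at role headers instead of relying on the old `STOP_TOKEN_HINT`. A standalone sketch of what the list works out to for the `supercot` tune follows; the role strings mirror PROMPT_ROLES above, but `get_roles()` itself isn't shown in this diff, so its exact output is an assumption.

```python
# Standalone sketch, not the module's actual code: rebuild the stop list that
# get_stop_tokens() would hand to each chain call for the "supercot" tune.
SUPERCOT_ROLES = {
    "system": "### Instruction",
    "user": "### Input",
    "assistant": "### Response",
}

def get_stop_tokens_sketch(tokens=[], roles=SUPERCOT_ROLES):
    stop = ["###"] + list(tokens)      # base marker plus caller-supplied extras
    for role in roles.values():        # then every role header, so the model
        stop.append(role)              # can't start writing the next turn itself
    return stop

print(get_stop_tokens_sketch([".", "/"]))
# ['###', '.', '/', '### Instruction', '### Input', '### Response']
```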
src/main.py
@@ -1,6 +1,7 @@
 import os
 import gradio as gr
 import gradio.utils
+from termcolor import colored
 
 from utils import create_agent, agent_observes, interview_agent, run_conversation, get_summary, save_agent, load_agent
 
|
@ -65,10 +66,42 @@ def get_summary_proxy( agents ):
|
||||||
messages.append(get_summary( agent, force_refresh = True ))
|
messages.append(get_summary( agent, force_refresh = True ))
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
|
||||||
def run_conversation_proxy( agents, message ):
|
def run_conversation_proxy( agents, observation, limit=2 ):
|
||||||
agents = [ AGENTS[agent] for agent in agents ]
|
agents = [ AGENTS[agent] for agent in agents ]
|
||||||
messages = run_conversation( agents, message, limit=len(agents)*2 )
|
|
||||||
|
if len(agents) < 2:
|
||||||
|
raise "Not enough agents"
|
||||||
|
|
||||||
|
dialogue = []
|
||||||
|
dialogue.append(f'[{agents[0].name}] {observation}')
|
||||||
|
|
||||||
|
"""Runs a conversation between agents."""
|
||||||
|
print(colored("[Conversation]", "magenta"))
|
||||||
|
yield "\n".join(dialogue)
|
||||||
|
agent_observes( agents[0], [observation] )
|
||||||
|
agents = agents[1:] + [agents[0]]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
break_dialogue = False
|
||||||
|
for agent in agents:
|
||||||
|
stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
|
||||||
|
dialogue.append(f'[{agent.name}] {observation}')
|
||||||
|
yield "\n".join(dialogue)
|
||||||
|
print(colored("[Conversation]", "magenta"), observation)
|
||||||
|
if not stay_in_dialogue:
|
||||||
|
break_dialogue = True
|
||||||
|
if break_dialogue:
|
||||||
|
break
|
||||||
|
if limit > 0 and len(dialogue) >= limit * len(agents):
|
||||||
|
break
|
||||||
|
print("END")
|
||||||
|
dialogue.append("END")
|
||||||
|
return "\n".join(dialogue)
|
||||||
|
|
||||||
|
"""
|
||||||
|
messages = run_conversation( agents, observation, limit=len(agents)*2 )
|
||||||
return "\n".join(messages)
|
return "\n".join(messages)
|
||||||
|
"""
|
||||||
|
|
||||||
def view_agent( agents, last_k = 50 ):
|
def view_agent( agents, last_k = 50 ):
|
||||||
if not isinstance( agents, list ):
|
if not isinstance( agents, list ):
|
||||||
|
|
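Because run_conversation_proxy now yields the growing transcript instead of returning once, the gradio side can stream each turn into the output box as it arrives. Below is a minimal, self-contained sketch of that wiring, assuming gradio 3.x generator support with the queue enabled; the component names and the stand-in conversation function are placeholders, not the repo's actual UI code.

```python
import time
import gradio as gr

def fake_conversation(observation):
    # stand-in for run_conversation_proxy: yield the transcript as it grows
    dialogue = [f"[Agent A] {observation}"]
    yield "\n".join(dialogue)
    for line in ('[Agent B] SAY: "Oh? Tell me more."', '[Agent A] GOODBYE: "See you."'):
        time.sleep(0.5)
        dialogue.append(line)
        yield "\n".join(dialogue)

with gr.Blocks() as demo:
    observation = gr.Textbox(label="Observation")
    transcript = gr.Textbox(label="Conversation", lines=8)
    # generator functions need the queue; each yield refreshes the output box
    gr.Button("Run").click(fake_conversation, inputs=observation, outputs=transcript)

demo.queue().launch()
```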
src/utils.py
@@ -25,8 +25,9 @@ from langchain.vectorstores import FAISS
 LLM_TYPE = os.environ.get('LLM_TYPE', "llamacpp") # options: llamacpp, oai
 LLM_LOCAL_MODEL = os.environ.get('LLM_MODEL',
     #"./models/ggml-vicuna-13b-1.1/ggml-vic13b-uncensored-q4_2.bin"
-    #"./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
-    "./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
+    "./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
+    #"./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
+    #"./models/gpt4-x-alpasta-30b-ggml-q4_1.bin"
 )
 LLM_CONTEXT = int(os.environ.get('LLM_CONTEXT', '2048'))
 LLM_THREADS = int(os.environ.get('LLM_THREADS', '6'))
@@ -40,6 +41,8 @@ else:
     LLM_PROMPT_TUNE_DEFAULT = "supercot"
 elif "vicuna" in LLM_LOCAL_MODEL.lower():
     LLM_PROMPT_TUNE_DEFAULT = "vicuna"
+elif "alpasta" in LLM_LOCAL_MODEL.lower():
+    LLM_PROMPT_TUNE_DEFAULT = "alpasta"
 else:
     LLM_PROMPT_TUNE_DEFAULT = "llama"
 
@@ -51,10 +54,6 @@ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) # unncess
 # Overrides for some fixes, like scoring memory and LLM-specific promptings
 from ext import GenerativeAgent, GenerativeAgentMemory, get_roles
 
-STOP_TOKENS = ["END"]
-for role in get_roles( tune=LLM_PROMPT_TUNE, special=True ):
-    STOP_TOKENS.append(f'{role}:')
-
 if LLM_TYPE=="llamacpp":
     from langchain.llms import LlamaCpp
 
@@ -64,9 +63,8 @@ if LLM_TYPE=="llamacpp":
         verbose=True,
         n_ctx=LLM_CONTEXT,
         #n_threads=LLM_THREADS,
-        use_mlock=True,
-        use_mmap=True,
-        stop=STOP_TOKENS
+        #use_mlock=True,
+        #use_mmap=True,
     )
 elif LLM_TYPE=="oai":
     from langchain.chat_models import ChatOpenAI
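With `stop=STOP_TOKENS` dropped from the constructor, stop sequences now travel with each chain call (the `run(stop=get_stop_tokens(), ...)` changes above) rather than being baked into the model object. A rough sketch of that pattern, reusing the model path and context size from this file; the prompt text and stop list here are illustrative assumptions, not the project's exact values.

```python
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

llm = LlamaCpp(
    model_path="./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin",
    n_ctx=2048,
)

prompt = PromptTemplate.from_template("### Instruction:\n{query}\n### Response:\n")
chain = LLMChain(llm=llm, prompt=prompt)

# stop sequences are passed per call now, so different queries can use
# different stop lists without rebuilding the LlamaCpp object
print(chain.run(stop=["###", "### Instruction", "### Input", "### Response"],
                query="Describe FAISS in one sentence."))
```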
@@ -95,7 +93,6 @@ elif LLM_TYPE=="oai":
 else:
     raise f"Invalid LLM type: {LLM_TYPE}"
 
-
 if EMBEDDING_TYPE == "hf":
     from langchain.embeddings import HuggingFaceEmbeddings
 
@@ -144,7 +141,7 @@ def _create_new_memories():
         memory_retriever=_create_new_memory_retriever(),
         reflection_threshold=8,
         verbose=True,
-        max_tokens_limit=256 # LLM_CONTEXT/4
+        max_tokens_limit=128 # LLM_CONTEXT/4
     )
 
 def create_agent(**kwargs):
@@ -210,13 +207,13 @@ def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int
     """Runs a conversation between agents."""
     print(colored("[Conversation]", "magenta"))
     agent_observes( agents[0], [observation] )
+    agents = agents[1:] + [agents[0]]
 
     dialogue = []
     while True:
         break_dialogue = False
         for agent in agents:
-            stay_in_dialogue, observation = agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
-            yield observation
+            stay_in_dialogue, observation = agent.generate_response(observation) # agent.generate_reaction(observation) if random.random() < p_reaction else agent.generate_dialogue_response(observation)
             dialogue.append(observation)
             print(colored("[Conversation]", "magenta"), observation)
             if not stay_in_dialogue:
@@ -225,4 +222,5 @@ def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int
             break
         if limit > 0 and len(dialogue) >= limit:
             break
+    agent_observes( agent, [observation] )
     return dialogue