tunings

commit 287406e7ba (parent f13d05dbb2)

@@ -29,7 +29,7 @@ Set your environment variables accordingly:
 - `OPENAI_API_MODEL`: target model
 * `LLM_MODEL`: (`./path/to/your/llama/model.bin`): path to your GGML-formatted LLaMA model, if using `llamacpp` as the LLM backend
 * `LLM_EMBEDDING_TYPE`: (`oai`, `llamacpp`, `hf`): the embedding model to use for similarity computing.
-* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`, `supercot`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
+* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`, `supercot`, `cocktail`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
 * `LLM_CONTEXT`: sets maximum context size
 
 To run:
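As an aside, a minimal sketch of how these variables might be set from Python before `src/utils.py` is imported; the model path and values below are illustrative assumptions, not defaults shipped by this commit:

```python
# Hypothetical configuration sketch; adjust paths/values to your setup.
import os

os.environ["LLM_TYPE"] = "llamacpp"                  # llamacpp or oai
os.environ["LLM_MODEL"] = "./models/your-model.bin"  # placeholder GGML model path
os.environ["LLM_EMBEDDING_TYPE"] = "hf"              # oai, llamacpp, or hf
os.environ["LLM_PROMPT_TUNE"] = "cocktail"           # oai, vicuna, supercot, or cocktail
os.environ["LLM_CONTEXT"] = "2048"                   # maximum context size

# src/utils.py reads these at import time, so set them before importing it.
```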
@@ -44,11 +44,11 @@ I ***do not*** plan on making this uber-user friendly like [mrq/ai-voice-cloning
 
 ## Caveats
 
-A local LM is quite slow.
+A local LM is quite slow. Things seem to be getting faster as llama.cpp is being developed.
 
 Even using one that's more instruction-tuned like Vicuna (with a `SYSTEM:\nUSER:\nASSISTANT:` structure of prompts), it's still inconsistent.
 
-However, I seem to be getting consistent results with SuperCOT 33B, it's just, well, slow. SuperCOT 13B seems to be giving better answers than Vicuna-1.1 13B, so.
+However, I seem to be getting consistent results with SuperCOT 33B, it's just, well, slow. SuperCOT 13B seems to be giving better answers than Vicuna-1.1 13B, so. Cocktail 13B seems to be the best of the 13Bs.
 
 A ***lot*** of prompt wrangling is needed, and a lot of the routines could be polished up (for example, an observation queries the LM for a rating, and each response reaction requires querying for the observed entity, then the relationship between an agent and observed entity, which ends up just summarizing relevant context/memories, and then queries for a response), and if one of these steps fails, the overall fail rate is higher. If anything, I might as well just work from the ground up and only really salvage the use of FAISS to store embedded vectors.
 
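To make the prompt-format remark above concrete, here is a rough, illustrative sketch of the `SYSTEM/USER/ASSISTANT` shape a Vicuna-style tune produces; the actual strings come from `get_prompt` and the `PROMPT_TUNES`/`PROMPT_ROLES` tables further down this diff, and the example text is made up:

```python
# Illustrative only: the general SYSTEM/USER/ASSISTANT shape described above.
roles = {"system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT"}
turns = [
    ("system", "Rate the importance of the following event from 1 to 10."),
    ("user", "Event: the kitchen catches fire"),
    ("assistant", "Rating: "),
]
prompt = "\n".join(f"{roles[role]}: {query}" for role, query in turns)
print(prompt)
```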
@@ -86,16 +86,26 @@ class GenerativeAgent(BaseModel):
             llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
         )
 
-    def get_most_recent_memories(self, last_k: int = 4) -> str:
+    def get_most_recent_memories(self, last_k: int = 8) -> str:
         memories = self.memory.memory_retriever.memory_stream[-last_k:]
-        return [ document.page_content for document in memories ]
+        return [ document.page_content.replace(u"\u200B", "").strip() for document in memories ]
 
-    def summarize_related_memories(self, observation: str) -> str:
-        """Summarize memories that are most relevant to an observation."""
+    def get_relevant_memories(self, observation: str, first_k : int = 8) -> str:
+        queries = [ observation ]
+        relevant_memories = [
+            mem.page_content.replace(u"\u200B", "").strip() for query in queries for mem in self.memory.fetch_memories(query)
+        ]
+        relevant_memories = relevant_memories[:first_k]
+        relevant_memories.reverse()
+        return relevant_memories
+
+    """
+    def summarize_related_memories(self, observation: str, first_k : int = 4) -> str:
         prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
-        q1 = f"Summarize the relationship between the subjects in that interaction."
-        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
+        query = f"Summarize the relationship between the subjects in that interaction in two sentences or less. Avoid repeating."
+        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), query=query, observation=observation, queries=[observation]).strip()
         return f'{self.name} {summary}'
+    """
 
         #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
 
@@ -104,17 +114,27 @@ class GenerativeAgent(BaseModel):
         prompt = PromptTemplate.from_template(
             get_prompt('generate_reaction').replace("{suffix}", suffix)
         )
-        summary = self.get_summary().replace(u"\u200B", "").strip()
-        relevant_memories = self.summarize_related_memories(observation).replace(u"\u200B", "").strip()
-        recent_memories = "\n".join(self.get_most_recent_memories())
+        summary = self.get_summary()
+        relevant_memories = self.get_relevant_memories(observation)
+        recent_memories = self.get_most_recent_memories()
 
         # I think relevant_memories is suppose to only provide context for a relationship between agent and observer, as suggested with the query
         # but the original implementation seems to just leverage it to further filter relevant memories, per the name
+        # avoid repeating
         memory = ""
-
-        if relevant_memories and relevant_memories != "N/A":
-            memory = relevant_memories
-        else:
-            memory = "\n".join(self.get_most_recent_memories())
+        for mem in relevant_memories:
+            if mem in summary or mem in memory:
+                continue
+            memory += f"\n{mem}"
+
+        for mem in recent_memories:
+            if mem in summary:
+                continue
+            if mem is observation:
+                continue
+            # erase it, move it to bottom
+            if mem in memory:
+                memory = memory.replace(f'{mem}\n', "")
+            memory += f"\n{mem}"
 
         current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
         kwargs: Dict[str, Any] = dict(
@@ -127,12 +147,23 @@ class GenerativeAgent(BaseModel):
             #recent_memories=recent_memories if recent_memories else "N/A",
             observation=observation if observation else "N/A",
         )
-        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), queries=[observation], **kwargs).strip()
+        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
         import re
 
         emoji_pattern = re.compile("["
             u"\U0001F600-\U0001F64F"  # emoticons
             u"\U0001F300-\U0001F5FF"  # symbols & pictographs
             u"\U0001F680-\U0001F6FF"  # transport & map symbols
             u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
             "]+", flags=re.UNICODE)
         reaction = emoji_pattern.sub(r'', reaction)
+
+        # cleanup
+        reactions = reaction.replace(u"\u200B", "").strip().split("\n")
+
+        for reaction in reactions:
+            if reaction in summary or reaction in memory:
+                continue
+            if reaction:
+                break
 
@@ -140,20 +171,14 @@ class GenerativeAgent(BaseModel):
         print(reaction)
         return reaction
 
-    def _clean_response(self, text: str) -> str:
-        return re.sub(f"^{self.name} ", "", text.strip()).strip()
-
     def generate_response(self, observation: str) -> Tuple[bool, str]:
         """React to a given observation."""
         call_to_action_template = get_prompt('suffix_generate_response')
-        full_result = f"{self.name} {self._generate_reaction(observation, call_to_action_template)}"
-
-        self.memory.save_context(
-            {},
-            {
-                self.memory.add_memory_key: full_result
-            },
-        )
+        full_result = ""
+        while not full_result:
+            full_result = f"{self._generate_reaction(observation, call_to_action_template)}"
+            if full_result:
+                break
 
         return True, full_result
 
@@ -191,7 +216,8 @@ class GenerativeAgent(BaseModel):
             f"Status: {self.status}"
         ]
 
-        return "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\n{self.summary.strip()}"
+        summary = "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\nSummary: {self.summary.strip()}"
+        return summary.replace(u"\u200B", "").strip()
 
     def get_full_header(self, force_refresh: bool = False) -> str:
         """Return a full header of the agent's status, summary, and current time."""
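The thread running through the agent changes above is prompt deduplication: memories that already appear in the agent's summary or earlier in the assembled memory block are skipped, and repeated recent memories are moved to the bottom. A self-contained sketch of that idea with made-up inputs (it mirrors the intent, not the exact repo code):

```python
# Hypothetical inputs; illustrates the dedup/reorder idea described above.
summary = "Sam is a farmer.\nStatus: tending the fields"
relevant = ["Sam is a farmer.", "Sam dislikes rain."]
recent = ["Sam saw a storm coming."]

memory = ""
for mem in relevant:
    if mem in summary or mem in memory:
        continue  # already present somewhere in the prompt; skip it
    memory += f"\n{mem}"

for mem in recent:
    if mem in summary:
        continue
    if mem in memory:
        # drop the earlier copy so the freshest mention sits at the bottom
        memory = memory.replace(f"{mem}\n", "")
    memory += f"\n{mem}"

print(memory.strip())
# Sam dislikes rain.
# Sam saw a storm coming.
```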
@@ -71,6 +71,8 @@ class GenerativeAgentMemory(BaseMemory):
     relevant_memories_simple_key: str = "relevant_memories_simple"
     most_recent_memories_key: str = "most_recent_memories"
+
+    reflecting: bool = False
 
     def chain(self, prompt: PromptTemplate) -> LLMChain:
         return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)
 
@@ -133,9 +135,10 @@ class GenerativeAgentMemory(BaseMemory):
 
         return (float(2) / 10) * self.importance_weight
 
-    def add_memory(self, memory_content: str) -> List[str]:
+    def add_memory(self, memory_content: str, importance_score: int = 0) -> List[str]:
         """Add an observation or memory to the agent's memory."""
-        importance_score = self._score_memory_importance(memory_content)
+        if not importance_score:
+            importance_score = self._score_memory_importance(memory_content)
         self.aggregate_importance += importance_score
         document = Document( page_content=memory_content, metadata={"importance": importance_score} )
         result = self.memory_retriever.add_documents([document])
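A note on the `importance_score` parameter added here: it lets a caller rate an observation once and reuse that rating, instead of asking the LLM to score the same text again for every agent (which is how `src/main.py` uses it further down). A hedged usage sketch, with agent objects assumed to already exist:

```python
# Hypothetical usage: `alice` and `bob` are GenerativeAgent instances assumed
# to exist. The first call scores the observation via the LLM; later calls
# pass that score back in, so add_memory() skips re-scoring it.
observation = "A storm rolls in over the farm."

importance_score = 0
for agent in (alice, bob):
    importance_score, _ = agent.memory.add_memory(observation, importance_score=importance_score)
```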
@@ -146,10 +149,13 @@ class GenerativeAgentMemory(BaseMemory):
         if (
             self.reflection_threshold is not None
             and self.aggregate_importance > self.reflection_threshold
+            and not self.reflecting
         ):
+            self.reflecting = True
             self.pause_to_reflect()
             # Hack to clear the importance from reflection
             self.aggregate_importance = 0.0
+            self.reflecting = False
 
         return (importance_score, result)
 
@@ -169,7 +175,7 @@ class GenerativeAgentMemory(BaseMemory):
         return "\n".join([f"{mem}" for mem in content])
 
     def format_memories_simple(self, relevant_memories: List[Document]) -> str:
-        return "; ".join([f"{mem.page_content}" for mem in relevant_memories]).replace(".;", ";")
+        return "; ".join([f"{mem.page_content}" for mem in relevant_memories]).replace(".;", ".\n")
 
     def _get_memories_until_limit(self, consumed_tokens: int) -> str:
         """Reduce the number of tokens in the documents."""
@@ -2,32 +2,10 @@ import os
 LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE') # oai, vicuna, supercot
 
 USE_STOP_HINT = [ "llama" ]
 
 PROMPTS = {
-    "entity_from_observation": {
-        "system": (
-            "What is the observed entity in the following observation?"
-            " ONLY report one object and write one sentence."
-        ),
-        "user": (
-            "{observation}"
-        ),
-        "assistant": "Entity = ",
-    },
-    "entity_action": {
-        "system": (
-            "What is `{entity}` doing in the following observation?"
-            " ONLY write one sentence."
-        ),
-        "user": (
-            "{observation}"
-        ),
-        "assistant": "{entity} is ",
-    },
     "summarize_related_memories": {
         "system": (
-            "{q1}"
+            "{query}"
         ),
         "user": (
             "{relevant_memories_simple}"
@@ -44,7 +22,7 @@ PROMPTS = {
             "{summary}"
             "\n{relevant_memories_simple}"
         ),
-        "assistant": "",
+        "assistant": "{name} ",
     },
     "topic_of_reflection": {
        "system": (
@@ -53,7 +31,7 @@ PROMPTS = {
             " Provide each question on a new line."
         ),
         "user": (
-            "{observations}"
+            "Information: {observations}"
         ),
         "assistant": "",
     },
@@ -77,24 +55,22 @@ PROMPTS = {
             "\nRespond with only a single integer."
         ),
         "user": (
-            "{memory_content}"
+            "Event: {memory_content}"
         ),
-        "assistant": "",
+        "assistant": "Rating: ",
     },
     "generate_reaction": {
         "system": (
-            "\nIt is {current_time}."
+            "[Write one reply. Always stay in character. Maintain a casual tone using beige prose. Be brief. Avoid repeating anything below.]"
+            "\nCurrent Time: {current_time}"
             "\n{summary}"
-            "\n{relevant_memories_simple}"
+            "\n{memory}"
-            #"\nRecent memories: {recent_memories}"
-            #"\nRelevant memories: {relevant_memories}"
-            "\n\n{suffix}"
+            "\n{observation}"
         ),
         "user": (
-            "{observation}"
+            "{suffix}"
         ),
-        "assistant": "{name} "
+        "assistant": ""
     },
 
     #
@@ -102,24 +78,7 @@ PROMPTS = {
         ""
     ),
     "suffix_generate_response": (
-        "Given the following observation, how would {name} respond?"
-        "\nWrite only one sentence."
-    ),
-
-    ##
-    "suffix_generate_reaction": (
-        "Given the following observation, how would {name} appropriately react?"
-        "\nIf the action is to engage in dialogue, only write `SAY: \"what to say\"`."
-        "\nOr otherwise, only write `REACT: how to react`."
-        "\nWrite ONLY one line, one sentence."
-        #"\nBe proactive, creative, and drive the plot and conversation forward."
-    ),
-    "suffix_generate_dialogue": (
-        "Given the following observation, what would {name} say?"
-        "\nTo continue the conversation, only write: `SAY: \"what to say\"`."
-        "\nOr otherwise, to end the conversation, only write: `GOODBYE: \"what to say\"`."
-        "\nWrite ONLY one line, one sentence."
-        #"\nBe proactive, creative, and drive the plot and conversation forward."
+        "Given the current situation, in one sentence, what is {name}'s next response?"
     ),
 }
@@ -128,6 +87,7 @@ PROMPT_TUNES = {
     "vicuna": "{role}: {query}",
     "supercot": "{role}:\n{query}",
     "alpasta": "{role}# {query}",
+    "cocktail": "{role}: {query}",
 }
 PROMPT_ROLES = {
     "vicuna": {
@@ -145,6 +105,11 @@ PROMPT_ROLES = {
         "user": "<|user|>",
         "assistant": "<|assistant|>",
     },
+    "cocktail": {
+        "system": "",
+        "user": "USER",
+        "assistant": "ASSOCIATE",
+    },
 }
 
 ROLES = [ "system", "user", "assistant" ]
@@ -153,7 +118,8 @@ ROLES = [ "system", "user", "assistant" ]
 def get_stop_tokens( tokens=[], tune=LLM_PROMPT_TUNE ):
     STOP_TOKENS = ["###"] + tokens
     for role in get_roles( tune=LLM_PROMPT_TUNE, special=True ):
-        STOP_TOKENS.append(f'{role}')
+        if role:
+            STOP_TOKENS.append(f'{role}')
     return STOP_TOKENS
 
 for k in PROMPTS:
@@ -204,4 +170,7 @@ def get_prompt( key, tune=LLM_PROMPT_TUNE ):
         output = output.replace("{query}", query)
         outputs.append(output)
 
-    return "\n".join(outputs)
+    output = "\n".join(outputs)
+    #if LLM_PROMPT_TUNE == "cocktail":
+    output = output.strip()
+    return output
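For orientation, a small illustrative sketch of how the `{role}: {query}` template combines with the new `cocktail` role names (`USER`/`ASSOCIATE`); the empty system role is also why `get_stop_tokens` now skips empty roles. This snippet is standalone and only approximates what `get_prompt` assembles:

```python
# Standalone approximation; the repo's get_prompt() does the real assembly.
TEMPLATE = "{role}: {query}"   # "cocktail" shares vicuna's "{role}: {query}" shape
ROLES = {"user": "USER", "assistant": "ASSOCIATE"}

def format_turn(role: str, query: str) -> str:
    return TEMPLATE.replace("{role}", ROLES[role]).replace("{query}", query)

print(format_turn("user", "Event: the kitchen catches fire"))
print(format_turn("assistant", "Rating: "))
# USER: Event: the kitchen catches fire
# ASSOCIATE: Rating:
```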
src/main.py
@@ -37,8 +37,7 @@ def agent_observes_proxy( agents, observations ):
         if agent not in AGENTS:
             load_agent( agent )
         agent = AGENTS[agent]
-        observations = observations.split("\n")
-        results = agent_observes( agent, observations )
+        results = agent_observes( agent, observations.split("\n") )
         messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
     return "\n".join(messages)
 
@@ -51,8 +50,7 @@ def agent_reacts_proxy( agents, observations ):
         if agent not in AGENTS:
             load_agent( agent )
         agent = AGENTS[agent]
-        observations = observations.split("\n")
-        response = agent_reacts( agent, observations )
+        response = agent_reacts( agent, observations.split("\n") )
         messages.append(f"[{agent.name}] {response}")
     return "\n".join(messages)
 
@@ -80,29 +78,35 @@ def get_summary_proxy( agents ):
         messages.append(get_summary( agent, force_refresh = True ))
     return "\n".join(messages)
 
-def run_conversation_proxy( agents, observation, limit=2 ):
+def run_conversation_proxy( agents, message, limit=4 ):
     agents = [ AGENTS[agent] for agent in agents ]
 
     if len(agents) < 2:
         raise "Not enough agents"
 
     dialogue = []
-    dialogue.append(f'[{agents[0].name}] {observation}')
+    dialogue.append(f'[{agents[0].name}] {message}')
     yield "\n".join(dialogue)
 
     """Runs a conversation between agents."""
     print(colored("[Conversation]", "magenta"))
-    yield "\n".join(dialogue)
-    agent_observes( agents[0], [observation] )
+    importance_score = 0
+    for agent in agents:
+        importance_score = agent_observes( agent, [ message ], importance_score=importance_score )[0][0]
     agents = agents[1:] + [agents[0]]
 
-    dialogue = []
     while True:
         for agent in agents:
-            observation = agent_reacts( agent, [ observation ] )[0]
-            yield observation
-            if limit > 0 and len(dialogue) >= limit:
+            message = agent_reacts( agent, [ message ] )[0]
+            importance_score = 0
+            for a in agents:
+                importance_score = agent_observes( a, [ message ], importance_score=importance_score )[0][0]
+
+            dialogue.append(f'[{agent.name}] {message}')
+            yield "\n".join(dialogue)
+            if limit > 0 and len(dialogue) >= limit * len(agents):
                 break
-    return dialogue
+    print("END")
+    dialogue.append("END")
+    return "\n".join(dialogue)
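Since `run_conversation_proxy` is now a generator that yields the accumulated transcript after each turn (so the UI can stream it), a hedged sketch of driving it directly; it assumes both named agents have already been loaded into `AGENTS`:

```python
# Hypothetical driver; agent names are placeholders and must already be loaded.
for transcript in run_conversation_proxy(["Alice", "Bob"], "Hello there.", limit=4):
    print("----")
    print(transcript)  # the transcript so far, one "[Name] line" per turn
```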
src/utils.py
@@ -27,14 +27,20 @@ from langchain.vectorstores import FAISS
 LLM_TYPE = os.environ.get('LLM_TYPE', "llamacpp") # options: llamacpp, oai
 LLM_LOCAL_MODEL = os.environ.get('LLM_MODEL',
     #"./models/ggml-vicuna-13b-1.1/ggml-vic13b-uncensored-q4_2.bin"
-    "./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
+    "./models/ggml-vicuna-13b-cocktail-v1-q5_0.bin"
+    #"./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
     #"./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
     #"./models/gpt4-x-alpasta-30b-ggml-q4_1.bin"
 )
 LLM_CONTEXT = int(os.environ.get('LLM_CONTEXT', '2048'))
 LLM_THREADS = int(os.environ.get('LLM_THREADS', '6'))
+LLM_TEMPERATURE = float(os.environ.get('LLM_TEMPERATURE', '0.99'))
 EMBEDDING_TYPE = os.environ.get("LLM_EMBEDDING_TYPE", "hf") # options: llamacpp, oai, hf
 
+#LLM_TYPE="oai"
+#os.environ['OPENAI_API_BASE']="https://oai.ecker.tech/proxy/openai"
+#os.environ['OPENAI_API_KEY']=""
 
+# deduce a default given a model path
 if LLM_TYPE=="oai":
     LLM_PROMPT_TUNE_DEFAULT = "oai"
@@ -45,6 +51,8 @@ else:
         LLM_PROMPT_TUNE_DEFAULT = "vicuna"
     elif "alpasta" in LLM_LOCAL_MODEL.lower():
         LLM_PROMPT_TUNE_DEFAULT = "alpasta"
+    elif "cocktail" in LLM_LOCAL_MODEL.lower():
+        LLM_PROMPT_TUNE_DEFAULT = "cocktail"
     else:
         LLM_PROMPT_TUNE_DEFAULT = "llama"
 
@@ -64,6 +72,7 @@ if LLM_TYPE=="llamacpp":
         callback_manager=callback_manager,
         verbose=True,
         n_ctx=LLM_CONTEXT,
+        temperature=LLM_TEMPERATURE,
         #n_threads=LLM_THREADS,
         #use_mlock=True,
         #use_mmap=True,
@@ -89,6 +98,7 @@ elif LLM_TYPE=="oai":
 
     LLM = ChatOpenAI(
         max_tokens=LLM_CONTEXT,
+        temperature=LLM_TEMPERATURE,
         model_name=os.environ.get('OPENAI_MODEL_NAME', 'gpt-4'),
     )
 
@@ -98,7 +108,7 @@ else:
 if EMBEDDING_TYPE == "hf":
     from langchain.embeddings import HuggingFaceEmbeddings
 
-    EMBEDDINGS_MODEL = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+    EMBEDDINGS_MODEL = HuggingFaceEmbeddings()
     EMBEDDINGS_SIZE = 768
 elif EMBEDDING_TYPE == "oai":
     from langchain.embeddings import OpenAIEmbeddings
@@ -110,10 +120,6 @@ elif EMBEDDING_TYPE == "llamacpp":
 
     EMBEDDINGS_MODEL = LlamaCppEmbeddings(
         model_path=LLM_LOCAL_MODEL,
-        n_ctx=LLM_CONTEXT,
-        n_threads=LLM_THREADS,
-        use_mlock=True,
-        use_mmap=True,
     )
     EMBEDDINGS_SIZE = 5120
 else:
@@ -143,7 +149,7 @@ def _create_new_memories():
         memory_retriever=_create_new_memory_retriever(),
         reflection_threshold=8,
         verbose=True,
-        max_tokens_limit=128 # LLM_CONTEXT/4
+        max_tokens_limit=LLM_CONTEXT/2
     )
 
 def create_agent(**kwargs):
@@ -190,40 +196,47 @@ def get_summary(agent: GenerativeAgent, force_refresh: bool = True) -> str:
     print(summary)
     return summary
 
-def agent_observes( agent: GenerativeAgent, observations: List[str] ):
+def agent_observes( agent: GenerativeAgent, observations: List[str], importance_score=0 ):
     results = []
     for observation in observations:
         observation = observation.replace("{name}", agent.name)
-        print(colored("[Observation]", "magenta"), observation)
-        results.append(agent.memory.add_memory(observation))
+        print(colored("[Observation]", "magenta"), f'[{agent.name}] {observation}')
+        results.append(agent.memory.add_memory(observation, importance_score=importance_score))
     return results
 
 def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
     results = []
     for observation in observations:
         observation = observation.replace("{name}", agent.name)
-        print(colored("[Observation]", "magenta"), observation)
+        print(colored("[Observation]", "magenta"), f'[{agent.name}] {observation}')
         _, response = agent.generate_response(observation)
-        print(colored("[Reaction]", "magenta"), response)
+        print(colored("[Reaction]", "magenta"), f'[{agent.name}] {response}')
         results.append(response)
     return results
 
-def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
+def interview_agent(agent: GenerativeAgent, message: str) -> str:
     message = message.replace("{name}", agent.name)
-    new_message = f"{username} says {message}"
-    print(colored("[Interview]", "magenta"), message)
-    return agent.generate_dialogue_response(new_message)
+    print(colored("[Interview]", "magenta"), f"[User] {message}")
+    _, response = agent.generate_response(message)
+    print(colored("[Interview]", "magenta"), f"[{agent.name}] {response}")
+    return response
 
 
 def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
     print(colored("[Conversation]", "magenta"))
-    agent_observes( agents[0], [observation] )
+    for agent in agents:
+        agent_observes( agent, [observation] )
 
     agents = agents[1:] + [agents[0]]
 
     dialogue = []
     while True:
         for agent in agents:
             observation = agent_reacts( agent, [ observation ] )[0]
+            for a in agents:
+                if a is agent:
+                    continue
+                agent_observes( a, [ observation ] )
             if limit > 0 and len(dialogue) >= limit:
                 break
     return dialogue
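Finally, a rough end-to-end sketch of the helpers touched in this file; `create_agent` takes keyword arguments, and the ones shown here are assumptions rather than a documented signature:

```python
# Hypothetical flow: create an agent, feed it an observation, ask for a reaction,
# then print a refreshed summary. Keyword arguments to create_agent are assumed.
agent = create_agent(name="Sam", age=28, traits="curious, talkative", status="idle")

agent_observes(agent, ["{name} woke up to the sound of rain."])
print(agent_reacts(agent, ["{name} hears a knock at the door."])[0])
print(get_summary(agent, force_refresh=True))
```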