tunings
commit 287406e7ba (parent f13d05dbb2)
@@ -29,7 +29,7 @@ Set your environment variables accordingly:
 - `OPENAI_API_MODEL`: target model
 * `LLM_MODEL`: (`./path/to/your/llama/model.bin`): path to your GGML-formatted LLaMA model, if using `llamacpp` as the LLM backend
 * `LLM_EMBEDDING_TYPE`: (`oai`, `llamacpp`, `hf`): the embedding model to use for similarity computing.
-* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`, `supercot`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
+* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`, `supercot`, `cocktail`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
 * `LLM_CONTEXT`: sets maximum context size
 
 To run:
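As a quick illustration of the configuration variables listed above (a sketch only: the model path and values below are placeholder assumptions, not project defaults), the backend can be configured through the environment before launching:

```python
# Illustrative only: these are the variables src/utils.py and src/prompts.py read.
# The path and values below are placeholders, not project defaults.
import os

os.environ["LLM_TYPE"] = "llamacpp"                  # "llamacpp" or "oai"
os.environ["LLM_MODEL"] = "./models/your-model.bin"  # hypothetical GGML model path
os.environ["LLM_EMBEDDING_TYPE"] = "hf"              # "oai", "llamacpp", or "hf"
os.environ["LLM_PROMPT_TUNE"] = "cocktail"           # "oai", "vicuna", "supercot", or "cocktail"
os.environ["LLM_CONTEXT"] = "2048"                   # maximum context size
```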
@@ -44,11 +44,11 @@ I ***do not*** plan on making this uber-user friendly like [mrq/ai-voice-cloning
 
 ## Caveats
 
-A local LM is quite slow.
+A local LM is quite slow. Things seem to be getting faster as llama.cpp development progresses.
 
 Even using one that's more instruction-tuned, like Vicuna (with a `SYSTEM:\nUSER:\nASSISTANT:` prompt structure), it's still inconsistent.
 
-However, I seem to be getting consistent results with SuperCOT 33B; it's just, well, slow. SuperCOT 13B seems to give better answers than Vicuna-1.1 13B.
+However, I seem to be getting consistent results with SuperCOT 33B; it's just, well, slow. SuperCOT 13B seems to give better answers than Vicuna-1.1 13B, and Cocktail 13B seems to be the best of the 13Bs.
 
 A ***lot*** of prompt wrangling is needed, and a lot of the routines could be polished up. For example, an observation queries the LM for a rating, and each response reaction requires querying for the observed entity, then for the relationship between the agent and the observed entity (which ends up just summarizing relevant context/memories), and then querying for a response; if any one of these steps fails, the overall failure rate climbs. If anything, I might as well work from the ground up and only really salvage the use of FAISS to store embedded vectors.
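To make the formatting caveat concrete, here is a rough sketch of the role-tagged layout a Vicuna-style tune expects; the role labels and example turns are illustrative assumptions, and the project's actual templates live in `src/prompts.py`:

```python
# Rough sketch of a Vicuna-style SYSTEM/USER/ASSISTANT prompt, as described above.
# Each turn is rendered as "{role}: {query}" and joined with newlines.
roles = {"system": "SYSTEM", "user": "USER", "assistant": "ASSISTANT"}

def render(turns):
    lines = []
    for role, query in turns:
        label = roles[role]
        lines.append(f"{label}: {query}" if label else query)
    return "\n".join(lines).strip()

print(render([
    ("system", "You are summarizing a memory."),     # hypothetical instruction
    ("user", "Alice saw Bob watering the garden."),  # hypothetical observation
    ("assistant", ""),                               # the model continues from here
]))
# SYSTEM: You are summarizing a memory.
# USER: Alice saw Bob watering the garden.
# ASSISTANT:
```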
@@ -86,16 +86,26 @@ class GenerativeAgent(BaseModel):
             llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
         )
 
-    def get_most_recent_memories(self, last_k: int = 4) -> str:
+    def get_most_recent_memories(self, last_k: int = 8) -> str:
         memories = self.memory.memory_retriever.memory_stream[-last_k:]
-        return [ document.page_content for document in memories ]
+        return [ document.page_content.replace(u"\u200B", "").strip() for document in memories ]
 
-    def summarize_related_memories(self, observation: str) -> str:
-        """Summarize memories that are most relevant to an observation."""
+    def get_relevant_memories(self, observation: str, first_k : int = 8) -> str:
+        queries = [ observation ]
+        relevant_memories = [
+            mem.page_content.replace(u"\u200B", "").strip() for query in queries for mem in self.memory.fetch_memories(query)
+        ]
+        relevant_memories = relevant_memories[:first_k]
+        relevant_memories.reverse()
+        return relevant_memories
+
+    """
+    def summarize_related_memories(self, observation: str, first_k : int = 4) -> str:
         prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
-        q1 = f"Summarize the relationship between the subjects in that interaction."
-        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), q1=q1, observation=observation, queries=[observation]).strip()
+        query = f"Summarize the relationship between the subjects in that interaction in two sentences or less. Avoid repeating."
+        summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), query=query, observation=observation, queries=[observation]).strip()
         return f'{self.name} {summary}'
+    """
 
         #return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
 
@@ -104,17 +114,27 @@ class GenerativeAgent(BaseModel):
         prompt = PromptTemplate.from_template(
             get_prompt('generate_reaction').replace("{suffix}", suffix)
         )
-        summary = self.get_summary().replace(u"\u200B", "").strip()
-        relevant_memories = self.summarize_related_memories(observation).replace(u"\u200B", "").strip()
-        recent_memories = "\n".join(self.get_most_recent_memories())
+        summary = self.get_summary()
+        relevant_memories = self.get_relevant_memories(observation)
+        recent_memories = self.get_most_recent_memories()
 
-        # I think relevant_memories is supposed to only provide context for a relationship between agent and observer, as suggested by the query,
-        # but the original implementation seems to just leverage it to further filter relevant memories, per the name
+        # avoid repeating
+        memory = ""
 
-        if relevant_memories and relevant_memories != "N/A":
-            memory = relevant_memories
-        else:
-            memory = "\n".join(self.get_most_recent_memories())
+        for mem in relevant_memories:
+            if mem in summary or mem in memory:
+                continue
+            memory += f"\n{mem}"
 
+        for mem in recent_memories:
+            if mem in summary:
+                continue
+            if mem is observation:
+                continue
+            # erase it, move it to the bottom
+            if mem in memory:
+                memory = memory.replace(f'{mem}\n', "")
+            memory += f"\n{mem}"
+
         current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
         kwargs: Dict[str, Any] = dict(
@@ -127,12 +147,23 @@ class GenerativeAgent(BaseModel):
             #recent_memories=recent_memories if recent_memories else "N/A",
             observation=observation if observation else "N/A",
         )
-        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), queries=[observation], **kwargs).strip()
+        reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
+        import re
+
+        emoji_pattern = re.compile("["
+            u"\U0001F600-\U0001F64F"  # emoticons
+            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
+            u"\U0001F680-\U0001F6FF"  # transport & map symbols
+            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
+            "]+", flags=re.UNICODE)
+        reaction = emoji_pattern.sub(r'', reaction)
 
         # cleanup
         reactions = reaction.replace(u"\u200B", "").strip().split("\n")
 
         for reaction in reactions:
+            if reaction in summary or reaction in memory:
+                continue
             if reaction:
                 break
 
@@ -140,20 +171,14 @@ class GenerativeAgent(BaseModel):
             print(reaction)
         return reaction
 
-    def _clean_response(self, text: str) -> str:
-        return re.sub(f"^{self.name} ", "", text.strip()).strip()
-
     def generate_response(self, observation: str) -> Tuple[bool, str]:
         """React to a given observation."""
         call_to_action_template = get_prompt('suffix_generate_response')
-        full_result = f"{self.name} {self._generate_reaction(observation, call_to_action_template)}"
-        self.memory.save_context(
-            {},
-            {
-                self.memory.add_memory_key: full_result
-            },
-        )
+        full_result = ""
+        while not full_result:
+            full_result = f"{self._generate_reaction(observation, call_to_action_template)}"
+            if full_result:
+                break
 
         return True, full_result
 
@@ -191,7 +216,8 @@ class GenerativeAgent(BaseModel):
             f"Status: {self.status}"
         ]
 
-        return "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\n{self.summary.strip()}"
+        summary = "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\nSummary: {self.summary.strip()}"
+        return summary.replace(u"\u200B", "").strip()
 
     def get_full_header(self, force_refresh: bool = False) -> str:
         """Return a full header of the agent's status, summary, and current time."""
@@ -71,6 +71,8 @@ class GenerativeAgentMemory(BaseMemory):
     relevant_memories_simple_key: str = "relevant_memories_simple"
     most_recent_memories_key: str = "most_recent_memories"
 
+    reflecting: bool = False
+
     def chain(self, prompt: PromptTemplate) -> LLMChain:
         return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)
 
@@ -133,9 +135,10 @@ class GenerativeAgentMemory(BaseMemory):
 
         return (float(2) / 10) * self.importance_weight
 
-    def add_memory(self, memory_content: str) -> List[str]:
+    def add_memory(self, memory_content: str, importance_score: int = 0) -> List[str]:
         """Add an observation or memory to the agent's memory."""
-        importance_score = self._score_memory_importance(memory_content)
+        if not importance_score:
+            importance_score = self._score_memory_importance(memory_content)
         self.aggregate_importance += importance_score
         document = Document( page_content=memory_content, metadata={"importance": importance_score} )
         result = self.memory_retriever.add_documents([document])
@@ -146,10 +149,13 @@ class GenerativeAgentMemory(BaseMemory):
         if (
             self.reflection_threshold is not None
             and self.aggregate_importance > self.reflection_threshold
+            and not self.reflecting
         ):
+            self.reflecting = True
             self.pause_to_reflect()
             # Hack to clear the importance from reflection
             self.aggregate_importance = 0.0
+            self.reflecting = False
 
         return (importance_score, result)
 
@@ -169,7 +175,7 @@ class GenerativeAgentMemory(BaseMemory):
         return "\n".join([f"{mem}" for mem in content])
 
     def format_memories_simple(self, relevant_memories: List[Document]) -> str:
-        return "; ".join([f"{mem.page_content}" for mem in relevant_memories]).replace(".;", ";")
+        return "; ".join([f"{mem.page_content}" for mem in relevant_memories]).replace(".;", ".\n")
 
     def _get_memories_until_limit(self, consumed_tokens: int) -> str:
         """Reduce the number of tokens in the documents."""
@@ -2,32 +2,10 @@ import os
 
 LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE') # oai, vicuna, supercot
 
-USE_STOP_HINT = [ "llama" ]
-
 PROMPTS = {
-    "entity_from_observation": {
-        "system": (
-            "What is the observed entity in the following observation?"
-            " ONLY report one object and write one sentence."
-        ),
-        "user": (
-            "{observation}"
-        ),
-        "assistant": "Entity = ",
-    },
-    "entity_action": {
-        "system": (
-            "What is `{entity}` doing in the following observation?"
-            " ONLY write one sentence."
-        ),
-        "user": (
-            "{observation}"
-        ),
-        "assistant": "{entity} is ",
-    },
     "summarize_related_memories": {
         "system": (
-            "{q1}"
+            "{query}"
         ),
         "user": (
             "{relevant_memories_simple}"
@@ -44,7 +22,7 @@ PROMPTS = {
             "{summary}"
             "\n{relevant_memories_simple}"
         ),
-        "assistant": "",
+        "assistant": "{name} ",
     },
     "topic_of_reflection": {
         "system": (
@@ -53,7 +31,7 @@ PROMPTS = {
             " Provide each question on a new line."
         ),
         "user": (
-            "{observations}"
+            "Information: {observations}"
        ),
         "assistant": "",
     },
@@ -77,24 +55,22 @@ PROMPTS = {
             "\nRespond with only a single integer."
         ),
         "user": (
-            "{memory_content}"
+            "Event: {memory_content}"
         ),
-        "assistant": "",
+        "assistant": "Rating: ",
     },
     "generate_reaction": {
         "system": (
-            "\nIt is {current_time}."
+            "[Write one reply. Always stay in character. Maintain a casual tone using beige prose. Be brief. Avoid repeating anything below.]"
+            "\nCurrent Time: {current_time}"
             "\n{summary}"
-            "\n{relevant_memories_simple}"
             "\n{memory}"
-            #"\nRecent memories: {recent_memories}"
-            #"\nRelevant memories: {relevant_memories}"
-            "\n\n{suffix}"
+            "\n{observation}"
         ),
         "user": (
-            "{observation}"
+            "{suffix}"
         ),
-        "assistant": "{name} "
+        "assistant": ""
     },
 
 #
@@ -102,24 +78,7 @@ PROMPTS = {
         ""
     ),
     "suffix_generate_response": (
-        "Given the following observation, how would {name} respond?"
-        "\nWrite only one sentence."
-    ),
-
-    ##
-    "suffix_generate_reaction": (
-        "Given the following observation, how would {name} appropriately react?"
-        "\nIf the action is to engage in dialogue, only write `SAY: \"what to say\"`."
-        "\nOr otherwise, only write `REACT: how to react`."
-        "\nWrite ONLY one line, one sentence."
-        #"\nBe proactive, creative, and drive the plot and conversation forward."
-    ),
-    "suffix_generate_dialogue": (
-        "Given the following observation, what would {name} say?"
-        "\nTo continue the conversation, only write: `SAY: \"what to say\"`."
-        "\nOr otherwise, to end the conversation, only write: `GOODBYE: \"what to say\"`."
-        "\nWrite ONLY one line, one sentence."
-        #"\nBe proactive, creative, and drive the plot and conversation forward."
+        "Given the current situation, in one sentence, what is {name}'s next response?"
     ),
 }
 
@@ -128,6 +87,7 @@ PROMPT_TUNES = {
     "vicuna": "{role}: {query}",
     "supercot": "{role}:\n{query}",
     "alpasta": "{role}# {query}",
+    "cocktail": "{role}: {query}",
 }
 PROMPT_ROLES = {
     "vicuna": {
@@ -145,6 +105,11 @@ PROMPT_ROLES = {
         "user": "<|user|>",
         "assistant": "<|assistant|>",
     },
+    "cocktail": {
+        "system": "",
+        "user": "USER",
+        "assistant": "ASSOCIATE",
+    },
 }
 
 ROLES = [ "system", "user", "assistant" ]
@@ -153,7 +118,8 @@ ROLES = [ "system", "user", "assistant" ]
 def get_stop_tokens( tokens=[], tune=LLM_PROMPT_TUNE ):
     STOP_TOKENS = ["###"] + tokens
     for role in get_roles( tune=LLM_PROMPT_TUNE, special=True ):
-        STOP_TOKENS.append(f'{role}')
+        if role:
+            STOP_TOKENS.append(f'{role}')
     return STOP_TOKENS
 
 for k in PROMPTS:
@@ -204,4 +170,7 @@ def get_prompt( key, tune=LLM_PROMPT_TUNE ):
         output = output.replace("{query}", query)
         outputs.append(output)
 
-    return "\n".join(outputs)
+    output = "\n".join(outputs)
+    #if LLM_PROMPT_TUNE == "cocktail":
+    output = output.strip()
+    return output
src/main.py (28 changed lines)
@@ -37,8 +37,7 @@ def agent_observes_proxy( agents, observations ):
         if agent not in AGENTS:
             load_agent( agent )
         agent = AGENTS[agent]
-        observations = observations.split("\n")
-        results = agent_observes( agent, observations )
+        results = agent_observes( agent, observations.split("\n") )
         messages.append(f"[{agent.name}] Observation noted. Importance score: {[ result[0] for result in results ]}")
     return "\n".join(messages)
 
@@ -51,8 +50,7 @@ def agent_reacts_proxy( agents, observations ):
         if agent not in AGENTS:
             load_agent( agent )
         agent = AGENTS[agent]
-        observations = observations.split("\n")
-        response = agent_reacts( agent, observations )
+        response = agent_reacts( agent, observations.split("\n") )
         messages.append(f"[{agent.name}] {response}")
     return "\n".join(messages)
 
@@ -80,29 +78,35 @@ def get_summary_proxy( agents ):
         messages.append(get_summary( agent, force_refresh = True ))
     return "\n".join(messages)
 
-def run_conversation_proxy( agents, observation, limit=2 ):
+def run_conversation_proxy( agents, message, limit=4 ):
     agents = [ AGENTS[agent] for agent in agents ]
 
     if len(agents) < 2:
         raise "Not enough agents"
 
     dialogue = []
-    dialogue.append(f'[{agents[0].name}] {observation}')
+    dialogue.append(f'[{agents[0].name}] {message}')
+    yield "\n".join(dialogue)
+
     """Runs a conversation between agents."""
     print(colored("[Conversation]", "magenta"))
-    yield "\n".join(dialogue)
-    agent_observes( agents[0], [observation] )
+    importance_score = 0
+    for agent in agents:
+        importance_score = agent_observes( agent, [ message ], importance_score=importance_score )[0][0]
     agents = agents[1:] + [agents[0]]
 
     dialogue = []
     while True:
         for agent in agents:
-            observation = agent_reacts( agent, [ observation ] )[0]
-            yield observation
-            if limit > 0 and len(dialogue) >= limit:
+            message = agent_reacts( agent, [ message ] )[0]
+            importance_score = 0
+            for a in agents:
+                importance_score = agent_observes( a, [ message ], importance_score=importance_score )[0][0]
+
+            dialogue.append(f'[{agent.name}] {message}')
+            yield "\n".join(dialogue)
+            if limit > 0 and len(dialogue) >= limit * len(agents):
                 break
-    return dialogue
     print("END")
     dialogue.append("END")
     return "\n".join(dialogue)
src/utils.py (47 changed lines)
@@ -27,14 +27,20 @@ from langchain.vectorstores import FAISS
 LLM_TYPE = os.environ.get('LLM_TYPE', "llamacpp") # options: llamacpp, oai
 LLM_LOCAL_MODEL = os.environ.get('LLM_MODEL',
     #"./models/ggml-vicuna-13b-1.1/ggml-vic13b-uncensored-q4_2.bin"
-    "./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
+    "./models/ggml-vicuna-13b-cocktail-v1-q5_0.bin"
+    #"./models/llama-13b-supercot-ggml/ggml-model-q4_2.bin"
     #"./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
     #"./models/gpt4-x-alpasta-30b-ggml-q4_1.bin"
 )
 LLM_CONTEXT = int(os.environ.get('LLM_CONTEXT', '2048'))
 LLM_THREADS = int(os.environ.get('LLM_THREADS', '6'))
+LLM_TEMPERATURE = float(os.environ.get('LLM_TEMPERATURE', '0.99'))
 EMBEDDING_TYPE = os.environ.get("LLM_EMBEDDING_TYPE", "hf") # options: llamacpp, oai, hf
 
+#LLM_TYPE="oai"
+#os.environ['OPENAI_API_BASE']="https://oai.ecker.tech/proxy/openai"
+#os.environ['OPENAI_API_KEY']=""
+
 # deduce a default given a model path
 if LLM_TYPE=="oai":
     LLM_PROMPT_TUNE_DEFAULT = "oai"
@@ -45,6 +51,8 @@ else:
         LLM_PROMPT_TUNE_DEFAULT = "vicuna"
     elif "alpasta" in LLM_LOCAL_MODEL.lower():
         LLM_PROMPT_TUNE_DEFAULT = "alpasta"
+    elif "cocktail" in LLM_LOCAL_MODEL.lower():
+        LLM_PROMPT_TUNE_DEFAULT = "cocktail"
     else:
         LLM_PROMPT_TUNE_DEFAULT = "llama"
 
@@ -64,6 +72,7 @@ if LLM_TYPE=="llamacpp":
         callback_manager=callback_manager,
         verbose=True,
         n_ctx=LLM_CONTEXT,
+        temperature=LLM_TEMPERATURE,
         #n_threads=LLM_THREADS,
         #use_mlock=True,
         #use_mmap=True,
@@ -89,6 +98,7 @@ elif LLM_TYPE=="oai":
 
     LLM = ChatOpenAI(
         max_tokens=LLM_CONTEXT,
+        temperature=LLM_TEMPERATURE,
         model_name=os.environ.get('OPENAI_MODEL_NAME', 'gpt-4'),
     )
 
@@ -98,7 +108,7 @@ else:
 if EMBEDDING_TYPE == "hf":
     from langchain.embeddings import HuggingFaceEmbeddings
 
-    EMBEDDINGS_MODEL = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+    EMBEDDINGS_MODEL = HuggingFaceEmbeddings()
     EMBEDDINGS_SIZE = 768
 elif EMBEDDING_TYPE == "oai":
     from langchain.embeddings import OpenAIEmbeddings
@@ -110,10 +120,6 @@ elif EMBEDDING_TYPE == "llamacpp":
 
     EMBEDDINGS_MODEL = LlamaCppEmbeddings(
         model_path=LLM_LOCAL_MODEL,
-        n_ctx=LLM_CONTEXT,
-        n_threads=LLM_THREADS,
-        use_mlock=True,
-        use_mmap=True,
     )
     EMBEDDINGS_SIZE = 5120
 else:
@@ -143,7 +149,7 @@ def _create_new_memories():
         memory_retriever=_create_new_memory_retriever(),
         reflection_threshold=8,
         verbose=True,
-        max_tokens_limit=128 # LLM_CONTEXT/4
+        max_tokens_limit=LLM_CONTEXT/2
     )
 
 def create_agent(**kwargs):
@@ -190,40 +196,47 @@ def get_summary(agent: GenerativeAgent, force_refresh: bool = True) -> str:
     print(summary)
     return summary
 
-def agent_observes( agent: GenerativeAgent, observations: List[str] ):
+def agent_observes( agent: GenerativeAgent, observations: List[str], importance_score=0 ):
     results = []
     for observation in observations:
         observation = observation.replace("{name}", agent.name)
-        print(colored("[Observation]", "magenta"), observation)
-        results.append(agent.memory.add_memory(observation))
+        print(colored("[Observation]", "magenta"), f'[{agent.name}] {observation}')
+        results.append(agent.memory.add_memory(observation, importance_score=importance_score))
     return results
 
 def agent_reacts( agent: GenerativeAgent, observations: List[str] ):
     results = []
     for observation in observations:
         observation = observation.replace("{name}", agent.name)
-        print(colored("[Observation]", "magenta"), observation)
+        print(colored("[Observation]", "magenta"), f'[{agent.name}] {observation}')
         _, response = agent.generate_response(observation)
-        print(colored("[Reaction]", "magenta"), response)
+        print(colored("[Reaction]", "magenta"), f'[{agent.name}] {response}')
         results.append(response)
     return results
 
-def interview_agent(agent: GenerativeAgent, message: str, username: str = "Person A") -> str:
+def interview_agent(agent: GenerativeAgent, message: str) -> str:
     message = message.replace("{name}", agent.name)
-    new_message = f"{username} says {message}"
-    print(colored("[Interview]", "magenta"), message)
-    return agent.generate_dialogue_response(new_message)
+    print(colored("[Interview]", "magenta"), f"[User] {message}")
+    _, response = agent.generate_response(message)
+    print(colored("[Interview]", "magenta"), f"[{agent.name}] {response}")
+    return response
 
 
 def run_conversation(agents: List[GenerativeAgent], observation: str, limit: int = 0, p_reaction: float = 1 ) -> None:
     print(colored("[Conversation]", "magenta"))
-    agent_observes( agents[0], [observation] )
+    for agent in agents:
+        agent_observes( agent, [observation] )
+
     agents = agents[1:] + [agents[0]]
 
     dialogue = []
     while True:
         for agent in agents:
             observation = agent_reacts( agent, [ observation ] )[0]
+            for a in agents:
+                if a is agent:
+                    continue
+                agent_observes( a, [ observation ] )
             if limit > 0 and len(dialogue) >= limit:
                 break
     return dialogue