added prompt tuning for SuperCOT (its 33B variant seems to be the best option for a local LLM)
parent 089b7043b9
commit f10ea1ec2a
README.md
@@ -29,7 +29,7 @@ Set your environment variables accordingly:
 - `OPENAI_API_MODEL`: target model
 * `LLM_MODEL`: (`./path/to/your/llama/model.bin`): path to your GGML-formatted LLaMA model, if using `llamacpp` as the LLM backend
 * `LLM_EMBEDDING_TYPE`: (`oai`, `llamacpp`, `hf`): the embedding model to use for similarity computing.
-* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
+* `LLM_PROMPT_TUNE`: (`oai`, `vicuna`, `supercot`): prompt formatting to use, for variants with specific finetunes for instructions, etc.
 * `LLM_CONTEXT`: sets maximum context size

 To run:
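As a rough sketch of how these variables might be set for the SuperCOT path this commit targets (the values mirror the defaults that appear in `src/utils.py` further down; treat them as illustrative and adjust paths to your setup):

```python
# sketch: environment for a local SuperCOT 33B run (values are illustrative)
import os

os.environ["LLM_TYPE"] = "llamacpp"
os.environ["LLM_MODEL"] = "./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
os.environ["LLM_EMBEDDING_TYPE"] = "hf"
os.environ["LLM_PROMPT_TUNE"] = "supercot"
os.environ["LLM_CONTEXT"] = "2048"
```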
@@ -44,6 +44,10 @@ I ***do not*** plan on making this uber-user friendly like [mrq/ai-voice-cloning
 
 ## Caveats
 
-A local LM is quite slow. Even using one that's more instruction-tuned like Vicuna (with a `SYSTEM:\nUSER:\nASSISTANT:` structure of prompts) is still inconsistent.
+A local LM is quite slow.
+
+Even using one that's more instruction-tuned like Vicuna (with a `SYSTEM:\nUSER:\nASSISTANT:` structure of prompts) is still inconsistent.
+
+However, I seem to be getting consistent results with SuperCOT 33B; it's just, well, slow.
 
 GPT4 seems to Just Work, unfortunately.
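For context on what the new `supercot` option does: the prompt helpers added in `src/ext/prompts.py` below map the generic system/user/assistant roles onto Alpaca-style section headers, so a formatted request roughly takes this shape (the bracketed text is a placeholder, not part of the template):

```
### Instruction:
[system text, plus the optional "context" preamble]
### Input:
[user text, e.g. "Observation: ..."]
### Response:
[the assistant prefix, if any]
```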
src/ext/__init__.py
@@ -25,5 +25,6 @@ THE SOFTWARE.
 """Generative Agents primitives."""
 from .generative_agent import GenerativeAgent
 from .memory import GenerativeAgentMemory
+from .prompts import get_prompt, get_roles
 
 __all__ = ["GenerativeAgent", "GenerativeAgentMemory"]
src/ext/generative_agent.py
@@ -84,17 +84,25 @@ class GenerativeAgent(BaseModel):
 
     def _get_entity_from_observation(self, observation: str) -> str:
         prompt = PromptTemplate.from_template(get_prompt('entity_from_observation'))
-        return self.chain(prompt).run(observation=observation).strip()
+        response = self.chain(prompt).run(observation=observation).strip().replace("Entity=", "").replace("Entity: ", "") # OAI will keep this
+        if self.verbose:
+            print(response)
+        return response
 
     def _get_entity_action(self, observation: str, entity_name: str) -> str:
         prompt = PromptTemplate.from_template(get_prompt('entity_action'))
-        return self.chain(prompt).run(entity=entity_name, observation=observation).strip()
+        response = self.chain(prompt).run(entity=entity_name, observation=observation).strip()
+        if self.verbose:
+            print(response)
+        return response
 
     def summarize_related_memories(self, observation: str) -> str:
         """Summarize memories that are most relevant to an observation."""
         prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
-        entity_name = self._get_entity_from_observation(observation).split("\n")[0]
+        entity_name = self._get_entity_from_observation(observation).split("\n")[0].strip()
         q1 = f"What is the relationship between {self.name} and {entity_name}"
+        if self.name.strip() == entity_name:
+            return ""
 
         # this is unused, so ignore for now
         """
@@ -103,6 +111,8 @@ class GenerativeAgent(BaseModel):
         summary = self.chain(prompt=prompt).run(q1=q1, queries=[q1, q2]).strip()
         """
         summary = self.chain(prompt=prompt).run(q1=q1, queries=[q1]).strip()
+        if self.verbose:
+            print(summary)
         return summary
 
         #return self.chain(prompt=prompt).run(q1=q1, q2=q2).strip()
@@ -128,7 +138,10 @@ class GenerativeAgent(BaseModel):
         consumed_tokens = self.llm.get_num_tokens(formatted_prompt)
 
         kwargs[self.memory.most_recent_memories_token_key] = consumed_tokens
-        return self.chain(prompt=prompt).run(**kwargs).strip()
+        reaction = self.chain(prompt=prompt).run(**kwargs).strip()
+        if self.verbose:
+            print(reaction)
+        return reaction
 
     def _clean_response(self, text: str) -> str:
         return re.sub(f"^{self.name} ", "", text.strip()).strip()
@@ -138,18 +151,20 @@ class GenerativeAgent(BaseModel):
         full_result = self._generate_reaction(observation, get_prompt('suffix_generate_reaction'))
         candidates = full_result.replace(u"\u200B", "").strip().split("\n")
 
-        result = ""
+        response = ""
         results = []
 
         for candidate in candidates:
             if "REACT:" in candidate or "SAY:" in candidate:
-                candidate = candidate.strip()
-                results.append(f'reacted by {candidate}'.replace("SAY:", "saying").replace("reacted by REACT: ", ""))
+                # can't be assed to iteratively replace
+                candidate = candidate.strip().replace("React:", "REACT:").replace("Say:", "SAY:")
+                results.append(f'{candidate}'.replace("SAY:", "said").replace(f"REACT: {self.name}", "").replace("REACT:", ""))
         if len(results) > 0:
-            result = "and".join(results)
-            response = f"reacted by {result}"
+            response = "and".join(results).strip().replace("  ", " ")
+            valid = True
         else:
-            response = f"did not react"
+            response = f"did not react in a relevant way"
+            valid = False
 
         # AAA
         self.memory.save_context(
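For reference, the reaction suffix prompts (see `suffix_generate_reaction` and `suffix_generate_dialogue_response` in `src/ext/prompts.py` further down) ask the model to answer in one of these line formats, which is what the candidate loop above scans for and rewrites into a past-tense description:

```
SAY: "what to say"
REACT: {agent_name}'s reaction
GOODBYE: "what to say"    (dialogue responses only)
```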
@@ -158,6 +173,10 @@ class GenerativeAgent(BaseModel):
                 self.memory.add_memory_key: f"{self.name} observed {observation} and {response}"
             },
         )
 
+        return valid, f"{self.name} {response}"
+
+        """
         if "REACT:" in result:
             reaction = self._clean_response(result.split("REACT:")[-1])
             return True, f"{self.name} {reaction}"
@@ -166,6 +185,7 @@ class GenerativeAgent(BaseModel):
             return True, f"{self.name} said {said_value}"
         else:
             return False, f"{self.name} did not react in a relevant way"
+        """
 
     def generate_dialogue_response(self, observation: str) -> Tuple[bool, str]:
         """React to a given observation."""
@@ -206,6 +226,8 @@ class GenerativeAgent(BaseModel):
         # The agent seeks to think about their core characteristics.
         prompt = PromptTemplate.from_template(get_prompt('compute_agent_summary'))
         summary = self.chain(prompt).run(name=self.name, queries=[f"{self.name}'s core characteristics"]).strip()
+        if self.verbose:
+            print(summary)
         return summary
 
     def get_summary(self, force_refresh: bool = False) -> str:
src/ext/memory.py
@@ -85,6 +85,9 @@ class GenerativeAgentMemory(BaseMemory):
         observations = self.memory_retriever.memory_stream[-last_k:]
         observation_str = "\n".join([o.page_content for o in observations])
         result = self.chain(prompt).run(observations=observation_str)
+        if self.verbose:
+            print(result)
+
         return self._parse_list(result)
 
     def _get_insights_on_topic(self, topic: str) -> List[str]:
@@ -121,14 +124,15 @@ class GenerativeAgentMemory(BaseMemory):
         prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
         score = self.chain(prompt).run(memory_content=memory_content).strip()
         if self.verbose:
-            logger.info(f"Importance score: {score}")
+            print(f"Importance score: {score}")
         try:
             match = re.search(r"(\d+)", score)
             if match:
                 return (float(match.group(0)) / 10) * self.importance_weight
         except Exception as e:
             print(colored("[Scoring Error]", "red"), score)
-        return 0.0
+
+        return (float(2) / 10) * self.importance_weight
 
     def add_memory(self, memory_content: str) -> List[str]:
         """Add an observation or memory to the agent's memory."""
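The `memory_importance` prompt asks the model for a single integer (the OAI variant additionally primes the reply with `Rating: `), so scoring boils down to pulling the first number out of the reply. A minimal illustration of that parse, with a made-up reply:

```python
import re

score = "Rating: 7"                  # made-up model reply
match = re.search(r"(\d+)", score)   # grabs the first integer, here "7"
value = float(match.group(0)) / 10   # 0.7, then scaled by importance_weight
```

When no number can be extracted, the change above now falls back to a flat 2/10 instead of 0.0.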
src/ext/prompts.py
@@ -2,6 +2,118 @@ import os
 
 LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE', "vicuna") # oai, vicuna
 
+# split because I can't prematurely end on the END token like I can with a local LM
+if LLM_PROMPT_TUNE == "oai":
+    PROMPTS = {
+        "entity_from_observation": {
+            "system": (
+                "What is the observed entity in the following observation?"
+                " ONLY report one object and write one sentence."
+            ),
+            "user": (
+                "Observation: {observation}"
+            ),
+            "assistant": "Entity=",
+        },
+        "entity_action": {
+            "system": (
+                "What is `{entity}` doing in the following observation?"
+                " ONLY report one object and write one sentence."
+            ),
+            "user": (
+                "Observation: {observation}"
+            ),
+            "assistant": "`{entity}` is ",
+        },
+        "summarize_related_memories": {
+            "system": (
+                "Given the following context, answer the following question in four sentences or less. Summarize the answer as well."
+            ),
+            "user": (
+                "{q1}?"
+                "\nContext: {relevant_memories_simple}"
+            ),
+            "assistant": "Summary of relevant context: ",
+        },
+        "compute_agent_summary": {
+            "system": (
+                "Given the following statements, how would you summarize {name}'s core characteristics?"
+                " Do not embellish under any circumstances."
+            ),
+            "user": (
+                "Statements: {relevant_memories_simple}"
+            ),
+            "assistant": "Summary: ",
+        },
+        "topic_of_reflection": {
+            "system": (
+                "Given only the following information, what are the 3 most salient"
+                " high-level questions we can answer about the subjects in the statements?"
+                " Provide each question on a new line."
+            ),
+            "user": (
+                "Information: {observations}"
+            ),
+            "assistant": "",
+        },
+        "insights_on_topic": {
+            "system": (
+                "Given the following statements about {topic},"
+                " what 5 high-level insights can you infer?"
+                " (example format: insight (because of 1, 5, 3))"
+            ),
+            "user": (
+                "Statements: {related_statements}"
+            ),
+            "assistant": "",
+        },
+        "memory_importance": {
+            "system": (
+                "On the scale of 1 to 10, where 1 is purely mundane"
+                " (e.g., brushing teeth, making bed) and 10 is extremely poignant"
+                " (e.g., a break up, college acceptance),"
+                " rate the likely poignancy of the following piece of memory."
+                " Respond with only a single integer."
+            ),
+            "user": (
+                "Memory: {memory_content}"
+            ),
+            "assistant": "Rating: ",
+        },
+        "generate_reaction": {
+            "system": (
+                "It is {current_time}."
+                " The following is a description of {agent_name}:"
+                "\n{agent_summary_description}"
+                "\n{agent_name}'s status: {agent_status}"
+                "\nSummary of relevant context from {agent_name}'s memory: {relevant_memories}"
+                "\nMost recent observations: {most_recent_memories}"
+                "\n\n{suffix}"
+            ),
+            "user": (
+                "Observation: {observation}"
+            ),
+            "assistant": ""
+        },
+
+        #
+        "context": ( # insert your JB here
+            ""
+        ),
+        "suffix_generate_reaction": (
+            "Given the following observation, in one sentence, how would {agent_name} appropriately react?"
+            "\nWrite 1 reply only in internet RP style, italicize actions, and avoid quotation marks. Use markdown. Be proactive, creative, and drive the plot and conversation forward. Write no less than six sentences each. Always stay in character and avoid repetition."
+            "\nIf the action is to engage in dialogue, write `SAY: \"what to say\"`."
+            "\nOtherwise, write `REACT: {agent_name}'s reaction`."
+        ),
+        "suffix_generate_dialogue_response": (
+            "\nWrite 1 reply only in internet RP style, italicize actions, and avoid quotation marks. Use markdown. Be proactive, creative, and drive the plot and conversation forward. Write no less than six sentences each. Always stay in character and avoid repetition."
+            "Given the following observation, in one sentence, what would {agent_name} say?"
+            "\nTo continue the conversation, write: `SAY: \"what to say\"`."
+            "\nOtherwise, to end the conversation, write: `GOODBYE: \"what to say\"`."
+        ),
+    }
+else:
 PROMPTS = {
     "entity_from_observation": {
         "system": (
@@ -27,14 +139,14 @@ PROMPTS = {
     },
     "summarize_related_memories": {
         "system": (
-            "Given the following context, answer the following question in four sentences or less."
+            "Given the following context, answer the following question in four sentences or less. Summarize the answer as well."
             " Write `END` afterwards."
+            "\nContext: {relevant_memories_simple}"
         ),
         "user": (
             "{q1}?"
-            "\nContext: {relevant_memories_simple}"
         ),
-        "assistant": "Relevant context: ",
+        "assistant": "Summary of relevant context: ",
     },
     "compute_agent_summary": {
         "system": (
@@ -104,12 +216,14 @@ PROMPTS = {
     ),
     "suffix_generate_reaction": (
         "Given the following observation, in one sentence, how would {agent_name} appropriately react?"
+        "\nWrite 1 reply only in internet RP style, italicize actions, and avoid quotation marks. Use markdown. Be proactive, creative, and drive the plot and conversation forward. Write no less than six sentences each. Always stay in character and avoid repetition."
         "\nIf the action is to engage in dialogue, write `SAY: \"what to say\"`."
         "\nOtherwise, write `REACT: {agent_name}'s reaction`."
         "\nWrite 'END' afterwards."
     ),
     "suffix_generate_dialogue_response": (
         "Given the following observation, in one sentence, what would {agent_name} say?"
+        "\nWrite 1 reply only in internet RP style, italicize actions, and avoid quotation marks. Use markdown. Be proactive, creative, and drive the plot and conversation forward. Write no less than six sentences each. Always stay in character and avoid repetition."
         "\nTo continue the conversation, write: `SAY: \"what to say\"`."
         "\nOtherwise, to end the conversation, write: `GOODBYE: \"what to say\"`."
         "\nWrite \"END\" afterwards."
@ -118,11 +232,35 @@ PROMPTS = {
|
||||||
|
|
||||||
PROMPT_TUNES = {
|
PROMPT_TUNES = {
|
||||||
"default": "{query}",
|
"default": "{query}",
|
||||||
"vicuna": "{ROLE}: {query}"
|
"vicuna": "{role}: {query}",
|
||||||
|
"supercot": "{role}:\n{query}",
|
||||||
|
}
|
||||||
|
PROMPT_ROLES = {
|
||||||
|
"vicuna": {
|
||||||
|
"system": "SYSTEM",
|
||||||
|
"user": "USER",
|
||||||
|
"assistant": "ASSISTANT",
|
||||||
|
},
|
||||||
|
"supercot": {
|
||||||
|
"system": "### Instruction",
|
||||||
|
"user": "### Input",
|
||||||
|
"assistant": "### Response",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ROLES = [ "system", "user", "assistant" ]
|
ROLES = [ "system", "user", "assistant" ]
|
||||||
|
|
||||||
|
for k in PROMPTS:
|
||||||
|
if k == "context":
|
||||||
|
continue
|
||||||
|
|
||||||
|
def get_roles( tune=LLM_PROMPT_TUNE, special=True ):
|
||||||
|
if tune in PROMPT_ROLES:
|
||||||
|
return list(PROMPT_ROLES[tune].values())
|
||||||
|
if special:
|
||||||
|
return []
|
||||||
|
return ROLES
|
||||||
|
|
||||||
def get_prompt( key, tune=LLM_PROMPT_TUNE ):
|
def get_prompt( key, tune=LLM_PROMPT_TUNE ):
|
||||||
prompt = PROMPTS[key]
|
prompt = PROMPTS[key]
|
||||||
|
|
||||||
|
@@ -134,20 +272,29 @@ def get_prompt( key, tune=LLM_PROMPT_TUNE ):
     if tune not in PROMPT_TUNES:
         tune = "default"
 
-    outputs = []
-    for role in ROLES:
-        if role not in prompt:
-            # implicitly add in our context as a system message
-            if role == "system" and PROMPTS["context"]:
-                query = PROMPTS["context"]
-            else:
-                continue
-        else:
-            query = prompt[role]
-
-        output = f'{PROMPT_TUNES[tune]}'
-        output = output.replace("{role}", role.lower())
-        output = output.replace("{ROLE}", role.upper())
-        output = output.replace("{query}", query)
-        outputs.append(output)
+    context = PROMPTS["context"]
+    if context:
+        if "system" in prompt:
+            if context not in prompt["system"]:
+                prompt["system"] = f'{context}\n{prompt["system"]}'
+        else:
+            prompt["system"] = f'{context}'
+
+    outputs = []
+    for r in ROLES:
+        role = f'{r}' # i can't be assed to check if strings COW
+        if role not in prompt:
+            continue
+        else:
+            query = prompt[role]
+
+        if tune in PROMPT_ROLES:
+            roles = PROMPT_ROLES[tune]
+            if role in roles:
+                role = roles[role]
+
+        output = f'{PROMPT_TUNES[tune]}'
+        output = output.replace("{role}", role)
+        output = output.replace("{query}", query)
+        outputs.append(output)
 
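Putting the new helpers together: with the `supercot` tune, `get_roles` returns the Alpaca-style headers and `get_prompt` renders each defined role as `{role}:\n{query}`. A sketch of how a caller might poke at them (the exact rendered text depends on the template and the `context` preamble):

```python
from ext import get_prompt, get_roles

print(get_roles(tune="supercot"))
# ['### Instruction', '### Input', '### Response']

template = get_prompt("memory_importance", tune="supercot")
# roughly:
#   ### Instruction:
#   On the scale of 1 to 10, ... Respond with only a single integer. ...
#   ### Input:
#   Memory: {memory_content}
#   ### Response:
#   ...
```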
src/main.py
@@ -38,6 +38,8 @@ def agent_observes_proxy( agents, observations ):
 
     messages = []
     for agent in agents:
+        if agent not in AGENTS:
+            load_agent( agent )
         agent = AGENTS[agent]
         observations = observations.split("\n")
         results = agent_observes( agent, observations )
@@ -50,6 +52,8 @@ def interview_agent_proxy( agents, message ):
 
     messages = []
     for agent in agents:
+        if agent not in AGENTS:
+            load_agent( agent )
         agent = AGENTS[agent]
         messages.append(interview_agent( agent, message )[-1])
     return "\n".join(messages)
@@ -60,13 +64,15 @@ def get_summary_proxy( agents ):
 
     messages = []
     for agent in agents:
+        if agent not in AGENTS:
+            load_agent( agent )
         agent = AGENTS[agent]
         messages.append(get_summary( agent, force_refresh = True ))
     return "\n".join(messages)
 
 def run_conversation_proxy( agents, message ):
     agents = [ AGENTS[agent] for agent in agents ]
-    messages = run_conversation( agents, message, limit=len(agents)*3 )
+    messages = run_conversation( agents, message, limit=len(agents)*2 )
     return "\n".join(messages)
 
 def view_agent( agents, last_k = 50 ):
@@ -75,6 +81,8 @@ def view_agent( agents, last_k = 50 ):
 
     messages = []
     for agent in agents:
+        if agent not in AGENTS:
+            load_agent( agent )
         agent = AGENTS[agent]
         memories = agent.memory.memory_retriever.memory_stream[-last_k:]
         memories = "\n".join([ document.page_content for document in memories])
src/utils.py
@@ -21,29 +21,40 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.vectorstores import FAISS
 
 # shit I can shove behind an env var
+os.environ['LLM_PROMPT_TUNE'] = "supercot"
+
 LLM_TYPE = os.environ.get('LLM_TYPE', "llamacpp") # options: llamacpp, oai
-LLM_LOCAL_MODEL = os.environ.get('LLM_MODEL', "./models/ggml-vicuna-13b-1.1/ggml-vic13b-uncensored-q4_2.bin") # "./models/llama-13b-supercot-ggml/ggml-model-q4_0.bin"
+LLM_LOCAL_MODEL = os.environ.get('LLM_MODEL',
+    #"./models/ggml-vicuna-13b-1.1/ggml-vic13b-uncensored-q4_2.bin"
+    #"./models/llama-13b-supercot-ggml/ggml-model-q4_0.bin"
+    "./models/llama-33b-supercot-ggml/ggml-model-q4_2.bin"
+)
 LLM_CONTEXT = int(os.environ.get('LLM_CONTEXT', '2048'))
 LLM_THREADS = int(os.environ.get('LLM_THREADS', '6'))
 EMBEDDING_TYPE = os.environ.get("LLM_EMBEDDING_TYPE", "hf") # options: llamacpp, oai, hf
 
+if LLM_TYPE=="oai":
+    os.environ['LLM_PROMPT_TUNE'] = "oai"
+LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE', "supercot")
+
 callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) # unncessesary but whatever
-if LLM_TYPE=="llamacpp":
-    from langchain.llms import LlamaCpp
-
-    STOP_TOKENS = ["END"]
-    if os.environ.get('LLM_PROMPT_TUNE', "vicuna") == "vicuna":
-        STOP_TOKENS.append("SYSTEM:")
-        STOP_TOKENS.append("USER:")
-        STOP_TOKENS.append("ASSISTANT:")
+
+# Overrides for some fixes, like scoring memory and LLM-specific promptings
+from ext import GenerativeAgent, GenerativeAgentMemory, get_roles
+
+STOP_TOKENS = ["END"]
+for role in get_roles( tune=LLM_PROMPT_TUNE, special=True ):
+    STOP_TOKENS.append(f'{role}:')
+
+if LLM_TYPE=="llamacpp":
+    from langchain.llms import LlamaCpp
 
     LLM = LlamaCpp(
         model_path=LLM_LOCAL_MODEL,
         callback_manager=callback_manager,
         verbose=True,
         n_ctx=LLM_CONTEXT,
-        n_threads=LLM_THREADS,
+        #n_threads=LLM_THREADS,
         use_mlock=True,
         use_mmap=True,
         stop=STOP_TOKENS
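With `get_roles` feeding the stop list, a `supercot` run would end up with stop strings like the following (a sketch of the loop above):

```python
STOP_TOKENS = ["END"]
for role in get_roles(tune="supercot", special=True):
    STOP_TOKENS.append(f"{role}:")
# STOP_TOKENS == ["END", "### Instruction:", "### Input:", "### Response:"]
```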
@@ -51,10 +62,6 @@ if LLM_TYPE=="llamacpp":
 elif LLM_TYPE=="oai":
     from langchain.chat_models import ChatOpenAI
 
-    # os.environ["OPENAI_API_BASE"] = ""
-    # os.environ["OPENAI_API_KEY"] = ""
-    os.environ['LLM_PROMPT_TUNE'] = "vicuna"
-
 # Override for Todd
 if os.environ.get('LANGCHAIN_OVERRIDE_RESULT', '1') == '1':
     from langchain.schema import Generation, ChatResult, LLMResult, ChatGeneration
@@ -104,12 +111,6 @@ elif EMBEDDING_TYPE == "llamacpp":
 else:
     raise f"Invalid embedding type: {EMBEDDING_TYPE}"
 
-# Overrides for some fixes, like scoring memory and LLM-specific promptings
-if os.environ.get('LANGCHAIN_OVERRIDE', '1') == '1':
-    from ext import GenerativeAgent, GenerativeAgentMemory
-else:
-    from langchain.experimental.generative_agents import GenerativeAgent, GenerativeAgentMemory
-
 def _relevance_score_fn(score: float) -> float:
     if EMBEDDING_TYPE == "oai":
         return 1.0 - score / math.sqrt(2)
@@ -140,6 +141,7 @@ def _create_new_memories():
 def create_agent(**kwargs):
     settings = {
         "llm": LLM,
+        "verbose": True,
         "memory": _create_new_memories(),
     }
     settings.update(kwargs)
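With `verbose` now defaulted on, each chain result gets printed as agents run. For completeness, a hypothetical call (the keyword arguments are illustrative; `create_agent` merges whatever it receives into the settings above and hands them to `GenerativeAgent`):

```python
# hypothetical usage; field names follow langchain's GenerativeAgent
agent = create_agent(
    name="Alice",
    age=23,
    traits="curious, talkative",
    status="wandering around the town square",
)
```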