slight rewrite, mostly as functional as before
parent
6fa2c18fb1
commit
0093a70c51
@ -0,0 +1,223 @@
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import re
|
||||
import logging
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseMemory, Document
|
||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from prompts import get_prompt, get_stop_tokens
|
||||
|
||||
class GenerativeAgent(BaseModel):
    """A character with innate traits, a memory stream, and an LLM used to
    score, summarize, and react to observations.

    Memories are kept twice: as plain dicts in ``memories`` (for JSON/pickle
    serialization via :meth:`save`/:meth:`load`) and as embedded texts in
    ``vectorstore`` (for retrieval).
    """

    name: str                     # the character's name
    sex: str                      # the character's sex
    age: Optional[int] = None     # optional age
    traits: str = "N/A"           # permanent traits ascribed to the character
    status: str                   # current status line injected into prompts
    # entries of shape {"time": int, "importance": float, "observation": str}
    memories: List[dict] = Field(default_factory=list)
    # history of generated self-summaries; the last one is the current summary
    summaries: List[str] = Field(default_factory=list)

    # when the self-summary was last regenerated
    last_refreshed: datetime = Field(default_factory=datetime.now)

    llm: Optional[BaseLanguageModel] = None
    embeddings: Optional[Embeddings] = None
    vectorstore: Optional[VectorStore] = None
    memory: Optional[VectorStoreRetrieverMemory] = None

    verbose: bool = True

    class Config:
        # allow langchain / vectorstore objects, which are not pydantic models
        arbitrary_types_allowed = True

    @classmethod
    def create(
        cls,
        name: str,
        age: int,
        sex: str,
        traits: str,
        status: str,
        summaries: Optional[List[str]] = None,
        memories: Optional[List[dict]] = None,
        llm: Optional[BaseLanguageModel] = None,
        embeddings: Optional[Embeddings] = None,
        vectorstore: Optional[VectorStore] = None,
    ) -> "GenerativeAgent":
        """Build an agent and seed the vectorstore with any prior memories.

        ``summaries`` and ``memories`` default to fresh objects per call;
        the original mutable defaults (``["N/A"]`` / ``[]``) were shared
        across every call to this method.
        """
        summaries = ["N/A"] if summaries is None else summaries
        memories = [] if memories is None else memories

        agent = cls(
            name=name,
            age=age,
            sex=sex,
            traits=traits,
            status=status,
            memories=memories,
            summaries=summaries,
            llm=llm,
            embeddings=embeddings,
            vectorstore=vectorstore,
            # guard: the original crashed on .as_retriever() when vectorstore
            # was None even though the field is Optional
            memory=(
                VectorStoreRetrieverMemory(retriever=vectorstore.as_retriever())
                if vectorstore is not None
                else None
            ),
        )

        if agent.memories and agent.vectorstore is not None:
            agent.vectorstore.add_texts(
                texts=[memory["observation"] for memory in agent.memories],
                metadatas=[
                    {
                        "name": agent.name,
                        "time": memory["time"],
                        "importance": memory["importance"],
                    }
                    for memory in agent.memories
                ],
            )

        return agent

    def chain(self, prompt: PromptTemplate) -> LLMChain:
        """Wrap ``prompt`` in an LLMChain bound to this agent's LLM."""
        return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)

    def save(self, pickled: bool = False) -> str:
        """Serialize the agent's plain state under ./agents/ and return the path.

        Only the JSON-serializable fields are persisted; the LLM, embeddings,
        and vectorstore must be re-supplied on :meth:`load`.
        """
        os.makedirs("./agents/", exist_ok=True)
        obj = {
            "name": self.name,
            "age": self.age,
            "sex": self.sex,
            "traits": self.traits,
            "status": self.status,
            "summaries": self.summaries,
            "memories": self.memories,
        }

        # use context managers: the original leaked both file handles
        if pickled:
            path = f"./agents/{self.name}.pth"
            with open(path, "wb") as f:
                pickle.dump(obj, f)
        else:
            path = f"./agents/{self.name}.json"
            with open(path, "w", encoding="utf-8") as f:
                json.dump(obj, f)

        # the original was annotated -> str but implicitly returned None
        return path

    @classmethod
    def load(
        cls,
        name: str,
        llm: Optional[BaseLanguageModel] = None,
        embeddings: Optional[Embeddings] = None,
        vectorstore: Optional[VectorStore] = None,
        pickled: bool = False,
    ) -> "GenerativeAgent":
        """Restore an agent previously written by :meth:`save`.

        SECURITY: ``pickle.load`` executes arbitrary code from the file —
        only load pickles this process (or a trusted one) wrote itself.
        """
        if pickled:
            path = f"./agents/{name}.pth"
            with open(path, "rb") as f:
                obj = pickle.load(f)
        else:
            path = f"./agents/{name}.json"
            with open(path, "r", encoding="utf-8") as f:
                obj = json.load(f)

        return cls.create(**obj, llm=llm, embeddings=embeddings, vectorstore=vectorstore)

    def importance(self, observation: str, weight: float = 0.15) -> float:
        """Ask the LLM for a 1-10 poignancy score, scaled into [0, weight]."""
        prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
        score = self.chain(prompt).run(
            stop=get_stop_tokens(tokens=[".", "/", "("]),
            observation=observation,
        ).strip()
        match = re.search(r"(\d+)", score)
        # fall back to a low-but-nonzero score when the LLM emits no digits
        value = float(match.group(0)) if match else 2.0
        return value / 10.0 * weight

    def summarize(self) -> str:
        """Regenerate the agent's self-summary, append it, and return it."""
        prompt = PromptTemplate.from_template(get_prompt("compute_agent_summary"))
        summary = self.chain(prompt).run(
            stop=get_stop_tokens(),
            name=self.name,
            summary=self.summary(),
            memories="\n".join(self.recent_memories()),
        ).strip()
        result = f"{self.name} {summary}"
        self.summaries.append(result)
        return result

    def summary(self, refresh: bool = False) -> str:
        """Return the latest self-summary, regenerating first when ``refresh``."""
        # todo: invoke summarizer on a refresh schedule
        if refresh:
            self.summarize()
        # the original raised IndexError when no summary existed yet
        return self.summaries[-1] if self.summaries else "N/A"

    def relevant_memories(self, observation: str, k: int = 12) -> List[str]:
        """Memories relevant to ``observation``.

        TODO: query the vectorstore; currently just returns the ``k`` latest.
        """
        return [memory["observation"] for memory in self.memories[-k:]]

    def recent_memories(self, k: int = 12) -> List[str]:
        """The ``k`` most recently appended memories.

        TODO: sort by the "time" field; currently relies on append order.
        """
        return [memory["observation"] for memory in self.memories[-k:]]

    def memorize(
        self,
        observation: str,
        importance: float = 0,
        time: Optional[datetime] = None,
    ) -> dict:
        """Append an observation to the memory stream and index it.

        ``time`` defaults to *now at call time*; the original default of
        ``datetime.now()`` was evaluated once at import time, stamping every
        memory with the process start time.
        """
        time = datetime.now() if time is None else time
        entry = {
            "time": int(time.timestamp()),
            "importance": importance,
            "observation": observation,
        }
        self.memories.append(entry)
        if self.vectorstore is not None:  # vectorstore is Optional
            self.vectorstore.add_texts(
                texts=[observation],
                metadatas=[
                    {
                        "name": self.name,
                        "time": entry["time"],
                        "importance": entry["importance"],
                    }
                ],
            )
        return entry

    def observe(
        self,
        observation: str,
        importance: float = 0,
        time: Optional[datetime] = None,
    ) -> float:
        """Score (when not provided) and store an observation; return its importance."""
        if importance == 0:
            importance = self.importance(observation)
        self.memorize(observation, importance, time)
        return importance

    def react(
        self,
        observation: str,
        history: Optional[List[str]] = None,
        time: Optional[datetime] = None,
    ) -> str:
        """Generate the agent's next dialogue line in reaction to ``observation``.

        ``history`` is a list of prior dialogue lines used for de-duplication
        of the prompt context. ``time`` is currently unused; it is kept for
        API compatibility. (The original annotated the return as ``dict`` but
        always returned a string.)
        """
        history = [] if history is None else history
        # self.memorize(observation)
        suffix = get_prompt("suffix_generate_response")
        prompt = PromptTemplate.from_template(
            get_prompt("generate_reaction").replace("{suffix}", suffix)
        )
        summary = self.summary()
        relevant_memories = self.relevant_memories(observation)
        recent_memories = self.recent_memories()

        # avoid repeating: skip memories already present elsewhere in the prompt
        memory = ""

        for mem in relevant_memories:
            if mem in summary or mem in memory or mem in observation or mem in history:
                continue
            memory += f"\n{mem}"

        for mem in recent_memories:
            if mem in summary or mem in observation or mem in history:
                continue
            # already listed as relevant: erase it and move it to the bottom
            if mem in memory:
                memory = memory.replace(f"{mem}\n", "")
            memory += f"\n{mem}"

        history_str = "\n".join(history)
        reaction = self.chain(prompt=prompt).run(
            stop=get_stop_tokens(tokens=[f"\n{self.name}: "]),
            current_time=datetime.now().strftime("%B %d, %Y, %I:%M %p"),
            name=self.name,
            status=self.status if self.status else "N/A",
            summary=summary if summary else "N/A",
            memory=memory if memory else "N/A",
            history=history_str if history_str else "N/A",
            observation=observation if observation else "N/A",
        ).strip()

        # strip common emoji ranges; they break some downstream consumers
        emoji_pattern = re.compile(
            "["
            "\U0001F600-\U0001F64F"  # emoticons
            "\U0001F300-\U0001F5FF"  # symbols & pictographs
            "\U0001F680-\U0001F6FF"  # transport & map symbols
            "\U0001F1E0-\U0001F1FF"  # flags (iOS)
            "]+",
            flags=re.UNICODE,
        )
        reaction = emoji_pattern.sub(r"", reaction)

        # cleanup: drop zero-width spaces, then keep the first novel line
        reactions = reaction.replace("\u200B", "").strip().split("\n")

        for reaction in reactions:
            if reaction in summary or reaction in memory or reaction in history_str:
                continue
            if reaction:
                break

        return f"{self.name}: {reaction}"
|
@ -1,30 +0,0 @@
|
||||
"""
|
||||
The MIT License
|
||||
|
||||
Copyright (c) Harrison Chase
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
"""
|
||||
|
||||
"""Generative Agents primitives."""
|
||||
from .generative_agent import GenerativeAgent
|
||||
from .memory import GenerativeAgentMemory
|
||||
from .prompts import get_prompt, get_roles
|
||||
|
||||
__all__ = ["GenerativeAgent", "GenerativeAgentMemory"]
|
@ -1,226 +0,0 @@
|
||||
# From https://github.com/hwchase17/langchain/tree/master/langchain/experimental/generative_agents
|
||||
"""
|
||||
The MIT License
|
||||
|
||||
Copyright (c) Harrison Chase
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from .memory import GenerativeAgentMemory
|
||||
from .prompts import get_prompt, get_stop_tokens
|
||||
|
||||
class GenerativeAgent(BaseModel):
|
||||
"""A character with memory and innate characteristics."""
|
||||
|
||||
name: str
|
||||
"""The character's name."""
|
||||
|
||||
sex: str
|
||||
"""The character's sex."""
|
||||
|
||||
age: Optional[int] = None
|
||||
"""The optional age of the character."""
|
||||
traits: str = "N/A"
|
||||
"""Permanent traits to ascribe to the character."""
|
||||
status: str
|
||||
"""The traits of the character you wish not to change."""
|
||||
memory: GenerativeAgentMemory
|
||||
"""The memory object that combines relevance, recency, and 'importance'."""
|
||||
llm: BaseLanguageModel
|
||||
"""The underlying language model."""
|
||||
verbose: bool = True
|
||||
summary: str = "N/A" #: :meta private:
|
||||
"""Stateful self-summary generated via reflection on the character's memory."""
|
||||
|
||||
summary_refresh_seconds: int = 3600 #: :meta private:
|
||||
"""How frequently to re-generate the summary."""
|
||||
|
||||
last_refreshed: datetime = Field(default_factory=datetime.now) # : :meta private:
|
||||
"""The last time the character's summary was regenerated."""
|
||||
|
||||
summaries: List[str] = Field(default_factory=list) # : :meta private:
|
||||
"""Summary of the events in the plan that the agent took."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# LLM-related methods
|
||||
@staticmethod
|
||||
def _parse_list(text: str) -> List[str]:
|
||||
"""Parse a newline-separated string into a list of strings."""
|
||||
lines = re.split(r"\n", text.strip())
|
||||
return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]
|
||||
|
||||
def chain(self, prompt: PromptTemplate) -> LLMChain:
|
||||
return LLMChain(
|
||||
llm=self.llm, prompt=prompt, verbose=self.verbose, memory=self.memory
|
||||
)
|
||||
|
||||
def get_most_recent_memories(self, last_k: int = 8) -> str:
|
||||
memories = self.memory.memory_retriever.memory_stream[-last_k:]
|
||||
return [ document.page_content.replace(u"\u200B", "").strip() for document in memories ]
|
||||
|
||||
def get_relevant_memories(self, observation: str, first_k : int = 8) -> str:
|
||||
queries = [ observation ]
|
||||
relevant_memories = [
|
||||
mem.page_content.replace(u"\u200B", "").strip() for query in queries for mem in self.memory.fetch_memories(query)
|
||||
]
|
||||
relevant_memories = relevant_memories[:first_k]
|
||||
relevant_memories.reverse()
|
||||
return relevant_memories
|
||||
|
||||
"""
|
||||
def summarize_related_memories(self, observation: str, first_k : int = 4) -> str:
|
||||
prompt = PromptTemplate.from_template(get_prompt('summarize_related_memories'))
|
||||
query = f"Summarize the relationship between the subjects in that interaction in two sentences or less. Avoid repeating."
|
||||
summary = self.chain(prompt=prompt).run(name=self.name, stop=get_stop_tokens(), query=query, observation=observation, queries=[observation]).strip()
|
||||
return f'{self.name} {summary}'
|
||||
"""
|
||||
|
||||
#return self.chain(prompt=prompt).run(stop=get_stop_tokens(), q1=q1, q2=q2).strip()
|
||||
|
||||
def _generate_reaction(self, observation: str, suffix: str) -> str:
|
||||
"""React to a given observation or dialogue act."""
|
||||
prompt = PromptTemplate.from_template(
|
||||
get_prompt('generate_reaction').replace("{suffix}", suffix)
|
||||
)
|
||||
summary = self.get_summary()
|
||||
relevant_memories = self.get_relevant_memories(observation)
|
||||
recent_memories = self.get_most_recent_memories()
|
||||
|
||||
# avoid repeating
|
||||
memory = ""
|
||||
|
||||
for mem in relevant_memories:
|
||||
if mem in summary or mem in memory or mem in observation:
|
||||
continue
|
||||
memory += f"\n{mem}"
|
||||
|
||||
for mem in recent_memories:
|
||||
if mem in summary or mem in observation:
|
||||
continue
|
||||
# erase it, move it to bottom
|
||||
if mem in memory:
|
||||
memory = memory.replace(f'{mem}\n', "")
|
||||
memory += f"\n{mem}"
|
||||
|
||||
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
|
||||
kwargs: Dict[str, Any] = dict(
|
||||
current_time=current_time_str,
|
||||
name=self.name,
|
||||
status=self.status if self.status else "N/A",
|
||||
summary=summary if summary else "N/A",
|
||||
memory=memory if memory else "N/A",
|
||||
#relevant_memories=relevant_memories if relevant_memories else "N/A",
|
||||
#recent_memories=recent_memories if recent_memories else "N/A",
|
||||
observation=observation if observation else "N/A",
|
||||
)
|
||||
reaction = self.chain(prompt=prompt).run(stop=get_stop_tokens(), **kwargs).strip()
|
||||
import re
|
||||
|
||||
emoji_pattern = re.compile("["
|
||||
u"\U0001F600-\U0001F64F" # emoticons
|
||||
u"\U0001F300-\U0001F5FF" # symbols & pictographs
|
||||
u"\U0001F680-\U0001F6FF" # transport & map symbols
|
||||
u"\U0001F1E0-\U0001F1FF" # flags (iOS)
|
||||
"]+", flags=re.UNICODE)
|
||||
reaction = emoji_pattern.sub(r'', reaction)
|
||||
|
||||
# cleanup
|
||||
reactions = reaction.replace(u"\u200B", "").strip().split("\n")
|
||||
|
||||
for reaction in reactions:
|
||||
if reaction in summary or reaction in memory:
|
||||
continue
|
||||
if reaction:
|
||||
break
|
||||
|
||||
if self.verbose:
|
||||
print(reaction)
|
||||
return f'{self.name}: {reaction}'
|
||||
|
||||
def generate_response(self, observation: str) -> Tuple[bool, str]:
|
||||
"""React to a given observation."""
|
||||
call_to_action_template = get_prompt('suffix_generate_response')
|
||||
full_result = ""
|
||||
while not full_result:
|
||||
full_result = f"{self._generate_reaction(observation, call_to_action_template)}"
|
||||
if full_result:
|
||||
break
|
||||
|
||||
return True, full_result
|
||||
|
||||
######################################################
|
||||
# Agent stateful' summary methods. #
|
||||
# Each dialog or response prompt includes a header #
|
||||
# summarizing the agent's self-description. This is #
|
||||
# updated periodically through probing its memories #
|
||||
######################################################
|
||||
def _compute_agent_summary(self) -> str:
|
||||
""""""
|
||||
# The agent seeks to think about their core characteristics.
|
||||
prompt = PromptTemplate.from_template(get_prompt('compute_agent_summary'))
|
||||
summary = self.chain(prompt).run(stop=get_stop_tokens(), name=self.name, summary=self.summaries[-1] if len(self.summaries) else self.summary, queries=[f"{self.name}'s core characteristics"]).strip()
|
||||
if self.verbose:
|
||||
print(summary)
|
||||
return f'{self.name} {summary}'
|
||||
|
||||
def get_summary(self, force_refresh: bool = False) -> str:
|
||||
"""Return a descriptive summary of the agent."""
|
||||
current_time = datetime.now()
|
||||
since_refresh = (current_time - self.last_refreshed).seconds
|
||||
if (
|
||||
not self.summary
|
||||
or since_refresh >= self.summary_refresh_seconds
|
||||
or force_refresh
|
||||
):
|
||||
self.summary = self._compute_agent_summary()
|
||||
self.summaries.append(self.summary)
|
||||
self.last_refreshed = current_time
|
||||
|
||||
values = [
|
||||
f"Name: {self.name} (sex: {self.sex}, age: {self.age if self.age is not None else 'N/A'})",
|
||||
f"Innate traits: {self.traits}",
|
||||
f"Status: {self.status}"
|
||||
]
|
||||
|
||||
summary = "\n".join([ value for value in values if value[-3:] != "N/A" ]) + f"\nSummary: {self.summary.strip()}"
|
||||
return summary.replace(u"\u200B", "").strip()
|
||||
|
||||
def get_full_header(self, force_refresh: bool = False) -> str:
|
||||
"""Return a full header of the agent's status, summary, and current time."""
|
||||
summary = self.get_summary(force_refresh=force_refresh)
|
||||
current_time_str = datetime.now().strftime("%B %d, %Y, %I:%M %p")
|
||||
return (
|
||||
f"{summary}\nIt is {current_time_str}.\n{self.name}'s status: {self.status}"
|
||||
)
|
@ -1,224 +0,0 @@
|
||||
# From https://github.com/hwchase17/langchain/tree/master/langchain/experimental/generative_agents
|
||||
"""
|
||||
The MIT License
|
||||
|
||||
Copyright (c) Harrison Chase
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.retrievers import TimeWeightedVectorStoreRetriever
|
||||
from langchain.schema import BaseMemory, Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .prompts import get_prompt, get_stop_tokens
|
||||
|
||||
class GenerativeAgentMemory(BaseMemory):
|
||||
llm: BaseLanguageModel
|
||||
"""The core language model."""
|
||||
|
||||
memory_retriever: TimeWeightedVectorStoreRetriever
|
||||
"""The retriever to fetch related memories."""
|
||||
verbose: bool = True
|
||||
|
||||
reflection_threshold: Optional[float] = None
|
||||
"""When aggregate_importance exceeds reflection_threshold, stop to reflect."""
|
||||
|
||||
current_plan: List[str] = []
|
||||
"""The current plan of the agent."""
|
||||
|
||||
# A weight of 0.15 makes this less important than it
|
||||
# would be otherwise, relative to salience and time
|
||||
importance_weight: float = 0.15
|
||||
"""How much weight to assign the memory importance."""
|
||||
|
||||
aggregate_importance: float = 0.0 # : :meta private:
|
||||
"""Track the sum of the 'importance' of recent memories.
|
||||
|
||||
Triggers reflection when it reaches reflection_threshold."""
|
||||
|
||||
max_tokens_limit: int = 1200 # : :meta private:
|
||||
# input keys
|
||||
queries_key: str = "queries"
|
||||
most_recent_memories_token_key: str = "recent_memories_token"
|
||||
add_memory_key: str = "add_memory"
|
||||
# output keys
|
||||
relevant_memories_key: str = "relevant_memories"
|
||||
relevant_memories_simple_key: str = "relevant_memories_simple"
|
||||
most_recent_memories_key: str = "most_recent_memories"
|
||||
|
||||
reflecting: bool = False
|
||||
|
||||
def chain(self, prompt: PromptTemplate) -> LLMChain:
|
||||
return LLMChain(llm=self.llm, prompt=prompt, verbose=self.verbose)
|
||||
|
||||
@staticmethod
|
||||
def _parse_list(text: str) -> List[str]:
|
||||
"""Parse a newline-separated string into a list of strings."""
|
||||
lines = re.split(r"\n", text.strip())
|
||||
return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]
|
||||
|
||||
def _get_topics_of_reflection(self, last_k: int = 50) -> List[str]:
|
||||
"""Return the 3 most salient high-level questions about recent observations."""
|
||||
prompt = PromptTemplate.from_template(get_prompt("topic_of_reflection"))
|
||||
observations = self.memory_retriever.memory_stream[-last_k:]
|
||||
observation_str = "\n".join([o.page_content for o in observations])
|
||||
result = self.chain(prompt).run(stop=get_stop_tokens(), observations=observation_str)
|
||||
if self.verbose:
|
||||
print(result)
|
||||
|
||||
return self._parse_list(result)
|
||||
|
||||
def _get_insights_on_topic(self, topic: str) -> List[str]:
|
||||
"""Generate 'insights' on a topic of reflection, based on pertinent memories."""
|
||||
prompt = PromptTemplate.from_template(get_prompt("insights_on_topic"))
|
||||
related_memories = self.fetch_memories(topic)
|
||||
related_statements = "\n".join(
|
||||
[
|
||||
f"{i+1}. {memory.page_content}"
|
||||
for i, memory in enumerate(related_memories)
|
||||
]
|
||||
)
|
||||
result = self.chain(prompt).run( stop=get_stop_tokens(), topic=topic, related_statements=related_statements )
|
||||
# TODO: Parse the connections between memories and insights
|
||||
return self._parse_list(result)
|
||||
|
||||
def pause_to_reflect(self) -> List[str]:
|
||||
"""Reflect on recent observations and generate 'insights'."""
|
||||
if self.verbose:
|
||||
logger.info("Character is reflecting")
|
||||
new_insights = []
|
||||
topics = self._get_topics_of_reflection()
|
||||
for topic in topics:
|
||||
insights = self._get_insights_on_topic(topic)
|
||||
for insight in insights:
|
||||
self.add_memory(insight)
|
||||
new_insights.extend(insights)
|
||||
return new_insights
|
||||
|
||||
def _score_memory_importance(self, memory_content: str) -> float:
|
||||
"""Score the absolute importance of the given memory."""
|
||||
prompt = PromptTemplate.from_template(get_prompt("memory_importance"))
|
||||
score = self.chain(prompt).run(stop=get_stop_tokens(tokens=[".", "/", "("]), memory_content=memory_content).strip()
|
||||
if self.verbose:
|
||||
print(f"Importance score: {score}")
|
||||
try:
|
||||
match = re.search(r"(\d+)", score)
|
||||
if match:
|
||||
return (float(match.group(0)) / 10) * self.importance_weight
|
||||
except Exception as e:
|
||||
print(colored("[Scoring Error]", "red"), score)
|
||||
|
||||
return (float(2) / 10) * self.importance_weight
|
||||
|
||||
def add_memory(self, memory_content: str, importance_score: int = 0) -> List[str]:
|
||||
"""Add an observation or memory to the agent's memory."""
|
||||
if not importance_score:
|
||||
importance_score = self._score_memory_importance(memory_content)
|
||||
self.aggregate_importance += importance_score
|
||||
document = Document( page_content=memory_content, metadata={"importance": importance_score} )
|
||||
result = self.memory_retriever.add_documents([document])
|
||||
|
||||
# After an agent has processed a certain amount of memories (as measured by
|
||||
# aggregate importance), it is time to reflect on recent events to add
|
||||
# more synthesized memories to the agent's memory stream.
|
||||
if (
|
||||
self.reflection_threshold is not None
|
||||
and self.aggregate_importance > self.reflection_threshold
|
||||
and not self.reflecting
|
||||
):
|
||||
self.reflecting = True
|
||||
self.pause_to_reflect()
|
||||
# Hack to clear the importance from reflection
|
||||
self.aggregate_importance = 0.0
|
||||
self.reflecting = False
|
||||
|
||||
return (importance_score, result)
|
||||
|
||||
def fetch_memories(self, observation: str) -> List[Document]:
|
||||
"""Fetch related memories."""
|
||||
return self.memory_retriever.get_relevant_documents(observation)
|
||||
|
||||
def format_memories_detail(self, relevant_memories: List[Document]) -> str:
|
||||
content_strs = set()
|
||||
content = []
|
||||
for mem in relevant_memories:
|
||||
if mem.page_content in content_strs:
|
||||
continue
|
||||
content_strs.add(mem.page_content)
|
||||
created_time = mem.metadata["created_at"].strftime("%B %d, %Y, %I:%M %p")
|
||||
content.append(f"- {created_time}: {mem.page_content.strip()}")
|
||||
return "\n".join([f"{mem}" for mem in content])
|
||||
|
||||
def format_memories_simple(self, relevant_memories: List[Document]) -> str:
|
||||
return "; ".join([f"{mem.page_content}" for mem in relevant_memories]).replace(".;", ".\n")
|
||||
|
||||
def _get_memories_until_limit(self, consumed_tokens: int) -> str:
|
||||
"""Reduce the number of tokens in the documents."""
|
||||
result = []
|
||||
for doc in self.memory_retriever.memory_stream[::-1]:
|
||||
if consumed_tokens >= self.max_tokens_limit:
|
||||
break
|
||||
consumed_tokens += self.llm.get_num_tokens(doc.page_content)
|
||||
if consumed_tokens < self.max_tokens_limit:
|
||||
result.append(doc)
|
||||
return self.format_memories_simple(result)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Input keys this memory class will load dynamically."""
|
||||
return []
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
queries = inputs.get(self.queries_key)
|
||||
if queries is not None:
|
||||
relevant_memories = [
|
||||
mem for query in queries for mem in self.fetch_memories(query)
|
||||
]
|
||||
return {
|
||||
self.relevant_memories_key: self.format_memories_detail( relevant_memories ),
|
||||
self.relevant_memories_simple_key: self.format_memories_simple( relevant_memories ),
|
||||
}
|
||||
|
||||
most_recent_memories_token = inputs.get(self.most_recent_memories_token_key)
|
||||
if most_recent_memories_token is not None:
|
||||
return {
|
||||
self.most_recent_memories_key: self._get_memories_until_limit( most_recent_memories_token )
|
||||
}
|
||||
return {}
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this model run to memory."""
|
||||
# TODO: fix the save memory key
|
||||
mem = outputs.get(self.add_memory_key)
|
||||
if mem:
|
||||
self.add_memory(mem)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
# TODO
|
@ -1,182 +0,0 @@
|
||||
import os
|
||||
|
||||
LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE') # oai, vicuna, supercot
|
||||
|
||||
PROMPTS = {
|
||||
"summarize_related_memories": {
|
||||
"system": (
|
||||
"{query}"
|
||||
),
|
||||
"user": (
|
||||
"{relevant_memories_simple}"
|
||||
"{observation}"
|
||||
),
|
||||
"assistant": "{name} ",
|
||||
},
|
||||
"compute_agent_summary": {
|
||||
"system": (
|
||||
"Given the following previous summary and the following statements, how would you summarize {name}'s core characteristics?"
|
||||
" Do not embellish under any circumstances."
|
||||
),
|
||||
"user": (
|
||||
"{summary}"
|
||||
"\n{relevant_memories_simple}"
|
||||
),
|
||||
"assistant": "{name} ",
|
||||
},
|
||||
"topic_of_reflection": {
|
||||
"system": (
|
||||
"Given only the following information, what are the 3 most salient"
|
||||
" high-level questions we can answer about the subjects in the statements?"
|
||||
" Provide each question on a new line."
|
||||
),
|
||||
"user": (
|
||||
"Information: {observations}"
|
||||
),
|
||||
"assistant": "",
|
||||
},
|
||||
"insights_on_topic": {
|
||||
"system": (
|
||||
"Given the following statements about {topic},"
|
||||
" what 5 high-level insights can you infer?"
|
||||
" (example format: insight (because of 1, 5, 3))"
|
||||
),
|
||||
"user": (
|
||||
"Statements: {related_statements}"
|
||||
),
|
||||
"assistant": "",
|
||||
},
|
||||
"memory_importance": {
|
||||
"system": (
|
||||
"On the scale of 1 to 10, where 1 is purely mundane"
|
||||
" (e.g., brushing teeth, making bed) and 10 is extremely poignant"
|
||||
" (e.g., a break up, college acceptance),"
|
||||
" rate the likely poignancy of the following event."
|
||||
"\nRespond with only a single integer."
|
||||
),
|
||||
"user": (
|
||||
"Event: {memory_content}"
|
||||
),
|
||||
"assistant": "Rating: ",
|
||||
},
|
||||
"generate_reaction": {
|
||||
"system": (
|
||||
#"\nCurrent Time: {current_time}" # commented out, not necessary if I'm not passing time anyways, and I think bigger LLMs would only take advantage of it / llama's prompt caching will get ruined with this changing
|
||||
"\n{summary}"
|
||||
"\n{memory}"
|
||||
"\n{observation}"
|
||||
),
|
||||
"user": (
|
||||
"{suffix}"
|
||||
),
|
||||
"assistant": "{name}: "
|
||||
},
|
||||
|
||||
#
|
||||
"context": ( # insert your JB here
|
||||
""
|
||||
),
|
||||
"suffix_generate_response": (
|
||||
"Given the current situation, in one sentence, what is {name}'s next response?"
|
||||
),
|
||||
}
|
||||
|
||||
PROMPT_TUNES = {
|
||||
"default": "{query}",
|
||||
"vicuna": "{role}: {query}",
|
||||
"supercot": "{role}:\n{query}",
|
||||
"alpasta": "{role}# {query}",
|
||||
"cocktail": "{role}: {query}",
|
||||
"wizard-vicuna": "{role}: {query}",
|
||||
}
|
||||
PROMPT_ROLES = {
|
||||
"vicuna": {
|
||||
"system": "SYSTEM",
|
||||
"user": "USER",
|
||||
"assistant": "ASSISTANT",
|
||||
},
|
||||
"supercot": {
|
||||
"system": "### Instruction",
|
||||
"user": "### Input",
|
||||
"assistant": "### Response",
|
||||
},
|
||||
"wizard-vicuna": {
|
||||
"system": "### Instruction",
|
||||
"user": "### Input",
|
||||
"assistant": "### Response",
|
||||
},
|
||||
"alpasta": {
|
||||
"system": "<|system|>",
|
||||
"user": "<|user|>",
|
||||
"assistant": "<|assistant|>",
|
||||
},
|
||||
"cocktail": {
|
||||
"system": "",
|
||||
"user": "USER",
|
||||
"assistant": "ASSOCIATE",
|
||||
},
|
||||
}
|
||||
|
||||
ROLES = [ "system", "user", "assistant" ]
|
||||
|
||||
|
||||
def get_stop_tokens( tokens=None, tune=LLM_PROMPT_TUNE ):
    """Return the stop-token strings for the given prompt tune.

    Starts from the generic "###" delimiter plus any caller-supplied
    extras, then appends every non-empty role marker of the tune.
    """
    # BUG FIX: the default was a mutable list ([]); use None instead.
    stop_tokens = ["###"] + ( tokens if tokens else [] )
    # BUG FIX: the original ignored the `tune` argument and always queried
    # the module-level LLM_PROMPT_TUNE; honor the parameter instead.
    for role in get_roles( tune=tune, special=True ):
        # Skip falsy markers (e.g. cocktail maps "system" to "").
        if role:
            stop_tokens.append( role )
    return stop_tokens
|
||||
|
||||
# NOTE(review): this loop is a no-op — it only skips the "context" key and
# does nothing with the remaining keys. Presumably a leftover from earlier
# per-prompt preprocessing; looks safe to delete — TODO confirm.
for k in PROMPTS:
    if k == "context":
        continue
|
||||
|
||||
def get_roles( tune=LLM_PROMPT_TUNE, special=True ):
    """Return the role markers for a prompt tune.

    Tunes listed in PROMPT_ROLES yield their mapped markers. Otherwise,
    `special=True` yields no markers at all, while `special=False` falls
    back to the generic role names in ROLES.
    """
    mapping = PROMPT_ROLES.get( tune )
    if mapping is not None:
        return list( mapping.values() )
    return [] if special else ROLES
|
||||
|
||||
def get_prompt( key, tune=LLM_PROMPT_TUNE ):
    """Render the prompt template `key` for the given model tune.

    Plain-string entries (suffixes) are returned verbatim. Dict entries are
    rendered role-by-role in ROLES order using the tune's line template and
    role names, joined with newlines, and stripped.
    """
    prompt = PROMPTS[key]

    # is a suffix
    if not isinstance( prompt, dict ):
        return prompt

    # Vicuna is finetuned for `USER: [query]\nASSISTANT:`; unknown tunes
    # fall back to the bare "{query}" template.
    if tune not in PROMPT_TUNES:
        tune = "default"

    # Prepend the global context ("JB") to the system message.
    # BUG FIX: the original mutated PROMPTS[key]["system"] in place, leaking
    # state into the module-level table across calls; build locally instead
    # (output is unchanged — the original guarded against re-prepending).
    context = PROMPTS["context"]
    queries = dict( prompt )
    if context:
        system = queries.get( "system" )
        if system is None:
            queries["system"] = f'{context}'
        elif context not in system:
            queries["system"] = f'{context}\n{system}'

    template = PROMPT_TUNES[tune]
    role_names = PROMPT_ROLES.get( tune, {} )

    outputs = []
    for role in ROLES:
        if role not in queries:
            continue
        query = queries[role]
        # Substitute via replace (not str.format) so other {placeholders}
        # in the template text survive untouched.
        line = template.replace( "{role}", role_names.get( role, role ) )
        line = line.replace( "{query}", query )
        outputs.append( line )

    # BUG FIX: removed leftover debug `print([output[-1]])`, which also
    # raised IndexError whenever the rendered prompt was empty.
    return "\n".join( outputs ).strip()
|
@ -0,0 +1,151 @@
|
||||
import os
|
||||
|
||||
# Selected prompt format (e.g. "vicuna", "supercot"); None falls back to
# the "default" template in get_prompt().
LLM_PROMPT_TUNE = os.environ.get('LLM_PROMPT_TUNE')
|
||||
|
||||
# Prompt templates, keyed by task. Dict values map chat roles
# ("system"/"user"/"assistant") to template strings with {placeholders};
# plain-string values are suffixes returned verbatim by get_prompt().
PROMPTS = {
    "compute_agent_summary": {
        "system": (
            "Given the following previous summary and the following statements, how would you summarize {name}'s core characteristics?"
        ),
        "user": (
            "{summary}"
            "\n{memories}"
        ),
        # Seeds the reply so the model continues a sentence about {name}.
        "assistant": "{name} ",
    },
    "memory_importance": {
        "system": (
            "On the scale of 1 to 10, where 1 is purely mundane"
            " (e.g., brushing teeth, making bed) and 10 is extremely poignant"
            " (e.g., a break up, college acceptance),"
            " rate the likely poignancy of the following event."
            "\nRespond with only a single integer."
        ),
        "user": (
            "Event: {observation}"
        ),
        # Seeds "Rating: " so the model answers with just the integer.
        "assistant": "Rating: ",
    },
    "generate_reaction": {
        "system": (
            #"\nCurrent Time: {current_time}"
            "\n{summary}"
        ),
        "user": (
            "{suffix}"
        ),
        # Memory and observation are replayed in the assistant turn so the
        # model continues speaking as {name}.
        "assistant": (
            "{memory}"
            "\n{observation}"
            "\n{name}: "
        )
    },

    #
    "context": ( # insert your JB here
        ""
    ),
    "suffix_generate_response": (
        #"Given the current situation, in one sentence, what is {name}'s next response?"
        "Respond to the user's inputs as an immersive fictional story game. Use markdown. Always italicize every actions. Be verbose, proactive, creative, and drive the plot and conversation forward. Always stay in character, never reference the user's instructions."
    ),
}
|
||||
|
||||
# Per-tune line template: how one (role, query) pair is laid out by
# get_prompt(). "{role}" and "{query}" are substituted via str.replace.
PROMPT_TUNES = {
    "default": "{query}",
    "vicuna": "{role}: {query}",
    "supercot": "{role}:\n{query}",
    "alpasta": "{role}# {query}",
    "cocktail": "{role}: {query}",
    "wizard-vicuna": "{role}: {query}",
}
|
||||
# Per-tune role-name mapping ("system"/"user"/"assistant" -> the marker the
# finetune expects). Tunes absent from this table use the generic ROLES names.
PROMPT_ROLES = {
    "vicuna": {
        "system": "SYSTEM",
        "user": "USER",
        "assistant": "ASSISTANT",
    },
    "supercot": {
        "system": "### Instruction",
        "user": "### Input",
        "assistant": "### Response",
    },
    "wizard-vicuna": {
        "system": "### Instruction",
        "user": "### Input",
        "assistant": "### Response",
    },
    "alpasta": {
        "system": "<|system|>",
        "user": "<|user|>",
        "assistant": "<|assistant|>",
    },
    "cocktail": {
        # Empty marker: cocktail has no explicit system prefix; get_stop_tokens
        # filters this falsy value out.
        "system": "",
        "user": "USER",
        "assistant": "ASSOCIATE",
    },
}
|
||||
|
||||
# Generic chat roles, in the order get_prompt() renders them.
ROLES = [ "system", "user", "assistant" ]
|
||||
|
||||
|
||||
def get_stop_tokens( tokens=None, tune=LLM_PROMPT_TUNE ):
    """Return the stop-token strings for the given prompt tune.

    Starts from the generic "###" delimiter plus any caller-supplied
    extras, then appends every non-empty role marker of the tune.
    """
    # BUG FIX: the default was a mutable list ([]); use None instead.
    stop_tokens = ["###"] + ( tokens if tokens else [] )
    # BUG FIX: the original ignored the `tune` argument and always queried
    # the module-level LLM_PROMPT_TUNE; honor the parameter instead.
    for role in get_roles( tune=tune, special=True ):
        # Skip falsy markers (e.g. cocktail maps "system" to "").
        if role:
            stop_tokens.append( role )
    return stop_tokens
|
||||
|
||||
# NOTE(review): this loop is a no-op — it only skips the "context" key and
# does nothing with the remaining keys. Presumably a leftover from earlier
# per-prompt preprocessing. Looks safe to delete — TODO confirm.
for k in PROMPTS:
    if k == "context":
        continue
|
||||
|
||||
def get_roles( tune=LLM_PROMPT_TUNE, special=True ):
    """Return the role markers for a prompt tune.

    Tunes listed in PROMPT_ROLES yield their mapped markers. Otherwise,
    `special=True` yields no markers at all, while `special=False` falls
    back to the generic role names in ROLES.
    """
    mapping = PROMPT_ROLES.get( tune )
    if mapping is not None:
        return list( mapping.values() )
    return [] if special else ROLES
|
||||
|
||||
# to-do: spit out a list of properly assigned Templates
def get_prompt( key, tune=LLM_PROMPT_TUNE ):
    """Render the prompt template `key` for the given model tune.

    Plain-string entries (suffixes) are returned verbatim. Dict entries are
    rendered role-by-role in ROLES order using the tune's line template and
    role names, joined with newlines, and stripped.
    """
    prompt = PROMPTS[key]

    # is a suffix
    if not isinstance( prompt, dict ):
        return prompt

    # Vicuna is finetuned for `USER: [query]\nASSISTANT:`; unknown tunes
    # fall back to the bare "{query}" template.
    if tune not in PROMPT_TUNES:
        tune = "default"

    # Prepend the global context ("JB") to the system message.
    # BUG FIX: the original mutated PROMPTS[key]["system"] in place, leaking
    # state into the module-level table across calls; build locally instead
    # (output is unchanged — the original guarded against re-prepending).
    context = PROMPTS["context"]
    queries = dict( prompt )
    if context:
        system = queries.get( "system" )
        if system is None:
            queries["system"] = f'{context}'
        elif context not in system:
            queries["system"] = f'{context}\n{system}'

    template = PROMPT_TUNES[tune]
    role_names = PROMPT_ROLES.get( tune, {} )

    outputs = []
    for role in ROLES:
        if role not in queries:
            continue
        query = queries[role]
        # Substitute via replace (not str.format) so other {placeholders}
        # in the template text survive untouched.
        line = template.replace( "{role}", role_names.get( role, role ) )
        line = line.replace( "{query}", query )
        outputs.append( line )

    return "\n".join( outputs ).strip()
|
Loading…
Reference in New Issue