This commit is contained in:
mrq 2023-09-07 09:14:03 -05:00
parent 712808494f
commit e7a67410d1
5 changed files with 86 additions and 50 deletions

View File

@ -168,7 +168,7 @@ class Model:
def full_name(self): def full_name(self):
name = [ self.name ] name = [ self.name ]
if self.size != "full": if self.size != "full" and isinstance(self.size, str):
name.append(self.size) name.append(self.size)
if self.arch_type != "transformer": if self.arch_type != "transformer":
@ -287,6 +287,7 @@ class Hyperparameters:
gradient_clipping: int = 100 gradient_clipping: int = 100
optimizer: str = "Adamw" optimizer: str = "Adamw"
torch_optimizer: bool = False
optimizer_params: dict = field(default_factory=lambda: {}) optimizer_params: dict = field(default_factory=lambda: {})
learning_rate: float = 3.25e-4 learning_rate: float = 3.25e-4
@ -328,7 +329,7 @@ class DeepSpeed:
"params": { "params": {
"lr": cfg.hyperparameters.learning_rate, "lr": cfg.hyperparameters.learning_rate,
} }
} if not cfg.hyperparameters.optimizer.endswith("-torch") else None, } if not cfg.hyperparameters.torch_optimizer else None,
"scheduler": { "scheduler": {
"type": cfg.hyperparameters.scheduler_type, "type": cfg.hyperparameters.scheduler_type,
"params": scheduler_params, "params": scheduler_params,

View File

@ -17,7 +17,7 @@ class AR_NAR(Base):
@property @property
def norm_type(self): def norm_type(self):
return "ln" if self.n_resp_levels == 1 else "adaln" return "ln" # if self.n_resp_levels == 1 else "adaln"
@property @property
def arch_type(self) -> str: def arch_type(self) -> str:
@ -202,9 +202,9 @@ def example_usage():
kwargs = { kwargs = {
'n_tokens': 1024, 'n_tokens': 1024,
'd_model': 1536, # 1536 'd_model': 1024, # 1536
'n_heads': 24, # 24 'n_heads': 16, # 24
'n_layers': 24, # 32 'n_layers': 12, # 32
} }
""" """
@ -219,6 +219,8 @@ def example_usage():
#optimizer = ml.AdamW(model.parameters(), lr=0.0001) #optimizer = ml.AdamW(model.parameters(), lr=0.0001)
engine = Engine(model=model, optimizer=optimizer) engine = Engine(model=model, optimizer=optimizer)
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
def sample( name, steps=600 ): def sample( name, steps=600 ):
engine.eval() engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 ) resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
@ -243,7 +245,7 @@ def example_usage():
tqdm.write(f"{stats}") tqdm.write(f"{stats}")
sample("init", 75) #sample("init", 75)
train() train()
sample("final") sample("final")

View File

@ -87,8 +87,31 @@ class MultiEmbedding(nn.Embedding):
x_list = x.split([*map(len, x_list)]) x_list = x.split([*map(len, x_list)])
return x_list return x_list
"""
class PromEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
if len(x_list) == 0:
return []
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
class RespEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
if len(x_list) == 0:
return []
res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
return res
"""
class Base(nn.Module): class Base(nn.Module):
@property @property
def causal(self) -> bool: def causal(self) -> bool:
@ -130,6 +153,10 @@ class Base(nn.Module):
def dual(self) -> bool: def dual(self) -> bool:
return False return False
@property
def n_embeddings(self):
return self.n_resp_levels if self.dual else 1
@property @property
def stop_token(self): def stop_token(self):
if not self.causal: if not self.causal:
@ -172,12 +199,18 @@ class Base(nn.Module):
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
self.text_emb = Embedding(n_tokens, d_model) self.text_emb = Embedding(n_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
if self.dual: self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)]) if self.n_embeddings == 1:
else:
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
else:
self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
"""
if self.n_embeddings == 1:
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
else:
self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
"""
self.sep = nn.Parameter(torch.randn(d_model)) self.sep = nn.Parameter(torch.randn(d_model))
@ -262,18 +295,19 @@ class Base(nn.Module):
state: dict | None = None, state: dict | None = None,
): ):
if self.dual: if self.n_embeddings == 1:
x_list = self._samplewise_merge_tensors( x_list = self._samplewise_merge_tensors(
self.text_emb(text_list), self.text_emb(text_list),
self.proms_emb(proms_list), self.proms_emb(proms_list),
self.resps_emb[0 if quant_levels is None else 1](resps_list), self.resps_emb(resps_list),
sep=self.sep, sep=self.sep,
) )
else: else:
x_list = self._samplewise_merge_tensors( x_list = self._samplewise_merge_tensors(
self.text_emb(text_list), self.text_emb(text_list),
self.proms_emb(proms_list), self.proms_emb(proms_list),
self.resps_emb(resps_list), self.resps_emb[0 if quant_levels is None else 1](resps_list),
#self.resps_emb(resps_list, quant_levels),
sep=self.sep, sep=self.sep,
) )
@ -296,8 +330,11 @@ class Base(nn.Module):
if self.arch_type == "transformer": if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x) x = self.sin_emb.add_pe(x)
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
for block in self.blocks: for block in self.blocks:
x = block(x, m, quant_levels) x = block(x, m, l)
elif self.arch_type == "retnet": elif self.arch_type == "retnet":
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True) x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)

View File

@ -91,8 +91,6 @@ class NAR(Base):
proms_list, proms_list,
prev_list, prev_list,
targ_list, targ_list,
return_all_resp=True,
shift_targ_list=False,
quant_levels=quant_levels, quant_levels=quant_levels,
) )
@ -112,8 +110,6 @@ class NAR(Base):
text_list, text_list,
proms_list, proms_list,
prev_list, prev_list,
return_all_resp=True,
shift_targ_list=False,
quant_levels=quant_levels, quant_levels=quant_levels,
sampling_temperature=sampling_temperature, sampling_temperature=sampling_temperature,
) )

View File

@ -62,8 +62,8 @@ def load_engines(invert=False):
optimizer = None optimizer = None
lr_scheduler = None lr_scheduler = None
# cfg.deepspeed.torch_adam if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"): if cfg.hyperparameters.optimizer.lower() == "adamw":
params = { params = {
"lr": cfg.hyperparameters.learning_rate, "lr": cfg.hyperparameters.learning_rate,
"betas": (0.9, 0.96), "betas": (0.9, 0.96),
@ -75,7 +75,7 @@ def load_engines(invert=False):
model.parameters(), model.parameters(),
**params, **params,
) )
elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "sgd") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "sgd-torch"): elif cfg.hyperparameters.optimizer.lower() == "sgd":
params = { params = {
"lr": cfg.hyperparameters.learning_rate, "lr": cfg.hyperparameters.learning_rate,
} }
@ -84,7 +84,7 @@ def load_engines(invert=False):
model.parameters(), model.parameters(),
**params, **params,
) )
elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"): elif cfg.hyperparameters.optimizer.lower() == "prodigy":
params = { params = {
"lr": cfg.hyperparameters.learning_rate, "lr": cfg.hyperparameters.learning_rate,
} }