oops
This commit is contained in:
parent
712808494f
commit
e7a67410d1
|
@ -168,7 +168,7 @@ class Model:
|
|||
def full_name(self):
|
||||
name = [ self.name ]
|
||||
|
||||
if self.size != "full":
|
||||
if self.size != "full" and isinstance(self.size, str):
|
||||
name.append(self.size)
|
||||
|
||||
if self.arch_type != "transformer":
|
||||
|
@ -287,6 +287,7 @@ class Hyperparameters:
|
|||
gradient_clipping: int = 100
|
||||
|
||||
optimizer: str = "Adamw"
|
||||
torch_optimizer: bool = False
|
||||
optimizer_params: dict = field(default_factory=lambda: {})
|
||||
learning_rate: float = 3.25e-4
|
||||
|
||||
|
@ -328,7 +329,7 @@ class DeepSpeed:
|
|||
"params": {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
} if not cfg.hyperparameters.optimizer.endswith("-torch") else None,
|
||||
} if not cfg.hyperparameters.torch_optimizer else None,
|
||||
"scheduler": {
|
||||
"type": cfg.hyperparameters.scheduler_type,
|
||||
"params": scheduler_params,
|
||||
|
|
|
@ -17,7 +17,7 @@ class AR_NAR(Base):
|
|||
|
||||
@property
|
||||
def norm_type(self):
|
||||
return "ln" if self.n_resp_levels == 1 else "adaln"
|
||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
||||
|
||||
@property
|
||||
def arch_type(self) -> str:
|
||||
|
@ -202,9 +202,9 @@ def example_usage():
|
|||
|
||||
kwargs = {
|
||||
'n_tokens': 1024,
|
||||
'd_model': 1536, # 1536
|
||||
'n_heads': 24, # 24
|
||||
'n_layers': 24, # 32
|
||||
'd_model': 1024, # 1536
|
||||
'n_heads': 16, # 24
|
||||
'n_layers': 12, # 32
|
||||
}
|
||||
|
||||
"""
|
||||
|
@ -219,6 +219,8 @@ def example_usage():
|
|||
#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
def sample( name, steps=600 ):
|
||||
engine.eval()
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
|
@ -243,7 +245,7 @@ def example_usage():
|
|||
|
||||
tqdm.write(f"{stats}")
|
||||
|
||||
sample("init", 75)
|
||||
#sample("init", 75)
|
||||
train()
|
||||
sample("final")
|
||||
|
||||
|
|
|
@ -87,8 +87,31 @@ class MultiEmbedding(nn.Embedding):
|
|||
x_list = x.split([*map(len, x_list)])
|
||||
|
||||
return x_list
|
||||
"""
|
||||
class PromEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
|
||||
def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
|
||||
if len(x_list) == 0:
|
||||
return []
|
||||
|
||||
return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
|
||||
|
||||
class RespEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
||||
if len(x_list) == 0:
|
||||
return []
|
||||
res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
|
||||
return res
|
||||
"""
|
||||
class Base(nn.Module):
|
||||
@property
|
||||
def causal(self) -> bool:
|
||||
|
@ -130,6 +153,10 @@ class Base(nn.Module):
|
|||
def dual(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def n_embeddings(self):
|
||||
return self.n_resp_levels if self.dual else 1
|
||||
|
||||
@property
|
||||
def stop_token(self):
|
||||
if not self.causal:
|
||||
|
@ -172,12 +199,18 @@ class Base(nn.Module):
|
|||
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
|
||||
|
||||
self.text_emb = Embedding(n_tokens, d_model)
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
|
||||
if self.dual:
|
||||
self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)])
|
||||
else:
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
if self.n_embeddings == 1:
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
else:
|
||||
self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
|
||||
"""
|
||||
if self.n_embeddings == 1:
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
else:
|
||||
self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
"""
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
|
@ -262,18 +295,19 @@ class Base(nn.Module):
|
|||
|
||||
state: dict | None = None,
|
||||
):
|
||||
if self.dual:
|
||||
if self.n_embeddings == 1:
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
self.resps_emb[0 if quant_levels is None else 1](resps_list),
|
||||
self.resps_emb(resps_list),
|
||||
sep=self.sep,
|
||||
)
|
||||
else:
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
self.proms_emb(proms_list),
|
||||
self.resps_emb(resps_list),
|
||||
self.resps_emb[0 if quant_levels is None else 1](resps_list),
|
||||
#self.resps_emb(resps_list, quant_levels),
|
||||
sep=self.sep,
|
||||
)
|
||||
|
||||
|
@ -296,8 +330,11 @@ class Base(nn.Module):
|
|||
|
||||
if self.arch_type == "transformer":
|
||||
x = self.sin_emb.add_pe(x)
|
||||
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
|
||||
l = l.to(device)
|
||||
for block in self.blocks:
|
||||
x = block(x, m, quant_levels)
|
||||
x = block(x, m, l)
|
||||
|
||||
elif self.arch_type == "retnet":
|
||||
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||
|
||||
|
|
|
@ -91,8 +91,6 @@ class NAR(Base):
|
|||
proms_list,
|
||||
prev_list,
|
||||
targ_list,
|
||||
return_all_resp=True,
|
||||
shift_targ_list=False,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
|
@ -112,8 +110,6 @@ class NAR(Base):
|
|||
text_list,
|
||||
proms_list,
|
||||
prev_list,
|
||||
return_all_resp=True,
|
||||
shift_targ_list=False,
|
||||
quant_levels=quant_levels,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
|
|
|
@ -62,8 +62,8 @@ def load_engines(invert=False):
|
|||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
# cfg.deepspeed.torch_adam
|
||||
if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
|
||||
if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
"betas": (0.9, 0.96),
|
||||
|
@ -75,7 +75,7 @@ def load_engines(invert=False):
|
|||
model.parameters(),
|
||||
**params,
|
||||
)
|
||||
elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "sgd") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "sgd-torch"):
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ def load_engines(invert=False):
|
|||
model.parameters(),
|
||||
**params,
|
||||
)
|
||||
elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"):
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user