commit e7a67410d1
parent 712808494f

    oops

@@ -168,7 +168,7 @@ class Model:
 	def full_name(self):
 		name = [ self.name ]
 
-		if self.size != "full":
+		if self.size != "full" and isinstance(self.size, str):
 			name.append(self.size)
 
 		if self.arch_type != "transformer":

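Note: the added isinstance guard reads as defensive typing: Model.size is apparently not always a string, and only string sizes should be folded into the model's name. A minimal sketch of the intended behaviour, with the values invented for illustration (the rest of full_name() is not shown in this hunk):

	size = "quarter"            # hypothetical value; a non-string size is now skipped
	name = [ "example-model" ]  # hypothetical name list, as built in full_name()
	if size != "full" and isinstance(size, str):
		name.append(size)
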
@@ -287,6 +287,7 @@ class Hyperparameters:
 	gradient_clipping: int = 100
 
 	optimizer: str = "Adamw"
+	torch_optimizer: bool = False
 	optimizer_params: dict = field(default_factory=lambda: {})
 	learning_rate: float = 3.25e-4
 

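Note: the new torch_optimizer flag replaces the earlier string convention of suffixing the optimizer name with "-torch" (see the DeepSpeed and load_engines hunks below). A minimal sketch of the two ways the same intent is expressed, assuming the cfg layout shown in this diff:

	# old convention: encoded in the optimizer name
	use_torch = cfg.hyperparameters.optimizer.lower().endswith("-torch")
	# new convention: an explicit boolean field
	use_torch = cfg.hyperparameters.torch_optimizer
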
@@ -328,7 +329,7 @@ class DeepSpeed:
 			"params": {
 				"lr": cfg.hyperparameters.learning_rate,
 			}
-		} if not cfg.hyperparameters.optimizer.endswith("-torch") else None,
+		} if not cfg.hyperparameters.torch_optimizer else None,
 		"scheduler": {
 			"type": cfg.hyperparameters.scheduler_type,
 			"params": scheduler_params,

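Note: with the flag in place, the DeepSpeed config's "optimizer" entry becomes None whenever a torch-side optimizer is requested (presumably pruned before being handed to DeepSpeed, so the client-built optimizer is used instead). A sketch of the surrounding dict, where the "type" key and overall layout are assumptions; only the "params" block appears in the context lines:

	ds_cfg = {
		"optimizer": {
			"type": cfg.hyperparameters.optimizer,
			"params": {
				"lr": cfg.hyperparameters.learning_rate,
			},
		} if not cfg.hyperparameters.torch_optimizer else None,
		"scheduler": {
			"type": cfg.hyperparameters.scheduler_type,
			"params": scheduler_params,
		},
	}
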
@@ -17,7 +17,7 @@ class AR_NAR(Base):
 
 	@property
 	def norm_type(self):
-		return "ln" if self.n_resp_levels == 1 else "adaln"
+		return "ln" # if self.n_resp_levels == 1 else "adaln"
 
 	@property
 	def arch_type(self) -> str:

@@ -202,9 +202,9 @@ def example_usage():
 
 	kwargs = {
 		'n_tokens': 1024,
-		'd_model': 1536, # 1536
-		'n_heads': 24, # 24
-		'n_layers': 24, # 32
+		'd_model': 1024, # 1536
+		'n_heads': 16, # 24
+		'n_layers': 12, # 32
 	}
 
 	"""

@@ -219,6 +219,8 @@ def example_usage():
 	#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
 	engine = Engine(model=model, optimizer=optimizer)
 
+	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+
 	def sample( name, steps=600 ):
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )

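Note: the new print reports the exact trainable-parameter count. As a rough sanity check against the shrunken example kwargs above, a back-of-the-envelope sketch that ignores embeddings, biases, norms, and any retnet-specific parameters, counting only ~4·d² attention and ~8·d² feed-forward weights per layer:

	def approx_transformer_params(d_model: int, n_layers: int) -> int:
		per_layer = 4 * d_model * d_model + 8 * d_model * d_model  # attention + 4x-wide MLP
		return n_layers * per_layer

	print(approx_transformer_params(1536, 24))  # old example size: ~0.68B
	print(approx_transformer_params(1024, 12))  # new example size: ~0.15B
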
@@ -243,7 +245,7 @@ def example_usage():
 
 			tqdm.write(f"{stats}")
 
-	sample("init", 75)
+	#sample("init", 75)
 	train()
 	sample("final")
 

@@ -87,8 +87,31 @@ class MultiEmbedding(nn.Embedding):
 		x_list = x.split([*map(len, x_list)])
 
 		return x_list
+"""
+class PromEmbedding(nn.Module):
+	def __init__(self, n_levels, n_tokens, token_dim):
+		super().__init__()
+		self.n_levels = n_levels
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+
+	def forward(self, x_list: list[Tensor] ) -> list[Tensor]:
+		if len(x_list) == 0:
+			return []
+
+		return [ sum([ self.embeddings[k](xi[:, k]) for k in range(xi.shape[-1]) ]) for i, xi in enumerate(x_list) ]
+
+class RespEmbedding(nn.Module):
+	def __init__(self, n_levels, n_tokens, token_dim):
+		super().__init__()
+		self.n_levels = n_levels
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+
+	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
+		if len(x_list) == 0:
+			return []
+		res = [ self.embeddings[quant_levels[i] if quant_levels is not None else 0](xi) for i, xi in enumerate(x_list) ]
+		return res
+"""
 class Base(nn.Module):
 	@property
 	def causal(self) -> bool:

@@ -130,6 +153,10 @@ class Base(nn.Module):
 	def dual(self) -> bool:
 		return False
 
+	@property
+	def n_embeddings(self):
+		return self.n_resp_levels if self.dual else 1
+
 	@property
 	def stop_token(self):
 		if not self.causal:

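Note: the new n_embeddings property decides how many response-embedding tables a model allocates: one table for ordinary models, n_resp_levels tables when the model's dual property is True. It is consumed by the constructor and forward() hunks below. A toy illustration with invented values:

	class _Toy:
		dual = True
		n_resp_levels = 4
		@property
		def n_embeddings(self):
			return self.n_resp_levels if self.dual else 1

	assert _Toy().n_embeddings == 4   # a non-dual model would report 1
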
@@ -172,12 +199,18 @@ class Base(nn.Module):
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
 		self.text_emb = Embedding(n_tokens, d_model)
-		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 
-		if self.dual:
-			self.resps_emb = nn.ModuleList([MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(2)])
-		else:
+		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
+		if self.n_embeddings == 1:
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+		else:
+			self.resps_emb = nn.ModuleList([ MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) for _ in range(self.n_embeddings) ])
+		"""
+		if self.n_embeddings == 1:
+			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+		else:
+			self.resps_emb = RespEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+		"""
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 

@@ -262,18 +295,19 @@ class Base(nn.Module):
 
 		state: dict | None = None,
 	):
-		if self.dual:
+		if self.n_embeddings == 1:
 			x_list = self._samplewise_merge_tensors(
 				self.text_emb(text_list),
 				self.proms_emb(proms_list),
-				self.resps_emb[0 if quant_levels is None else 1](resps_list),
+				self.resps_emb(resps_list),
 				sep=self.sep,
 			)
 		else:
 			x_list = self._samplewise_merge_tensors(
 				self.text_emb(text_list),
 				self.proms_emb(proms_list),
-				self.resps_emb(resps_list),
+				self.resps_emb[0 if quant_levels is None else 1](resps_list),
+				#self.resps_emb(resps_list, quant_levels),
 				sep=self.sep,
 			)
 

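Note: the single-embedding and multi-embedding branches have effectively swapped roles: a model with one response embedding now calls it directly, while a multi-embedding model indexes its ModuleList with [0 if quant_levels is None else 1] (reading that expression, entry 0 serves passes without quant levels and entry 1 serves passes with them). A paraphrase of the selection as a helper, purely for illustration:

	from torch import nn

	def pick_resps_emb(resps_emb, quant_levels):
		if not isinstance(resps_emb, nn.ModuleList):   # single-embedding models
			return resps_emb
		# 0: no quant levels given (AR-style pass), 1: quant levels given (NAR-style pass)
		return resps_emb[0] if quant_levels is None else resps_emb[1]
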
@@ -296,8 +330,11 @@ class Base(nn.Module):
 
 		if self.arch_type == "transformer":
 			x = self.sin_emb.add_pe(x)
+
+			l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
+			l = l.to(device)
 			for block in self.blocks:
-				x = block(x, m, quant_levels)
+				x = block(x, m, l)
 
 		elif self.arch_type == "retnet":
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)

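Note: transformer blocks previously received quant_levels directly, which could be None; the added lines substitute a zero tensor of shape (batch_size,) so the blocks always get a per-sample level index on the right device. A minimal sketch of the defaulting, assuming batch_size and device are in scope as they are in the surrounding forward():

	import torch

	def default_quant_levels(quant_levels, batch_size, device):
		l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
		return l.to(device)   # never None by the time it reaches the blocks
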
@@ -91,8 +91,6 @@ class NAR(Base):
 			proms_list,
 			prev_list,
 			targ_list,
-			return_all_resp=True,
-			shift_targ_list=False,
 			quant_levels=quant_levels,
 		)
 

@@ -112,8 +110,6 @@ class NAR(Base):
 			text_list,
 			proms_list,
 			prev_list,
-			return_all_resp=True,
-			shift_targ_list=False,
 			quant_levels=quant_levels,
 			sampling_temperature=sampling_temperature,
 		)

@@ -62,8 +62,8 @@ def load_engines(invert=False):
 	optimizer = None
 	lr_scheduler = None
 
-	# cfg.deepspeed.torch_adam
-	if (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "adamw") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "adamw-torch"):
+	if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+		if cfg.hyperparameters.optimizer.lower() == "adamw":
 			params = {
 				"lr": cfg.hyperparameters.learning_rate,
 				"betas": (0.9, 0.96),

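Note: optimizer selection in load_engines() is now two-tiered: an outer check decides whether a torch-side optimizer should be built at all (local backend, or DeepSpeed with the new torch_optimizer flag), and inner checks dispatch on the plain optimizer name, so the "*-torch" name variants disappear. A sketch of the flow, with the actual construction calls left out because they sit in the elided lines of this function:

	def wants_torch_optimizer(cfg) -> bool:
		return cfg.trainer.backend == "local" or (
			cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer
		)

	# inside load_engines():
	#   if wants_torch_optimizer(cfg):
	#       if   cfg.hyperparameters.optimizer.lower() == "adamw":   build AdamW(model.parameters(), **params)
	#       elif cfg.hyperparameters.optimizer.lower() == "sgd":     build SGD(model.parameters(), **params)
	#       elif cfg.hyperparameters.optimizer.lower() == "prodigy": build Prodigy(model.parameters(), **params)
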
@@ -75,7 +75,7 @@ def load_engines(invert=False):
 				model.parameters(),
 				**params,
 			)
-	elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "sgd") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "sgd-torch"):
+		elif cfg.hyperparameters.optimizer.lower() == "sgd":
 			params = {
 				"lr": cfg.hyperparameters.learning_rate,
 			}

@@ -84,7 +84,7 @@ def load_engines(invert=False):
 				model.parameters(),
 				**params,
 			)
-	elif (cfg.trainer.backend == "local" and cfg.hyperparameters.optimizer.lower() == "prodigy") or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.optimizer.lower() == "prodigy-torch"):
+		elif cfg.hyperparameters.optimizer.lower() == "prodigy":
 			params = {
 				"lr": cfg.hyperparameters.learning_rate,
 			}