added LLaMA/Mixtral (if experts>1) model arches, utilize XMoE's loss as well, set MoE frequency to 1 to make every layer MoE'd for RetNet, etc. (going to do tests without burning out again to see how things go)

mrq 2023-12-22 19:27:36 -06:00
parent 9c198eb75a
commit 0db3203b21
4 changed files with 93 additions and 11 deletions
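The core of the change is the new backbone selection in the base model: a dense LLaMA decoder when only a single expert is configured, a Mixtral MoE decoder otherwise. Below is a minimal standalone sketch of that selection using the Hugging Face classes the diff imports; the helper name `build_backbone` and the toy dimensions are illustrative, not part of the commit.

from transformers import LlamaConfig, LlamaModel, MixtralConfig, MixtralModel

def build_backbone(n_experts: int, d_model: int = 256, n_layers: int = 4, n_heads: int = 4, vocab_size: int = 1025):
    # Dense decoder when MoE is disabled (a single expert).
    if n_experts <= 1:
        return LlamaModel(LlamaConfig(
            vocab_size=vocab_size,
            hidden_size=d_model,
            intermediate_size=d_model * 4,
            num_hidden_layers=n_layers,
            num_attention_heads=n_heads,
            num_key_value_heads=n_heads,
        ))
    # Sparse MoE decoder: every layer carries n_experts experts with top-2 routing.
    return MixtralModel(MixtralConfig(
        vocab_size=vocab_size,
        hidden_size=d_model,
        intermediate_size=d_model * 4,
        num_hidden_layers=n_layers,
        num_attention_heads=n_heads,
        num_key_value_heads=n_heads,
        num_local_experts=n_experts,
        num_experts_per_tok=min(2, n_experts),
    ))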

View File

@@ -52,6 +52,7 @@ Training is very dependent on:
* the quality of your dataset.
* how much data you have.
* the bandwidth you quantized your audio to.
* the underlying model architecture used.
### Pre-Processed Dataset

View File

@@ -48,6 +48,7 @@ setup(
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"transformer>4.36.0",
"pandas>=1.5.0",
"torch>=1.13.0",

View File

@@ -132,10 +132,13 @@ class AR_NAR(Base):
quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
else:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
"""
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else:
quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
"""
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
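For reference, here is the quant-level sampling from this hunk as a standalone function, with the commented-out `p_ar_level` branch spelled out: level 0 trains the AR pass, levels 1+ train the NAR passes. The function name is illustrative, and the NAR draw is clamped to `n_resp_levels - 1` so it stays inside the valid range.

import random
import torch

def sample_quant_levels(batch_size: int, n_resp_levels: int, p_ar_level=None) -> torch.Tensor:
    # "auto"/None: uniform draw over all RVQ-bin levels (0 = AR, 1+ = NAR).
    if p_ar_level is None or p_ar_level == "auto":
        return torch.randint(0, n_resp_levels, (batch_size,))
    # Otherwise bias how often the AR level is selected per sample.
    return torch.tensor([
        0 if random.random() < p_ar_level else random.randint(1, n_resp_levels - 1)
        for _ in range(batch_size)
    ])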
@@ -338,7 +341,7 @@ def example_usage():
'd_model': 256,
'n_heads': 4,
'n_layers': 12,
'n_experts': 1,
'n_experts': 8,
}
"""
@@ -349,7 +352,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 250
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
@@ -385,6 +388,10 @@ def example_usage():
tqdm.write(f"{stats}")
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
sample("init", 5)
train()
sample("final")

View File

@@ -14,10 +14,32 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
try:
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
except Exception as e:
print("Error importing `transformer` arch:", e)
pass
try:
from .retnet import RetNetDecoder, RetNetConfig
except Exception as e:
print("Error importing `retnet` arch:", e)
pass
try:
from transformers import LlamaModel, LlamaConfig
except Exception as e:
print("Error importing `llama` arch:", e)
pass
try:
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
except Exception as e:
print("Error importing `mixtral` arch:", e)
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
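The hunk cuts off inside `_create_mask`; for context, this is the usual shape of such a helper given the docstring and the first line shown, where `l` is a list of per-sample lengths. This reconstruction is an assumption, not the commit's verbatim code.

import torch

def create_mask_sketch(l, device):
    """1 is valid region and 0 is invalid."""
    seq = torch.arange(max(l), device=device).unsqueeze(0)   # (1, t)
    stop = torch.tensor(l, device=device).unsqueeze(1)       # (b, 1)
    return (seq < stop).float()                              # (b, t)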
@@ -254,9 +276,40 @@ class Base(nn.Module):
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "llama":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
))
elif self.arch_type == "retnet":
self.retnet = RetNetDecoder(RetNetConfig(
vocab_size=n_tokens,
self.model = RetNetDecoder(RetNetConfig(
vocab_size=n_resp_tokens,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
@@ -278,8 +331,9 @@ class Base(nn.Module):
# MoE
use_xmoe=n_experts>1,
moe_freq=2,
moe_freq=1,
moe_expert_count=n_experts,
moe_gating_use_fp32=False,
))
self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -326,6 +380,7 @@ class Base(nn.Module):
)
x, m = list_to_tensor(x_list)
aux_loss = None
device = x.device
@@ -336,12 +391,26 @@
# run the initial prompt to fill the KV cache
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
# HF transformer derived model
elif self.arch_type == "llama":
kwargs = dict(
#attention_mask=m,
inputs_embeds=x,
)
if self.n_experts > 1:
kwargs["output_router_logits"] = True
if self.arch_type == "transformer":
t = self.model(**kwargs)
x = t[0]
if self.n_experts > 1:
router_logits = t[-1]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer":
# ensures we specify a quant_level for the transformer implementation's AdaLN
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
@@ -352,8 +421,9 @@
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
if _ is not None and "l_aux" in _:
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
# output projection layer with masking
x = self.classifier(x) * m
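Both MoE paths in this hunk funnel their balancing terms into the same `aux_loss` scalar. Condensed for reference below; the function names are illustrative, while the transformers import is the one the commit itself uses.

import torch
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

def mixtral_aux_loss(router_logits, num_local_experts, num_experts_per_tok, coef):
    # router_logits: per-layer gate logits returned when output_router_logits=True
    return coef * load_balancing_loss_func(router_logits, num_local_experts, num_experts_per_tok)

def xmoe_aux_loss(inner_state, scale=0.001):
    # inner_state["l_aux"]: per-layer XMoE load-balancing terms (entries may be None)
    return scale * torch.sum(torch.stack([t for t in inner_state["l_aux"] if t is not None]))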
@@ -390,6 +460,9 @@
acc = self.accuracy_metric( inputs, target ),
precision = self.precision_metric( inputs, target ),
)
if aux_loss is not None:
self.loss["nll"] += aux_loss
return logits