added LLaMA/Mixtral (if experts>1) model arches, utilize XMoE's loss as well, set MoE frequency to 1 to make every layer MoE'd for RetNet, etc. (going to do tests without burning out again to see how things go)

This commit is contained in:
mrq 2023-12-22 19:27:36 -06:00
parent 9c198eb75a
commit 0db3203b21
4 changed files with 93 additions and 11 deletions

View File

@@ -52,6 +52,7 @@ Training is very dependent on:
* the quality of your dataset.
* how much data you have.
* the bandwidth you quantized your audio to.
+* the underlying model architecture used.
### Pre-Processed Dataset

View File

@@ -48,6 +48,7 @@ setup(
	"omegaconf==2.0.6",
	"tqdm>=4.64.1",
	"humanize>=4.4.0",
"transformer>4.36.0",
"pandas>=1.5.0", "pandas>=1.5.0",
"torch>=1.13.0", "torch>=1.13.0",

View File

@@ -132,10 +132,13 @@ class AR_NAR(Base):
	quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
else:
+	quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+	"""
	if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
		quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
	else:
		quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
+	"""
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
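For reference, a toy sketch (standalone, not part of the commit; sizes and names are made up) of what the two comprehensions above do to a single response tensor of shape (timesteps, n_resp_levels):

# toy illustration of the RVQ-bin selection above; tensor sizes are made up
import torch

n_resp_levels = 4
r = torch.randint(0, 1024, (6, n_resp_levels))   # one response: (timesteps, RVQ levels)
l = int(torch.randint(0, n_resp_levels, (1,)))   # randomly drawn target level

targ = r[..., l]                     # (timesteps,) -- the single RVQ bin used as the target
cond = r if l == 0 else r[..., :l]   # kept whole for the AR level, truncated to levels < l for the NAR
print(targ.shape, cond.shape)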
@@ -338,7 +341,7 @@ def example_usage():
	'd_model': 256,
	'n_heads': 4,
	'n_layers': 12,
-	'n_experts': 1,
+	'n_experts': 8,
}
"""
@@ -349,7 +352,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
-steps = 250
+steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
@@ -385,6 +388,10 @@ def example_usage():
	tqdm.write(f"{stats}")
+torch.save( {
+	'module': model.state_dict()
+}, "./data/test.pth" )
sample("init", 5)
train()
sample("final")

View File

@@ -14,10 +14,32 @@ from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
-from .retnet import RetNetDecoder, RetNetConfig
-from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..samplers import reptition_penalize, length_penalize, top_k_top_p_filtering, dynamic_temperature, top_k_logits_list, mirostat_sample
+try:
+	from .transformer import SinusoidalEmbedding, Block as TransformerBlock
+except Exception as e:
+	print("Error importing `transformer` arch:", e)
+	pass
+try:
+	from .retnet import RetNetDecoder, RetNetConfig
+except Exception as e:
+	print("Error importing `retnet` arch:", e)
+	pass
+try:
+	from transformers import LlamaModel, LlamaConfig
+except Exception as e:
+	print("Error importing `llama` arch:", e)
+	pass
+try:
+	from transformers import MixtralModel, MixtralConfig
+	from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
+except Exception as e:
+	print("Error importing `mixtral` arch:", e)
def _create_mask(l, device):
	"""1 is valid region and 0 is invalid."""
	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
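MixtralModel/MixtralConfig and load_balancing_loss_func only ship with transformers 4.36.0 and later (hence the setup.py bump above), and the guarded imports mean an older or missing transformers just disables that arch instead of breaking the whole module. A minimal sketch of the same pattern; the None fallback is only an illustration, not something this commit adds:

# optional-arch import, sketched standalone; falling back to None is illustrative only
try:
    from transformers import MixtralModel, MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
except Exception as e:
    print("Error importing `mixtral` arch:", e)
    MixtralModel = MixtralConfig = load_balancing_loss_func = None  # mark the arch unavailable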
@@ -254,9 +276,40 @@ class Base(nn.Module):
		norm_type=self.norm_type,
		n_levels=self.n_resp_levels,
	) for _ in range(n_layers) ])
+elif self.arch_type == "llama":
+	if n_experts <= 1:
+		self.model = LlamaModel(LlamaConfig(
+			vocab_size=n_resp_tokens,
+			hidden_size=d_model,
+			max_position_embeddings=75 * 60, # max-length of 60 seconds
+			intermediate_size=d_model*4,
+			num_hidden_layers=n_layers,
+			num_attention_heads=n_heads,
+			attention_dropout=p_dropout,
+			num_key_value_heads=n_heads,
+			hidden_act="gelu",
+			is_encoder_decoder=False,
+			is_decoder=True,
+		))
+	else:
+		self.model = MixtralModel(MixtralConfig(
+			vocab_size=n_resp_tokens,
+			hidden_size=d_model,
+			max_position_embeddings=75 * 60, # max-length of 60 seconds
+			intermediate_size=d_model*4,
+			num_hidden_layers=n_layers,
+			num_attention_heads=n_heads,
+			attention_dropout=p_dropout,
+			num_key_value_heads=n_heads,
+			hidden_act="gelu",
+			is_encoder_decoder=False,
+			is_decoder=True,
+			num_local_experts=n_experts,
+			num_experts_per_tok=min(2, n_experts),
+		))
elif self.arch_type == "retnet":
-	self.retnet = RetNetDecoder(RetNetConfig(
-		vocab_size=n_tokens,
+	self.model = RetNetDecoder(RetNetConfig(
+		vocab_size=n_resp_tokens,
		decoder_embed_dim=d_model,
		decoder_value_embed_dim =d_model * 2,
		decoder_retention_heads=n_heads,
@@ -278,8 +331,9 @@ class Base(nn.Module):
		# MoE
		use_xmoe=n_experts>1,
-		moe_freq=2,
+		moe_freq=1,
		moe_expert_count=n_experts,
+		moe_gating_use_fp32=False,
	))
	self.classifier = nn.Linear(d_model, n_resp_tokens)
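For a quick standalone sanity check of the MoE path, a minimal sketch (toy sizes, not the model's real config and not part of the commit) that builds a tiny Mixtral like the config above, feeds it pre-computed embeddings, and derives the router auxiliary loss the same way the forward pass below does:

# standalone sketch: tiny Mixtral on pre-computed embeddings, plus the router auxiliary loss
import torch
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

config = MixtralConfig(
    vocab_size=1025,            # toy stand-in for n_resp_tokens
    hidden_size=256,
    intermediate_size=256 * 4,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    hidden_act="gelu",
    num_local_experts=8,
    num_experts_per_tok=2,
)
model = MixtralModel(config)

x = torch.randn(1, 16, config.hidden_size)   # (batch, seq, d_model) embeddings, passed as inputs_embeds
out = model(inputs_embeds=x, output_router_logits=True)
aux_loss = config.router_aux_loss_coef * load_balancing_loss_func(
    out.router_logits, config.num_local_experts, config.num_experts_per_tok
)
print(out.last_hidden_state.shape, aux_loss)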
@@ -326,6 +380,7 @@ class Base(nn.Module):
)
x, m = list_to_tensor(x_list)
+aux_loss = None
device = x.device
@@ -336,12 +391,26 @@ class Base(nn.Module):
	# run the initial prompt to fill the KV cache
	for n in range(prefill_size):
		xi = x[:, n, :].unsqueeze(1)
-		self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
+		self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
	# grab last token(s)
	x = x[:, -1, :].unsqueeze(1)
-if self.arch_type == "transformer":
+# HF transformer derived model
+elif self.arch_type == "llama":
+	kwargs = dict(
+		#attention_mask=m,
+		inputs_embeds=x,
+	)
+	if self.n_experts > 1:
+		kwargs["output_router_logits"] = True
+	t = self.model(**kwargs)
+	x = t[0]
+	if self.n_experts > 1:
+		router_logits = t[-1]
+		aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
+elif self.arch_type == "transformer":
	# ensures we specify a quant_level for the transformer implementation's AdaLN
	l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
	l = l.to(device)
@@ -352,8 +421,9 @@ class Base(nn.Module):
		x = block(x, m, l)
elif self.arch_type == "retnet":
	# pass our inputs through the RetNet
-	x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
+	x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
+	if _ is not None and "l_aux" in _:
+		aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
# output projection layer with masking
x = self.classifier(x) * m
@@ -390,6 +460,9 @@ class Base(nn.Module):
	acc = self.accuracy_metric( inputs, target ),
	precision = self.precision_metric( inputs, target ),
)
+if aux_loss is not None:
+	self.loss["nll"] += aux_loss
return logits
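Both MoE paths land their balancing term in aux_loss (the Mixtral router loss scaled by router_aux_loss_coef, or the XMoE l_aux terms summed and scaled by 0.001), and either way it simply rides along with the cross-entropy under the "nll" key. Schematically, assuming the training engine backpropagates the sum of the self.loss values:

# schematic of the loss bookkeeping above; the `nll` name and engine behaviour are assumptions
loss = dict(nll=nll)                 # cross-entropy over the target RVQ bin
if aux_loss is not None:             # Mixtral router loss or RetNet/XMoE l_aux
    loss["nll"] = loss["nll"] + aux_loss
total = sum(loss.values())           # what the engine is assumed to backpropagate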