added Mistral (non-Mixtral) backend, (probably useless) inference-time optimizations (dropout and router-logit outputs disabled when not training), proper adjustment of the LR for Prodigyopt through d_coef (maybe), recurrent sampling for the LLaMA/Mistral/Mixtral backends (again, doesn't actually work)

mrq 2024-01-31 21:48:36 -06:00
parent cce929e136
commit 3da1518ace
7 changed files with 98 additions and 30 deletions

View File

@@ -25,8 +25,8 @@ except Exception as e:
from functools import cache
@cache
def load_engines():
models = get_models(cfg.models.get())
def load_engines(training=True):
models = get_models(cfg.models.get(), training=training)
engines = dict()
for name, model in models.items():
@@ -59,6 +59,9 @@ def load_engines():
optimizer = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
optimizer_class = ml.Prodigy
params['d_coef'] = params['lr']
params['lr'] = 1.0
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
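
For reference, here is a minimal sketch of what the Prodigy remapping above is going for, assuming the `prodigyopt` package provides the `Prodigy` class behind `ml.Prodigy`; the model and learning-rate value are placeholders:

```python
import torch
from prodigyopt import Prodigy  # assumed to back ml.Prodigy

model = torch.nn.Linear(16, 16)
base_lr = 2.5e-4  # stand-in for cfg.hyperparameters.learning_rate

params = {"params": model.parameters(), "lr": base_lr}
params["d_coef"] = params["lr"]  # let the configured LR scale Prodigy's adapted step size
params["lr"] = 1.0               # Prodigy expects lr ~= 1.0 and estimates the step itself

optimizer = Prodigy(**params)
```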

View File

@@ -193,13 +193,17 @@ class Engine():
def get_lr(self):
lrs = []
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coef' in param_group:
lrs.append(param_group['d_coef'])
elif 'lr' in param_group:
lrs.append(param_group['lr'])
return lrs
def set_lr(self, lr):
for param_group in self.optimizer.param_groups:
if 'lr' in param_group:
if 'd_coef' in param_group:
param_group['d_coef'] = lr
elif 'lr' in param_group:
param_group['lr'] = lr
def get_global_grad_norm(self):
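
A rough sketch of the accessor behavior above, using an illustrative param group; Prodigy is assumed to expose its coefficient under the `d_coef` key while `lr` stays pinned at 1.0:

```python
# Illustrative param group; contents are made up, not pulled from the repo.
param_groups = [{"lr": 1.0, "d_coef": 2.5e-4}]

# get_lr(): report the Prodigy coefficient rather than the pinned lr of 1.0
lrs = [g["d_coef"] if "d_coef" in g else g["lr"] for g in param_groups]
assert lrs == [2.5e-4]

# set_lr(): scheduler updates land on d_coef, leaving lr untouched
for g in param_groups:
    g["d_coef" if "d_coef" in g else "lr"] = 1.0e-4
assert param_groups[0] == {"lr": 1.0, "d_coef": 1.0e-4}
```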

View File

@@ -99,7 +99,7 @@ class Engine(DeepSpeedEngine):
try:
if hasattr(self.optimizer, 'param_groups'):
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
else:
self.optimizer.set_lr(lr)
except Exception as e:

View File

@@ -73,7 +73,7 @@ class TTS():
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get())
models = get_models(cfg.models.get(), training=False)
for name, model in models.items():
if name.startswith("ar"):
@@ -101,7 +101,7 @@ class TTS():
self.loading = False
def load_models( self ):
engines = load_engines()
engines = load_engines(training=False)
for name, engine in engines.items():
if name.startswith("ar"):
self.ar = engine.module

View File

@@ -2,7 +2,7 @@ from .ar import AR
from .nar import NAR
from .ar_nar import AR_NAR
def get_model(cfg):
def get_model(cfg, training=True):
if cfg.name == "ar":
Model = AR
elif cfg.name == "nar":
@@ -20,6 +20,7 @@ def get_model(cfg):
n_layers=cfg.layers,
n_experts=cfg.experts,
training=training,
config = cfg,
)
model._cfg = cfg
@@ -28,5 +29,5 @@ def get_model(cfg):
return model
def get_models(models):
return { model.full_name: get_model(model) for model in models }
def get_models(models, training=True):
return { model.full_name: get_model(model, training=training) for model in models }
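
A small sketch of how the new `training` flag is meant to propagate from inference callers down to model construction; `FakeModelConfig` and the stand-in constructors here are illustrative, not repo code:

```python
from dataclasses import dataclass

@dataclass
class FakeModelConfig:
    name: str = "ar+nar"
    full_name: str = "ar+nar"

def get_model(cfg, training=True):
    # stand-in for the real constructor; only the flag handling is shown
    return {"name": cfg.full_name, "p_dropout": 0.1 if training else 0.0}

def get_models(models, training=True):
    return {model.full_name: get_model(model, training=training) for model in models}

# inference path (TTS / load_engines(training=False)) ends up with dropout off
models = get_models([FakeModelConfig()], training=False)
assert models["ar+nar"]["p_dropout"] == 0.0
```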

View File

@@ -120,7 +120,7 @@ class AR_NAR(Base):
if n_levels == self.n_resp_levels:
# might be better to have this decided on the dataloader level
if cfg.experimental:
if cfg.experimental and False:
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
@@ -228,13 +228,22 @@ class AR_NAR(Base):
else:
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
if recurrent_state is not None:
logits, recurrent_state = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
else:
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
state=recurrent_state
)
r = super().sample(
logits=logits,
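
The recurrent path above leans on the HF decoder cache: when a state is supplied, `forward()` is expected to hand back `(logits, state)` so the next step only has to process the newest token. A standalone sketch of that mechanism with a toy `MistralModel` (dimensions are placeholders, not the repo's model):

```python
import torch
from transformers import MistralConfig, MistralModel

model = MistralModel(MistralConfig(
    vocab_size=64, hidden_size=32, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)).eval()

with torch.no_grad():
    # first pass: full prompt, ask for the cache
    prompt = torch.randn(1, 5, 32)  # (batch, seq, hidden) prompt embeddings
    out = model(inputs_embeds=prompt, use_cache=True)
    state = out.past_key_values     # plays the role of recurrent_state

    # later passes: feed only the newly sampled token plus the cache
    step = torch.randn(1, 1, 32)
    out = model(inputs_embeds=step, past_key_values=state, use_cache=True)
    hidden, state = out.last_hidden_state, out.past_key_values
```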

View File

@@ -34,6 +34,12 @@ except Exception as e:
print("Error importing `llama` arch:", e)
pass
try:
from transformers import MistralModel, MistralConfig
except Exception as e:
print("Error importing `mistral` arch:", e)
pass
try:
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
@@ -269,9 +275,11 @@ class Base(nn.Module):
n_experts: int=1,
training = True,
config = None,
):
super().__init__()
self.training = training
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
@@ -312,21 +320,21 @@ class Base(nn.Module):
self.blocks = nn.ModuleList([TransformerBlock(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout,
p_dropout=p_dropout if training else 0.0,
causal=self.causal,
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "llama":
elif self.arch_type == "mistral":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
self.model = MistralModel(MistralConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
@@ -340,14 +348,51 @@ class Base(nn.Module):
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
))
elif self.arch_type == "llama":
if n_experts <= 1:
self.model = LlamaModel(LlamaConfig(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
))
elif self.arch_type == "retnet":
kwargs = dict(
vocab_size=n_resp_tokens,
@@ -356,7 +401,7 @@ class Base(nn.Module):
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
dropout=p_dropout if training else 0.0,
checkpoint_activations=self.activation_checkpointing,
activation_fn="gelu",
use_layernorm=True, # self.version < 3,
@@ -409,7 +454,7 @@ class Base(nn.Module):
lang_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | None = None,
state: dict | list | None = None,
):
batch_size = len(text_list)
@@ -441,18 +486,25 @@ class Base(nn.Module):
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
# HF transformer derived model
elif self.arch_type == "llama":
elif self.arch_type == "llama" or self.arch_type == "mistral":
kwargs = dict(
#attention_mask=m,
inputs_embeds=x,
past_key_values=state,
use_cache=state is not None,
# return_dict=True,
)
if self.n_experts > 1:
if self.n_experts > 1 and targ_list is not None:
kwargs["output_router_logits"] = True
t = self.model(**kwargs)
x = t[0]
if self.n_experts > 1:
if state is not None:
state = t[1]
if self.n_experts > 1 and targ_list is not None:
router_logits = t[-1]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer":
@@ -477,7 +529,6 @@ class Base(nn.Module):
# compute loss if the target is given
if targ_list is not None:
target_list = self._samplewise_merge_tensors(
text_list,
lang_list,
@@ -509,7 +560,7 @@ class Base(nn.Module):
if aux_loss is not None:
self.loss["nll"] += aux_loss
return logits
return (logits, state) if state is not None else logits
def sample(
self,
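
To close out, a small sketch of the MoE auxiliary-loss gating introduced above: router logits are only requested, and the balancing loss only added, when targets are present (training). The shapes and coefficient below are illustrative; the helper is the one already imported from `transformers`:

```python
import torch
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

training = True           # stands in for `targ_list is not None`
num_experts, top_k = 4, 2
router_aux_loss_coef = 0.001

aux_loss = None
if training:
    # one (batch * seq_len, num_experts) logit tensor per decoder layer
    router_logits = tuple(torch.randn(2 * 32, num_experts) for _ in range(6))
    aux_loss = router_aux_loss_coef * load_balancing_loss_func(router_logits, num_experts, top_k)
```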