added Mistral (non-Mixtral) backend, useless optimization when not training, proper adjustment of the LR for Prodigyopt through d_coeff (maybe), recurrent sampling for LLaMA/Mistral/Mixtral backends (again, doesn't actually work)
This commit is contained in:
parent
cce929e136
commit
3da1518ace
|
@ -25,8 +25,8 @@ except Exception as e:
|
|||
from functools import cache
|
||||
|
||||
@cache
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
def load_engines(training=True):
|
||||
models = get_models(cfg.models.get(), training=training)
|
||||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
|
@ -59,6 +59,9 @@ def load_engines():
|
|||
optimizer = ml.SGD
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
optimizer_class = ml.Prodigy
|
||||
|
||||
params['d_coef'] = params['lr']
|
||||
params['lr'] = 1.0
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
|
|
|
@ -193,13 +193,17 @@ class Engine():
|
|||
def get_lr(self):
|
||||
lrs = []
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
if 'd_coeff' in param_group:
|
||||
lrs.append(param_group['d_coeff'])
|
||||
elif 'lr' in param_group:
|
||||
lrs.append(param_group['lr'])
|
||||
return lrs
|
||||
|
||||
def set_lr(self, lr):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
if 'lr' in param_group:
|
||||
if 'd_coeff' in param_group:
|
||||
param_group['d_coeff'] = lr
|
||||
elif 'lr' in param_group:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def get_global_grad_norm(self):
|
||||
|
|
|
@ -99,7 +99,7 @@ class Engine(DeepSpeedEngine):
|
|||
try:
|
||||
if hasattr(self.optimizer, 'param_groups'):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
|
||||
else:
|
||||
self.optimizer.set_lr(lr)
|
||||
except Exception as e:
|
||||
|
|
|
@ -73,7 +73,7 @@ class TTS():
|
|||
self.ar_ckpt = ar_ckpt
|
||||
self.nar_ckpt = nar_ckpt
|
||||
|
||||
models = get_models(cfg.models.get())
|
||||
models = get_models(cfg.models.get(), training=False)
|
||||
|
||||
for name, model in models.items():
|
||||
if name.startswith("ar"):
|
||||
|
@ -101,7 +101,7 @@ class TTS():
|
|||
self.loading = False
|
||||
|
||||
def load_models( self ):
|
||||
engines = load_engines()
|
||||
engines = load_engines(training=False)
|
||||
for name, engine in engines.items():
|
||||
if name.startswith("ar"):
|
||||
self.ar = engine.module
|
||||
|
|
|
@ -2,7 +2,7 @@ from .ar import AR
|
|||
from .nar import NAR
|
||||
from .ar_nar import AR_NAR
|
||||
|
||||
def get_model(cfg):
|
||||
def get_model(cfg, training=True):
|
||||
if cfg.name == "ar":
|
||||
Model = AR
|
||||
elif cfg.name == "nar":
|
||||
|
@ -20,6 +20,7 @@ def get_model(cfg):
|
|||
n_layers=cfg.layers,
|
||||
n_experts=cfg.experts,
|
||||
|
||||
training=training,
|
||||
config = cfg,
|
||||
)
|
||||
model._cfg = cfg
|
||||
|
@ -28,5 +29,5 @@ def get_model(cfg):
|
|||
|
||||
return model
|
||||
|
||||
def get_models(models):
|
||||
return { model.full_name: get_model(model) for model in models }
|
||||
def get_models(models, training=True):
|
||||
return { model.full_name: get_model(model, training=training) for model in models }
|
||||
|
|
|
@ -120,7 +120,7 @@ class AR_NAR(Base):
|
|||
if n_levels == self.n_resp_levels:
|
||||
# might be better to have this decided on the dataloader level
|
||||
|
||||
if cfg.experimental:
|
||||
if cfg.experimental and False:
|
||||
# makes higher levels less likely
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
|
@ -228,13 +228,22 @@ class AR_NAR(Base):
|
|||
else:
|
||||
resps_list = self._unsqueeze_list(sequence_list)
|
||||
|
||||
logits = super().forward(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
state=recurrent_state
|
||||
)
|
||||
if recurrent_state is not None:
|
||||
logits, recurrent_state = super().forward(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
state=recurrent_state
|
||||
)
|
||||
else:
|
||||
logits = super().forward(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
state=recurrent_state
|
||||
)
|
||||
|
||||
r = super().sample(
|
||||
logits=logits,
|
||||
|
|
|
@ -34,6 +34,12 @@ except Exception as e:
|
|||
print("Error importing `llama` arch:", e)
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import MistralModel, MistralConfig
|
||||
except Exception as e:
|
||||
print("Error importing `mistral` arch:", e)
|
||||
pass
|
||||
|
||||
try:
|
||||
from transformers import MixtralModel, MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
||||
|
@ -269,9 +275,11 @@ class Base(nn.Module):
|
|||
|
||||
n_experts: int=1,
|
||||
|
||||
training = True,
|
||||
config = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.training = training
|
||||
self.config = config
|
||||
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
|
||||
|
||||
|
@ -312,21 +320,21 @@ class Base(nn.Module):
|
|||
self.blocks = nn.ModuleList([TransformerBlock(
|
||||
d_model=d_model,
|
||||
n_heads=n_heads,
|
||||
p_dropout=p_dropout,
|
||||
p_dropout=p_dropout if training else 0.0,
|
||||
causal=self.causal,
|
||||
norm_type=self.norm_type,
|
||||
n_levels=self.n_resp_levels,
|
||||
) for _ in range(n_layers) ])
|
||||
elif self.arch_type == "llama":
|
||||
elif self.arch_type == "mistral":
|
||||
if n_experts <= 1:
|
||||
self.model = LlamaModel(LlamaConfig(
|
||||
self.model = MistralModel(MistralConfig(
|
||||
vocab_size=n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
|
@ -340,14 +348,51 @@ class Base(nn.Module):
|
|||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
sliding_window=75 * 12, # 12 second context window
|
||||
output_router_logits=training,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
))
|
||||
elif self.arch_type == "llama":
|
||||
if n_experts <= 1:
|
||||
self.model = LlamaModel(LlamaConfig(
|
||||
vocab_size=n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
sliding_window=75 * 12, # 12 second context window
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
))
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
vocab_size =n_resp_tokens,
|
||||
hidden_size=d_model,
|
||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||
intermediate_size=d_model*4,
|
||||
num_hidden_layers=n_layers,
|
||||
num_attention_heads=n_heads,
|
||||
attention_dropout=p_dropout if training else 0.0,
|
||||
num_key_value_heads=n_heads,
|
||||
sliding_window=75 * 12, # 12 second context window
|
||||
output_router_logits=training,
|
||||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
))
|
||||
|
||||
elif self.arch_type == "retnet":
|
||||
kwargs = dict(
|
||||
vocab_size=n_resp_tokens,
|
||||
|
@ -356,7 +401,7 @@ class Base(nn.Module):
|
|||
decoder_retention_heads=n_heads,
|
||||
decoder_ffn_embed_dim=d_model * 4,
|
||||
decoder_layers=n_layers,
|
||||
dropout=p_dropout,
|
||||
dropout=p_dropout if training else 0.0,
|
||||
checkpoint_activations=self.activation_checkpointing,
|
||||
activation_fn="gelu",
|
||||
use_layernorm=True, # self.version < 3,
|
||||
|
@ -409,7 +454,7 @@ class Base(nn.Module):
|
|||
lang_list: list[Tensor] | None = None,
|
||||
|
||||
quant_levels: Tensor | None = None,
|
||||
state: dict | None = None,
|
||||
state: dict | list | None = None,
|
||||
):
|
||||
batch_size = len(text_list)
|
||||
|
||||
|
@ -441,18 +486,25 @@ class Base(nn.Module):
|
|||
# grab last token(s)
|
||||
x = x[:, -1, :].unsqueeze(1)
|
||||
# HF transformer derived model
|
||||
elif self.arch_type == "llama":
|
||||
elif self.arch_type == "llama" or self.arch_type == "mistral":
|
||||
kwargs = dict(
|
||||
#attention_mask=m,
|
||||
inputs_embeds=x,
|
||||
past_key_values=state,
|
||||
use_cache=state is not None,
|
||||
# return_dict=True,
|
||||
)
|
||||
if self.n_experts > 1:
|
||||
if self.n_experts > 1 and targ_list is not None:
|
||||
kwargs["output_router_logits"] = True
|
||||
|
||||
t = self.model(**kwargs)
|
||||
|
||||
x = t[0]
|
||||
|
||||
if self.n_experts > 1:
|
||||
if state is not None:
|
||||
state = t[1]
|
||||
|
||||
if self.n_experts > 1 and targ_list is not None:
|
||||
router_logits = t[-1]
|
||||
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
|
||||
elif self.arch_type == "transformer":
|
||||
|
@ -477,7 +529,6 @@ class Base(nn.Module):
|
|||
|
||||
# compute loss if the target is given
|
||||
if targ_list is not None:
|
||||
|
||||
target_list = self._samplewise_merge_tensors(
|
||||
text_list,
|
||||
lang_list,
|
||||
|
@ -509,7 +560,7 @@ class Base(nn.Module):
|
|||
if aux_loss is not None:
|
||||
self.loss["nll"] += aux_loss
|
||||
|
||||
return logits
|
||||
return (logits, state) if state is not None else logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user