added Mistral (non-Mixtral) backend, useless optimization when not training, proper adjustment of the LR for Prodigyopt through d_coeff (maybe), recurrent sampling for LLaMA/Mistral/Mixtral backends (again, doesn't actually work)
This commit is contained in:
parent
cce929e136
commit
3da1518ace
|
@ -25,8 +25,8 @@ except Exception as e:
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def load_engines():
|
def load_engines(training=True):
|
||||||
models = get_models(cfg.models.get())
|
models = get_models(cfg.models.get(), training=training)
|
||||||
engines = dict()
|
engines = dict()
|
||||||
|
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
|
@ -59,6 +59,9 @@ def load_engines():
|
||||||
optimizer = ml.SGD
|
optimizer = ml.SGD
|
||||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||||
optimizer_class = ml.Prodigy
|
optimizer_class = ml.Prodigy
|
||||||
|
|
||||||
|
params['d_coef'] = params['lr']
|
||||||
|
params['lr'] = 1.0
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||||
|
|
||||||
|
|
|
@ -193,13 +193,17 @@ class Engine():
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lrs = []
|
lrs = []
|
||||||
for param_group in self.optimizer.param_groups:
|
for param_group in self.optimizer.param_groups:
|
||||||
if 'lr' in param_group:
|
if 'd_coeff' in param_group:
|
||||||
|
lrs.append(param_group['d_coeff'])
|
||||||
|
elif 'lr' in param_group:
|
||||||
lrs.append(param_group['lr'])
|
lrs.append(param_group['lr'])
|
||||||
return lrs
|
return lrs
|
||||||
|
|
||||||
def set_lr(self, lr):
|
def set_lr(self, lr):
|
||||||
for param_group in self.optimizer.param_groups:
|
for param_group in self.optimizer.param_groups:
|
||||||
if 'lr' in param_group:
|
if 'd_coeff' in param_group:
|
||||||
|
param_group['d_coeff'] = lr
|
||||||
|
elif 'lr' in param_group:
|
||||||
param_group['lr'] = lr
|
param_group['lr'] = lr
|
||||||
|
|
||||||
def get_global_grad_norm(self):
|
def get_global_grad_norm(self):
|
||||||
|
|
|
@ -99,7 +99,7 @@ class Engine(DeepSpeedEngine):
|
||||||
try:
|
try:
|
||||||
if hasattr(self.optimizer, 'param_groups'):
|
if hasattr(self.optimizer, 'param_groups'):
|
||||||
for param_group in self.optimizer.param_groups:
|
for param_group in self.optimizer.param_groups:
|
||||||
param_group['lr'] = lr
|
param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
|
||||||
else:
|
else:
|
||||||
self.optimizer.set_lr(lr)
|
self.optimizer.set_lr(lr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -73,7 +73,7 @@ class TTS():
|
||||||
self.ar_ckpt = ar_ckpt
|
self.ar_ckpt = ar_ckpt
|
||||||
self.nar_ckpt = nar_ckpt
|
self.nar_ckpt = nar_ckpt
|
||||||
|
|
||||||
models = get_models(cfg.models.get())
|
models = get_models(cfg.models.get(), training=False)
|
||||||
|
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
if name.startswith("ar"):
|
if name.startswith("ar"):
|
||||||
|
@ -101,7 +101,7 @@ class TTS():
|
||||||
self.loading = False
|
self.loading = False
|
||||||
|
|
||||||
def load_models( self ):
|
def load_models( self ):
|
||||||
engines = load_engines()
|
engines = load_engines(training=False)
|
||||||
for name, engine in engines.items():
|
for name, engine in engines.items():
|
||||||
if name.startswith("ar"):
|
if name.startswith("ar"):
|
||||||
self.ar = engine.module
|
self.ar = engine.module
|
||||||
|
|
|
@ -2,7 +2,7 @@ from .ar import AR
|
||||||
from .nar import NAR
|
from .nar import NAR
|
||||||
from .ar_nar import AR_NAR
|
from .ar_nar import AR_NAR
|
||||||
|
|
||||||
def get_model(cfg):
|
def get_model(cfg, training=True):
|
||||||
if cfg.name == "ar":
|
if cfg.name == "ar":
|
||||||
Model = AR
|
Model = AR
|
||||||
elif cfg.name == "nar":
|
elif cfg.name == "nar":
|
||||||
|
@ -20,6 +20,7 @@ def get_model(cfg):
|
||||||
n_layers=cfg.layers,
|
n_layers=cfg.layers,
|
||||||
n_experts=cfg.experts,
|
n_experts=cfg.experts,
|
||||||
|
|
||||||
|
training=training,
|
||||||
config = cfg,
|
config = cfg,
|
||||||
)
|
)
|
||||||
model._cfg = cfg
|
model._cfg = cfg
|
||||||
|
@ -28,5 +29,5 @@ def get_model(cfg):
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_models(models):
|
def get_models(models, training=True):
|
||||||
return { model.full_name: get_model(model) for model in models }
|
return { model.full_name: get_model(model, training=training) for model in models }
|
||||||
|
|
|
@ -120,7 +120,7 @@ class AR_NAR(Base):
|
||||||
if n_levels == self.n_resp_levels:
|
if n_levels == self.n_resp_levels:
|
||||||
# might be better to have this decided on the dataloader level
|
# might be better to have this decided on the dataloader level
|
||||||
|
|
||||||
if cfg.experimental:
|
if cfg.experimental and False:
|
||||||
# makes higher levels less likely
|
# makes higher levels less likely
|
||||||
def generate( lo=0, hi=8 ):
|
def generate( lo=0, hi=8 ):
|
||||||
index = lo
|
index = lo
|
||||||
|
@ -228,6 +228,15 @@ class AR_NAR(Base):
|
||||||
else:
|
else:
|
||||||
resps_list = self._unsqueeze_list(sequence_list)
|
resps_list = self._unsqueeze_list(sequence_list)
|
||||||
|
|
||||||
|
if recurrent_state is not None:
|
||||||
|
logits, recurrent_state = super().forward(
|
||||||
|
text_list=text_list,
|
||||||
|
proms_list=proms_list,
|
||||||
|
resps_list=resps_list,
|
||||||
|
lang_list=lang_list,
|
||||||
|
state=recurrent_state
|
||||||
|
)
|
||||||
|
else:
|
||||||
logits = super().forward(
|
logits = super().forward(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
|
|
|
@ -34,6 +34,12 @@ except Exception as e:
|
||||||
print("Error importing `llama` arch:", e)
|
print("Error importing `llama` arch:", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import MistralModel, MistralConfig
|
||||||
|
except Exception as e:
|
||||||
|
print("Error importing `mistral` arch:", e)
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import MixtralModel, MixtralConfig
|
from transformers import MixtralModel, MixtralConfig
|
||||||
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
|
||||||
|
@ -269,9 +275,11 @@ class Base(nn.Module):
|
||||||
|
|
||||||
n_experts: int=1,
|
n_experts: int=1,
|
||||||
|
|
||||||
|
training = True,
|
||||||
config = None,
|
config = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.training = training
|
||||||
self.config = config
|
self.config = config
|
||||||
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
|
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
|
||||||
|
|
||||||
|
@ -312,21 +320,21 @@ class Base(nn.Module):
|
||||||
self.blocks = nn.ModuleList([TransformerBlock(
|
self.blocks = nn.ModuleList([TransformerBlock(
|
||||||
d_model=d_model,
|
d_model=d_model,
|
||||||
n_heads=n_heads,
|
n_heads=n_heads,
|
||||||
p_dropout=p_dropout,
|
p_dropout=p_dropout if training else 0.0,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
norm_type=self.norm_type,
|
norm_type=self.norm_type,
|
||||||
n_levels=self.n_resp_levels,
|
n_levels=self.n_resp_levels,
|
||||||
) for _ in range(n_layers) ])
|
) for _ in range(n_layers) ])
|
||||||
elif self.arch_type == "llama":
|
elif self.arch_type == "mistral":
|
||||||
if n_experts <= 1:
|
if n_experts <= 1:
|
||||||
self.model = LlamaModel(LlamaConfig(
|
self.model = MistralModel(MistralConfig(
|
||||||
vocab_size=n_resp_tokens,
|
vocab_size=n_resp_tokens,
|
||||||
hidden_size=d_model,
|
hidden_size=d_model,
|
||||||
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||||
intermediate_size=d_model*4,
|
intermediate_size=d_model*4,
|
||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=n_heads,
|
num_key_value_heads=n_heads,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
|
@ -340,14 +348,51 @@ class Base(nn.Module):
|
||||||
intermediate_size=d_model*4,
|
intermediate_size=d_model*4,
|
||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=n_heads,
|
num_key_value_heads=n_heads,
|
||||||
|
sliding_window=75 * 12, # 12 second context window
|
||||||
|
output_router_logits=training,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
num_local_experts=n_experts,
|
num_local_experts=n_experts,
|
||||||
num_experts_per_tok=min(2, n_experts),
|
num_experts_per_tok=min(2, n_experts),
|
||||||
))
|
))
|
||||||
|
elif self.arch_type == "llama":
|
||||||
|
if n_experts <= 1:
|
||||||
|
self.model = LlamaModel(LlamaConfig(
|
||||||
|
vocab_size=n_resp_tokens,
|
||||||
|
hidden_size=d_model,
|
||||||
|
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||||
|
intermediate_size=d_model*4,
|
||||||
|
num_hidden_layers=n_layers,
|
||||||
|
num_attention_heads=n_heads,
|
||||||
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
|
num_key_value_heads=n_heads,
|
||||||
|
sliding_window=75 * 12, # 12 second context window
|
||||||
|
hidden_act="gelu",
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
is_decoder=True,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
self.model = MixtralModel(MixtralConfig(
|
||||||
|
vocab_size =n_resp_tokens,
|
||||||
|
hidden_size=d_model,
|
||||||
|
max_position_embeddings=75 * 60, # max-length of 60 seconds
|
||||||
|
intermediate_size=d_model*4,
|
||||||
|
num_hidden_layers=n_layers,
|
||||||
|
num_attention_heads=n_heads,
|
||||||
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
|
num_key_value_heads=n_heads,
|
||||||
|
sliding_window=75 * 12, # 12 second context window
|
||||||
|
output_router_logits=training,
|
||||||
|
hidden_act="gelu",
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
is_decoder=True,
|
||||||
|
num_local_experts=n_experts,
|
||||||
|
num_experts_per_tok=min(2, n_experts),
|
||||||
|
))
|
||||||
|
|
||||||
elif self.arch_type == "retnet":
|
elif self.arch_type == "retnet":
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
vocab_size=n_resp_tokens,
|
vocab_size=n_resp_tokens,
|
||||||
|
@ -356,7 +401,7 @@ class Base(nn.Module):
|
||||||
decoder_retention_heads=n_heads,
|
decoder_retention_heads=n_heads,
|
||||||
decoder_ffn_embed_dim=d_model * 4,
|
decoder_ffn_embed_dim=d_model * 4,
|
||||||
decoder_layers=n_layers,
|
decoder_layers=n_layers,
|
||||||
dropout=p_dropout,
|
dropout=p_dropout if training else 0.0,
|
||||||
checkpoint_activations=self.activation_checkpointing,
|
checkpoint_activations=self.activation_checkpointing,
|
||||||
activation_fn="gelu",
|
activation_fn="gelu",
|
||||||
use_layernorm=True, # self.version < 3,
|
use_layernorm=True, # self.version < 3,
|
||||||
|
@ -409,7 +454,7 @@ class Base(nn.Module):
|
||||||
lang_list: list[Tensor] | None = None,
|
lang_list: list[Tensor] | None = None,
|
||||||
|
|
||||||
quant_levels: Tensor | None = None,
|
quant_levels: Tensor | None = None,
|
||||||
state: dict | None = None,
|
state: dict | list | None = None,
|
||||||
):
|
):
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
|
||||||
|
@ -441,18 +486,25 @@ class Base(nn.Module):
|
||||||
# grab last token(s)
|
# grab last token(s)
|
||||||
x = x[:, -1, :].unsqueeze(1)
|
x = x[:, -1, :].unsqueeze(1)
|
||||||
# HF transformer derived model
|
# HF transformer derived model
|
||||||
elif self.arch_type == "llama":
|
elif self.arch_type == "llama" or self.arch_type == "mistral":
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
#attention_mask=m,
|
#attention_mask=m,
|
||||||
inputs_embeds=x,
|
inputs_embeds=x,
|
||||||
|
past_key_values=state,
|
||||||
|
use_cache=state is not None,
|
||||||
|
# return_dict=True,
|
||||||
)
|
)
|
||||||
if self.n_experts > 1:
|
if self.n_experts > 1 and targ_list is not None:
|
||||||
kwargs["output_router_logits"] = True
|
kwargs["output_router_logits"] = True
|
||||||
|
|
||||||
t = self.model(**kwargs)
|
t = self.model(**kwargs)
|
||||||
|
|
||||||
x = t[0]
|
x = t[0]
|
||||||
|
|
||||||
if self.n_experts > 1:
|
if state is not None:
|
||||||
|
state = t[1]
|
||||||
|
|
||||||
|
if self.n_experts > 1 and targ_list is not None:
|
||||||
router_logits = t[-1]
|
router_logits = t[-1]
|
||||||
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
|
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
|
||||||
elif self.arch_type == "transformer":
|
elif self.arch_type == "transformer":
|
||||||
|
@ -477,7 +529,6 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
if targ_list is not None:
|
if targ_list is not None:
|
||||||
|
|
||||||
target_list = self._samplewise_merge_tensors(
|
target_list = self._samplewise_merge_tensors(
|
||||||
text_list,
|
text_list,
|
||||||
lang_list,
|
lang_list,
|
||||||
|
@ -509,7 +560,7 @@ class Base(nn.Module):
|
||||||
if aux_loss is not None:
|
if aux_loss is not None:
|
||||||
self.loss["nll"] += aux_loss
|
self.loss["nll"] += aux_loss
|
||||||
|
|
||||||
return logits
|
return (logits, state) if state is not None else logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user