fixes and compat (MoE-fying an existing model and retraining from there just ruins it after a second of audio...)
parent e513d2ef19
commit c690aa509d
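For context on the commit message: "MoE-fying" a dense checkpoint means seeding every expert of a mixture-of-experts model from the dense model's feed-forward weights and resuming training from there. A minimal sketch of that idea, with hypothetical state-dict key names (".ffn.", ".experts.N." are illustrative, not this repo's actual layout):

import torch

def moefy_state_dict(state: dict, n_experts: int = 8) -> dict:
    # replicate each dense FFN weight into every expert slot
    out = {}
    for k, v in state.items():
        if ".ffn." in k:
            for e in range(n_experts):
                out[k.replace(".ffn.", f".experts.{e}.")] = v.clone()
        else:
            out[k] = v
    return out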
@@ -184,7 +184,7 @@ class Model:
 			name.append(self.size)
 
 		if self.arch_type != "transformer":
-			if self.experts:
+			if self.experts > 1:
 				name.append(f'{self.experts}x'+self.arch_type.replace("/", "-"))
 			else:
 				name.append(self.arch_type.replace("/", "-"))
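The tightened guard keeps dense models out of the `Nx` naming scheme: assuming `experts` defaults to 1 for a non-MoE model (truthy), the old check would have produced names like `1x-retnet`. A quick illustration with assumed values:

# assumed values for illustration only
experts, arch_type = 1, "retnet"
name = []
if experts > 1:            # old check `if experts:` is true even for 1
    name.append(f"{experts}x" + arch_type.replace("/", "-"))
else:
    name.append(arch_type.replace("/", "-"))
print("-".join(name))      # "retnet", not "1x-retnet"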
@@ -13,6 +13,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
 from ..models import get_models
 from ..utils import wrapper as ml
 import torch
+import re
 
 deepspeed_available = False
 try:
@@ -90,6 +91,22 @@ def load_engines():
 			if "module" in state:
 				state = state["module"]
 
+			# maintain compat if I change variable names
+			insert = {}
+			erase = []
+
+			for k in state.keys():
+				key = re.sub(r'^retnet\.', "model.", k)
+				if k != key:
+					insert[key] = state[k]
+					erase.append(k)
+
+			for k in insert.keys():
+				state[k] = insert[k]
+
+			for k in erase:
+				del state[k]
+
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		# deepspeed inferencing
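The compat block above renames legacy `retnet.*` state-dict keys to `model.*` before `load_state_dict`, staging the renames in `insert`/`erase` so the dict is never mutated while being iterated. The same remap as a standalone function, a sketch rather than the repo's code:

import re

def remap_compat_keys(state: dict) -> dict:
    # build a fresh dict, so no mutation happens mid-iteration
    return { re.sub(r'^retnet\.', 'model.', k): v for k, v in state.items() }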
@@ -175,40 +175,47 @@ class TTS():
 		mirostat_eta=0.1,
 		out_path=None
 	):
-		if out_path is None:
-			out_path = f"./data/{cfg.start_time}.wav"
-
-		prom = self.encode_audio( references, trim_length=input_prompt_length )
-		phns = self.encode_text( text, language=language )
-		lang = self.encode_lang( language )
-
-		prom = to_device(prom, self.device).to(torch.int16)
-		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
-		lang = to_device(lang, self.device).to(torch.uint8)
-
-		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-			resps_list = self.ar(
-				text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
-				sampling_temperature=ar_temp,
-				sampling_min_temperature=min_ar_temp,
-				sampling_top_p=top_p, sampling_top_k=top_k,
-				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-				sampling_length_penalty=length_penalty,
-				sampling_beam_width=beam_width,
-				sampling_mirostat_tau=mirostat_tau,
-				sampling_mirostat_eta=mirostat_eta,
-			)
-			resps_list = [r.unsqueeze(-1) for r in resps_list]
-			resps_list = self.nar(
-				text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
-				max_levels=max_nar_levels,
-				sampling_temperature=nar_temp,
-				sampling_min_temperature=min_nar_temp,
-				sampling_top_p=top_p, sampling_top_k=top_k,
-				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
-			)
-
-		wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
-
-		return (wav, sr)
+		lines = text.split("\n")
+
+		wavs = []
+		sr = None
+
+		for line in lines:
+			if out_path is None:
+				out_path = f"./data/{cfg.start_time}.wav"
+
+			prom = self.encode_audio( references, trim_length=input_prompt_length )
+			phns = self.encode_text( line, language=language )
+			lang = self.encode_lang( language )
+
+			prom = to_device(prom, self.device).to(torch.int16)
+			phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
+			lang = to_device(lang, self.device).to(torch.uint8)
+
+			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+				resps_list = self.ar(
+					text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
+					sampling_temperature=ar_temp,
+					sampling_min_temperature=min_ar_temp,
+					sampling_top_p=top_p, sampling_top_k=top_k,
+					sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+					sampling_length_penalty=length_penalty,
+					sampling_beam_width=beam_width,
+					sampling_mirostat_tau=mirostat_tau,
+					sampling_mirostat_eta=mirostat_eta,
+				)
+				resps_list = [r.unsqueeze(-1) for r in resps_list]
+				resps_list = self.nar(
+					text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
+					max_levels=max_nar_levels,
+					sampling_temperature=nar_temp,
+					sampling_min_temperature=min_nar_temp,
+					sampling_top_p=top_p, sampling_top_k=top_k,
+					sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+				)
+
+			wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
+			wavs.append(wav)
+
+		return (torch.concat(wavs, dim=-1), sr)
 
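Net effect of this hunk: the input text is split on newlines, each segment goes through its own AR+NAR pass, and the per-segment waveforms are joined with `torch.concat(..., dim=-1)`, so a long multi-paragraph prompt no longer has to fit in a single AR context. Hypothetical usage (constructor and method names are assumptions, not confirmed API):

# hypothetical usage; constructor arguments and method name are assumed
tts = TTS()
wav, sr = tts.inference(
    text="First paragraph.\nSecond paragraph.",   # two segments, two AR passes
    references=["./reference.wav"],
    language="en",
)
# wav holds both segments concatenated along the time axis

Note that each iteration decodes into the same `out_path`, so the file on disk ends up holding only the last segment, while the returned tensor holds the full concatenation.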
@@ -327,7 +327,6 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
-	"""
 	kwargs = {
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
@@ -343,6 +342,7 @@ def example_usage():
 		'n_layers': 12,
 		'n_experts': 8,
 	}
+	"""
 
 	"""
 	try:
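Together, these two hunks move the string delimiter from before the first `kwargs` block to after its closing `}`, flipping the toggle: the block with `'n_experts': 8` becomes live code, and what follows stays quoted out. The pattern in isolation (entries abbreviated):

kwargs = {          # now active: the 8-expert MoE example config
    'n_tokens': 1024,
    'd_model': 1024,
    'n_layers': 12,
    'n_experts': 8,
}
"""
kwargs = {}         # the alternative config stays disabled inside the string
"""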
@@ -308,7 +308,7 @@ class Base(nn.Module):
 				num_experts_per_tok=min(2, n_experts),
 			))
 		elif self.arch_type == "retnet":
-			self.model = RetNetDecoder(RetNetConfig(
+			kwargs = dict(
 				vocab_size=n_resp_tokens,
 				decoder_embed_dim=d_model,
 				decoder_value_embed_dim =d_model * 2,
@@ -328,13 +328,17 @@ class Base(nn.Module):
 				decoder_normalize_before=True,
 
 				rotary_embedding_base=self.rotary_embedding_base, # 10000
-
-				# MoE
-				use_xmoe=n_experts>1,
-				moe_freq=1,
-				moe_expert_count=n_experts,
-				moe_gating_use_fp32=False,
-			))
+			)
+
+			if n_experts > 1:
+				kwargs.update(dict(
+					use_xmoe=True,
+					moe_freq=1,
+					moe_expert_count=n_experts,
+					moe_gating_use_fp32=False,
+				))
+
+			self.model = RetNetDecoder(RetNetConfig(**kwargs))
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
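Deferring construction (`kwargs = dict(...)` in the earlier hunk, `RetNetConfig(**kwargs)` here) lets the MoE-only options be merged in conditionally, so a dense model never passes `use_xmoe`/`moe_*` to the config at all. The same pattern in isolation, a sketch with `SimpleNamespace` standing in for `RetNetConfig`:

from types import SimpleNamespace

def make_retnet_config(d_model: int, n_experts: int) -> SimpleNamespace:
    kwargs = dict(
        decoder_embed_dim=d_model,
        decoder_value_embed_dim=d_model * 2,
    )
    if n_experts > 1:   # only MoE models ever see the xmoe options
        kwargs.update(dict(
            use_xmoe=True,
            moe_freq=1,
            moe_expert_count=n_experts,
        ))
    return SimpleNamespace(**kwargs)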
@@ -422,7 +426,7 @@ class Base(nn.Module):
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
-			if _ is not None and "l_aux" in _:
+			if _ is not None and "l_aux" in _ and self.n_experts > 1:
 				aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
 		# output projection layer with masking
 		x = self.classifier(x) * m
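The extra `self.n_experts > 1` guard keeps the MoE load-balancing auxiliary loss from being accumulated on dense models, where the decoder may still return `l_aux` entries. How a weighted aux loss of this shape would typically fold into the objective, as a sketch (not this repo's training loop):

import torch

def total_loss(main_loss: torch.Tensor, l_aux: list, n_experts: int) -> torch.Tensor:
    terms = [t for t in l_aux if t is not None]
    if n_experts > 1 and terms:
        # small weight so load balancing nudges routing without dominating
        return main_loss + torch.sum(torch.stack(terms)) * 0.001
    return main_loss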