re-added loading multiple models, because I'm now entertaining splitting the AR and NAR into separate models again (and need a way to load both at once)

This commit is contained in:
mrq 2024-06-06 09:48:43 -05:00
parent b05a905b95
commit b2194b859a
11 changed files with 105 additions and 91 deletions

View File

@@ -13,7 +13,6 @@ def main():
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--max-nar-levels", type=int, default=7)
@@ -40,7 +39,7 @@ def main():
parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args()
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference(
text=args.text,
references=args.references,
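
With the --model-ckpt flag gone, the CLI entry point is driven entirely by the YAML config. A minimal sketch of the updated constructor call, assuming the TTS wrapper is importable as vall_e.inference.TTS (the module path and file paths are placeholders, not part of this diff):

# minimal sketch; module path and paths below are assumptions
from pathlib import Path
from vall_e.inference import TTS

tts = TTS(config=Path("./config.yaml"), device="cuda", dtype="float16", amp=False)
# the checkpoint to load is now resolved from the YAML itself, not a CLI flag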

View File

@@ -256,7 +256,7 @@ class Model:
if self.interleave:
name.append("interleaved")
else:
name.append(f'{cfg.model.prom_levels}')
name.append(f'{self.prom_levels}')
return "-".join(name)
@@ -627,8 +627,7 @@ class Config(_Config):
experimental: bool = False # So I can stop commenting out things when committing
dataset: Dataset = field(default_factory=lambda: Dataset)
model: Model = field(default_factory=lambda: Model)
models: dict | list | None = None # deprecated
models: dict | list | None = field(default_factory=lambda: [Model])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
@@ -643,6 +642,14 @@ class Config(_Config):
audio_backend: str = "vocos"
@property
def model(self):
for i, model in enumerate(self.models):
if model.training:
return model
return self.models[0]
@property
def distributed(self):
return world_size() > 1
@@ -681,8 +688,8 @@ class Config(_Config):
if isinstance(self.dataset, type):
self.dataset = dict()
if isinstance(self.model, type):
self.model = dict()
if isinstance(self.models, type):
self.models = dict()
if isinstance(self.hyperparameters, type):
self.hyperparameters = dict()
@@ -704,10 +711,14 @@ class Config(_Config):
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
"""
if self.models is not None:
self.model = Model(**next(iter(self.models)))
else:
self.model = Model(**self.model)
"""
self.models = [ Model(**model) for model in self.models ]
self.hyperparameters = Hyperparameters(**self.hyperparameters)
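
The net effect of the Config changes: models is a plain list again, and model becomes a property returning whichever entry is flagged for training, falling back to the first. A self-contained sketch of that selection logic using stand-in dataclasses (not the real Config):

from dataclasses import dataclass, field

@dataclass
class StubModel:
    name: str = "ar+nar"
    training: bool = False

@dataclass
class StubConfig:
    models: list = field(default_factory=lambda: [StubModel()])

    @property
    def model(self):
        # prefer the model currently being trained, else the first one
        for model in self.models:
            if model.training:
                return model
        return self.models[0]

cfg = StubConfig(models=[StubModel("ar"), StubModel("nar", training=True)])
assert cfg.model.name == "nar"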

View File

@@ -26,14 +26,14 @@ from functools import cache
@cache
def load_engines(training=True):
models = get_models(cfg.model.get(), training=training)
models = get_models(cfg.models, training=training)
engines = dict()
for name, model in models.items():
optimizer = None
lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model.hyper_config.training
inferencing = cfg.mode == "inferencing" or not model.config.training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
@@ -43,7 +43,7 @@ def load_engines(training=True):
engine_class = _Engine if backend == "local" or inferencing else Engine
if inferencing:
model.hyper_config.training = False
model.config.training = False
if cfg.optimizations.replace and cfg.optimizations.linear:
model.model = ml.replace_linear( model.model )
@@ -83,7 +83,7 @@ def load_engines(training=True):
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
**params,
)
@@ -96,7 +96,7 @@ def load_engines(training=True):
raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}')
optimizer = scheduler_class(
[ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ],
[ param for name, param in model.named_parameters() if name not in model.config.frozen_params ],
lr = params['lr'],
warmup_steps = cfg.hyperparameters.warmup_steps
)
@@ -143,12 +143,16 @@ def load_engines(training=True):
del state[k]
# resize text embedding
if cfg.model.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:cfg.model.text_tokens]
if model.config.text_tokens != state["text_emb.weight"].shape[0]:
state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
# resize text embedding
if model.config.resp_levels != state["rvq_level_emb.weight"].shape[0]:
state["rvq_level_emb.weight"] = state["rvq_level_emb.weight"][:model.config.resp_levels]
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
hyper_config = model.hyper_config
hyper_config = model.config
# wrap if DDP is requested
if ddp:
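
The state-dict massaging above only handles shrinking: embedding rows beyond the configured text token count or RVQ level count are sliced off before load_state_dict. The same operation in isolation (shapes are illustrative, not the repo's defaults):

import torch

state = {"text_emb.weight": torch.randn(260, 1024)}
text_tokens = 256  # stand-in for model.config.text_tokens

# drop surplus rows so the checkpoint matches the configured embedding size
if text_tokens != state["text_emb.weight"].shape[0]:
    state["text_emb.weight"] = state["text_emb.weight"][:text_tokens]

assert state["text_emb.weight"].shape == (256, 1024)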

View File

@ -19,7 +19,7 @@ if deepspeed_available:
import deepspeed
class TTS():
def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
def __init__( self, config=None, device=None, amp=None, dtype=None ):
self.loading = True
self.input_sample_rate = 24000
@@ -53,32 +53,12 @@ class TTS():
self.symmap = None
if model_ckpt:
state = torch.load(model_ckpt)
self.model = get_models(cfg.model.get(), training=False)[0]
if "userdata" in state and 'symmap' in state['userdata']:
self.symmap = state['userdata']['symmap']
elif "symmap" in state:
self.symmap = state['symmap']
self.engines = load_engines(training=False)
for name, engine in self.engines.items():
if self.dtype != torch.int8:
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
if "module" in state:
state = state['module']
self.model.load_state_dict(state)
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
else:
engines = load_engines(training=False)
for name, engine in engines.items():
self.model = engine.module
break
if self.dtype != torch.int8:
self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.model.eval()
self.engines.eval()
if self.symmap is None:
self.symmap = get_phone_symmap()
@@ -159,6 +139,15 @@ class TTS():
wavs = []
sr = None
model_ar = None
model_nar = None
for name, engine in self.engines.items():
if "ar" in engine.hyper_config.capabilities:
model_ar = engine.module
if "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
for line in lines:
if out_path is None:
out_path = f"./data/{cfg.start_time}.wav"
@@ -172,7 +161,7 @@ class TTS():
lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.model(
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
@@ -183,8 +172,7 @@ class TTS():
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.model(
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
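
Inference now dispatches on each engine's declared capabilities instead of a single self.model: the engine listing "ar" runs the autoregressive pass, the one listing "nar" refines it, and a monolithic AR+NAR model simply satisfies both checks. The dispatch reduced to a standalone sketch (the engine objects here are stand-ins):

from types import SimpleNamespace

# stand-ins exposing .module and .hyper_config.capabilities like the real engines
engines = {
    "ar+nar": SimpleNamespace(module="combined-model",
                              hyper_config=SimpleNamespace(capabilities=["ar", "nar"])),
}

model_ar, model_nar = None, None
for name, engine in engines.items():
    if "ar" in engine.hyper_config.capabilities:
        model_ar = engine.module
    if "nar" in engine.hyper_config.capabilities:
        model_nar = engine.module

assert model_ar is model_nar  # one combined model fills both roles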

View File

@@ -1,37 +1,36 @@
def get_model(cfg, training=True):
name = cfg.name
def get_model(config, training=True):
name = config.name
if not cfg.experimental:
if not config.experimental:
from .ar_nar import AR_NAR
model = AR_NAR(
n_text_tokens=cfg.text_tokens,
n_audio_tokens=cfg.audio_tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
n_layers=cfg.layers,
n_experts=cfg.experts,
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=cfg.dropout,
p_dropout=config.dropout,
l_padding = cfg.input_alignment,
l_padding = config.input_alignment,
training = training,
config = cfg,
config = config,
)
model._cfg = cfg
else:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=cfg.text_tokens,
n_audio_tokens=cfg.audio_tokens,
d_model=cfg.dim,
n_layers=cfg.layers,
n_heads=cfg.heads,
p_dropout=cfg.dropout,
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
config = cfg,
d_model=config.dim,
n_layers=config.layers,
n_heads=config.heads,
p_dropout=config.dropout,
config = config,
)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
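
Renaming the parameter from cfg to config keeps it from shadowing the module-level cfg import. For context, a hypothetical get_models() companion (not shown in this hunk) that load_engines() would call with cfg.models might look like:

# hypothetical sketch; the real helper may name/key the models differently
def get_models(configs, training=True):
    return { config.name: get_model(config, training=training) for config in configs }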

View File

@@ -16,7 +16,7 @@ class AR_NAR(Base):
@property
def causal(self):
if hasattr(self, "config") and self.config:
return "ar" in self.capabilities
return "ar" in self.config.capabilities
return True
@property
@@ -31,6 +31,8 @@ class AR_NAR(Base):
@property
def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.prom_levels
return cfg.model.prom_levels
@property
@@ -41,18 +43,26 @@ class AR_NAR(Base):
@property
def n_max_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.max_levels
return cfg.model.max_levels
@property
def n_tasks(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tasks
return cfg.model.tasks
@property
def n_langs(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.langs
return cfg.model.langs
@property
def n_tones(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tones
return cfg.model.tones
@property
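
Every property above follows the same fallback: read from the per-model config passed at construction when present, otherwise fall back to the global cfg.model. The pattern in isolation, with stand-in names:

from dataclasses import dataclass

@dataclass
class StubDefaults:
    prom_levels: int = 8

cfg_model = StubDefaults()  # stand-in for the global cfg.model

class StubARNAR:
    def __init__(self, config=None):
        self.config = config

    @property
    def n_prom_levels(self) -> int:
        if hasattr(self, "config") and self.config:
            return self.config.prom_levels
        return cfg_model.prom_levels

assert StubARNAR(StubDefaults(prom_levels=4)).n_prom_levels == 4
assert StubARNAR().n_prom_levels == 8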

View File

@@ -1,5 +1,7 @@
# https://github.com/kyegomez/BitNet
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
# re-enable logging because zetascale fucking sucks

View File

@@ -1,4 +1,6 @@
# https://github.com/state-spaces/mamba
from torch.utils.checkpoint import checkpoint
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):

View File

@@ -1,12 +1,13 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
import torch
import torch.nn.functional as F
from transformers import MixtralModel, MixtralConfig
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
# This is required because batch sizes > 1 throws errors
def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim) # was view()
@@ -41,5 +42,4 @@ def Fixed_MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> to
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
Original_MixtralSparseMoeBlock_forward = MixtralSparseMoeBlock.forward
MixtralSparseMoeBlock.forward = Fixed_MixtralSparseMoeBlock_forward
MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
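
The patched forward keeps the reshape()-instead-of-view() change; the in-file comment blames batch sizes > 1, presumably because those produce non-contiguous hidden states, which view() rejects but reshape() copies. A minimal illustration of that difference, independent of Mixtral:

import torch

x = torch.randn(2, 5, 4).transpose(0, 1)  # non-contiguous after transpose
try:
    x.view(-1, 4)
except RuntimeError:
    print("view() refuses non-contiguous tensors")
y = x.reshape(-1, 4)  # reshape() copies when it has to
print(y.shape)  # torch.Size([10, 4])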

View File

@@ -207,9 +207,9 @@ class Base(nn.Module):
return -100
def loss_factor(self, k):
if self.hyper_config is None:
if self.config is None:
return 1.0
return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
def __init__(
self,
@@ -231,8 +231,8 @@ class Base(nn.Module):
):
super().__init__()
self.training = training
self.hyper_config = config
self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True
self.config = config
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
self.n_text_tokens = n_text_tokens
self.n_audio_tokens = n_audio_tokens
@@ -246,7 +246,7 @@ class Base(nn.Module):
# +1 to include the stop token
n_prom_tokens = n_audio_tokens
n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
n_resp_tokens = n_audio_tokens + 1 # (1 if self.causal else 0) interoperability
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@@ -263,13 +263,13 @@ class Base(nn.Module):
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
sums=self.config.audio_embedding_sums if self.config is not None else True,
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
sums=self.config.audio_embedding_sums if self.config is not None else True
)
# useless since I actually removed using these with the input processing overhaul...
@@ -290,20 +290,20 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
hf_attention = self.hyper_config.attention if self.hyper_config is not None else None
hf_attention = self.config.attention if self.config is not None else None
if self.hyper_config.attention == "auto":
if self.config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS:
self.hyper_config.attention = "flash"
self.config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS:
self.hyper_config.attention = "xformers"
self.config.attention = "xformers"
else:
self.hyper_config.attention = "mem_efficient"
self.config.attention = "mem_efficient"
if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]:
if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None
if self.hyper_config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if self.arch_type == "transformer":
@@ -326,7 +326,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
@@ -342,7 +342,7 @@ class Base(nn.Module):
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads,
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
sliding_window=75 * 12, # 12 second context window
output_router_logits=training,
hidden_act="gelu",
@@ -492,8 +492,8 @@ class Base(nn.Module):
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.hyper_config.attention )
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -691,7 +691,7 @@ class Base(nn.Module):
quant_levels: Tensor | None = None,
):
# old, "naive" way, no loss factoring
if not self.hyper_config.loss_factors:
if not self.config.loss_factors:
target_list = []
for batch_index, batch in enumerate(inputs):
target = []
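
Most of this file is the hyper_config -> config rename, but the attention backend resolution it touches is worth isolating. A standalone sketch of that logic (the helper name and the available list are mine, not the repo's):

def resolve_attention(requested: str, available: list[str]) -> str:
    # auto-pick in preference order: flash > xformers > mem_efficient
    if requested == "auto":
        if "flash" in available:
            requested = "flash"
        elif "xformers" in available:
            requested = "xformers"
        else:
            requested = "mem_efficient"
    if requested in ["xformers", "mem_efficient", "math", "flash"]:
        if requested not in available:
            raise ValueError(f"Requesting attention `{requested}` but is not available. Currently available: {available}")
    return requested

assert resolve_attention("auto", ["xformers", "math"]) == "xformers"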

View File

@@ -54,13 +54,12 @@ def init_tts(restart=False):
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="auto")
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())