more tweaks (vall_e.webui --yaml still breaks things, --model needs to deduce which audio backend to use now that i'm supporting other ones again // added the easy top-sampler settings back for the new implementation)

This commit is contained in:
mrq 2025-03-14 20:18:25 -05:00
parent 6ee505cffd
commit ca8cc15271
3 changed files with 41 additions and 106 deletions


@@ -116,10 +116,28 @@ class BaseConfig:
 		# load state dict and copy its stored model config
 		model_kwargs = { "attention": "auto", "training": False, "teacher": False }
-		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
-		lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
-		state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
+		model_state_dict = torch_load( model_path ) if model_path and model_path.exists() else None
+		lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
+
+		models_config = [ model_state_dict["config"] | { "path": model_path } | model_kwargs ] if model_state_dict is not None else []
+		loras_config = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
+		state = { "models": models_config, "loras": loras_config, "trainer": { "load_state_dict": True } }
+
+		deduced_backend = None
+		if model_state_dict is not None:
+			# 9 audio levels, will always be DAC
+			if "proms_emb.embs.8.weight" in model_state_dict["module"]:
+				deduced_backend = "dac"
+			# 8 audio levels, may be encodec/vocos (1024 tokens) or nemo (1000 tokens)
+			elif "proms_emb.embs.7.weight" in model_state_dict["module"]:
+				deduced_backend = "nemo" if model_state_dict["module"]["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"
+
+		if deduced_backend:
+			_logger.info(f'Deduced audio backend: {deduced_backend}')
+			state["audio_backend"] = deduced_backend

 		return cls(**state)

 	@classmethod
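
The deduction above is just a key/shape probe on the checkpoint: DAC is the only backend with nine prompt embedding levels, and the two eight-level families differ in codebook size. A minimal standalone sketch of the same heuristic, using plain torch.load in place of the repo's torch_load wrapper (names mirror the diff, but this isn't a drop-in API):

    import torch

    def deduce_audio_backend( model_path ):
        module = torch.load( model_path, map_location="cpu" )["module"]
        # a 9th prompt embedding (index 8) only exists for DAC's 9 RVQ levels
        if "proms_emb.embs.8.weight" in module:
            return "dac"
        # 8 levels: disambiguate by codebook size (nemo: 1000 tokens, encodec/vocos: 1024)
        if "proms_emb.embs.7.weight" in module:
            return "nemo" if module["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"
        return None
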
@@ -867,19 +885,19 @@ class Config(BaseConfig):
 		if audio_backend in ["encodec", "vocos"]:
 			audio_extension = ".enc"
 			cfg.sample_rate = 24_000
-			cfg.model.resp_levels = 8
+			#cfg.model.resp_levels = 8
 		elif audio_backend == "dac":
 			audio_extension = ".dac"
 			cfg.sample_rate = 44_100
-			cfg.model.resp_levels = 9
+			#cfg.model.resp_levels = 9
 		elif cfg.audio_backend == "audiodec":
 			audio_extension = ".dec"
 			cfg.sample_rate = 48_000
-			cfg.model.resp_levels = 8 # ?
+			#cfg.model.resp_levels = 8 # ?
 		elif cfg.audio_backend == "nemo":
 			audio_extension = ".nem"
 			cfg.sample_rate = 44_100
-			cfg.model.resp_levels = 8
+			#cfg.model.resp_levels = 8
+			#cfg.model.audio_tokens = 1000
 		else:
 			raise Exception(f"Unknown audio backend: {audio_backend}")
@@ -1144,6 +1162,8 @@ class Config(BaseConfig):
 		self.text_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(text_tokenizer_path))

+		self.set_audio_backend(self.audio_backend)
+
 # Preserves the old behavior
 class NaiveTokenizer:


@@ -148,36 +148,6 @@ class AR_NAR_V2(Base_V2):
 		# final validations and stuff
 		for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
-			# cap quant_level if it exceeds its corresponding resp/prom
-			# this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
-			if quant_level >= resps.shape[-1]:
-				quant_levels[i] = resps.shape[-1] - 1
-
-			# proms could be a Tensor, list[Tensor], or None
-			if isinstance( proms, torch.Tensor ):
-				if quant_level >= proms.shape[-1]:
-					quant_levels[i] = proms.shape[-1] - 1
-			elif isinstance( proms, list ):
-				for j, prom in enumerate( proms ):
-					if not isinstance( prom, torch.Tensor ):
-						continue
-					if quant_level >= prom.shape[-1]:
-						quant_levels[i] = prom.shape[-1] - 1
-
-			# apply token dropout error compensation
-			"""
-			if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
-				steps = resps.shape[0]
-				for l in range( quant_level ):
-					for t in range( steps ):
-						token = resps[t, l].item()
-						if random.random() < token_dropout_error:
-							offset = 1 * ( 1 if random.random() < 0.5 else -1 )
-							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
-			"""
-
 			# only apply stop token for RVQ level 0
 			if timesteps[i] is None or (self.predict_causally):
 				# append stop tokens for AR
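
The removed guard, in isolation: if the sampled quantizer level exceeded the number of RVQ levels actually present in a sample, it got clamped to the last valid level. A toy sketch of that idea, not repo code:

    import torch

    resps = torch.zeros( 100, 8, dtype=torch.long )  # audio trimmed to 8 levels instead of 9
    quant_level = 8
    if quant_level >= resps.shape[-1]:
        quant_level = resps.shape[-1] - 1  # clamped to 7, the last valid level
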


@@ -81,75 +81,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
 		x[..., level] = torch.where( dropout_mask, lhs, rhs )
 	return x

-# aims to properly encode an RVQ-encoded token sequence into an embedding
-# this and the decoder might not work, as i haven't gotten speech to emerge (although I might need to give it more time)
-# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
-class ResidualAudioEncoder(nn.Module):
-	def __init__(
-		self,
-		n_tokens: int,
-		n_levels: int,
-		token_dim: int,
-		training: bool = True,
-	):
-		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim)) # i still don't understand why this needs to be explicitly added instead of it being deduced in the embedding itself
-		self.cross_attn = nn.MultiheadAttention( embed_dim=token_dim, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True )
-		self.proj = nn.Linear(token_dim, token_dim) # i don't understand why this is necessary
-
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
-		# empty
-		if xi.shape[0] == 0:
-			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
-			return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
-
-		if dropout_mask is not None:
-			xi = _dropout_codes( xi, dropout_mask, dropout_token )
-
-		# ( seq_len, dim ) => ( seq_len, levels, dim )
-		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
-		x = x + self.pos_embedding
-		attn, _ = self.cross_attn( x, x, x )
-		x = x + attn
-		x = self.proj( x.mean(dim=1) )
-
-		return x
-
-# aims to properly decode the last hidden states from a model into logits for an RVQ-encoded token sequence
-class ResidualAudioDecoder(nn.Module):
-	def __init__(
-		self,
-		d_model,
-		vocab_size,
-		resp_levels,
-		training: bool = True,
-		use_ln: bool = False,
-	):
-		super().__init__()
-		self.projs = nn.ModuleList([nn.Sequential(
-			(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
-			nn.Linear(d_model, d_model),
-		) for _ in range(resp_levels)]) # per-level projs
-		self.cross_attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True ) # xattn so each level can attend to others per residual-ness
-		self.head = nn.Linear(d_model, vocab_size) # final output head, i feel it would be better to have it per-level but i assume the proj handles it
-
-	# forward for one sequence
-	def _forward( self, x: Tensor ) -> Tensor:
-		seq_len, resp_levels = x.shape[0], len(self.projs)
-		x = torch.stack([proj(x) for proj in self.projs], dim=1)
-		attn, _ = self.cross_attn( x, x, x )
-		x = x + attn
-		x = self.head( x )
-		x = x.view( resp_levels, seq_len, -1 )
-		return x
-
-	# required to act on per sequence and not a batch due to headed-ness
-	def forward( self, x_i: Tensor ) -> Tensor:
-		return torch.stack([ self._forward(x) for x in x_i ], dim=0)
-
-# the above, but for FSQ codecs, as each level is independent from one another
-# this for sure "works" as speech emerges to some extent
+# aims to properly encode token sequences into an embedding
+# despite being named for FSQ codecs, this works for RVQ codecs
 class FiniteAudioEncoder(nn.Module):
 	def __init__(
 		self,
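
The gist of the removed ResidualAudioEncoder, trimmed down: embed each RVQ level with its own table, add a learned per-level position (attention is permutation-invariant, so level identity has to be injected explicitly rather than being deduced from the embeddings), let the levels attend to one another, then pool them into a single embedding per timestep. A runnable sketch with made-up sizes, not the class itself:

    import torch
    import torch.nn as nn

    n_tokens, n_levels, dim, seq_len = 1024, 8, 512, 57
    embs = nn.ModuleList([ nn.Embedding(n_tokens, dim) for _ in range(n_levels) ])
    pos = nn.Parameter(torch.randn(1, n_levels, dim))  # attention alone can't tell levels apart
    attn = nn.MultiheadAttention( dim, num_heads=8, batch_first=True )

    codes = torch.randint( 0, n_tokens, (seq_len, n_levels) )   # ( seq_len, levels )
    x = torch.stack([ emb(codes[:, i]) for i, emb in enumerate(embs) ], dim=1)
    x = x + pos                  # ( seq_len, levels, dim )
    x = x + attn( x, x, x )[0]   # each level attends to the others
    x = x.mean(dim=1)            # pooled: ( seq_len, dim )
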
@@ -1332,6 +1265,18 @@ class Base_V2(nn.Module):
 		seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
 		logits = [ logit[-self.causal_size:] for logit in logits ]

+		# perform min_p filtering of our logits
+		if min_p > 0.0:
+			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
+		# perform top_k/top_p filtering of our logits
+		if top_k > 0 or top_p < 1.0:
+			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
+
+		# do top-no logit processing
+		if top_no > 0.0:
+			logits = [ top_no_logits_processing(logit) for logit in logits ]
+
 		# argmax instead
 		if temperature <= 0.0:
 			res = [ logit.argmax(dim=-1) for logit in logits ]
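
min_p_filtering, top_k_top_p_filtering, and top_no_logits_processing are the repo's own helpers. For context, min-p keeps only tokens whose probability is at least min_p times that of the most likely token; a rough sketch of that idea (illustrative, not the repo's implementation):

    import torch

    def min_p_filter( logits, min_p=0.05, filter_value=-float("inf") ):
        probs = logits.softmax(dim=-1)
        threshold = probs.amax(dim=-1, keepdim=True) * min_p
        return logits.masked_fill( probs < threshold, filter_value )

With temperature <= 0.0 the new path skips sampling entirely and greedily argmaxes over whatever survives the filters.
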