more tweaks (vall_e.webui --yaml still breaks things; --model now needs to deduce which audio backend to use since I'm supporting other ones again // added easy top-sampler settings back for the new implementation)
commit ca8cc15271 (parent 6ee505cffd)
@@ -116,10 +116,28 @@ class BaseConfig:
# load state dict and copy its stored model config
model_kwargs = { "attention": "auto", "training": False, "teacher": False }
model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []

state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
model_state_dict = torch_load( model_path ) if model_path and model_path.exists() else None
lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None

models_config = [ model_state_dict["config"] | { "path": model_path } | model_kwargs ] if model_state_dict is not None else []
loras_config = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []

state = { "models": models_config, "loras": loras_config, "trainer": { "load_state_dict": True } }

deduced_backend = None
if model_state_dict is not None:
    # 9 audio levels, will always be DAC
    if "proms_emb.embs.8.weight" in model_state_dict["module"]:
        deduced_backend = "dac"
    # 8 audio levels, may be encodec/vocos (1024 tokens) or nemo (1000 tokens)
    elif "proms_emb.embs.7.weight" in model_state_dict["module"]:
        deduced_backend = "nemo" if model_state_dict["module"]["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"

if deduced_backend:
    _logger.info(f'Deduced audio backend: {deduced_backend}')
    state["audio_backend"] = deduced_backend

return cls(**state)

@classmethod
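For reference, the backend deduction added in this hunk can be exercised standalone against a checkpoint. A minimal sketch, assuming a checkpoint whose weights sit under a top-level "module" dict; the helper name and example path are only illustrative (the actual code goes through the repo's own torch_load wrapper):

import torch

# illustrative only: mirrors the key checks performed in BaseConfig above
def guess_audio_backend( path ):
    state_dict = torch.load( path, map_location="cpu" )
    module = state_dict["module"]
    # 9 audio levels will always be DAC
    if "proms_emb.embs.8.weight" in module:
        return "dac"
    # 8 audio levels: nemo uses 1000 tokens, encodec/vocos use 1024
    if "proms_emb.embs.7.weight" in module:
        return "nemo" if module["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"
    return None

# e.g. print( guess_audio_backend( "./models/ckpt/model.pth" ) )  # hypothetical path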
@@ -867,19 +885,19 @@ class Config(BaseConfig):
if audio_backend in ["encodec", "vocos"]:
    audio_extension = ".enc"
    cfg.sample_rate = 24_000
    cfg.model.resp_levels = 8
    #cfg.model.resp_levels = 8
elif audio_backend == "dac":
    audio_extension = ".dac"
    cfg.sample_rate = 44_100
    cfg.model.resp_levels = 9
    #cfg.model.resp_levels = 9
elif cfg.audio_backend == "audiodec":
    audio_extension = ".dec"
    cfg.sample_rate = 48_000
    cfg.model.resp_levels = 8 # ?
    #cfg.model.resp_levels = 8 # ?
elif cfg.audio_backend == "nemo":
    audio_extension = ".nem"
    cfg.sample_rate = 44_100
    cfg.model.resp_levels = 8
    #cfg.model.resp_levels = 8
    #cfg.model.audio_tokens = 1000
else:
    raise Exception(f"Unknown audio backend: {audio_backend}")
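For reference, the per-backend settings selected in this hunk (on-disk extension and sample rate; resp_levels is now commented out and left to the loaded model config) can be summarized as a plain mapping. This dict does not exist in the repo; it is just a restatement of the branches above:

# summary of the set_audio_backend branches above (not actual repo code)
AUDIO_BACKEND_SETTINGS = {
    "encodec":  { "extension": ".enc", "sample_rate": 24_000 },
    "vocos":    { "extension": ".enc", "sample_rate": 24_000 },
    "dac":      { "extension": ".dac", "sample_rate": 44_100 },
    "audiodec": { "extension": ".dec", "sample_rate": 48_000 },
    "nemo":     { "extension": ".nem", "sample_rate": 44_100 },
}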
@@ -1144,6 +1162,8 @@ class Config(BaseConfig):
    self.text_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(text_tokenizer_path))

    self.set_audio_backend(self.audio_backend)


# Preserves the old behavior
class NaiveTokenizer:
@@ -148,36 +148,6 @@ class AR_NAR_V2(Base_V2):
# final validations and stuff
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
    # cap quant_level if it exceeds its corresponding resp/prom
    # this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
    if quant_level >= resps.shape[-1]:
        quant_levels[i] = resps.shape[-1] - 1

    # proms could be a Tensor, list[Tensor], or None
    if isinstance( proms, torch.Tensor ):
        if quant_level >= proms.shape[-1]:
            quant_levels[i] = proms.shape[-1] - 1

    elif isinstance( proms, list ):
        for j, prom in enumerate( proms ):
            if not isinstance( prom, torch.Tensor ):
                continue
            if quant_level >= prom.shape[-1]:
                quant_levels[i] = prom.shape[-1] - 1

    # apply token dropout error compensation
    """
    if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
        steps = resps.shape[0]
        for l in range( quant_level ):
            for t in range( steps ):
                token = resps[t, l].item()

                if random.random() < token_dropout_error:
                    offset = 1 * ( 1 if random.random() < 0.5 else -1 )
                    resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
    """

    # only apply stop token for RVQ level 0
    if timesteps[i] is None or (self.predict_causally):
        # append stop tokens for AR
@@ -81,75 +81,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
        x[..., level] = torch.where( dropout_mask, lhs, rhs )
    return x

# aims to properly encode RVQ-encoded token sequence into an embedding
# this and the decoder might not work, as i haven't gotten speech to emerge (although I might need to give it more time)
# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
class ResidualAudioEncoder(nn.Module):
    def __init__(
        self,
        n_tokens: int,
        n_levels: int,
        token_dim: int,
        training: bool = True,
    ):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
        self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim)) # i still don't understand why this needs to be explicitly added instead of it being deduced in the embedding itself
        self.cross_attn = nn.MultiheadAttention( embed_dim=token_dim, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True )
        self.proj = nn.Linear(token_dim, token_dim) # i don't understand why this is necessary

    def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
        # empty
        if xi.shape[0] == 0:
            dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
            return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
        if dropout_mask is not None:
            xi = _dropout_codes( xi, dropout_mask, dropout_token )

        # ( seq_len, dim ) => ( seq_len, levels, dim )
        x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
        x = x + self.pos_embedding
        attn, _ = self.cross_attn( x, x, x )
        x = x + attn
        x = self.proj( x.mean(dim=1) )

        return x

# aims to properly decode the last hidden states from a model into logits for an RVQ-encoded token sequence
class ResidualAudioDecoder(nn.Module):
    def __init__(
        self,
        d_model,
        vocab_size,
        resp_levels,
        training: bool = True,
        use_ln: bool = False,
    ):
        super().__init__()

        self.projs = nn.ModuleList([nn.Sequential(
            (nn.LayerNorm(d_model) if use_ln else nn.Identity()),
            nn.Linear(d_model, d_model),
        ) for _ in range(resp_levels)]) # per-level projs

        self.cross_attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True ) # xattn so each level can attend to others per residual-ness
        self.head = nn.Linear(d_model, vocab_size) # final output head, i feel it would be better to have it per-level but i assume the proj handles it

    # forward for one sequence
    def _forward( self, x: Tensor ) -> Tensor:
        seq_len, resp_levels = x.shape[0], len(self.projs)
        x = torch.stack([proj(x) for proj in self.projs], dim=1)
        attn, _ = self.cross_attn( x, x, x )
        x = x + attn
        x = self.head( x )
        x = x.view( resp_levels, seq_len, -1 )
        return x

    # required to act on per sequence and not a batch due to headed-ness
    def forward( self, x_i: Tensor ) -> Tensor:
        return torch.stack([ self._forward(x) for x in x_i ], dim=0)

# the above, but for FSQ codecs, as each level is independent from one another
# this for sure "works" as speech emerges to some extent
# aims to properly encode token sequences into an embedding
# despite being named for FSQ codecs, this works for RVQ codecs
class FiniteAudioEncoder(nn.Module):
    def __init__(
        self,
@@ -1332,6 +1265,18 @@ class Base_V2(nn.Module):
seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
logits = [ logit[-self.causal_size:] for logit in logits ]

# perform min_p filtering of our logits
if min_p > 0.0:
    logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]

# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
    logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

# do top-no logit processing
if top_no > 0.0:
    logits = [ top_no_logits_processing(logit) for logit in logits ]

# argmax instead
if temperature <= 0.0:
    res = [ logit.argmax(dim=-1) for logit in logits ]
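The min_p_filtering, top_k_top_p_filtering, and top_no_logits_processing helpers reinstated in this hunk are defined elsewhere in the codebase. As a rough illustration of what the top_k/top_p step does to a single logit vector (standard top-k/nucleus filtering, not the repo's exact implementation):

import torch

def top_k_top_p_filter( logits, top_k=0, top_p=1.0, filter_value=-float("inf") ):
    # top_k: keep only the k highest-scoring tokens
    if top_k > 0:
        threshold = torch.topk( logits, top_k ).values[..., -1, None]
        logits = logits.masked_fill( logits < threshold, filter_value )
    # top_p: keep the smallest set of tokens whose cumulative probability exceeds top_p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort( logits, descending=True )
        cumulative_probs = torch.cumsum( torch.softmax( sorted_logits, dim=-1 ), dim=-1 )
        to_remove = cumulative_probs > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        to_remove[..., 0] = False  # always keep the most probable token
        mask = to_remove.scatter( -1, sorted_indices, to_remove )
        logits = logits.masked_fill( mask, filter_value )
    return logits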