diff --git a/vall_e/config.py b/vall_e/config.py
index e9a1dfa..a622aa0 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -116,10 +116,28 @@ class BaseConfig:
 
 		# load state dict and copy its stored model config
 		model_kwargs = { "attention": "auto", "training": False, "teacher": False }
-		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
-		lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
-		state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
+		model_state_dict = torch_load( model_path ) if model_path and model_path.exists() else None
+		lora_state_dict = torch_load( lora_path ) if lora_path and lora_path.exists() else None
+
+		models_config = [ model_state_dict["config"] | { "path": model_path } | model_kwargs ] if model_state_dict is not None else []
+		loras_config = [ lora_state_dict["config"] | { "path": lora_path } ] if lora_state_dict is not None else []
+
+		state = { "models": models_config, "loras": loras_config, "trainer": { "load_state_dict": True } }
+
+		deduced_backend = None
+		if model_state_dict is not None:
+			# 9 audio levels, will always be DAC
+			if "proms_emb.embs.8.weight" in model_state_dict["module"]:
+				deduced_backend = "dac"
+			# 8 audio levels, may be encodec/vocos (1024 tokens) or nemo (1000 tokens)
+			elif "proms_emb.embs.7.weight" in model_state_dict["module"]:
+				deduced_backend = "nemo" if model_state_dict["module"]["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"
+
+		if deduced_backend:
+			_logger.info(f'Deduced audio backend: {deduced_backend}')
+			state["audio_backend"] = deduced_backend
+
 		return cls(**state)
 
 	@classmethod
@@ -867,19 +885,19 @@ class Config(BaseConfig):
 		if audio_backend in ["encodec", "vocos"]:
 			audio_extension = ".enc"
 			cfg.sample_rate = 24_000
-			cfg.model.resp_levels = 8
+			#cfg.model.resp_levels = 8
 		elif audio_backend == "dac":
 			audio_extension = ".dac"
 			cfg.sample_rate = 44_100
-			cfg.model.resp_levels = 9
+			#cfg.model.resp_levels = 9
 		elif cfg.audio_backend == "audiodec":
 			audio_extension = ".dec"
 			cfg.sample_rate = 48_000
-			cfg.model.resp_levels = 8 # ?
+			#cfg.model.resp_levels = 8 # ?
 		elif cfg.audio_backend == "nemo":
 			audio_extension = ".nem"
 			cfg.sample_rate = 44_100
-			cfg.model.resp_levels = 8
+			#cfg.model.resp_levels = 8
 			#cfg.model.audio_tokens = 1000
 		else:
 			raise Exception(f"Unknown audio backend: {audio_backend}")
@@ -1144,6 +1162,8 @@ class Config(BaseConfig):
 
 		self.text_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(text_tokenizer_path))
 
+		self.set_audio_backend(self.audio_backend)
+
 		# Preserves the old behavior
 		class NaiveTokenizer:
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 369ee1f..df7ddbc 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -148,36 +148,6 @@ class AR_NAR_V2(Base_V2):
 
 		# final validations and stuff
 		for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
-			# cap quant_level if it exceeds its corresponding resp/prom
-			# this was needed for when my DAC-encoded audio was erroneously trimmed to 8 RVQ levels instead of 9
-			if quant_level >= resps.shape[-1]:
-				quant_levels[i] = resps.shape[-1] - 1
-
-			# proms could be a Tensor, list[Tensor], or None
-			if isinstance( proms, torch.Tensor ):
-				if quant_level >= proms.shape[-1]:
-					quant_levels[i] = proms.shape[-1] - 1
-
-			elif isinstance( proms, list ):
-				for j, prom in enumerate( proms ):
-					if not isinstance( prom, torch.Tensor ):
-						continue
-					if quant_level >= prom.shape[-1]:
-						quant_levels[i] = prom.shape[-1] - 1
-
-			# apply token dropout error compensation
-			"""
-			if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
-				steps = resps.shape[0]
-				for l in range( quant_level ):
-					for t in range( steps ):
-						token = resps[t, l].item()
-
-						if random.random() < token_dropout_error:
-							offset = 1 * ( 1 if random.random() < 0.5 else -1 )
-							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
-			"""
-
 			# only apply stop token for RVQ level 0
 			if timesteps[i] is None or (self.predict_causally):
 				# append stop tokens for AR
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 8fb1aef..770907c 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -81,75 +81,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
 		x[..., level] = torch.where( dropout_mask, lhs, rhs )
 	return x
 
-# aims to properly encode RVQ-encoded token sequence into an embedding
-# this and the decoder might not work, as i haven't gotten speech to emerge (although I might need to give it more time)
-# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
-class ResidualAudioEncoder(nn.Module):
-	def __init__(
-		self,
-		n_tokens: int,
-		n_levels: int,
-		token_dim: int,
-		training: bool = True,
-	):
-		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
-		self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim)) # i still don't understand why this needs to be explicitly added instead of it being deduced in the embedding itself
-		self.cross_attn = nn.MultiheadAttention( embed_dim=token_dim, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True )
-		self.proj = nn.Linear(token_dim, token_dim) # i don't understand why this is necessary
-
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
-		# empty
-		if xi.shape[0] == 0:
-			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
-			return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
-		if dropout_mask is not None:
-			xi = _dropout_codes( xi, dropout_mask, dropout_token )
-
-		# ( seq_len, dim ) => ( seq_len, levels, dim )
-		x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
-		x = x + self.pos_embedding
-		attn, _ = self.cross_attn( x, x, x )
-		x = x + attn
-		x = self.proj( x.mean(dim=1) )
-
-		return x
-# aims to properly decode the last hidden states from a model into logits for an RVQ-encoded token sequence
-class ResidualAudioDecoder(nn.Module):
-	def __init__(
-		self,
-		d_model,
-		vocab_size,
-		resp_levels,
-		training: bool = True,
-		use_ln: bool = False,
-	):
-		super().__init__()
-
-		self.projs = nn.ModuleList([nn.Sequential(
-			(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
-			nn.Linear(d_model, d_model),
-		) for _ in range(resp_levels)]) # per-level projs
-
-		self.cross_attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True ) # xattn so each level can attend to others per residual-ness
-		self.head = nn.Linear(d_model, vocab_size) # final output head, i feel it would be better to have it per-level but i assume the proj handles it
-
-	# forward for one sequence
-	def _forward( self, x: Tensor ) -> Tensor:
-		seq_len, resp_levels = x.shape[0], len(self.projs)
-		x = torch.stack([proj(x) for proj in self.projs], dim=1)
-		attn, _ = self.cross_attn( x, x, x )
-		x = x + attn
-		x = self.head( x )
-		x = x.view( resp_levels, seq_len, -1 )
-		return x
-
-	# required to act on per sequence and not a batch due to headed-ness
-	def forward( self, x_i: Tensor ) -> Tensor:
-		return torch.stack([ self._forward(x) for x in x_i ], dim=0)
-
-# the above, but for FSQ codecs, as each level is independent from one another
-# this for sure "works" as speech emerges to some extent
+# aims to properly encode token sequences into an embedding
+# despite being named for FSQ codecs, this works for RVQ codecs
 class FiniteAudioEncoder(nn.Module):
 	def __init__(
 		self,
@@ -1332,6 +1265,18 @@ class Base_V2(nn.Module):
 			seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
 			logits = [ logit[-self.causal_size:] for logit in logits ]
 
+		# perform min_p filtering of our logits
+		if min_p > 0.0:
+			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
+		# perform top_k/top_p filtering of our logits
+		if top_k > 0 or top_p < 1.0:
+			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
+
+		# do top-no logit processing
+		if top_no > 0.0:
+			logits = [ top_no_logits_processing(logit) for logit in logits ]
+
 		# argmax instead
 		if temperature <= 0.0:
 			res = [ logit.argmax(dim=-1) for logit in logits ]
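
Reviewer note: the backend deduction added to BaseConfig above keys off the per-level prompt embeddings stored under the checkpoint's "module" dict. Below is a minimal standalone sketch of that rule, with a hypothetical state-dict fragment purely for illustration; the helper name deduce_audio_backend is not part of the codebase.

```python
import torch

def deduce_audio_backend(module: dict) -> str | None:
	# 9 RVQ levels (embs.0 .. embs.8) can only come from DAC
	if "proms_emb.embs.8.weight" in module:
		return "dac"
	# 8 levels: a 1000-entry codebook implies nemo, otherwise encodec/vocos
	if "proms_emb.embs.7.weight" in module:
		return "nemo" if module["proms_emb.embs.7.weight"].shape[0] == 1000 else "vocos"
	return None

# hypothetical checkpoint fragment: 8 levels with 1024-entry codebooks => encodec/vocos
fake_module = { f"proms_emb.embs.{i}.weight": torch.zeros(1024, 64) for i in range(8) }
assert deduce_audio_backend(fake_module) == "vocos"
```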
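A further note on the sampling change: the base_v2.py hunk applies min-p, then top-k/top-p, then top-no filtering before the temperature/argmax branch, reusing helpers that already exist in the model code (min_p_filtering, top_k_top_p_filtering, top_no_logits_processing). The sketch below is a generic, self-contained version of the min-p and top-k/top-p masking, not the repo's implementation, just to make the intended behaviour concrete.

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
	# min-p: drop tokens whose probability is below min_p times the top token's probability
	probs = logits.softmax(dim=-1)
	threshold = min_p * probs.max(dim=-1, keepdim=True).values
	return logits.masked_fill(probs < threshold, -float("inf"))

def top_k_top_p(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
	filter_value = -float("inf")
	# top-k: mask everything below the k-th largest logit
	if top_k > 0:
		kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
		logits = logits.masked_fill(logits < kth, filter_value)
	# top-p (nucleus): mask the tail once cumulative probability exceeds top_p
	if top_p < 1.0:
		sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
		cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
		remove = cumulative > top_p
		remove[..., 1:] = remove[..., :-1].clone()  # shift right so the token that crosses the threshold is kept
		remove[..., 0] = False
		logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), filter_value)
	return logits

# toy usage: only the two largest logits survive top-k filtering
print(top_k_top_p(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), top_k=2))
```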