diff --git a/vall_e/config.py b/vall_e/config.py index 4e3f4b1..363e003 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -280,6 +280,7 @@ class ModelExperimentalSettings: len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong) # + logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298 per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training # "auto" will pick best for codec diff --git a/vall_e/inference.py b/vall_e/inference.py index 082c042..7f7edb5 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -523,8 +523,6 @@ class TTS(): # add an additional X seconds len_list = [ int(l * duration_padding) for l in len_list ] - print( len_list ) - kwargs = {} if prefix_context is not None: kwargs["prefix_context"] = prefix_context diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index bd1662d..c2b070f 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1250,11 +1250,11 @@ class Base_V2(nn.Module): tasks = self.get_input( inputs, name="task" ) # grab duration if no resp is provided or len task is requested - if tasks[0] == "len" or aux_lens[0][2] == 0: + if tasks[0] == "len": # do duration prediction logits_aux = self.len_decoder( output.logits ) - # only keep the designated token (although this should technically be logit[-1, :1]) - logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] + # it's more accurate this way + logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ] logits = logits_aux