actually do duration prediction

This commit is contained in:
mrq 2025-03-11 22:14:54 -05:00
parent 5c512717a6
commit 2ccf1b5740
3 changed files with 4 additions and 5 deletions

View File

@ -280,6 +280,7 @@ class ModelExperimentalSettings:
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
#
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
# "auto" will pick best for codec

View File

@ -523,8 +523,6 @@ class TTS():
# add an additional X seconds
len_list = [ int(l * duration_padding) for l in len_list ]
print( len_list )
kwargs = {}
if prefix_context is not None:
kwargs["prefix_context"] = prefix_context

View File

@ -1250,11 +1250,11 @@ class Base_V2(nn.Module):
tasks = self.get_input( inputs, name="task" )
# grab duration if no resp is provided or len task is requested
if tasks[0] == "len" or aux_lens[0][2] == 0:
if tasks[0] == "len":
# do duration prediction
logits_aux = self.len_decoder( output.logits )
# only keep the designated token (although this should technically be logit[-1, :1])
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
# it's more accurate this way
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits = logits_aux