actually do duration prediction
parent 5c512717a6
commit 2ccf1b5740
@@ -280,6 +280,7 @@ class ModelExperimentalSettings:
	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)

	#
	logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
	per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
	audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
	# "auto" will pick best for codec

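For context, the logit normalization that the logit_normalization setting refers to (per https://arxiv.org/abs/2205.09310) scales each logit vector by its L2 norm times a temperature-like factor; a minimal sketch of that idea (a hypothetical helper, not this repo's implementation) could look like:

import torch

def normalize_logits( logits: torch.Tensor, factor: float = 1.0, eps: float = 1e-7 ) -> torch.Tensor:
	# hypothetical helper: divide each logit vector by its L2 norm (scaled by `factor`),
	# the LogitNorm trick from arXiv:2205.09310; not the repo's actual code
	norms = logits.norm( p=2, dim=-1, keepdim=True )
	return logits / ( factor * norms + eps )
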
@@ -523,8 +523,6 @@ class TTS():
	# add an additional X seconds
	len_list = [ int(l * duration_padding) for l in len_list ]
-
-	print( len_list )

	kwargs = {}
	if prefix_context is not None:
		kwargs["prefix_context"] = prefix_context

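As a small illustration (values assumed, not taken from the repo), the padding step above simply scales each predicted length by duration_padding before decoding:

# hypothetical values for illustration only
len_list = [ 120, 95 ]        # predicted durations in tokens
duration_padding = 1.05       # ~5% headroom
len_list = [ int(l * duration_padding) for l in len_list ]
# len_list is now [126, 99]
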
@@ -1250,11 +1250,11 @@ class Base_V2(nn.Module):
	tasks = self.get_input( inputs, name="task" )

	# grab duration if no resp is provided or len task is requested
-	if tasks[0] == "len" or aux_lens[0][2] == 0:
+	if tasks[0] == "len":
		# do duration prediction
		logits_aux = self.len_decoder( output.logits )
-		# only keep the designated token (although this should technically be logit[-1, :1])
-		logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+		# it's more accurate this way
+		logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]

		logits = logits_aux

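To illustrate the indexing change with a toy tensor (shapes and the aux_len layout are assumed, not taken from the repo): when the offsets happen to sum to the final position, the old offset-based index and the new -1 index pick the same logit, while -1 always reads the last position:

import torch

seq_len, n_buckets = 10, 11
logit = torch.randn( seq_len, n_buckets )    # assumed shape of one sequence's len-head logits
aux_len = ( 4, 5 )                           # hypothetical (text length, prompt length)

offset_pick = logit[..., aux_len[0] + aux_len[1], :1]    # old: position computed from aux_lens
last_pick = logit[..., -1, :1]                           # new: simply the last position

assert torch.equal( offset_pick, last_pick )    # identical here, since 4 + 5 == seq_len - 1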