actually do duration prediction
This commit is contained in:
parent
5c512717a6
commit
2ccf1b5740
|
@ -280,6 +280,7 @@ class ModelExperimentalSettings:
|
||||||
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
|
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
|
||||||
|
|
||||||
#
|
#
|
||||||
|
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
|
||||||
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
|
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
|
||||||
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
|
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
|
||||||
# "auto" will pick best for codec
|
# "auto" will pick best for codec
|
||||||
|
|
|
@ -523,8 +523,6 @@ class TTS():
|
||||||
# add an additional X seconds
|
# add an additional X seconds
|
||||||
len_list = [ int(l * duration_padding) for l in len_list ]
|
len_list = [ int(l * duration_padding) for l in len_list ]
|
||||||
|
|
||||||
print( len_list )
|
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if prefix_context is not None:
|
if prefix_context is not None:
|
||||||
kwargs["prefix_context"] = prefix_context
|
kwargs["prefix_context"] = prefix_context
|
||||||
|
|
|
@ -1250,11 +1250,11 @@ class Base_V2(nn.Module):
|
||||||
tasks = self.get_input( inputs, name="task" )
|
tasks = self.get_input( inputs, name="task" )
|
||||||
|
|
||||||
# grab duration if no resp is provided or len task is requested
|
# grab duration if no resp is provided or len task is requested
|
||||||
if tasks[0] == "len" or aux_lens[0][2] == 0:
|
if tasks[0] == "len":
|
||||||
# do duration prediction
|
# do duration prediction
|
||||||
logits_aux = self.len_decoder( output.logits )
|
logits_aux = self.len_decoder( output.logits )
|
||||||
# only keep the designated token (although this should technically be logit[-1, :1])
|
# it's more accurate this way
|
||||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||||
|
|
||||||
logits = logits_aux
|
logits = logits_aux
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user