From 2ccf1b57408c44650702548d04746b1ff33d6d8d Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Tue, 11 Mar 2025 22:14:54 -0500
Subject: [PATCH] actually do duration prediction

---
 vall_e/config.py         | 1 +
 vall_e/inference.py      | 2 --
 vall_e/models/base_v2.py | 6 +++---
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 4e3f4b1..363e003 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -280,6 +280,7 @@ class ModelExperimentalSettings:
 	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
 
 	# 
+	logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
 	per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
 	audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
 	# "auto" will pick best for codec
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 082c042..7f7edb5 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -523,8 +523,6 @@ class TTS():
 						# add an additional X seconds
 						len_list = [ int(l * duration_padding) for l in len_list ]
 
-					print( len_list )
-
 					kwargs = {}
 					if prefix_context is not None:
 						kwargs["prefix_context"] = prefix_context
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index bd1662d..c2b070f 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1250,11 +1250,11 @@ class Base_V2(nn.Module):
 			tasks = self.get_input( inputs, name="task" )
 
 			# grab duration if no resp is provided or len task is requested
-			if tasks[0] == "len" or aux_lens[0][2] == 0:
+			if tasks[0] == "len":
 				# do duration prediction
 				logits_aux = self.len_decoder( output.logits )
-				# only keep the designated token (although this should technically be logit[-1, :1])
-				logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
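+				# read the prediction from the last position of each sequence instead of the aux_len[0] + aux_len[1] offset; it's more accurate this way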
+				logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
 
 				logits = logits_aux