From 7a0956863d5b5a599e88234b04758ee15c91849f Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 31 Mar 2025 21:11:43 -0500 Subject: [PATCH] oops: flatten the len aux-loss digit targets instead of squeeze(0) --- vall_e/models/base_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 9e4a06d..910e8fd 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1077,7 +1077,7 @@ class Base_V2(nn.Module): aux_loss_logit = torch.cat( logits_aux ) if self.len_use_logits: - aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0) + aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).flatten() loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor else: aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second