diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 9e4a06d..910e8fd 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -1077,7 +1077,7 @@ class Base_V2(nn.Module): aux_loss_logit = torch.cat( logits_aux ) if self.len_use_logits: - aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0) + aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).flatten() loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor else: aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second