This commit is contained in:
mrq 2025-03-31 21:11:43 -05:00
parent a1184586ef
commit 7a0956863d

View File

@@ -1077,7 +1077,7 @@ class Base_V2(nn.Module):
aux_loss_logit = torch.cat( logits_aux )
if self.len_use_logits:
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).flatten()
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
else:
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second