oops
This commit is contained in:
parent
a1184586ef
commit
7a0956863d
|
@ -1077,7 +1077,7 @@ class Base_V2(nn.Module):
|
|||
aux_loss_logit = torch.cat( logits_aux )
|
||||
|
||||
if self.len_use_logits:
|
||||
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
|
||||
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).flatten()
|
||||
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
else:
|
||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||
|
|
Loading…
Reference in New Issue
Block a user