oops
This commit is contained in:
parent
a1184586ef
commit
7a0956863d
|
@ -1077,7 +1077,7 @@ class Base_V2(nn.Module):
|
||||||
aux_loss_logit = torch.cat( logits_aux )
|
aux_loss_logit = torch.cat( logits_aux )
|
||||||
|
|
||||||
if self.len_use_logits:
|
if self.len_use_logits:
|
||||||
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
|
aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).flatten()
|
||||||
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||||
else:
|
else:
|
||||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||||
|
|
Loading…
Reference in New Issue
Block a user