Fix for bsz > 1: I forgot that the old implementation implicitly handled this case
This commit is contained in:
parent
068dbdb785
commit
caad99ab78
|
@ -1011,6 +1011,20 @@ class Base_V2(nn.Module):
|
|||
|
||||
break
|
||||
|
||||
# fill in gap to make dim=-1 equal by filling in with -infs
|
||||
# the old implementation inherently does this through the master Classifiers class
|
||||
# but it needs to explicitly be done here
|
||||
dim_neg_1 = 0
|
||||
for batch_index, logit in enumerate( loss_logits ):
|
||||
dim_neg_1 = max( dim_neg_1, logit.shape[-1] )
|
||||
|
||||
for batch_index, logit in enumerate( loss_logits ):
|
||||
if dim_neg_1 == logit.shape[-1]:
|
||||
continue
|
||||
|
||||
loss_logits[batch_index] = torch.cat([logit, torch.full( (logit.shape[0], dim_neg_1 - logit.shape[-1]), -float("inf"), device=logit.device, dtype=logit.dtype) ], dim=-1 )
|
||||
|
||||
|
||||
loss_target = torch.cat( loss_targets )
|
||||
loss_logit = torch.cat( loss_logits )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user