fix for bsz>1 because I forgot the old implementation implicitly handles this
parent 068dbdb785
commit caad99ab78
@@ -1011,6 +1011,20 @@ class Base_V2(nn.Module):
             break
 
+        # fill in the gap to make dim=-1 equal across the batch by padding with -infs
+        # the old implementation inherently does this through the master Classifiers class,
+        # but it needs to be done explicitly here
+        dim_neg_1 = 0
+        for batch_index, logit in enumerate( loss_logits ):
+            dim_neg_1 = max( dim_neg_1, logit.shape[-1] )
+
+        for batch_index, logit in enumerate( loss_logits ):
+            if dim_neg_1 == logit.shape[-1]:
+                continue
+
+            loss_logits[batch_index] = torch.cat([logit, torch.full( (logit.shape[0], dim_neg_1 - logit.shape[-1]), -float("inf"), device=logit.device, dtype=logit.dtype) ], dim=-1 )
+
+
     loss_target = torch.cat( loss_targets )
     loss_logit = torch.cat( loss_logits )
 
 
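For readers following along, below is a minimal, self-contained sketch of the padding trick this commit adds. The helper name pad_logits_to_max_classes and the toy shapes are illustrative assumptions, not the repo's code; the point is that right-padding ragged logits with -inf lets them be concatenated and fed to cross-entropy without changing the loss, since softmax assigns the padded classes zero probability.

    # Illustrative sketch (not the repo's code): pad a list of 2D logit
    # tensors with -inf on dim=-1 so every entry in the batch shares the
    # same class dimension before torch.cat.
    import torch
    import torch.nn.functional as F

    def pad_logits_to_max_classes( loss_logits ):
        # widest class dimension across the batch
        dim_neg_1 = max( logit.shape[-1] for logit in loss_logits )
        for batch_index, logit in enumerate( loss_logits ):
            if logit.shape[-1] == dim_neg_1:
                continue
            pad = torch.full(
                (logit.shape[0], dim_neg_1 - logit.shape[-1]),
                -float("inf"),
                device=logit.device, dtype=logit.dtype,
            )
            # exp(-inf) == 0, so padded classes get zero probability under
            # softmax and never contribute to the cross-entropy loss
            loss_logits[batch_index] = torch.cat( [logit, pad], dim=-1 )
        return loss_logits

    # usage: two sequences whose classifiers have different vocab sizes
    logits = [ torch.randn(4, 10), torch.randn(4, 12) ]
    targets = torch.randint( 0, 10, (8,) )

    padded = pad_logits_to_max_classes( [l.clone() for l in logits] )
    loss = F.cross_entropy( torch.cat( padded ), targets )

    # the padding is loss-neutral: the padded first tensor gives the same
    # per-row cross-entropy as the unpadded original
    row = F.cross_entropy( logits[0], targets[:4] )
    row_padded = F.cross_entropy( padded[0], targets[:4] )
    assert torch.allclose( row, row_padded )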