fix for bsz>1 because I forgot the old implementation implicitly handles this

mrq 2025-04-02 17:17:37 -05:00
parent 068dbdb785
commit caad99ab78


@@ -1011,6 +1011,20 @@ class Base_V2(nn.Module):
	break
# pad each logit along dim=-1 to the same width by filling the gap with -inf
# the old implementation inherently does this through the master Classifiers class,
# but it needs to be done explicitly here
dim_neg_1 = 0
for batch_index, logit in enumerate( loss_logits ):
	dim_neg_1 = max( dim_neg_1, logit.shape[-1] )
for batch_index, logit in enumerate( loss_logits ):
	if dim_neg_1 == logit.shape[-1]:
		continue
	loss_logits[batch_index] = torch.cat([ logit, torch.full( (logit.shape[0], dim_neg_1 - logit.shape[-1]), -float("inf"), device=logit.device, dtype=logit.dtype ) ], dim=-1 )
loss_target = torch.cat( loss_targets )
loss_logit = torch.cat( loss_logits )
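
For context, the padding trick above can be exercised in isolation. The sketch below is not the repository's code; the sample shapes and the names `gap` and `filler` are made up for illustration. It pads per-sample logits of different widths with -inf, which softmax maps to zero probability, so concatenating them into one cross_entropy call leaves each sample's loss unchanged.

import torch
import torch.nn.functional as F

# two hypothetical samples whose classifier heads have different widths
loss_logits  = [ torch.randn(4, 10), torch.randn(3, 12) ]                # (seq_len, n_classes) per sample
loss_targets = [ torch.randint(0, 10, (4,)), torch.randint(0, 12, (3,)) ]

# widest head in the batch
dim_neg_1 = max( logit.shape[-1] for logit in loss_logits )

for batch_index, logit in enumerate( loss_logits ):
	gap = dim_neg_1 - logit.shape[-1]
	if gap == 0:
		continue
	# -inf columns receive zero probability under softmax, so they never affect the loss
	filler = torch.full( (logit.shape[0], gap), -float("inf"), device=logit.device, dtype=logit.dtype )
	loss_logits[batch_index] = torch.cat( [ logit, filler ], dim=-1 )

loss_logit  = torch.cat( loss_logits )    # (4+3, 12)
loss_target = torch.cat( loss_targets )   # (4+3,)
loss = F.cross_entropy( loss_logit, loss_target )
print( loss.item() )

Because the targets of each sample only index into that sample's real columns, the -inf padding never contributes probability mass, and the concatenated loss matches what the per-sample losses would average to under the old Classifiers path.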