diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index fc6a509..1610b7e 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1011,6 +1011,24 @@ class Base_V2(nn.Module):
                 break
 
 
+        # Entries of `loss_logits` can differ in size along dim=-1, but the
+        # `torch.cat( loss_logits )` below only lets the concatenation dim (0) differ.
+        # Right-pad the narrower entries with -inf so every entry matches the widest one.
+        # The old implementation inherently did this through the master Classifiers class,
+        # but it needs to be done explicitly here.
+        widest = max( ( logit.shape[-1] for logit in loss_logits ), default=0 )
+
+        for batch_index, logit in enumerate( loss_logits ):
+            missing = widest - logit.shape[-1]
+            if missing == 0:
+                continue
+
+            # `new_full` inherits device and dtype from `logit`; keeping every leading
+            # dim makes the filler valid for any rank, not only 2-D logits.
+            filler = logit.new_full( ( *logit.shape[:-1], missing ), float("-inf") )
+            loss_logits[batch_index] = torch.cat( [ logit, filler ], dim=-1 )
+
+
         loss_target = torch.cat( loss_targets )
         loss_logit = torch.cat( loss_logits )
 