diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py index fd289252..d8e3f90a 100644 --- a/codes/models/gpt_voice/mini_encoder.py +++ b/codes/models/gpt_voice/mini_encoder.py @@ -139,10 +139,23 @@ class AudioMiniEncoderWithClassifierHead(nn.Module): super().__init__() self.enc = AudioMiniEncoder(**kwargs) self.head = nn.Linear(self.enc.dim, classes) + self.num_classes = classes - def forward(self, x): + def forward(self, x, labels=None): h = self.enc(x) - return self.head(h) + logits = self.head(h) + if labels is None: + return logits + else: + oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes) + zeros_indices = (labels == 0).unsqueeze(-1) + # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise. + zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1)) + zero_extra_mass[:, 0] = -.2 + zero_extra_mass = zero_extra_mass * zeros_indices + oh_labels = oh_labels + zero_extra_mass + loss = nn.functional.cross_entropy(logits, oh_labels) + return loss class QueryProvidedAttentionBlock(nn.Module):