minicoder with classifier head: spread out probability mass for 0 predictions

This commit is contained in:
James Betker 2022-03-08 15:51:31 -07:00
parent 29b2921222
commit f56edb2122

View File

@ -139,10 +139,23 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
super().__init__()
self.enc = AudioMiniEncoder(**kwargs)
self.head = nn.Linear(self.enc.dim, classes)
self.num_classes = classes
def forward(self, x, labels=None):
    """Classify ``x``; optionally return a training loss instead of logits.

    Args:
        x: Input passed straight to ``self.enc``.
        labels: Optional integer class indices (``torch.long``), one per
            item, as required by ``one_hot``.  When omitted, the raw
            logits are returned.

    Returns:
        The output of ``self.head`` (logits) when ``labels`` is ``None``;
        otherwise the mean cross-entropy between the logits and soft
        targets built from ``labels``.
    """
    logits = self.head(self.enc(x))
    if labels is None:
        return logits

    # Float targets so the smoothing correction can be added directly.
    targets = nn.functional.one_hot(labels, num_classes=self.num_classes).float()

    # With a single class there is nowhere to move mass to (and the
    # per-class share below would divide by zero), so skip the smoothing.
    if self.num_classes > 1:
        # Distribute 20% of the probability mass on all classes when zero
        # is specified, to compensate for dataset noise.
        zero_mass = .2
        correction = torch.full_like(targets, zero_mass / (self.num_classes - 1))
        correction[..., 0] = -zero_mass
        # Only rows labelled 0 receive the correction; other rows stay one-hot.
        is_zero = (labels == 0).unsqueeze(-1)
        targets = targets + correction * is_zero

    return nn.functional.cross_entropy(logits, targets)
class QueryProvidedAttentionBlock(nn.Module):