forked from mrq/DL-Art-School
minicoder with classifier head: spread out probability mass for 0 predictions
parent 29b2921222
commit f56edb2122
@@ -139,10 +139,23 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
         super().__init__()
         self.enc = AudioMiniEncoder(**kwargs)
         self.head = nn.Linear(self.enc.dim, classes)
+        self.num_classes = classes
 
-    def forward(self, x):
+    def forward(self, x, labels=None):
         h = self.enc(x)
-        return self.head(h)
+        logits = self.head(h)
+        if labels is None:
+            return logits
+        else:
+            oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
+            zeros_indices = (labels == 0).unsqueeze(-1)
+            # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
+            zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
+            zero_extra_mass[:, 0] = -.2
+            zero_extra_mass = zero_extra_mass * zeros_indices
+            oh_labels = oh_labels + zero_extra_mass
+            loss = nn.functional.cross_entropy(logits, oh_labels)
+            return loss
 
 
 class QueryProvidedAttentionBlock(nn.Module):
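
For reference (not part of the commit), a minimal standalone sketch of the target adjustment introduced above, with made-up values of 4 classes and a batch of 3. With these numbers, a sample labeled 0 gets the soft target [0.8, 0.0667, 0.0667, 0.0667] while other labels stay one-hot; the soft-target form of F.cross_entropy used here requires PyTorch >= 1.10.

import torch
import torch.nn.functional as F

num_classes = 4                          # illustrative only
logits = torch.randn(3, num_classes)     # fake classifier outputs for a batch of 3
labels = torch.tensor([0, 2, 0])         # two samples carry the noisy 0 label

# Hard one-hot targets as floats.
oh_labels = F.one_hot(labels, num_classes=num_classes).float()

# For samples labeled 0, move 20% of the mass off class 0 and spread it
# evenly over the remaining classes, mirroring the commit's adjustment.
zeros_indices = (labels == 0).unsqueeze(-1)
zero_extra_mass = torch.full_like(oh_labels, 0.2 / (num_classes - 1))
zero_extra_mass[:, 0] = -0.2
zero_extra_mass = zero_extra_mass * zeros_indices
soft_labels = oh_labels + zero_extra_mass

loss = F.cross_entropy(logits, soft_labels)  # soft probability targets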