forked from mrq/DL-Art-School
minicoder with classifier head: spread out probability mass for 0 predictions
This commit is contained in:
parent
29b2921222
commit
f56edb2122
|
@ -139,10 +139,23 @@ class AudioMiniEncoderWithClassifierHead(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.enc = AudioMiniEncoder(**kwargs)
|
self.enc = AudioMiniEncoder(**kwargs)
|
||||||
self.head = nn.Linear(self.enc.dim, classes)
|
self.head = nn.Linear(self.enc.dim, classes)
|
||||||
|
self.num_classes = classes
|
||||||
|
|
||||||
def forward(self, x, labels=None):
    """Encode audio and classify it.

    Args:
        x: input audio tensor, passed straight to the wrapped AudioMiniEncoder.
        labels: optional LongTensor of class indices, shape (batch,).
            When provided, the cross-entropy loss is returned instead of logits.

    Returns:
        Logits of shape (batch, num_classes) when ``labels`` is None,
        otherwise a scalar cross-entropy loss.
    """
    h = self.enc(x)
    logits = self.head(h)
    if labels is None:
        return logits
    else:
        oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
        zeros_indices = (labels == 0).unsqueeze(-1)
        # Distribute 20% of the probability mass across all classes when zero is
        # specified, to compensate for dataset noise: a label-0 target becomes
        # [.8, .2/(C-1), ..., .2/(C-1)] instead of a hard one-hot. Rows with
        # non-zero labels are untouched (zeros_indices masks the extra mass).
        zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
        zero_extra_mass[:, 0] = -.2
        zero_extra_mass = zero_extra_mass * zeros_indices
        oh_labels = oh_labels + zero_extra_mass
        # NOTE(review): passing class probabilities (a float target) to
        # cross_entropy requires PyTorch >= 1.10 — confirm the pinned version.
        loss = nn.functional.cross_entropy(logits, oh_labels)
        return loss
|
||||||
|
|
||||||
|
|
||||||
class QueryProvidedAttentionBlock(nn.Module):
|
class QueryProvidedAttentionBlock(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user