This commit is contained in:
James Betker 2021-11-22 17:24:13 -07:00
parent 973f47c525
commit 01e635168b

View File

@ -249,7 +249,6 @@ class GptAsrHf2(nn.Module):
return text_logits return text_logits
def forward(self, mel_inputs, text_targets, return_attentions=False): def forward(self, mel_inputs, text_targets, return_attentions=False):
plot_spectrogram(mel_inputs[0].cpu())
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token. text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions) text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
if return_attentions: if return_attentions: