forked from mrq/DL-Art-School
whoops
This commit is contained in:
parent
973f47c525
commit
01e635168b
|
@ -249,7 +249,6 @@ class GptAsrHf2(nn.Module):
|
||||||
return text_logits
|
return text_logits
|
||||||
|
|
||||||
def forward(self, mel_inputs, text_targets, return_attentions=False):
|
def forward(self, mel_inputs, text_targets, return_attentions=False):
|
||||||
plot_spectrogram(mel_inputs[0].cpu())
|
|
||||||
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
|
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
|
||||||
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
|
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
|
||||||
if return_attentions:
|
if return_attentions:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user