A few fixes for gpt_asr_hf2

This commit is contained in:
James Betker 2021-11-23 09:29:29 -07:00
parent 3b5c3d85d8
commit 934395d4b8
3 changed files with 5 additions and 3 deletions

View File

@ -231,7 +231,7 @@ class GptAsrHf2(nn.Module):
def get_logits(self, mel_inputs, text_targets, get_attns=False):
# Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
# Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
text_emb = self.gpt.get_input_embeddings()(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))

View File

@ -569,8 +569,9 @@ class TorchMelSpectrogramInjector(Injector):
self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
norm = opt_get(opt, ['normalize'], False)
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
win_length=self.win_length, power=2, normalized=False,
win_length=self.win_length, power=2, normalized=norm,
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
f_max=self.mel_fmax, n_mels=self.n_mel_channels)

View File

@ -91,7 +91,8 @@ class CombineMelInjector(Injector):
texts = state[self.text_key]
audio_lengths = state[self.audio_lengths]
text_lengths = state[self.text_lengths]
assert audio.shape[0] % 2 == 0 # Make sure there are an even number of batches.
if audio.shape[0] == 1:
return {self.output_audio_key: audio, self.output_text_key: texts}
combined_audios = []
combined_texts = []
for b in range(audio.shape[0]//2):