A few fixes for gpt_asr_hf2
parent 3b5c3d85d8
commit 934395d4b8

@@ -231,7 +231,7 @@ class GptAsrHf2(nn.Module):
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
-        # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
+        # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
         text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
         text_emb = self.gpt.get_input_embeddings()(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
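
For context, the pad-and-shift that the corrected comment describes is easy to check in isolation. The sketch below uses a stand-in START token id of 0 (the real value comes from the model's own `self.START_TOKEN`):

```python
import torch
import torch.nn.functional as F

START_TOKEN = 0  # stand-in id; the model uses self.START_TOKEN

# Two fake target sequences of text token ids.
text_targets = torch.tensor([[5, 6, 7, 8],
                             [9, 10, 11, 12]])

# Prepend START and drop the last element, so position t of the input
# lines up with target t: standard next-token (teacher-forced) prediction.
text_inputs = F.pad(text_targets, (1, 0), value=START_TOKEN)[:, :-1]
print(text_inputs)
# tensor([[ 0,  5,  6,  7],
#         [ 0,  9, 10, 11]])
```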

@@ -569,8 +569,9 @@ class TorchMelSpectrogramInjector(Injector):
         self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
         self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
+        norm = opt_get(opt, ['normalize'], False)
         self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
-                                                             win_length=self.win_length, power=2, normalized=False,
+                                                             win_length=self.win_length, power=2, normalized=norm,
                                                              sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                              f_max=self.mel_fmax, n_mels=self.n_mel_channels)
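
This hunk stops hard-coding `normalized=False` and instead reads it from a `normalize` option. A rough sketch of the resulting behaviour; the filter/hop/window/mel-channel values below are typical placeholders, since only the `mel_fmin`, `mel_fmax`, and `sampling_rate` defaults are visible in the diff:

```python
import torch
import torchaudio

# Placeholder options dict; 'normalize' now drives the `normalized` argument
# instead of it being fixed to False.
opt = {'filter_length': 1024, 'hop_length': 256, 'win_length': 1024,
       'n_mel_channels': 80, 'mel_fmin': 0, 'mel_fmax': 8000,
       'sampling_rate': 22050, 'normalize': True}

mel_stft = torchaudio.transforms.MelSpectrogram(
    n_fft=opt['filter_length'], hop_length=opt['hop_length'],
    win_length=opt['win_length'], power=2, normalized=opt['normalize'],
    sample_rate=opt['sampling_rate'], f_min=opt['mel_fmin'],
    f_max=opt['mel_fmax'], n_mels=opt['n_mel_channels'])

wav = torch.randn(1, opt['sampling_rate'])  # one second of random audio
mel = mel_stft(wav)                         # shape: (1, 80, n_frames)
print(mel.shape)
```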

@@ -91,7 +91,8 @@ class CombineMelInjector(Injector):
         texts = state[self.text_key]
         audio_lengths = state[self.audio_lengths]
         text_lengths = state[self.text_lengths]
-        assert audio.shape[0] % 2 == 0  # Make sure there are an even number of batches.
+        if audio.shape[0] == 1:
+            return {self.output_audio_key: audio, self.output_text_key: texts}
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
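
This hunk replaces the even-batch assertion with an early return for a batch of one, which has nothing to pair with. A simplified sketch of that control flow, with the pairing reduced to a plain concat (the real injector also trims using the recorded audio/text lengths, which this toy ignores):

```python
import torch

def combine_pairs(audio, texts, output_audio_key='audio', output_text_key='texts'):
    # A single sample has nothing to pair with: pass it through unchanged,
    # mirroring the new early return instead of the old even-batch assert.
    if audio.shape[0] == 1:
        return {output_audio_key: audio, output_text_key: texts}
    combined_audios, combined_texts = [], []
    for b in range(audio.shape[0] // 2):
        combined_audios.append(torch.cat([audio[2 * b], audio[2 * b + 1]], dim=-1))
        combined_texts.append(torch.cat([texts[2 * b], texts[2 * b + 1]], dim=-1))
    return {output_audio_key: torch.stack(combined_audios),
            output_text_key: torch.stack(combined_texts)}

# Example: a batch of one passes straight through.
out = combine_pairs(torch.randn(1, 16000), torch.randint(0, 100, (1, 20)))
print(out['audio'].shape, out['texts'].shape)  # torch.Size([1, 16000]) torch.Size([1, 20])
```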