forked from mrq/DL-Art-School
A few fixes for gpt_asr_hf2
This commit is contained in:
parent
3b5c3d85d8
commit
934395d4b8
|
@@ -231,7 +231,7 @@ class GptAsrHf2(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def get_logits(self, mel_inputs, text_targets, get_attns=False):
|
def get_logits(self, mel_inputs, text_targets, get_attns=False):
|
||||||
# Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
|
# Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
|
||||||
text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
|
text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
|
||||||
text_emb = self.gpt.get_input_embeddings()(text_inputs)
|
text_emb = self.gpt.get_input_embeddings()(text_inputs)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||||
|
|
|
@@ -569,8 +569,9 @@ class TorchMelSpectrogramInjector(Injector):
|
||||||
self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
|
self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
|
||||||
self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
|
self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
|
||||||
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
|
||||||
|
norm = opt_get(opt, ['normalize'], False)
|
||||||
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
||||||
win_length=self.win_length, power=2, normalized=False,
|
win_length=self.win_length, power=2, normalized=norm,
|
||||||
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
||||||
f_max=self.mel_fmax, n_mels=self.n_mel_channels)
|
f_max=self.mel_fmax, n_mels=self.n_mel_channels)
|
||||||
|
|
||||||
|
|
|
@@ -91,7 +91,8 @@ class CombineMelInjector(Injector):
|
||||||
texts = state[self.text_key]
|
texts = state[self.text_key]
|
||||||
audio_lengths = state[self.audio_lengths]
|
audio_lengths = state[self.audio_lengths]
|
||||||
text_lengths = state[self.text_lengths]
|
text_lengths = state[self.text_lengths]
|
||||||
assert audio.shape[0] % 2 == 0 # Make sure there are an even number of batches.
|
if audio.shape[0] == 1:
|
||||||
|
return {self.output_audio_key: audio, self.output_text_key: texts}
|
||||||
combined_audios = []
|
combined_audios = []
|
||||||
combined_texts = []
|
combined_texts = []
|
||||||
for b in range(audio.shape[0]//2):
|
for b in range(audio.shape[0]//2):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user