diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 6decc92e..993f9a82 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -231,7 +231,7 @@ class GptAsrHf2(nn.Module):
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
-        # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
+        # Pad front and remove last element to set up next token prediction. Pad at front is the "START" token.
         text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
         text_emb = self.gpt.get_input_embeddings()(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index e21c02bd..f13e7608 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -569,8 +569,9 @@ class TorchMelSpectrogramInjector(Injector):
         self.mel_fmin = opt_get(opt, ['mel_fmin'], 0)
         self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
+        norm = opt_get(opt, ['normalize'], False)
         self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
-                                                             win_length=self.win_length, power=2, normalized=False,
+                                                             win_length=self.win_length, power=2, normalized=norm,
                                                              sample_rate=self.sampling_rate, f_min=self.mel_fmin,
                                                              f_max=self.mel_fmax, n_mels=self.n_mel_channels)
diff --git a/codes/trainer/injectors/spec_augment.py b/codes/trainer/injectors/spec_augment.py
index cf425592..63fa7c6a 100644
--- a/codes/trainer/injectors/spec_augment.py
+++ b/codes/trainer/injectors/spec_augment.py
@@ -91,7 +91,8 @@ class CombineMelInjector(Injector):
         texts = state[self.text_key]
         audio_lengths = state[self.audio_lengths]
         text_lengths = state[self.text_lengths]
-        assert audio.shape[0] % 2 == 0 # Make sure there are an even number of batches.
+        if audio.shape[0] == 1:
+            return {self.output_audio_key: audio, self.output_text_key: texts}
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
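Note on the gpt_asr_hf2.py comment fix: the pad-front/drop-last alignment that the corrected comment describes can be sanity-checked in isolation. A minimal standalone sketch (the START token value here is illustrative, not the model's actual constant):

    import torch
    import torch.nn.functional as F

    START_TOKEN = 0  # illustrative; the real model defines its own START_TOKEN

    # A batch of two target sequences of length 4.
    text_targets = torch.tensor([[5, 6, 7, 8],
                                 [9, 10, 11, 12]])

    # Pad front with START and drop the last element, as in get_logits().
    text_inputs = F.pad(text_targets, (1, 0), value=START_TOKEN)[:, :-1]

    # text_inputs[:, t] is now the token *preceding* text_targets[:, t],
    # so predicting targets from inputs is next-token prediction.
    print(text_inputs)  # tensor([[ 0,  5,  6,  7], [ 0,  9, 10, 11]])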
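The new `normalize` option in TorchMelSpectrogramInjector passes straight through to torchaudio's `normalized` flag. A rough sketch of the equivalent standalone transform with the option enabled; the f_min/f_max/sample_rate values match the opt_get defaults visible in the diff, while the filter/hop/win lengths and n_mels are assumptions, since their defaults are outside this hunk:

    import torch
    import torchaudio

    # Mirrors the injector's construction when opt['normalize'] is True.
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=1024, hop_length=256, win_length=1024,  # assumed defaults, not shown in the diff
        power=2, normalized=True,                     # 'normalized' is now configurable
        sample_rate=22050, f_min=0, f_max=8000,       # opt_get defaults from the diff
        n_mels=80)                                    # assumed; n_mel_channels default not shown

    wav = torch.randn(1, 22050)  # one second of fake audio at 22050 Hz
    mel = mel_stft(wav)
    print(mel.shape)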
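The spec_augment.py change swaps the even-batch assert for a pass-through when the batch holds a single item, which cannot be paired. A sketch of the surrounding pairing logic under that guard (the combination step here is a placeholder; the real injector also accounts for audio/text lengths, and an odd batch larger than one would drop its last element in this sketch):

    import torch

    def combine_pairs(audio, texts):
        # Mirrors the guard added in the diff: a singleton batch is
        # returned untouched instead of tripping an assert.
        if audio.shape[0] == 1:
            return audio, texts
        combined_audios, combined_texts = [], []
        for b in range(audio.shape[0] // 2):
            # Placeholder combination: concatenate each pair along time.
            combined_audios.append(torch.cat([audio[2 * b], audio[2 * b + 1]], dim=-1))
            combined_texts.append(torch.cat([texts[2 * b], texts[2 * b + 1]], dim=-1))
        return torch.stack(combined_audios), torch.stack(combined_texts)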