diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index b6866166..cc08e832 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -231,9 +231,9 @@ class GptAsrHf2(nn.Module):
 
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
+        text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
+        text_emb = self.gpt.get_input_embeddings()(text_inputs)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         mel_emb = self.mel_encoder(mel_inputs)
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
diff --git a/codes/scripts/audio/use_vocoder.py b/codes/scripts/audio/use_vocoder.py
index f2270c2c..ad433d42 100644
--- a/codes/scripts/audio/use_vocoder.py
+++ b/codes/scripts/audio/use_vocoder.py
@@ -4,6 +4,8 @@ import numpy
 import torch
 from scipy.io import wavfile
 from tqdm import tqdm
+import matplotlib.pyplot as plt
+import librosa
 
 from models.waveglow.waveglow import WaveGlow
 
@@ -23,13 +25,22 @@ class Vocoder:
         return self.model.infer(mel)
 
 
-if __name__ == '__main__':
-    path = 'data/audio'
-    files = list(pathlib.Path(path).glob('*.npy'))
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    plt.show(block=False)
 
-    for inp in tqdm(files):
-        inp = str(inp)
-        mel = torch.tensor(numpy.load(inp)).to('cuda')
-        vocoder = Vocoder()
-        wav = vocoder.transform_mel_to_audio(mel)
-        wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
\ No newline at end of file
+
+if __name__ == '__main__':
+    vocoder = Vocoder()
+    m = torch.load('test_mels.pth')
+    for i, b in enumerate(m):
+        plot_spectrogram(b.cpu())
+        wav = vocoder.transform_mel_to_audio(b)
+        wavfile.write(f'{i}.wav', 22050, wav[0].cpu().numpy())
\ No newline at end of file
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index 88c72fdc..a6dd21de 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -577,6 +577,7 @@ class TorchMelSpectrogramInjector(Injector):
         if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
+        self.mel_stft = self.mel_stft.to(inp.device)
         mel = self.mel_stft(inp)
         return {self.output: mel}
 
diff --git a/codes/trainer/injectors/spec_augment.py b/codes/trainer/injectors/spec_augment.py
index 620f10af..cf425592 100644
--- a/codes/trainer/injectors/spec_augment.py
+++ b/codes/trainer/injectors/spec_augment.py
@@ -81,6 +81,8 @@ class CombineMelInjector(Injector):
         self.text_key = opt['text_key']
         self.audio_lengths = opt['audio_lengths_key']
         self.text_lengths = opt['text_lengths_key']
+        self.output_audio_key = opt['output_audio_key']
+        self.output_text_key = opt['output_text_key']
         from models.tacotron2.text import symbols
         self.text_separator = len(symbols)+1  # Probably need to allow this to be set by user.
 
@@ -93,9 +95,9 @@ class CombineMelInjector(Injector):
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
-            a1 = audio[b*2, :audio_lengths[b*2]]
-            a2 = audio[b*2+1, :audio_lengths[b*2+1]]
-            a = torch.cat([a1, a2], dim=0)
+            a1 = audio[b*2, :, :audio_lengths[b*2]]
+            a2 = audio[b*2+1, :, :audio_lengths[b*2+1]]
+            a = torch.cat([a1, a2], dim=1)
             a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2-a.shape[-1]))
             combined_audios.append(a)
 
@@ -105,8 +107,8 @@ class CombineMelInjector(Injector):
             t = torch.cat([t1, t2], dim=0)
             t = torch.nn.functional.pad(t, (0, texts.shape[-1]*2-t.shape[-1]))
            combined_texts.append(t)
-        return {self.audio_key: torch.stack(combined_audios, dim=0),
-                self.text_key: torch.stack(combined_texts, dim=0)}
+        return {self.output_audio_key: torch.stack(combined_audios, dim=0),
+                self.output_text_key: torch.stack(combined_texts, dim=0)}
 
 
 def test_mel_injector():