Further wandb logs

James Betker 2021-11-22 16:40:19 -07:00
parent 19c80bf7a7
commit 3125ca38f5
4 changed files with 31 additions and 17 deletions

View File

@@ -231,9 +231,9 @@ class GptAsrHf2(nn.Module):
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
+        text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
+        text_emb = self.gpt.get_input_embeddings()(text_inputs)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         mel_emb = self.mel_encoder(mel_inputs)
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
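
The rename keeps `text_targets` in its role as the labels while `text_inputs` is what actually enters the model. A minimal sketch of this shift-by-one (teacher-forcing) setup, with a stand-in START token and toy shapes rather than the repo's real values:

    import torch
    import torch.nn.functional as F

    START_TOKEN = 0  # stand-in; the model uses self.START_TOKEN

    # Toy batch of target token ids: (batch=1, seq_len=4).
    text_targets = torch.tensor([[5, 6, 7, 8]])

    # Prepend START and drop the last token, so inputs[t] predicts targets[t].
    text_inputs = F.pad(text_targets, (1, 0), value=START_TOKEN)[:, :-1]

    print(text_inputs)   # tensor([[0, 5, 6, 7]])
    print(text_targets)  # tensor([[5, 6, 7, 8]])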

View File

@@ -4,6 +4,8 @@ import numpy
 import torch
 from scipy.io import wavfile
 from tqdm import tqdm
+import matplotlib.pyplot as plt
+import librosa
 from models.waveglow.waveglow import WaveGlow
@@ -23,13 +25,22 @@ class Vocoder:
         return self.model.infer(mel)
 
-if __name__ == '__main__':
-    path = 'data/audio'
-    files = list(pathlib.Path(path).glob('*.npy'))
-    for inp in tqdm(files):
-        inp = str(inp)
-        mel = torch.tensor(numpy.load(inp)).to('cuda')
-        vocoder = Vocoder()
-        wav = vocoder.transform_mel_to_audio(mel)
-        wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    plt.show(block=False)
+
+if __name__ == '__main__':
+    vocoder = Vocoder()
+    m = torch.load('test_mels.pth')
+    for i, b in enumerate(m):
+        plot_spectrogram(b.cpu())
+        wav = vocoder.transform_mel_to_audio(b)
+        wavfile.write(f'{i}.wav', 22050, wav[0].cpu().numpy())
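
The rewritten `__main__` block loads `test_mels.pth`, which it expects to hold an iterable of mel spectrogram tensors, instead of globbing `.npy` files from `data/audio`. A hedged sketch of how such a file could be assembled from the old `.npy` inputs (the path and the use of CUDA are assumptions carried over from the removed code, not part of this commit):

    import pathlib
    import numpy
    import torch

    # Bundle individual mel .npy files into the single .pth list
    # that the updated script iterates over.
    mels = [torch.tensor(numpy.load(str(p))).to('cuda')
            for p in pathlib.Path('data/audio').glob('*.npy')]  # path is an assumption
    torch.save(mels, 'test_mels.pth')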

View File

@@ -577,6 +577,7 @@ class TorchMelSpectrogramInjector(Injector):
         if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
+        self.mel_stft = self.mel_stft.to(inp.device)
         mel = self.mel_stft(inp)
         return {self.output: mel}
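
The added line migrates the mel-spectrogram transform to whatever device the input tensor lives on, avoiding a CPU/GPU mismatch when the injector is constructed before tensors reach the GPU. A minimal sketch of the pattern using torchaudio's `MelSpectrogram` (parameters are illustrative, not the injector's actual configuration):

    import torch
    from torchaudio.transforms import MelSpectrogram

    mel_stft = MelSpectrogram(sample_rate=22050, n_mels=80)  # illustrative params

    def to_mel(inp: torch.Tensor) -> torch.Tensor:
        global mel_stft
        # Lazily move the transform's internal buffers (e.g. the mel
        # filterbank) to the audio's device; a no-op when they already match.
        mel_stft = mel_stft.to(inp.device)
        return mel_stft(inp)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    wav = torch.randn(1, 22050, device=device)  # one second of fake mono audio
    mel = to_mel(wav)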

View File

@@ -81,6 +81,8 @@ class CombineMelInjector(Injector):
         self.text_key = opt['text_key']
         self.audio_lengths = opt['audio_lengths_key']
         self.text_lengths = opt['text_lengths_key']
+        self.output_audio_key = opt['output_audio_key']
+        self.output_text_key = opt['output_text_key']
         from models.tacotron2.text import symbols
         self.text_separator = len(symbols)+1  # Probably need to allow this to be set by user.
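
With dedicated output keys, the combined tensors no longer have to overwrite the injector's input keys. A hypothetical `opt` dict showing the two new required entries alongside the existing ones (all key values here are made-up names, not from the repo's configs):

    opt = {
        'audio_key': 'padded_mel',            # existing input keys
        'text_key': 'padded_text',
        'audio_lengths_key': 'mel_lengths',
        'text_lengths_key': 'text_lengths',
        'output_audio_key': 'combined_mel',   # new: where the merged audio lands
        'output_text_key': 'combined_text',   # new: where the merged text lands
    }
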
@@ -93,9 +95,9 @@ class CombineMelInjector(Injector):
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
-            a1 = audio[b*2, :audio_lengths[b*2]]
-            a2 = audio[b*2+1, :audio_lengths[b*2+1]]
-            a = torch.cat([a1, a2], dim=0)
+            a1 = audio[b*2, :, :audio_lengths[b*2]]
+            a2 = audio[b*2+1, :, :audio_lengths[b*2+1]]
+            a = torch.cat([a1, a2], dim=1)
             a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2-a.shape[-1]))
             combined_audios.append(a)
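
The indexing fix accounts for each batch element now being a 2-D mel spectrogram of shape (n_mels, time) rather than a 1-D waveform: the length crop applies to the time axis, and each pair is concatenated along dim=1. A toy illustration of the combine-and-pad step (shapes invented for the example):

    import torch

    n_mels, max_len = 80, 10
    audio = torch.randn(4, n_mels, max_len)  # batch of 4 mel spectrograms
    audio_lengths = [6, 4, 10, 3]            # valid frames per clip

    combined = []
    for b in range(audio.shape[0] // 2):
        # Crop each clip to its true length along the time axis, then join.
        a1 = audio[b*2, :, :audio_lengths[b*2]]
        a2 = audio[b*2+1, :, :audio_lengths[b*2+1]]
        a = torch.cat([a1, a2], dim=1)
        # Right-pad the time axis so every combined pair has the same width.
        a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2 - a.shape[-1]))
        combined.append(a)

    out = torch.stack(combined, dim=0)
    print(out.shape)  # torch.Size([2, 80, 20])
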
@@ -105,8 +107,8 @@ class CombineMelInjector(Injector):
             t = torch.cat([t1, t2], dim=0)
             t = torch.nn.functional.pad(t, (0, texts.shape[-1]*2-t.shape[-1]))
             combined_texts.append(t)
-        return {self.audio_key: torch.stack(combined_audios, dim=0),
-                self.text_key: torch.stack(combined_texts, dim=0)}
+        return {self.output_audio_key: torch.stack(combined_audios, dim=0),
+                self.output_text_key: torch.stack(combined_texts, dim=0)}
 
 def test_mel_injector():