Further wandb logs

James Betker 2021-11-22 16:40:19 -07:00
parent 19c80bf7a7
commit 3125ca38f5
4 changed files with 31 additions and 17 deletions

View File

@@ -231,9 +231,9 @@ class GptAsrHf2(nn.Module):
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad the front and remove the last element to set up next-token prediction; the pad at the front is the "START" token.
-        text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
+        text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
+        text_emb = self.gpt.get_input_embeddings()(text_inputs)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         mel_emb = self.mel_encoder(mel_inputs)
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
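
The substance of this hunk: the old code reassigned text_targets to its shifted copy, so anything downstream that used the targets (most importantly the loss) would silently receive the shifted inputs instead of the labels. Renaming the shifted tensor to text_inputs keeps the two roles separate. A minimal sketch of the shift, with illustrative names and a placeholder START id:

    import torch
    import torch.nn.functional as F

    START_TOKEN = 0  # assumption: whatever id the model reserves for "START"

    text_targets = torch.tensor([[5, 6, 7, 8]])  # labels the loss should predict
    text_inputs = F.pad(text_targets, (1, 0), value=START_TOKEN)[:, :-1]
    # text_inputs is now [[START, 5, 6, 7]]: the model reads position i and is
    # trained to emit text_targets[:, i]. Because the shifted tensor has its
    # own name, a later loss such as
    #     F.cross_entropy(logits.permute(0, 2, 1), text_targets)
    # still receives the unshifted labels.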

View File

@@ -4,6 +4,8 @@ import numpy
 import torch
 from scipy.io import wavfile
 from tqdm import tqdm
+import matplotlib.pyplot as plt
+import librosa
 from models.waveglow.waveglow import WaveGlow
@@ -23,13 +25,22 @@ class Vocoder:
         return self.model.infer(mel)

-if __name__ == '__main__':
-    path = 'data/audio'
-    files = list(pathlib.Path(path).glob('*.npy'))
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    plt.show(block=False)

-    for inp in tqdm(files):
-        inp = str(inp)
-        mel = torch.tensor(numpy.load(inp)).to('cuda')
-        vocoder = Vocoder()
-        wav = vocoder.transform_mel_to_audio(mel)
-        wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
+if __name__ == '__main__':
+    vocoder = Vocoder()
+    m = torch.load('test_mels.pth')
+    for i, b in enumerate(m):
+        plot_spectrogram(b.cpu())
+        wav = vocoder.transform_mel_to_audio(b)
+        wavfile.write(f'{i}.wav', 22050, wav[0].cpu().numpy())
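
For reference, a standalone way to exercise the new plotting helper outside the __main__ block. The file name is hypothetical, and since librosa.power_to_db operates on numpy arrays, an explicit conversion from torch is the safe route; it also assumes the input is a power-scaled spectrogram:

    import numpy
    import torch

    mel = torch.from_numpy(numpy.load('demo_mel.npy'))  # assumed shape: [n_mels, frames]
    plot_spectrogram(mel.cpu().numpy(), title='demo mel', xmax=mel.shape[-1])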

View File

@@ -577,6 +577,7 @@ class TorchMelSpectrogramInjector(Injector):
         if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono audio).
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
+        self.mel_stft = self.mel_stft.to(inp.device)
         mel = self.mel_stft(inp)
         return {self.output: mel}
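
The one added line migrates the mel-spectrogram module to whatever device the incoming batch lives on, so the injector works regardless of where it was constructed. A generic sketch of the same lazy-migration pattern (the class and names here are illustrative, not the repository's):

    import torch
    import torch.nn as nn

    class LazyDeviceWrapper(nn.Module):
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Follow the data: nn.Module.to() is effectively a no-op when the
            # devices already match, so calling it every forward pass is cheap.
            self.inner = self.inner.to(x.device)
            return self.inner(x)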

View File

@@ -81,6 +81,8 @@ class CombineMelInjector(Injector):
         self.text_key = opt['text_key']
         self.audio_lengths = opt['audio_lengths_key']
         self.text_lengths = opt['text_lengths_key']
+        self.output_audio_key = opt['output_audio_key']
+        self.output_text_key = opt['output_text_key']
         from models.tacotron2.text import symbols
         self.text_separator = len(symbols)+1  # Probably need to allow this to be set by the user.
@@ -93,9 +95,9 @@ class CombineMelInjector(Injector):
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
-            a1 = audio[b*2, :audio_lengths[b*2]]
-            a2 = audio[b*2+1, :audio_lengths[b*2+1]]
-            a = torch.cat([a1, a2], dim=0)
+            a1 = audio[b*2, :, :audio_lengths[b*2]]
+            a2 = audio[b*2+1, :, :audio_lengths[b*2+1]]
+            a = torch.cat([a1, a2], dim=1)
             a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2-a.shape[-1]))
             combined_audios.append(a)
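
Why the indexing changed: the audio tensors here are mel spectrograms shaped [batch, n_mels, frames], not raw [batch, samples] waveforms. Once the batch dimension is indexed away, the frame axis is dim 1, so the length slice must keep every mel bin and the concatenation must run along frames. A quick shape check with illustrative dimensions:

    import torch

    mel = torch.randn(4, 80, 100)   # [batch, n_mels, frames]
    a1 = mel[0, :, :60]             # [80, 60] -- keep all mel bins
    a2 = mel[1, :, :30]             # [80, 30]
    a = torch.cat([a1, a2], dim=1)  # [80, 90] -- frames is dim 1 here
    a = torch.nn.functional.pad(a, (0, mel.shape[-1]*2 - a.shape[-1]))
    assert a.shape == (80, 200)     # padded out to twice the original frame count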
@@ -105,8 +107,8 @@ class CombineMelInjector(Injector):
             t = torch.cat([t1, t2], dim=0)
             t = torch.nn.functional.pad(t, (0, texts.shape[-1]*2-t.shape[-1]))
             combined_texts.append(t)
-        return {self.audio_key: torch.stack(combined_audios, dim=0),
-                self.text_key: torch.stack(combined_texts, dim=0)}
+        return {self.output_audio_key: torch.stack(combined_audios, dim=0),
+                self.output_text_key: torch.stack(combined_texts, dim=0)}

 def test_mel_injector():
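
Writing the combined batch under dedicated output keys, instead of clobbering audio_key and text_key, leaves the original tensors in the training state for any other injector that still needs them. A sketch of how the injector's options might be wired after this change; only the keys mirror what the diff reads, every value is a placeholder:

    # Illustrative opt dict for CombineMelInjector; values are placeholders.
    opt = {
        'audio_key': 'mel',                  # input mels: [b, n_mels, frames]
        'text_key': 'text',
        'audio_lengths_key': 'mel_lengths',
        'text_lengths_key': 'text_lengths',
        'output_audio_key': 'combined_mel',  # combined pairs land here,
        'output_text_key': 'combined_text',  # so 'mel' and 'text' survive intact
    }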