Further wandb logs
parent 19c80bf7a7
commit 3125ca38f5
@@ -231,9 +231,9 @@ class GptAsrHf2(nn.Module):
     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
+        text_inputs = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
+        text_emb = self.gpt.get_input_embeddings()(text_inputs)
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
         mel_emb = self.mel_encoder(mel_inputs)
         mel_emb = mel_emb.permute(0,2,1).contiguous()
         mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
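Note on this hunk: the rename from text_targets to text_inputs distinguishes the shifted model input from the untouched prediction target. A minimal sketch of the shift, with invented token values and an assumed START_TOKEN of 0:

import torch
import torch.nn.functional as F

START_TOKEN = 0                              # assumed value, for illustration only
text_targets = torch.tensor([[5, 9, 2, 7]])  # (batch, seq)
text_inputs = F.pad(text_targets, (1, 0), value=START_TOKEN)[:, :-1]
print(text_inputs)                           # tensor([[0, 5, 9, 2]])
# Position i of text_inputs is trained so the logits there predict text_targets[:, i].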
@@ -4,6 +4,8 @@ import numpy
 import torch
 from scipy.io import wavfile
 from tqdm import tqdm
+import matplotlib.pyplot as plt
+import librosa

 from models.waveglow.waveglow import WaveGlow

@@ -23,13 +25,22 @@ class Vocoder:
         return self.model.infer(mel)


-if __name__ == '__main__':
-    path = 'data/audio'
-    files = list(pathlib.Path(path).glob('*.npy'))
+def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
+    fig, axs = plt.subplots(1, 1)
+    axs.set_title(title or "Spectrogram (db)")
+    axs.set_ylabel(ylabel)
+    axs.set_xlabel("frame")
+    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
+    if xmax:
+        axs.set_xlim((0, xmax))
+    fig.colorbar(im, ax=axs)
+    plt.show(block=False)
+
-    for inp in tqdm(files):
-        inp = str(inp)
-        mel = torch.tensor(numpy.load(inp)).to('cuda')
-        vocoder = Vocoder()
-        wav = vocoder.transform_mel_to_audio(mel)
-        wavfile.write(f'{inp}.wav', 22050, wav[0].cpu().numpy())
+if __name__ == '__main__':
+    vocoder = Vocoder()
+    m = torch.load('test_mels.pth')
+    for i, b in enumerate(m):
+        plot_spectrogram(b.cpu())
+        wav = vocoder.transform_mel_to_audio(b)
+        wavfile.write(f'{i}.wav', 22050, wav[0].cpu().numpy())

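The rewritten entry point loads a saved batch of mels from test_mels.pth, plots each one, and vocodes it. Since plt.show(block=False) returns immediately, the figure only stays up while the process keeps working (here, during inference). A quick usage sketch of the new helper with a synthetic mel; the shape is an assumption:

import numpy
fake_mel = numpy.random.rand(80, 200)  # (n_mels, frames) power spectrogram, assumed layout
plot_spectrogram(fake_mel, title="Synthetic mel", xmax=150)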
@@ -577,6 +577,7 @@ class TorchMelSpectrogramInjector(Injector):
         if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
+        self.mel_stft = self.mel_stft.to(inp.device)
         mel = self.mel_stft(inp)
         return {self.output: mel}

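The added line lazily moves the mel-STFT module onto the device of the incoming batch, so a transform constructed on CPU still works once inputs arrive on GPU. The same pattern in isolation, a sketch with assumed construction parameters (not necessarily the project's values):

import torch
from torchaudio.transforms import MelSpectrogram

mel_stft = MelSpectrogram(sample_rate=22050, n_fft=1024, n_mels=80)  # built on CPU

def compute_mel(inp: torch.Tensor) -> torch.Tensor:
    global mel_stft
    mel_stft = mel_stft.to(inp.device)  # no-op after the first call on that device
    return mel_stft(inp)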
@@ -81,6 +81,8 @@ class CombineMelInjector(Injector):
         self.text_key = opt['text_key']
         self.audio_lengths = opt['audio_lengths_key']
         self.text_lengths = opt['text_lengths_key']
+        self.output_audio_key = opt['output_audio_key']
+        self.output_text_key = opt['output_text_key']
         from models.tacotron2.text import symbols
         self.text_separator = len(symbols)+1 # Probably need to allow this to be set by user.

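With the two new entries, the injector's opt dict now names where results are written as well as where inputs are read. A hypothetical configuration (key names from the diff, values invented):

opt = {
    'audio_key': 'padded_audio',
    'text_key': 'padded_text',
    'audio_lengths_key': 'audio_lengths',
    'text_lengths_key': 'text_lengths',
    'output_audio_key': 'combined_audio',  # new in this commit
    'output_text_key': 'combined_text',    # new in this commit
}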
@@ -93,9 +95,9 @@ class CombineMelInjector(Injector):
         combined_audios = []
         combined_texts = []
         for b in range(audio.shape[0]//2):
-            a1 = audio[b*2, :audio_lengths[b*2]]
-            a2 = audio[b*2+1, :audio_lengths[b*2+1]]
-            a = torch.cat([a1, a2], dim=0)
+            a1 = audio[b*2, :, :audio_lengths[b*2]]
+            a2 = audio[b*2+1, :, :audio_lengths[b*2+1]]
+            a = torch.cat([a1, a2], dim=1)
             a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2-a.shape[-1]))
             combined_audios.append(a)

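Why the extra ':': audio evidently carries a channel dimension, (batch, channels, samples), so length-trimming must index the last axis, and the two clips join along time, which is dim=1 once the batch index is gone. A tiny sketch with made-up shapes:

import torch
audio = torch.randn(4, 1, 16000)      # (batch, channels, samples), assumed
audio_lengths = torch.tensor([12000, 9000, 16000, 7000])
a1 = audio[0, :, :audio_lengths[0]]   # (1, 12000)
a2 = audio[1, :, :audio_lengths[1]]   # (1, 9000)
a = torch.cat([a1, a2], dim=1)        # (1, 21000)
a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2 - a.shape[-1]))  # (1, 32000)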
@@ -105,8 +107,8 @@ class CombineMelInjector(Injector):
             t = torch.cat([t1, t2], dim=0)
             t = torch.nn.functional.pad(t, (0, texts.shape[-1]*2-t.shape[-1]))
             combined_texts.append(t)
-        return {self.audio_key: torch.stack(combined_audios, dim=0),
-                self.text_key: torch.stack(combined_texts, dim=0)}
+        return {self.output_audio_key: torch.stack(combined_audios, dim=0),
+                self.output_text_key: torch.stack(combined_texts, dim=0)}


 def test_mel_injector():
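Returning under the dedicated output keys keeps the injector's input entries intact in the pipeline state instead of overwriting them. Roughly, with the invented key names from the sketch above:

result = {
    'combined_audio': torch.stack(combined_audios, dim=0),  # previously written back to 'audio_key'
    'combined_text': torch.stack(combined_texts, dim=0),
}
# Downstream steps can still read the original audio/text entries.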