misc nonfunctional

This commit is contained in:
James Betker 2021-11-22 17:16:39 -07:00
parent 3125ca38f5
commit 973f47c525
4 changed files with 34 additions and 21 deletions

View File

@ -9,6 +9,7 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_devic
from models.tacotron2.text import symbols from models.tacotron2.text import symbols
from trainer.networks import register_model from trainer.networks import register_model
from utils.audio import plot_spectrogram
from utils.util import opt_get from utils.util import opt_get
@ -248,6 +249,7 @@ class GptAsrHf2(nn.Module):
return text_logits return text_logits
def forward(self, mel_inputs, text_targets, return_attentions=False): def forward(self, mel_inputs, text_targets, return_attentions=False):
plot_spectrogram(mel_inputs[0].cpu())
text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token. text_targets = F.pad(text_targets, (0,1)) # Pad the targets with a <0> so that all have a "stop" token.
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions) text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
if return_attentions: if return_attentions:

View File

@ -1,13 +1,8 @@
import pathlib
import numpy
import torch import torch
from scipy.io import wavfile from scipy.io import wavfile
from tqdm import tqdm
import matplotlib.pyplot as plt
import librosa
from models.waveglow.waveglow import WaveGlow from models.waveglow.waveglow import WaveGlow
from utils.audio import plot_spectrogram
class Vocoder: class Vocoder:
@ -25,18 +20,6 @@ class Vocoder:
return self.model.infer(mel) return self.model.infer(mel)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
if __name__ == '__main__': if __name__ == '__main__':
vocoder = Vocoder() vocoder = Vocoder()
m = torch.load('test_mels.pth') m = torch.load('test_mels.pth')

View File

@ -5,8 +5,10 @@ import torchaudio.functional
from kornia.augmentation import RandomResizedCrop from kornia.augmentation import RandomResizedCrop
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from data.audio.unsupervised_audio_dataset import load_audio
from trainer.inject import Injector, create_injector from trainer.inject import Injector, create_injector
from trainer.losses import extract_params_from_state from trainer.losses import extract_params_from_state
from utils.audio import plot_spectrogram
from utils.util import opt_get from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt from utils.weight_scheduler import get_scheduler_for_opt
@ -568,7 +570,7 @@ class TorchMelSpectrogramInjector(Injector):
self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000) self.mel_fmax = opt_get(opt, ['mel_fmax'], 8000)
self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050) self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length, self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
win_length=self.win_length, power=2, normalized=True, win_length=self.win_length, power=2, normalized=False,
sample_rate=self.sampling_rate, f_min=self.mel_fmin, sample_rate=self.sampling_rate, f_min=self.mel_fmin,
f_max=self.mel_fmax, n_mels=self.n_mel_channels) f_max=self.mel_fmax, n_mels=self.n_mel_channels)
@ -582,6 +584,14 @@ class TorchMelSpectrogramInjector(Injector):
return {self.output: mel} return {self.output: mel}
def test_torch_mel_injector():
a = load_audio('D:\\data\\audio\\libritts\\train-clean-100\\19\\198\\19_198_000000_000000.wav', 22050)
inj = TorchMelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
f = inj({'in': a.unsqueeze(0)})['out']
plot_spectrogram(f[0])
print('Pause')
class RandomAudioCropInjector(Injector): class RandomAudioCropInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super().__init__(opt, env) super().__init__(opt, env)
@ -606,6 +616,10 @@ class AudioResampleInjector(Injector):
return {self.output: torchaudio.functional.resample(inp, self.input_sr, self.output_sr)} return {self.output: torchaudio.functional.resample(inp, self.input_sr, self.output_sr)}
if __name__ == '__main__': def test_audio_resample_injector():
inj = AudioResampleInjector({'in': 'x', 'out': 'y', 'input_sample_rate': 22050, 'output_sample_rate': '1'}, None) inj = AudioResampleInjector({'in': 'x', 'out': 'y', 'input_sample_rate': 22050, 'output_sample_rate': '1'}, None)
print(inj({'x':torch.rand(10,1,40800)})['y'].shape) print(inj({'x':torch.rand(10,1,40800)})['y'].shape)
if __name__ == '__main__':
test_torch_mel_injector()

14
codes/utils/audio.py Normal file
View File

@ -0,0 +1,14 @@
import librosa
import matplotlib.pyplot as plt
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)