Alterations to diffusion_dvae so it can be used directly on spectrograms
This commit is contained in:
parent 97ea329a59
commit 6833048bf7
@@ -73,18 +73,8 @@ def create_dataset(dataset_opt, return_collate=False):
        from data.audio.gpt_tts_dataset import GptTtsDataset as D
        from data.audio.gpt_tts_dataset import GptTtsCollater as C
        collate = C(dataset_opt)
    elif mode == 'wavfile_clips':
        from data.audio.wavfile_dataset import WavfileDataset as D
    elif mode == 'unsupervised_audio':
        from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
    elif mode == 'stop_prediction':
        from models.tacotron2.hparams import create_hparams
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
        from data.audio.stop_prediction_dataset import StopPredictionDataset as D
    elif mode == 'stop_prediction2':
        from data.audio.stop_prediction_dataset_2 import StopPredictionDataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)
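Editor's note: the dispatch above is driven by an options dict whose 'mode' key selects the dataset class. A minimal usage sketch, mirroring the `__main__` blocks of the dataset files deleted below; the option keys accepted by UnsupervisedAudioDataset are assumptions here, not taken from this commit:

# Sketch only, not part of the commit. 'mode' selects the class; other keys are illustrative.
params = {
    'mode': 'unsupervised_audio',        # dispatches to UnsupervisedAudioDataset
    'path': ['E:\\audio\\books-split'],  # assumed option name, by analogy with the deleted wavfile_dataset
    'phase': 'train',
    'n_workers': 0,
    'batch_size': 16,
}
from data import create_dataset, create_dataloader
ds = create_dataset(params)              # pass return_collate=True to also receive a collate_fn
dl = create_dataloader(ds, params)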

@@ -1,142 +0,0 @@
import os
import pathlib
import random

import audio2numpy
import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm

import models.tacotron2.layers as layers
from data.audio.nv_tacotron_dataset import load_mozilla_cv, load_voxpopuli
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text

from models.tacotron2.text import text_to_sequence
from utils.util import opt_get


def get_similar_files_libritts(filename):
    filedir = os.path.dirname(filename)
    return list(pathlib.Path(filedir).glob('*.wav'))


class StopPredictionDataset(torch.utils.data.Dataset):
    """
    1) loads audio,text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.
    """
    def __init__(self, hparams):
        self.path = hparams['path']
        if not isinstance(self.path, list):
            self.path = [self.path]

        fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
        if not isinstance(fetcher_mode, list):
            fetcher_mode = [fetcher_mode]
        assert len(self.path) == len(fetcher_mode)

        self.audiopaths_and_text = []
        for p, fm in zip(self.path, fetcher_mode):
            if fm == 'lj' or fm == 'libritts':
                fetcher_fn = load_filepaths_and_text
                self.get_similar_files = get_similar_files_libritts
            elif fm == 'voxpopuli':
                fetcher_fn = load_voxpopuli
                self.get_similar_files = None  # TODO: Fix.
            else:
                raise NotImplementedError()
            self.audiopaths_and_text.extend(fetcher_fn(p))
        self.sampling_rate = hparams.sampling_rate
        self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate)
        self.stft = layers.TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)
        random.seed(hparams.seed)
        random.shuffle(self.audiopaths_and_text)
        self.max_mel_len = opt_get(hparams, ['max_mel_length'], None)
        self.max_text_len = opt_get(hparams, ['max_text_length'], None)

    def get_mel(self, filename):
        filename = str(filename)
        if filename.endswith('.wav'):
            audio, sampling_rate = load_wav_to_torch(filename)
        else:
            audio, sampling_rate = audio2numpy.audio_from_file(filename)
            audio = torch.tensor(audio)

        if sampling_rate != self.input_sample_rate:
            if sampling_rate < self.input_sample_rate:
                print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.')
            audio_norm = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze()
        else:
            audio_norm = audio
        if audio_norm.std() > 1:
            print(f"Something is very wrong with the given audio. std_dev={audio_norm.std()}. file={filename}")
            return None
        audio_norm.clip_(-1, 1)
        audio_norm = audio_norm.unsqueeze(0)
        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
        if self.input_sample_rate != self.sampling_rate:
            ratio = self.sampling_rate / self.input_sample_rate
            audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)

        return melspec

    def __getitem__(self, index):
        path = self.audiopaths_and_text[index][0]
        similar_files = self.get_similar_files(path)
        mel = self.get_mel(path)
        terms = torch.zeros(mel.shape[1])
        terms[-1] = 1
        while mel.shape[-1] < self.max_mel_len:
            another_file = random.choice(similar_files)
            another_mel = self.get_mel(another_file)
            oterms = torch.zeros(another_mel.shape[1])
            oterms[-1] = 1
            mel = torch.cat([mel, another_mel], dim=-1)
            terms = torch.cat([terms, oterms], dim=-1)
        mel = mel[:, :self.max_mel_len]
        terms = terms[:self.max_mel_len]

        return {
            'padded_mel': mel,
            'termination_mask': terms,
        }

    def __len__(self):
        return len(self.audiopaths_and_text)


if __name__ == '__main__':
    params = {
        'mode': 'stop_prediction',
        'path': 'E:\\audio\\LibriTTS\\train-clean-360_list.txt',
        'phase': 'train',
        'n_workers': 0,
        'batch_size': 16,
        'fetcher_mode': 'libritts',
        'max_mel_length': 800,
        #'return_wavs': True,
        #'input_sample_rate': 22050,
        #'sampling_rate': 8000
    }
    from data import create_dataset, create_dataloader

    ds, c = create_dataset(params, return_collate=True)
    dl = create_dataloader(ds, params, collate_fn=c, shuffle=True)
    i = 0
    m = None
    for k in range(1000):
        for i, b in tqdm(enumerate(dl)):
            continue
            pm = b['padded_mel']
            pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1]))
            m = pm if m is None else torch.cat([m, pm], dim=0)
        print(m.mean(), m.std())

@@ -1,92 +0,0 @@
import os
import pathlib
import random

from munch import munchify
from torch.utils.data import Dataset
import torch
from tqdm import tqdm

from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get


# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined
# set of clips from the librivox corpus of equal length with the sentence alignment labeled.
class StopPredictionDataset(Dataset):
    def __init__(self, opt):
        path = opt['path']
        label_compaction = opt_get(opt, ['label_compaction'], 1)
        hp = munchify(hparams.create_hparams())
        cache_path = os.path.join(path, 'cache.pth')
        if os.path.exists(cache_path):
            self.files = torch.load(cache_path)
        else:
            print("Building cache..")
            self.files = list(pathlib.Path(path).glob('*.wav'))
            torch.save(self.files, cache_path)
        self.sampling_rate = 22050  # Fixed since the underlying data is also fixed at this SR.
        self.mel_length = 2000
        self.stft = TacotronSTFT(
            hp.filter_length, hp.hop_length, hp.win_length,
            hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
            hp.mel_fmax)
        self.label_compaction = label_compaction

    def __getitem__(self, index):
        audio, _ = load_wav_to_torch(self.files[index])
        starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth'))

        if audio.std() > 1:
            print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}")
            return None
        audio.clip_(-1, 1)
        mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)

        # Form labels.
        labels_start = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
        for s in starts:
            # Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction.
            s = s // (256 * self.label_compaction)
            if s >= 2000//self.label_compaction:
                continue
            labels_start[s] = 1
        labels_end = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
        for e in ends:
            e = e // (256 * self.label_compaction)
            if e >= 2000//self.label_compaction:
                continue
            labels_end[e] = 1

        return {
            'mels': mels,
            'labels_start': labels_start,
            'labels_end': labels_end,
        }

    def __len__(self):
        return len(self.files)


if __name__ == '__main__':
    opt = {
        'path': 'D:\\data\\audio\\libritts\\stop_dataset',
        'label_compaction': 4,
    }
    ds = StopPredictionDataset(opt)
    j = 0
    for i in tqdm(range(100)):
        b = ds[random.randint(0, len(ds))]
        start_indices = torch.nonzero(b['labels_start']).squeeze(1)
        end_indices = torch.nonzero(b['labels_end']).squeeze(1)
        assert len(end_indices) <= len(start_indices)  # There should always be more START tokens then END tokens.
        for i in range(len(end_indices)):
            s = start_indices[i].item()*4
            e = end_indices[i].item()*4
            m = b['mels'][:, s:e]
            save_mel_buffer_to_file(m, f'{j}.npy')
            j += 1

@@ -113,7 +113,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
            audio_norm, filename = self.get_audio_for_index(index)
            alt_files, actual_samples = self.get_related_audio_for_index(index)
        except:
            print(f"Error loading audio for file {filename} or {alt_files}")
            print(f"Error loading audio for file {self.audiopaths[index]}")
            return self[index+1]

        # This is required when training to make sure all clips align.

@@ -1,135 +0,0 @@
import os
import random

import torch
import torch.utils.data
import torchaudio
from tqdm import tqdm

from data.audio.wav_aug import WavAugmentor
from data.util import find_files_of_type, is_wav_file
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get


def load_audio_from_wav(audiopath, sampling_rate):
    audio, lsr = load_wav_to_torch(audiopath)
    if lsr != sampling_rate:
        if lsr < sampling_rate:
            print(f'{audiopath} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {sampling_rate}. This is not a good idea.')
        audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=sampling_rate/lsr, mode='nearest', recompute_scale_factor=False).squeeze()

    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
    # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
    if torch.any(audio > 2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)
    return audio.unsqueeze(0)


class WavfileDataset(torch.utils.data.Dataset):

    def __init__(self, opt):
        path = opt['path']
        cache_path = opt['cache_path']  # Will fail when multiple paths specified, must be specified in this case.
        if not isinstance(path, list):
            path = [path]
        if os.path.exists(cache_path):
            self.audiopaths = torch.load(cache_path)
        else:
            print("Building cache..")
            self.audiopaths = []
            for p in path:
                self.audiopaths.extend(find_files_of_type('img', p, qualifier=is_wav_file)[0])
            torch.save(self.audiopaths, cache_path)

        # Parse options
        self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
        self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
        if self.pad_to is not None:
            self.pad_to *= self.sampling_rate
        self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)

        self.augment = opt_get(opt, ['do_augmentation'], False)
        if self.augment:
            # The "window size" for the clips produced in seconds.
            self.window = 2 * self.sampling_rate
            self.augmentor = WavAugmentor()

    def get_audio_for_index(self, index):
        audiopath = self.audiopaths[index]
        audio = load_audio_from_wav(audiopath, self.sampling_rate)
        return audio, audiopath

    def __getitem__(self, index):
        success = False
        # This "success" thing is a hack: This dataset is randomly failing for no apparent good reason and I don't know why.
        # Symptoms are it complaining about being unable to read a nonsensical filename that is clearly corrupted. Memory corruption? I don't know..
        while not success:
            try:
                # Split audio_norm into two tensors of equal size.
                audio_norm, filename = self.get_audio_for_index(index)
                success = True
            except:
                print(f"Failed to load {index} {self.audiopaths[index]}")

        if self.augment:
            if audio_norm.shape[1] < self.window * 2:
                # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
                return self[(index + 1) % len(self)]
            j = random.randint(0, audio_norm.shape[1] - self.window)
            clip1 = audio_norm[:, j:j+self.window]
            if self.augment:
                clip1 = self.augmentor.augment(clip1, self.sampling_rate)
            j = random.randint(0, audio_norm.shape[1]-self.window)
            clip2 = audio_norm[:, j:j+self.window]
            if self.augment:
                clip2 = self.augmentor.augment(clip2, self.sampling_rate)

        # This is required when training to make sure all clips align.
        if self.pad_to is not None:
            if audio_norm.shape[-1] <= self.pad_to:
                audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
            else:
                gap = audio_norm.shape[-1] - self.pad_to
                start = random.randint(0, gap-1)
                audio_norm = audio_norm[:, start:start+self.pad_to]

        output = {
            'clip': audio_norm,
            'path': filename,
        }
        if self.augment:
            output.update({
                'clip1': clip1[0, :].unsqueeze(0),
                'clip2': clip2[0, :].unsqueeze(0),
            })
        return output

    def __len__(self):
        return len(self.audiopaths)


if __name__ == '__main__':
    params = {
        'mode': 'wavfile_clips',
        'path': ['E:\\audio\\books-split', 'E:\\audio\\LibriTTS\\train-clean-360', 'D:\\data\\audio\\podcasts-split'],
        'cache_path': 'E:\\audio\\clips-cache.pth',
        'sampling_rate': 22050,
        'pad_to_seconds': 5,
        'phase': 'train',
        'n_workers': 0,
        'batch_size': 16,
        'do_augmentation': False
    }
    from data import create_dataset, create_dataloader, util

    ds = create_dataset(params)
    dl = create_dataloader(ds, params)
    i = 0
    for b in tqdm(dl):
        for b_ in range(16):
            pass
            #torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
            #torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
            #i += 1

@@ -5,11 +5,12 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res
import torch
import torch.nn as nn

from models.gpt_voice.lucidrains_dvae import eval_decorator
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
import models.gpt_voice.my_dvae as mdvae
from utils.util import get_mask_from_lengths
import models.gpt_voice.mini_encoder as menc


class DiscreteEncoder(nn.Module):

@@ -22,13 +23,13 @@ class DiscreteEncoder(nn.Module):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_nd(1, in_channels, model_channels, 3, padding=1),
            mdvae.ResBlock(model_channels, dropout, dims=1),
            menc.ResBlock(model_channels, dropout, dims=1),
            Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale),
            mdvae.ResBlock(model_channels*2, dropout, dims=1),
            menc.ResBlock(model_channels*2, dropout, dims=1),
            Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale),
            mdvae.ResBlock(model_channels*4, dropout, dims=1),
            menc.ResBlock(model_channels*4, dropout, dims=1),
            AttentionBlock(model_channels*4, num_heads=4),
            mdvae.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
            menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
        )

    def forward(self, spectrogram):
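For orientation, a rough shape trace through the block stack above. It assumes each Downsample(..., factor=scale) divides the time axis by scale (an assumption about the helper in models.diffusion.unet_diffusion); the channel progression comes directly from the constructor arguments shown in the hunk:

# Illustrative only, not part of the commit: (channels, frames) through DiscreteEncoder.blocks.
def trace_discrete_encoder(spec_frames, model_channels, out_channels, scale):
    t = spec_frames            # after the initial conv_nd: model_channels x t
    t //= scale                # first Downsample: model_channels*2 x t/scale
    t //= scale                # second Downsample: model_channels*4 x t/scale^2
    return out_channels, t     # final ResBlock remaps channels to out_channels

print(trace_discrete_encoder(416, 64, 512, 4))  # example numbers -> (512, 26)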

@@ -249,6 +250,21 @@ class DiffusionDVAE(nn.Module):
            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
        )

    def get_debug_values(self, step, __):
        if self.record_codes:
            # Report annealing schedule
            return {'histogram_codes': self.codes}
        else:
            return {}

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        img = self.norm(images)
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, commitment_loss, codes = self.codebook(logits)
        return codes

    def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
        if self.conditioning_enabled:
            assert conditioning_inputs is not None

@@ -299,17 +315,28 @@ class DiffusionDVAE(nn.Module):
        return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals)

    def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, thus this requirement.

        # Compute DVAE portion first.
        spec_logits = self.encoder(spectrogram).permute((0,2,1))
        sampled, commitment_loss, codes = self.quantizer(spec_logits)

        if self.training:
            # Compute from softmax outputs to preserve gradients.
            embeddings = sampled.permute((0,2,1))
        else:
            # Compute from codes only.
            embeddings = self.quantizer.embed_code(codes).permute((0,2,1))

        # This is so we can debug the distribution of codes being learned.
        if self.internal_step % 50 == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
        self.internal_step += 1

        return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss
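The debug block above keeps a fixed-size ring buffer of recently emitted codebook indices so that get_debug_values() can report them as a histogram. A standalone restatement of that logic with illustrative names (not part of the commit):

import torch

class CodeRecorder:
    # Sketch of the ring buffer used in DiffusionDVAE.forward() above.
    def __init__(self, buffer_size=32768):
        self.codes = torch.zeros((buffer_size,), dtype=torch.long)
        self.code_ind = 0

    def record(self, codes):
        codes = codes.flatten()
        l = codes.shape[0]
        # Write at code_ind unless fewer than l slots remain; in that case back up
        # so the whole batch still fits inside the buffer.
        i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
        self.codes[i:i+l] = codes.cpu()
        self.code_ind += l
        if self.code_ind >= self.codes.shape[0]:
            self.code_ind = 0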

@@ -318,12 +345,44 @@ def register_unet_diffusion_dvae(opt_net, opt):
    return DiffusionDVAE(**opt_net['kwargs'])


'''


class DiffusionDVAE(nn.Module):
    def __init__(
        self,
        model_channels,
        num_res_blocks,
        in_channels=1,
        out_channels=2,  # mean and variance
        spectrogram_channels=80,
        spectrogram_conditioning_levels=[3,4,5],  # Levels at which spectrogram conditioning is applied to the waveform.
        dropout=0,
        channel_mult=(1, 2, 4, 8, 16, 32, 64),
        attention_resolutions=(16,32,64),
        conv_resample=True,
        dims=1,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        use_new_attention_order=False,
        kernel_size=5,
        quantize_dim=1024,
        num_discrete_codes=8192,
        scale_steps=4,
        conditioning_inputs_provided=True,
    ):
'''

# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
    clip = torch.randn(4, 1, 81920)
    spec = torch.randn(4, 80, 416)
    cond = torch.randn(4, 5, 80, 200)
    num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
    ts = torch.LongTensor([432, 234, 100, 555])
    model = DiffusionDVAE(32, 2)
    print(model(clip, ts, spec, cond, num_cond)[0].shape)
    model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
                          channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
    print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape)

@@ -4,13 +4,91 @@ import torch.nn as nn
from models.diffusion.nn import normalization, conv_nd, zero_module
from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
from models.gpt_voice.my_dvae import ResBlock


# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
from utils.util import checkpoint


class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        if self.do_checkpoint:
            return checkpoint(
                self._forward, x
            )
        else:
            return self._forward(x)

    def _forward(self, x):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h


class AudioMiniEncoder(nn.Module):
    def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
        super().__init__()
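A quick smoke test for the ResBlock added in this hunk, as a sketch only; it assumes the imports at the top of the hunk resolve and uses do_checkpoint=False to bypass the checkpoint wrapper:

import torch

x = torch.randn(2, 64, 256)   # (batch, channels, time)
block = ResBlock(64, dropout=0.1, out_channels=128, dims=1, do_checkpoint=False)
print(block(x).shape)         # expected: torch.Size([2, 128, 256])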

@@ -1,370 +0,0 @@
import functools
import math
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

from models.diffusion.nn import conv_nd, normalization, zero_module
from models.diffusion.unet_diffusion import Upsample, Downsample, AttentionBlock
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get, checkpoint


def default(val, d):
    return val if val is not None else d


def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner


class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        if self.do_checkpoint:
            return checkpoint(
                self._forward, x
            )
        else:
            return self._forward(x)

    def _forward(self, x):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h


class DisjointUnet(nn.Module):
    def __init__(
        self,
        attention_resolutions,
        channel_mult_down,
        channel_mult_up,
        in_channels = 3,
        model_channels = 64,
        out_channels = 3,
        dims=2,
        num_res_blocks = 2,
        stride = 2,
        dropout=0,
        num_heads=4,
    ):
        super().__init__()

        self.enc_input_blocks = nn.ModuleList(
            [
                conv_nd(dims, in_channels, model_channels, 3, padding=1)
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult_down):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=-1,
                        )
                    )
                self.enc_input_blocks.append(nn.Sequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult_down) - 1:
                out_ch = ch
                self.enc_input_blocks.append(
                    Downsample(
                        ch, True, dims=dims, out_channels=out_ch, factor=stride
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        self.enc_middle_block = nn.Sequential(
            ResBlock(
                ch,
                dropout,
                dims=dims,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=-1,
            ),
            ResBlock(
                ch,
                dropout,
                dims=dims,
            ),
        )

        self.enc_output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult_up)):
            for i in range(num_res_blocks + 1):
                if len(input_block_chans) > 0:
                    ich = input_block_chans.pop()
                else:
                    ich = 0
                layers = [
                    ResBlock(
                        ch + ich,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=-1,
                        )
                    )
                if level != len(channel_mult_up)-1 and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        Upsample(ch, True, dims=dims, out_channels=out_ch, factor=stride)
                    )
                    ds //= 2
                self.enc_output_blocks.append(nn.Sequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            conv_nd(dims, ch, out_channels, 3, padding=1),
        )

    def forward(self, x):
        hs = []
        h = x
        for module in self.enc_input_blocks:
            h = module(h)
            hs.append(h)
        h = self.enc_middle_block(h)
        for module in self.enc_output_blocks:
            if len(hs) > 0:
                h = torch.cat([h, hs.pop()], dim=1)
            h = module(h)
        h = h.type(x.dtype)
        return self.out(h)


class DiscreteVAE(nn.Module):
    def __init__(
        self,
        attention_resolutions,
        in_channels = 3,
        model_channels = 64,
        out_channels = 3,
        channel_mult=(1, 2, 4, 8),
        dims=2,
        num_tokens = 512,
        codebook_dim = 512,
        convergence_layer=2,
        num_res_blocks = 0,
        stride = 2,
        straight_through = False,
        dropout=0,
        num_heads=4,
        record_codes=True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.num_tokens = num_tokens
        self.num_layers = len(channel_mult)
        self.straight_through = straight_through
        self.codebook = Quantize(codebook_dim, num_tokens)
        self.positional_dims = dims
        self.dropout = dropout
        self.num_heads = num_heads
        self.record_codes = record_codes
        if record_codes:
            self.codes = torch.zeros((32768,), dtype=torch.long)
            self.code_ind = 0
        self.internal_step = 0

        enc_down = channel_mult
        enc_up = list(reversed(channel_mult[convergence_layer:]))
        self.encoder = DisjointUnet(attention_resolutions, enc_down, enc_up, in_channels=in_channels, model_channels=model_channels,
                                    out_channels=codebook_dim, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout,
                                    stride=stride)
        dec_down = list(reversed(enc_up))
        dec_up = list(reversed(enc_down))
        self.decoder = DisjointUnet(attention_resolutions, dec_down, dec_up, in_channels=codebook_dim, model_channels=model_channels,
                                    out_channels=out_channels, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout,
                                    stride=stride)

    def get_debug_values(self, step, __):
        if self.record_codes:
            # Report annealing schedule
            return {'histogram_codes': self.codes}
        else:
            return {}

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        img = images
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, commitment_loss, codes = self.codebook(logits)
        return codes

    def decode(
        self,
        img_seq
    ):
        image_embeds = self.codebook.embed_code(img_seq)
        b, n, d = image_embeds.shape

        kwargs = {}
        if self.positional_dims == 1:
            arrange = 'b n d -> b d n'
        else:
            h = w = int(sqrt(n))
            arrange = 'b (h w) d -> b d h w'
            kwargs = {'h': h, 'w': w}
        image_embeds = rearrange(image_embeds, arrange, **kwargs)
        images = self.decoder(image_embeds)
        return images

    def infer(self, img):
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, commitment_loss, codes = self.codebook(logits)
        return self.decode(codes)

    # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
    # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
    # more lossy (but useful for determining network performance).
    def forward(
        self,
        img
    ):
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, commitment_loss, codes = self.codebook(logits)
        sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))

        if self.training:
            out = sampled
            out = self.decoder(out)
        else:
            # This is non-differentiable, but gives a better idea of how the network is actually performing.
            out = self.decode(codes)

        # reconstruction loss
        recon_loss = F.mse_loss(img, out, reduction='none')

        # This is so we can debug the distribution of codes being learned.
        if self.record_codes and self.internal_step % 50 == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
        self.internal_step += 1

        return recon_loss, commitment_loss, out


@register_model
def register_my_dvae(opt_net, opt):
    return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    net = DiscreteVAE((8, 16), channel_mult=(1,2,4,8,8), in_channels=80, model_channels=128, out_channels=80, dims=1, num_res_blocks=2)
    inp = torch.randn((2,80,512))
    print([j.shape for j in net(inp)])

codes/scripts/audio/preparation/save_mels_to_disk.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import argparse

import numpy
import torch
from spleeter.audio.adapter import AudioAdapter
from tqdm import tqdm

from data.util import find_audio_files
# Uses pydub to process a directory of audio files, splitting them into clips at points where it detects a small amount
# of silence.
from trainer.injectors.base_injectors import MelSpectrogramInjector


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path')
    args = parser.parse_args()
    files = find_audio_files(args.path, include_nonwav=True)
    mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {})
    audio_loader = AudioAdapter.default()
    for e, wav_file in enumerate(tqdm(files)):
        if e < 272583:
            continue
        print(f"Processing {wav_file}..")
        outfile = f'{wav_file}.npz'

        try:
            wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050)
            wave = torch.tensor(wave)[:,0].unsqueeze(0)
            wave = wave / wave.abs().max()
        except:
            print(f"Error with {wav_file}")
            continue

        inj = mel_inj({'in': wave})
        numpy.savez_compressed(outfile, inj['out'].numpy())


if __name__ == '__main__':
    main()
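To read one of the saved spectrograms back: savez_compressed is called with a single positional array, so numpy stores it under the default key 'arr_0'. A sketch with an illustrative path:

import numpy as np

mel = np.load('example.wav.npz')['arr_0']
print(mel.shape)  # layout follows whatever MelSpectrogramInjector emits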

@@ -19,7 +19,7 @@ def main():
    maximum_duration = 20
    files = find_audio_files(args.path, include_nonwav=True)
    for e, wav_file in enumerate(tqdm(files)):
        if e < 2759:
        if e < 12593:
            continue
        print(f"Processing {wav_file}..")
        outdir = os.path.join(args.out, f'{e}_{os.path.basename(wav_file[:-4])}').replace('.', '').strip()

@@ -284,7 +284,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()