Alterations to diffusion_dvae so it can be used directly on spectrograms

James Betker 2021-09-23 15:56:25 -06:00
parent 97ea329a59
commit 6833048bf7
11 changed files with 191 additions and 763 deletions

View File

@@ -73,18 +73,8 @@ def create_dataset(dataset_opt, return_collate=False):
from data.audio.gpt_tts_dataset import GptTtsDataset as D
from data.audio.gpt_tts_dataset import GptTtsCollater as C
collate = C(dataset_opt)
elif mode == 'wavfile_clips':
from data.audio.wavfile_dataset import WavfileDataset as D
elif mode == 'unsupervised_audio':
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D
elif mode == 'stop_prediction':
from models.tacotron2.hparams import create_hparams
default_params = create_hparams()
default_params.update(dataset_opt)
dataset_opt = munchify(default_params)
from data.audio.stop_prediction_dataset import StopPredictionDataset as D
elif mode == 'stop_prediction2':
from data.audio.stop_prediction_dataset_2 import StopPredictionDataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
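For illustration (not code from this commit): with the wavfile and stop-prediction modes removed, raw audio clips are expected to come through the 'unsupervised_audio' mode. A minimal usage sketch, modeled on the __main__ blocks elsewhere in this diff; the exact options UnsupervisedAudioDataset accepts may differ:

from data import create_dataset, create_dataloader

params = {
    'mode': 'unsupervised_audio',        # replaces the removed 'wavfile_clips' mode
    'path': ['E:\\audio\\books-split'],  # assumed option, mirroring the removed WavfileDataset
    'phase': 'train',
    'n_workers': 0,
    'batch_size': 16,
}
ds = create_dataset(params)
dl = create_dataloader(ds, params)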

View File

@@ -1,142 +0,0 @@
import os
import pathlib
import random
import audio2numpy
import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm
import models.tacotron2.layers as layers
from data.audio.nv_tacotron_dataset import load_mozilla_cv, load_voxpopuli
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from utils.util import opt_get
def get_similar_files_libritts(filename):
filedir = os.path.dirname(filename)
return list(pathlib.Path(filedir).glob('*.wav'))
class StopPredictionDataset(torch.utils.data.Dataset):
"""
1) loads audio clips (and additional clips from similar files in the same directory)
2) computes mel-spectrograms from the audio
3) produces a termination mask marking where each clip ends
"""
def __init__(self, hparams):
self.path = hparams['path']
if not isinstance(self.path, list):
self.path = [self.path]
fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj')
if not isinstance(fetcher_mode, list):
fetcher_mode = [fetcher_mode]
assert len(self.path) == len(fetcher_mode)
self.audiopaths_and_text = []
for p, fm in zip(self.path, fetcher_mode):
if fm == 'lj' or fm == 'libritts':
fetcher_fn = load_filepaths_and_text
self.get_similar_files = get_similar_files_libritts
elif fm == 'voxpopuli':
fetcher_fn = load_voxpopuli
self.get_similar_files = None # TODO: Fix.
else:
raise NotImplementedError()
self.audiopaths_and_text.extend(fetcher_fn(p))
self.sampling_rate = hparams.sampling_rate
self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate)
self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
hparams.mel_fmax)
random.seed(hparams.seed)
random.shuffle(self.audiopaths_and_text)
self.max_mel_len = opt_get(hparams, ['max_mel_length'], None)
self.max_text_len = opt_get(hparams, ['max_text_length'], None)
def get_mel(self, filename):
filename = str(filename)
if filename.endswith('.wav'):
audio, sampling_rate = load_wav_to_torch(filename)
else:
audio, sampling_rate = audio2numpy.audio_from_file(filename)
audio = torch.tensor(audio)
if sampling_rate != self.input_sample_rate:
if sampling_rate < self.input_sample_rate:
print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.')
audio_norm = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze()
else:
audio_norm = audio
if audio_norm.std() > 1:
print(f"Something is very wrong with the given audio. std_dev={audio_norm.std()}. file={filename}")
return None
audio_norm.clip_(-1, 1)
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
if self.input_sample_rate != self.sampling_rate:
ratio = self.sampling_rate / self.input_sample_rate
audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
return melspec
def __getitem__(self, index):
path = self.audiopaths_and_text[index][0]
similar_files = self.get_similar_files(path)
mel = self.get_mel(path)
terms = torch.zeros(mel.shape[1])
terms[-1] = 1
while mel.shape[-1] < self.max_mel_len:
another_file = random.choice(similar_files)
another_mel = self.get_mel(another_file)
oterms = torch.zeros(another_mel.shape[1])
oterms[-1] = 1
mel = torch.cat([mel, another_mel], dim=-1)
terms = torch.cat([terms, oterms], dim=-1)
mel = mel[:, :self.max_mel_len]
terms = terms[:self.max_mel_len]
return {
'padded_mel': mel,
'termination_mask': terms,
}
def __len__(self):
return len(self.audiopaths_and_text)
if __name__ == '__main__':
params = {
'mode': 'stop_prediction',
'path': 'E:\\audio\\LibriTTS\\train-clean-360_list.txt',
'phase': 'train',
'n_workers': 0,
'batch_size': 16,
'fetcher_mode': 'libritts',
'max_mel_length': 800,
#'return_wavs': True,
#'input_sample_rate': 22050,
#'sampling_rate': 8000
}
from data import create_dataset, create_dataloader
ds, c = create_dataset(params, return_collate=True)
dl = create_dataloader(ds, params, collate_fn=c, shuffle=True)
i = 0
m = None
for k in range(1000):
for i, b in tqdm(enumerate(dl)):
continue
pm = b['padded_mel']
pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1]))
m = pm if m is None else torch.cat([m, pm], dim=0)
print(m.mean(), m.std())

View File

@@ -1,92 +0,0 @@
import os
import pathlib
import random
from munch import munchify
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
# A dataset that consumes the output of the `produce_libri_stretched_dataset` script: a set of equal-length clips
# from the LibriVox corpus with sentence-alignment labels.
class StopPredictionDataset(Dataset):
def __init__(self, opt):
path = opt['path']
label_compaction = opt_get(opt, ['label_compaction'], 1)
hp = munchify(hparams.create_hparams())
cache_path = os.path.join(path, 'cache.pth')
if os.path.exists(cache_path):
self.files = torch.load(cache_path)
else:
print("Building cache..")
self.files = list(pathlib.Path(path).glob('*.wav'))
torch.save(self.files, cache_path)
self.sampling_rate = 22050 # Fixed since the underlying data is also fixed at this SR.
self.mel_length = 2000
self.stft = TacotronSTFT(
hp.filter_length, hp.hop_length, hp.win_length,
hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
hp.mel_fmax)
self.label_compaction = label_compaction
def __getitem__(self, index):
audio, _ = load_wav_to_torch(self.files[index])
starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth'))
if audio.std() > 1:
print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}")
return None
audio.clip_(-1, 1)
mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0)
# Form labels.
labels_start = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
for s in starts:
# Mel compaction operates at a ratio of 1/256; the dataset also allows further compaction.
s = s // (256 * self.label_compaction)
if s >= 2000//self.label_compaction:
continue
labels_start[s] = 1
labels_end = torch.zeros((2000 // self.label_compaction,), dtype=torch.long)
for e in ends:
e = e // (256 * self.label_compaction)
if e >= 2000//self.label_compaction:
continue
labels_end[e] = 1
return {
'mels': mels,
'labels_start': labels_start,
'labels_end': labels_end,
}
def __len__(self):
return len(self.files)
if __name__ == '__main__':
opt = {
'path': 'D:\\data\\audio\\libritts\\stop_dataset',
'label_compaction': 4,
}
ds = StopPredictionDataset(opt)
j = 0
for i in tqdm(range(100)):
b = ds[random.randint(0, len(ds))]
start_indices = torch.nonzero(b['labels_start']).squeeze(1)
end_indices = torch.nonzero(b['labels_end']).squeeze(1)
assert len(end_indices) <= len(start_indices)  # There should always be at least as many START tokens as END tokens.
for i in range(len(end_indices)):
s = start_indices[i].item()*4
e = end_indices[i].item()*4
m = b['mels'][:, s:e]
save_mel_buffer_to_file(m, f'{j}.npy')
j += 1
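For reference on the label-compaction arithmetic in __getitem__ above: the mel hop is 256 samples, so with label_compaction=4 a sentence boundary at sample 256000 maps to mel frame 256000 // 256 = 1000 and to label bin 256000 // (256 * 4) = 250, out of 2000 // 4 = 500 available bins; boundaries at or beyond mel frame 2000 are skipped.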

View File

@@ -113,7 +113,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
audio_norm, filename = self.get_audio_for_index(index)
alt_files, actual_samples = self.get_related_audio_for_index(index)
except:
print(f"Error loading audio for file {filename} or {alt_files}")
print(f"Error loading audio for file {self.audiopaths[index]}")
return self[index+1]
# This is required when training to make sure all clips align.

View File

@@ -1,135 +0,0 @@
import os
import random
import torch
import torch.utils.data
import torchaudio
from tqdm import tqdm
from data.audio.wav_aug import WavAugmentor
from data.util import find_files_of_type, is_wav_file
from models.tacotron2.taco_utils import load_wav_to_torch
from utils.util import opt_get
def load_audio_from_wav(audiopath, sampling_rate):
audio, lsr = load_wav_to_torch(audiopath)
if lsr != sampling_rate:
if lsr < sampling_rate:
print(f'{audiopath} has a sample rate of {lsr} which is lower than the requested sample rate of {sampling_rate}. This is not a good idea.')
audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=sampling_rate/lsr, mode='nearest', recompute_scale_factor=False).squeeze()
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
audio.clip_(-1, 1)
return audio.unsqueeze(0)
class WavfileDataset(torch.utils.data.Dataset):
def __init__(self, opt):
path = opt['path']
cache_path = opt['cache_path'] # Will raise a KeyError if absent; a cache path must be specified explicitly when multiple paths are given.
if not isinstance(path, list):
path = [path]
if os.path.exists(cache_path):
self.audiopaths = torch.load(cache_path)
else:
print("Building cache..")
self.audiopaths = []
for p in path:
self.audiopaths.extend(find_files_of_type('img', p, qualifier=is_wav_file)[0])
torch.save(self.audiopaths, cache_path)
# Parse options
self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000)
self.pad_to = opt_get(opt, ['pad_to_seconds'], None)
if self.pad_to is not None:
self.pad_to *= self.sampling_rate
self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
self.augment = opt_get(opt, ['do_augmentation'], False)
if self.augment:
# The "window size" for the clips produced in seconds.
self.window = 2 * self.sampling_rate
self.augmentor = WavAugmentor()
def get_audio_for_index(self, index):
audiopath = self.audiopaths[index]
audio = load_audio_from_wav(audiopath, self.sampling_rate)
return audio, audiopath
def __getitem__(self, index):
success = False
# This "success" thing is a hack: This dataset is randomly failing for no apparent good reason and I don't know why.
# Symptoms are it complaining about being unable to read a nonsensical filename that is clearly corrupted. Memory corruption? I don't know..
while not success:
try:
# Split audio_norm into two tensors of equal size.
audio_norm, filename = self.get_audio_for_index(index)
success = True
except:
print(f"Failed to load {index} {self.audiopaths[index]}")
if self.augment:
if audio_norm.shape[1] < self.window * 2:
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
return self[(index + 1) % len(self)]
j = random.randint(0, audio_norm.shape[1] - self.window)
clip1 = audio_norm[:, j:j+self.window]
if self.augment:
clip1 = self.augmentor.augment(clip1, self.sampling_rate)
j = random.randint(0, audio_norm.shape[1]-self.window)
clip2 = audio_norm[:, j:j+self.window]
if self.augment:
clip2 = self.augmentor.augment(clip2, self.sampling_rate)
# This is required when training to make sure all clips align.
if self.pad_to is not None:
if audio_norm.shape[-1] <= self.pad_to:
audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1]))
else:
gap = audio_norm.shape[-1] - self.pad_to
start = random.randint(0, gap-1)
audio_norm = audio_norm[:, start:start+self.pad_to]
output = {
'clip': audio_norm,
'path': filename,
}
if self.augment:
output.update({
'clip1': clip1[0, :].unsqueeze(0),
'clip2': clip2[0, :].unsqueeze(0),
})
return output
def __len__(self):
return len(self.audiopaths)
if __name__ == '__main__':
params = {
'mode': 'wavfile_clips',
'path': ['E:\\audio\\books-split', 'E:\\audio\\LibriTTS\\train-clean-360', 'D:\\data\\audio\\podcasts-split'],
'cache_path': 'E:\\audio\\clips-cache.pth',
'sampling_rate': 22050,
'pad_to_seconds': 5,
'phase': 'train',
'n_workers': 0,
'batch_size': 16,
'do_augmentation': False
}
from data import create_dataset, create_dataloader, util
ds = create_dataset(params)
dl = create_dataloader(ds, params)
i = 0
for b in tqdm(dl):
for b_ in range(16):
pass
#torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate)
#torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate)
#i += 1

View File

@@ -5,11 +5,12 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res
import torch
import torch.nn as nn
from models.gpt_voice.lucidrains_dvae import eval_decorator
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
import models.gpt_voice.my_dvae as mdvae
from utils.util import get_mask_from_lengths
import models.gpt_voice.mini_encoder as menc
class DiscreteEncoder(nn.Module):
@@ -22,13 +23,13 @@ class DiscreteEncoder(nn.Module):
super().__init__()
self.blocks = nn.Sequential(
conv_nd(1, in_channels, model_channels, 3, padding=1),
mdvae.ResBlock(model_channels, dropout, dims=1),
menc.ResBlock(model_channels, dropout, dims=1),
Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale),
mdvae.ResBlock(model_channels*2, dropout, dims=1),
menc.ResBlock(model_channels*2, dropout, dims=1),
Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale),
mdvae.ResBlock(model_channels*4, dropout, dims=1),
menc.ResBlock(model_channels*4, dropout, dims=1),
AttentionBlock(model_channels*4, num_heads=4),
mdvae.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1),
)
def forward(self, spectrogram):
@@ -249,6 +250,21 @@ class DiffusionDVAE(nn.Module):
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
)
def get_debug_values(self, step, __):
if self.record_codes:
# Report the most recently recorded codes so their distribution can be logged as a histogram.
return {'histogram_codes': self.codes}
else:
return {}
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
return codes
def _decode_continouous(self, x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals):
if self.conditioning_enabled:
assert conditioning_inputs is not None
@@ -299,17 +315,28 @@
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals)
def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, hence this requirement.
# Compute DVAE portion first.
spec_logits = self.encoder(spectrogram).permute((0,2,1))
sampled, commitment_loss, codes = self.quantizer(spec_logits)
if self.training:
# Compute from softmax outputs to preserve gradients.
embeddings = sampled.permute((0,2,1))
else:
# Compute from codes only.
embeddings = self.quantizer.embed_code(codes).permute((0,2,1))
# This is so we can debug the distribution of codes being learned.
if self.internal_step % 50 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.internal_step += 1
return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss
@@ -318,12 +345,44 @@ def register_unet_diffusion_dvae(opt_net, opt):
return DiffusionDVAE(**opt_net['kwargs'])
'''
class DiffusionDVAE(nn.Module):
def __init__(
self,
model_channels,
num_res_blocks,
in_channels=1,
out_channels=2, # mean and variance
spectrogram_channels=80,
spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform.
dropout=0,
channel_mult=(1, 2, 4, 8, 16, 32, 64),
attention_resolutions=(16,32,64),
conv_resample=True,
dims=1,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
use_new_attention_order=False,
kernel_size=5,
quantize_dim=1024,
num_discrete_codes=8192,
scale_steps=4,
conditioning_inputs_provided=True,
):
'''
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
clip = torch.randn(4, 1, 81920)
spec = torch.randn(4, 80, 416)
cond = torch.randn(4, 5, 80, 200)
num_cond = torch.tensor([2,4,5,3], dtype=torch.long)
ts = torch.LongTensor([432, 234, 100, 555])
model = DiffusionDVAE(32, 2)
print(model(clip, ts, spec, cond, num_cond)[0].shape)
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2],
channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False)
print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape)
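For illustration (not code from this commit), the code-tracking block added to forward() above amounts to a fixed-size ring buffer over recently emitted codebook indices, sampled every 50 steps so get_debug_values() can expose them as a histogram. A standalone sketch of that logic, with a hypothetical class name:

import torch

class CodeHistogramTracker:
    # Hypothetical helper mirroring the ring-buffer logic in DiffusionDVAE.forward() above.
    def __init__(self, buffer_size=32768, every_n_steps=50):
        self.codes = torch.zeros((buffer_size,), dtype=torch.long)
        self.code_ind = 0
        self.internal_step = 0
        self.every_n_steps = every_n_steps

    def record(self, codes):
        if self.internal_step % self.every_n_steps == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            # Start index; fall back to the tail of the buffer if this batch would overrun it.
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
        self.internal_step += 1

    def histogram(self):
        return {'histogram_codes': self.codes}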

View File

@@ -4,13 +4,91 @@ import torch.nn as nn
from models.diffusion.nn import normalization, conv_nd, zero_module
from models.diffusion.unet_diffusion import Upsample, Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
from models.gpt_voice.my_dvae import ResBlock
# Combined resnet & full-attention encoder for converting an audio clip into an embedding.
from utils.util import checkpoint
class ResBlock(nn.Module):
def __init__(
self,
channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
kernel_size=3,
do_checkpoint=True,
):
super().__init__()
self.channels = channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.do_checkpoint = do_checkpoint
padding = 1 if kernel_size == 3 else 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x):
if self.do_checkpoint:
return checkpoint(
self._forward, x
)
else:
return self._forward(x)
def _forward(self, x):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class AudioMiniEncoder(nn.Module):
def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
super().__init__()
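A quick shape check for the ResBlock added above (a sketch, not from this commit; it assumes the ResBlock class above is in scope, and disables do_checkpoint so it runs outside a training graph):

import torch
block = ResBlock(channels=64, dropout=0.1, dims=1, do_checkpoint=False)
x = torch.randn(2, 64, 128)  # (batch, channels, length) for dims=1
y = block(x)
assert y.shape == x.shape    # out_channels defaults to channels, so the skip connection is an Identity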

View File

@@ -1,370 +0,0 @@
import functools
import math
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
from models.diffusion.nn import conv_nd, normalization, zero_module
from models.diffusion.unet_diffusion import Upsample, Downsample, AttentionBlock
from models.vqvae.vqvae import Quantize
from trainer.networks import register_model
from utils.util import opt_get, checkpoint
def default(val, d):
return val if val is not None else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
class ResBlock(nn.Module):
def __init__(
self,
channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
up=False,
down=False,
kernel_size=3,
do_checkpoint=True,
):
super().__init__()
self.channels = channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_scale_shift_norm = use_scale_shift_norm
self.do_checkpoint = do_checkpoint
padding = 1 if kernel_size == 3 else 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x):
if self.do_checkpoint:
return checkpoint(
self._forward, x
)
else:
return self._forward(x)
def _forward(self, x):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class DisjointUnet(nn.Module):
def __init__(
self,
attention_resolutions,
channel_mult_down,
channel_mult_up,
in_channels = 3,
model_channels = 64,
out_channels = 3,
dims=2,
num_res_blocks = 2,
stride = 2,
dropout=0,
num_heads=4,
):
super().__init__()
self.enc_input_blocks = nn.ModuleList(
[
conv_nd(dims, in_channels, model_channels, 3, padding=1)
]
)
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult_down):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
dropout,
out_channels=mult * model_channels,
dims=dims,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=-1,
)
)
self.enc_input_blocks.append(nn.Sequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult_down) - 1:
out_ch = ch
self.enc_input_blocks.append(
Downsample(
ch, True, dims=dims, out_channels=out_ch, factor=stride
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self.enc_middle_block = nn.Sequential(
ResBlock(
ch,
dropout,
dims=dims,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=-1,
),
ResBlock(
ch,
dropout,
dims=dims,
),
)
self.enc_output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult_up)):
for i in range(num_res_blocks + 1):
if len(input_block_chans) > 0:
ich = input_block_chans.pop()
else:
ich = 0
layers = [
ResBlock(
ch + ich,
dropout,
out_channels=model_channels * mult,
dims=dims,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=-1,
)
)
if level != len(channel_mult_up)-1 and i == num_res_blocks:
out_ch = ch
layers.append(
Upsample(ch, True, dims=dims, out_channels=out_ch, factor=stride)
)
ds //= 2
self.enc_output_blocks.append(nn.Sequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
conv_nd(dims, ch, out_channels, 3, padding=1),
)
def forward(self, x):
hs = []
h = x
for module in self.enc_input_blocks:
h = module(h)
hs.append(h)
h = self.enc_middle_block(h)
for module in self.enc_output_blocks:
if len(hs) > 0:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h)
h = h.type(x.dtype)
return self.out(h)
class DiscreteVAE(nn.Module):
def __init__(
self,
attention_resolutions,
in_channels = 3,
model_channels = 64,
out_channels = 3,
channel_mult=(1, 2, 4, 8),
dims=2,
num_tokens = 512,
codebook_dim = 512,
convergence_layer=2,
num_res_blocks = 0,
stride = 2,
straight_through = False,
dropout=0,
num_heads=4,
record_codes=True,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.num_tokens = num_tokens
self.num_layers = len(channel_mult)
self.straight_through = straight_through
self.codebook = Quantize(codebook_dim, num_tokens)
self.positional_dims = dims
self.dropout = dropout
self.num_heads = num_heads
self.record_codes = record_codes
if record_codes:
self.codes = torch.zeros((32768,), dtype=torch.long)
self.code_ind = 0
self.internal_step = 0
enc_down = channel_mult
enc_up = list(reversed(channel_mult[convergence_layer:]))
self.encoder = DisjointUnet(attention_resolutions, enc_down, enc_up, in_channels=in_channels, model_channels=model_channels,
out_channels=codebook_dim, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout,
stride=stride)
dec_down = list(reversed(enc_up))
dec_up = list(reversed(enc_down))
self.decoder = DisjointUnet(attention_resolutions, dec_down, dec_up, in_channels=codebook_dim, model_channels=model_channels,
out_channels=out_channels, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout,
stride=stride)
def get_debug_values(self, step, __):
if self.record_codes:
# Report the most recently recorded codes so their distribution can be logged as a histogram.
return {'histogram_codes': self.codes}
else:
return {}
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
img = images
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
return codes
def decode(
self,
img_seq
):
image_embeds = self.codebook.embed_code(img_seq)
b, n, d = image_embeds.shape
kwargs = {}
if self.positional_dims == 1:
arrange = 'b n d -> b d n'
else:
h = w = int(sqrt(n))
arrange = 'b (h w) d -> b d h w'
kwargs = {'h': h, 'w': w}
image_embeds = rearrange(image_embeds, arrange, **kwargs)
images = self.decoder(image_embeds)
return images
def infer(self, img):
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
return self.decode(codes)
# Note: forward() is primarily intended for use during training. When the module detects it is running in eval()
# mode, it instead decodes from the quantized codes, which is substantially more lossy (but useful for gauging
# actual network performance).
def forward(
self,
img
):
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
sampled, commitment_loss, codes = self.codebook(logits)
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
if self.training:
out = sampled
out = self.decoder(out)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out = self.decode(codes)
# reconstruction loss
recon_loss = F.mse_loss(img, out, reduction='none')
# This is so we can debug the distribution of codes being learned.
if self.record_codes and self.internal_step % 50 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i:i+l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.internal_step += 1
return recon_loss, commitment_loss, out
@register_model
def register_my_dvae(opt_net, opt):
return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
net = DiscreteVAE((8, 16), channel_mult=(1,2,4,8,8), in_channels=80, model_channels=128, out_channels=80, dims=1, num_res_blocks=2)
inp = torch.randn((2,80,512))
print([j.shape for j in net(inp)])

View File

@@ -0,0 +1,40 @@
import argparse
import numpy
import torch
from spleeter.audio.adapter import AudioAdapter
from tqdm import tqdm
from data.util import find_audio_files
# Processes a directory of audio files, computing a mel spectrogram for each one and saving it next to the source
# file as a compressed .npz.
from trainer.injectors.base_injectors import MelSpectrogramInjector
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--path')
args = parser.parse_args()
files = find_audio_files(args.path, include_nonwav=True)
mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {})
audio_loader = AudioAdapter.default()
for e, wav_file in enumerate(tqdm(files)):
if e < 272583:
continue
print(f"Processing {wav_file}..")
outfile = f'{wav_file}.npz'
try:
wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050)
wave = torch.tensor(wave)[:,0].unsqueeze(0)
wave = wave / wave.abs().max()
except:
print(f"Error with {wav_file}")
continue
inj = mel_inj({'in': wave})
numpy.savez_compressed(outfile, inj['out'].numpy())
if __name__ == '__main__':
main()
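Since the script stores each mel with numpy.savez_compressed using a positional argument, the array comes back under the default key 'arr_0'. A read-back sketch (the path is hypothetical):

import numpy
import torch
data = numpy.load('example.wav.npz')  # hypothetical path; real outputs are named '<wav_file>.npz'
mel = torch.tensor(data['arr_0'])     # shape depends on MelSpectrogramInjector, typically (1, n_mel_channels, n_frames)
print(mel.shape)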

View File

@@ -19,7 +19,7 @@ def main():
maximum_duration = 20
files = find_audio_files(args.path, include_nonwav=True)
for e, wav_file in enumerate(tqdm(files)):
if e < 2759:
if e < 12593:
continue
print(f"Processing {wav_file}..")
outdir = os.path.join(args.out, f'{e}_{os.path.basename(wav_file[:-4])}').replace('.', '').strip()

View File

@@ -284,7 +284,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()