From 6833048bf7609ae175d9f41baf4049365443261b Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 23 Sep 2021 15:56:25 -0600 Subject: [PATCH] Alterations to diffusion_dvae so it can be used directly on spectrograms --- codes/data/__init__.py | 10 - codes/data/audio/stop_prediction_dataset.py | 142 ------- codes/data/audio/stop_prediction_dataset_2.py | 92 ----- .../data/audio/unsupervised_audio_dataset.py | 2 +- codes/data/audio/wavfile_dataset.py | 135 ------- codes/models/diffusion/diffusion_dvae.py | 79 +++- codes/models/gpt_voice/mini_encoder.py | 80 +++- codes/models/gpt_voice/my_dvae.py | 370 ------------------ .../audio/preparation/save_mels_to_disk.py | 40 ++ .../audio/preparation/split_on_silence.py | 2 +- codes/train.py | 2 +- 11 files changed, 191 insertions(+), 763 deletions(-) delete mode 100644 codes/data/audio/stop_prediction_dataset.py delete mode 100644 codes/data/audio/stop_prediction_dataset_2.py delete mode 100644 codes/data/audio/wavfile_dataset.py delete mode 100644 codes/models/gpt_voice/my_dvae.py create mode 100644 codes/scripts/audio/preparation/save_mels_to_disk.py diff --git a/codes/data/__init__.py b/codes/data/__init__.py index 728b2d66..9d4ba52b 100644 --- a/codes/data/__init__.py +++ b/codes/data/__init__.py @@ -73,18 +73,8 @@ def create_dataset(dataset_opt, return_collate=False): from data.audio.gpt_tts_dataset import GptTtsDataset as D from data.audio.gpt_tts_dataset import GptTtsCollater as C collate = C(dataset_opt) - elif mode == 'wavfile_clips': - from data.audio.wavfile_dataset import WavfileDataset as D elif mode == 'unsupervised_audio': from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset as D - elif mode == 'stop_prediction': - from models.tacotron2.hparams import create_hparams - default_params = create_hparams() - default_params.update(dataset_opt) - dataset_opt = munchify(default_params) - from data.audio.stop_prediction_dataset import StopPredictionDataset as D - elif mode == 'stop_prediction2': - from data.audio.stop_prediction_dataset_2 import StopPredictionDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) diff --git a/codes/data/audio/stop_prediction_dataset.py b/codes/data/audio/stop_prediction_dataset.py deleted file mode 100644 index acdc98dd..00000000 --- a/codes/data/audio/stop_prediction_dataset.py +++ /dev/null @@ -1,142 +0,0 @@ -import os -import pathlib -import random - -import audio2numpy -import numpy as np -import torch -import torch.utils.data -import torch.nn.functional as F -from tqdm import tqdm - -import models.tacotron2.layers as layers -from data.audio.nv_tacotron_dataset import load_mozilla_cv, load_voxpopuli -from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text - -from models.tacotron2.text import text_to_sequence -from utils.util import opt_get - - -def get_similar_files_libritts(filename): - filedir = os.path.dirname(filename) - return list(pathlib.Path(filedir).glob('*.wav')) - - -class StopPredictionDataset(torch.utils.data.Dataset): - """ - 1) loads audio,text pairs - 2) normalizes text and converts them to sequences of one-hot vectors - 3) computes mel-spectrograms from audio files. 
- """ - def __init__(self, hparams): - self.path = hparams['path'] - if not isinstance(self.path, list): - self.path = [self.path] - - fetcher_mode = opt_get(hparams, ['fetcher_mode'], 'lj') - if not isinstance(fetcher_mode, list): - fetcher_mode = [fetcher_mode] - assert len(self.path) == len(fetcher_mode) - - self.audiopaths_and_text = [] - for p, fm in zip(self.path, fetcher_mode): - if fm == 'lj' or fm == 'libritts': - fetcher_fn = load_filepaths_and_text - self.get_similar_files = get_similar_files_libritts - elif fm == 'voxpopuli': - fetcher_fn = load_voxpopuli - self.get_similar_files = None # TODO: Fix. - else: - raise NotImplementedError() - self.audiopaths_and_text.extend(fetcher_fn(p)) - self.sampling_rate = hparams.sampling_rate - self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate) - self.stft = layers.TacotronSTFT( - hparams.filter_length, hparams.hop_length, hparams.win_length, - hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, - hparams.mel_fmax) - random.seed(hparams.seed) - random.shuffle(self.audiopaths_and_text) - self.max_mel_len = opt_get(hparams, ['max_mel_length'], None) - self.max_text_len = opt_get(hparams, ['max_text_length'], None) - - def get_mel(self, filename): - filename = str(filename) - if filename.endswith('.wav'): - audio, sampling_rate = load_wav_to_torch(filename) - else: - audio, sampling_rate = audio2numpy.audio_from_file(filename) - audio = torch.tensor(audio) - - if sampling_rate != self.input_sample_rate: - if sampling_rate < self.input_sample_rate: - print(f'{filename} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {self.input_sample_rate}. This is not a good idea.') - audio_norm = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=self.input_sample_rate/sampling_rate, mode='nearest', recompute_scale_factor=False).squeeze() - else: - audio_norm = audio - if audio_norm.std() > 1: - print(f"Something is very wrong with the given audio. std_dev={audio_norm.std()}. 
file={filename}") - return None - audio_norm.clip_(-1, 1) - audio_norm = audio_norm.unsqueeze(0) - audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) - if self.input_sample_rate != self.sampling_rate: - ratio = self.sampling_rate / self.input_sample_rate - audio_norm = torch.nn.functional.interpolate(audio_norm.unsqueeze(0), scale_factor=ratio, mode='area').squeeze(0) - melspec = self.stft.mel_spectrogram(audio_norm) - melspec = torch.squeeze(melspec, 0) - - return melspec - - def __getitem__(self, index): - path = self.audiopaths_and_text[index][0] - similar_files = self.get_similar_files(path) - mel = self.get_mel(path) - terms = torch.zeros(mel.shape[1]) - terms[-1] = 1 - while mel.shape[-1] < self.max_mel_len: - another_file = random.choice(similar_files) - another_mel = self.get_mel(another_file) - oterms = torch.zeros(another_mel.shape[1]) - oterms[-1] = 1 - mel = torch.cat([mel, another_mel], dim=-1) - terms = torch.cat([terms, oterms], dim=-1) - mel = mel[:, :self.max_mel_len] - terms = terms[:self.max_mel_len] - - - return { - 'padded_mel': mel, - 'termination_mask': terms, - } - - def __len__(self): - return len(self.audiopaths_and_text) - - -if __name__ == '__main__': - params = { - 'mode': 'stop_prediction', - 'path': 'E:\\audio\\LibriTTS\\train-clean-360_list.txt', - 'phase': 'train', - 'n_workers': 0, - 'batch_size': 16, - 'fetcher_mode': 'libritts', - 'max_mel_length': 800, - #'return_wavs': True, - #'input_sample_rate': 22050, - #'sampling_rate': 8000 - } - from data import create_dataset, create_dataloader - - ds, c = create_dataset(params, return_collate=True) - dl = create_dataloader(ds, params, collate_fn=c, shuffle=True) - i = 0 - m = None - for k in range(1000): - for i, b in tqdm(enumerate(dl)): - continue - pm = b['padded_mel'] - pm = torch.nn.functional.pad(pm, (0, 800-pm.shape[-1])) - m = pm if m is None else torch.cat([m, pm], dim=0) - print(m.mean(), m.std()) \ No newline at end of file diff --git a/codes/data/audio/stop_prediction_dataset_2.py b/codes/data/audio/stop_prediction_dataset_2.py deleted file mode 100644 index 61c93a45..00000000 --- a/codes/data/audio/stop_prediction_dataset_2.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import pathlib -import random - -from munch import munchify -from torch.utils.data import Dataset -import torch -from tqdm import tqdm - -from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file -from models.tacotron2 import hparams -from models.tacotron2.layers import TacotronSTFT -from models.tacotron2.taco_utils import load_wav_to_torch -from utils.util import opt_get - - -# A dataset that consumes the result from the script `produce_libri_stretched_dataset`, which itself is a combined -# set of clips from the librivox corpus of equal length with the sentence alignment labeled. -class StopPredictionDataset(Dataset): - def __init__(self, opt): - path = opt['path'] - label_compaction = opt_get(opt, ['label_compaction'], 1) - hp = munchify(hparams.create_hparams()) - cache_path = os.path.join(path, 'cache.pth') - if os.path.exists(cache_path): - self.files = torch.load(cache_path) - else: - print("Building cache..") - self.files = list(pathlib.Path(path).glob('*.wav')) - torch.save(self.files, cache_path) - self.sampling_rate = 22050 # Fixed since the underlying data is also fixed at this SR. 
- self.mel_length = 2000 - self.stft = TacotronSTFT( - hp.filter_length, hp.hop_length, hp.win_length, - hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin, - hp.mel_fmax) - self.label_compaction = label_compaction - - def __getitem__(self, index): - audio, _ = load_wav_to_torch(self.files[index]) - starts, ends = torch.load(str(self.files[index]).replace('.wav', '_se.pth')) - - if audio.std() > 1: - print(f"Something is very wrong with the given audio. std_dev={audio.std()}. file={self.files[index]}") - return None - audio.clip_(-1, 1) - mels = self.stft.mel_spectrogram(audio.unsqueeze(0))[:, :, :self.mel_length].squeeze(0) - - # Form labels. - labels_start = torch.zeros((2000 // self.label_compaction,), dtype=torch.long) - for s in starts: - # Mel compaction operates at a ratio of 1/256, the dataset also allows further compaction. - s = s // (256 * self.label_compaction) - if s >= 2000//self.label_compaction: - continue - labels_start[s] = 1 - labels_end = torch.zeros((2000 // self.label_compaction,), dtype=torch.long) - for e in ends: - e = e // (256 * self.label_compaction) - if e >= 2000//self.label_compaction: - continue - labels_end[e] = 1 - - return { - 'mels': mels, - 'labels_start': labels_start, - 'labels_end': labels_end, - } - - - def __len__(self): - return len(self.files) - - -if __name__ == '__main__': - opt = { - 'path': 'D:\\data\\audio\\libritts\\stop_dataset', - 'label_compaction': 4, - } - ds = StopPredictionDataset(opt) - j = 0 - for i in tqdm(range(100)): - b = ds[random.randint(0, len(ds))] - start_indices = torch.nonzero(b['labels_start']).squeeze(1) - end_indices = torch.nonzero(b['labels_end']).squeeze(1) - assert len(end_indices) <= len(start_indices) # There should always be more START tokens then END tokens. - for i in range(len(end_indices)): - s = start_indices[i].item()*4 - e = end_indices[i].item()*4 - m = b['mels'][:, s:e] - save_mel_buffer_to_file(m, f'{j}.npy') - j += 1 \ No newline at end of file diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py index 5e4fd300..0eaea644 100644 --- a/codes/data/audio/unsupervised_audio_dataset.py +++ b/codes/data/audio/unsupervised_audio_dataset.py @@ -113,7 +113,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset): audio_norm, filename = self.get_audio_for_index(index) alt_files, actual_samples = self.get_related_audio_for_index(index) except: - print(f"Error loading audio for file {filename} or {alt_files}") + print(f"Error loading audio for file {self.audiopaths[index]}") return self[index+1] # This is required when training to make sure all clips align. diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py deleted file mode 100644 index a8b3a2b2..00000000 --- a/codes/data/audio/wavfile_dataset.py +++ /dev/null @@ -1,135 +0,0 @@ -import os -import random - -import torch -import torch.utils.data -import torchaudio -from tqdm import tqdm - -from data.audio.wav_aug import WavAugmentor -from data.util import find_files_of_type, is_wav_file -from models.tacotron2.taco_utils import load_wav_to_torch -from utils.util import opt_get - - -def load_audio_from_wav(audiopath, sampling_rate): - audio, lsr = load_wav_to_torch(audiopath) - if lsr != sampling_rate: - if lsr < sampling_rate: - print(f'{audiopath} has a sample rate of {sampling_rate} which is lower than the requested sample rate of {sampling_rate}. 
This is not a good idea.') - audio = torch.nn.functional.interpolate(audio.unsqueeze(0).unsqueeze(1), scale_factor=sampling_rate/lsr, mode='nearest', recompute_scale_factor=False).squeeze() - - # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. - # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. - if torch.any(audio > 2) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") - audio.clip_(-1, 1) - return audio.unsqueeze(0) - - -class WavfileDataset(torch.utils.data.Dataset): - - def __init__(self, opt): - path = opt['path'] - cache_path = opt['cache_path'] # Will fail when multiple paths specified, must be specified in this case. - if not isinstance(path, list): - path = [path] - if os.path.exists(cache_path): - self.audiopaths = torch.load(cache_path) - else: - print("Building cache..") - self.audiopaths = [] - for p in path: - self.audiopaths.extend(find_files_of_type('img', p, qualifier=is_wav_file)[0]) - torch.save(self.audiopaths, cache_path) - - # Parse options - self.sampling_rate = opt_get(opt, ['sampling_rate'], 24000) - self.pad_to = opt_get(opt, ['pad_to_seconds'], None) - if self.pad_to is not None: - self.pad_to *= self.sampling_rate - self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to) - - self.augment = opt_get(opt, ['do_augmentation'], False) - if self.augment: - # The "window size" for the clips produced in seconds. - self.window = 2 * self.sampling_rate - self.augmentor = WavAugmentor() - - def get_audio_for_index(self, index): - audiopath = self.audiopaths[index] - audio = load_audio_from_wav(audiopath, self.sampling_rate) - return audio, audiopath - - def __getitem__(self, index): - success = False - # This "success" thing is a hack: This dataset is randomly failing for no apparent good reason and I don't know why. - # Symptoms are it complaining about being unable to read a nonsensical filename that is clearly corrupted. Memory corruption? I don't know.. - while not success: - try: - # Split audio_norm into two tensors of equal size. - audio_norm, filename = self.get_audio_for_index(index) - success = True - except: - print(f"Failed to load {index} {self.audiopaths[index]}") - - if self.augment: - if audio_norm.shape[1] < self.window * 2: - # Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this. - return self[(index + 1) % len(self)] - j = random.randint(0, audio_norm.shape[1] - self.window) - clip1 = audio_norm[:, j:j+self.window] - if self.augment: - clip1 = self.augmentor.augment(clip1, self.sampling_rate) - j = random.randint(0, audio_norm.shape[1]-self.window) - clip2 = audio_norm[:, j:j+self.window] - if self.augment: - clip2 = self.augmentor.augment(clip2, self.sampling_rate) - - # This is required when training to make sure all clips align. 
- if self.pad_to is not None: - if audio_norm.shape[-1] <= self.pad_to: - audio_norm = torch.nn.functional.pad(audio_norm, (0, self.pad_to - audio_norm.shape[-1])) - else: - gap = audio_norm.shape[-1] - self.pad_to - start = random.randint(0, gap-1) - audio_norm = audio_norm[:, start:start+self.pad_to] - - output = { - 'clip': audio_norm, - 'path': filename, - } - if self.augment: - output.update({ - 'clip1': clip1[0, :].unsqueeze(0), - 'clip2': clip2[0, :].unsqueeze(0), - }) - return output - - def __len__(self): - return len(self.audiopaths) - - -if __name__ == '__main__': - params = { - 'mode': 'wavfile_clips', - 'path': ['E:\\audio\\books-split', 'E:\\audio\\LibriTTS\\train-clean-360', 'D:\\data\\audio\\podcasts-split'], - 'cache_path': 'E:\\audio\\clips-cache.pth', - 'sampling_rate': 22050, - 'pad_to_seconds': 5, - 'phase': 'train', - 'n_workers': 0, - 'batch_size': 16, - 'do_augmentation': False - } - from data import create_dataset, create_dataloader, util - - ds = create_dataset(params) - dl = create_dataloader(ds, params) - i = 0 - for b in tqdm(dl): - for b_ in range(16): - pass - #torchaudio.save(f'{i}_clip1_{b_}.wav', b['clip1'][b_], ds.sampling_rate) - #torchaudio.save(f'{i}_clip2_{b_}.wav', b['clip2'][b_], ds.sampling_rate) - #i += 1 diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py index 09977a06..fa1f9b50 100644 --- a/codes/models/diffusion/diffusion_dvae.py +++ b/codes/models/diffusion/diffusion_dvae.py @@ -5,11 +5,12 @@ from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, Res import torch import torch.nn as nn +from models.gpt_voice.lucidrains_dvae import eval_decorator from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner from models.vqvae.vqvae import Quantize from trainer.networks import register_model -import models.gpt_voice.my_dvae as mdvae from utils.util import get_mask_from_lengths +import models.gpt_voice.mini_encoder as menc class DiscreteEncoder(nn.Module): @@ -22,13 +23,13 @@ class DiscreteEncoder(nn.Module): super().__init__() self.blocks = nn.Sequential( conv_nd(1, in_channels, model_channels, 3, padding=1), - mdvae.ResBlock(model_channels, dropout, dims=1), + menc.ResBlock(model_channels, dropout, dims=1), Downsample(model_channels, use_conv=True, dims=1, out_channels=model_channels*2, factor=scale), - mdvae.ResBlock(model_channels*2, dropout, dims=1), + menc.ResBlock(model_channels*2, dropout, dims=1), Downsample(model_channels*2, use_conv=True, dims=1, out_channels=model_channels*4, factor=scale), - mdvae.ResBlock(model_channels*4, dropout, dims=1), + menc.ResBlock(model_channels*4, dropout, dims=1), AttentionBlock(model_channels*4, num_heads=4), - mdvae.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1), + menc.ResBlock(model_channels*4, dropout, out_channels=out_channels, dims=1), ) def forward(self, spectrogram): @@ -249,6 +250,21 @@ class DiffusionDVAE(nn.Module): zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)), ) + def get_debug_values(self, step, __): + if self.record_codes: + # Report annealing schedule + return {'histogram_codes': self.codes} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + img = self.norm(images) + logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + sampled, commitment_loss, codes = self.codebook(logits) + return codes + def _decode_continouous(self, x, timesteps, embeddings, 
conditioning_inputs, num_conditioning_signals): if self.conditioning_enabled: assert conditioning_inputs is not None @@ -299,17 +315,28 @@ class DiffusionDVAE(nn.Module): return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals) def forward(self, x, timesteps, spectrogram, conditioning_inputs=None, num_conditioning_signals=None): - assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. - # Compute DVAE portion first. spec_logits = self.encoder(spectrogram).permute((0,2,1)) sampled, commitment_loss, codes = self.quantizer(spec_logits) + if self.training: # Compute from softmax outputs to preserve gradients. embeddings = sampled.permute((0,2,1)) else: # Compute from codes only. embeddings = self.quantizer.embed_code(codes).permute((0,2,1)) + + # This is so we can debug the distribution of codes being learned. + if self.internal_step % 50 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i:i+l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.internal_step += 1 + return self._decode_continouous(x, timesteps, embeddings, conditioning_inputs, num_conditioning_signals), commitment_loss @@ -318,12 +345,44 @@ def register_unet_diffusion_dvae(opt_net, opt): return DiffusionDVAE(**opt_net['kwargs']) + +''' + + +class DiffusionDVAE(nn.Module): + def __init__( + self, + model_channels, + num_res_blocks, + in_channels=1, + out_channels=2, # mean and variance + spectrogram_channels=80, + spectrogram_conditioning_levels=[3,4,5], # Levels at which spectrogram conditioning is applied to the waveform. + dropout=0, + channel_mult=(1, 2, 4, 8, 16, 32, 64), + attention_resolutions=(16,32,64), + conv_resample=True, + dims=1, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + use_new_attention_order=False, + kernel_size=5, + quantize_dim=1024, + num_discrete_codes=8192, + scale_steps=4, + conditioning_inputs_provided=True, + ): + ''' + # Test for ~4 second audio clip at 22050Hz if __name__ == '__main__': - clip = torch.randn(4, 1, 81920) spec = torch.randn(4, 80, 416) cond = torch.randn(4, 5, 80, 200) num_cond = torch.tensor([2,4,5,3], dtype=torch.long) ts = torch.LongTensor([432, 234, 100, 555]) - model = DiffusionDVAE(32, 2) - print(model(clip, ts, spec, cond, num_cond)[0].shape) + model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160, spectrogram_conditioning_levels=[1,2], + channel_mult=(1,2,4), attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2, conditioning_inputs_provided=False) + print(model(torch.randn_like(spec), ts, spec, cond, num_cond)[0].shape) diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py index 22d6c3f4..8b2d54b5 100644 --- a/codes/models/gpt_voice/mini_encoder.py +++ b/codes/models/gpt_voice/mini_encoder.py @@ -4,13 +4,91 @@ import torch.nn as nn from models.diffusion.nn import normalization, conv_nd, zero_module from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy -from models.gpt_voice.my_dvae import ResBlock # Combined resnet & full-attention encoder for converting an audio clip into an embedding. 
from utils.util import checkpoint +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + kernel_size=3, + do_checkpoint=True, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.do_checkpoint = do_checkpoint + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x): + if self.do_checkpoint: + return checkpoint( + self._forward, x + ) + else: + return self._forward(x) + + def _forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + class AudioMiniEncoder(nn.Module): def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0): super().__init__() diff --git a/codes/models/gpt_voice/my_dvae.py b/codes/models/gpt_voice/my_dvae.py deleted file mode 100644 index 5443c609..00000000 --- a/codes/models/gpt_voice/my_dvae.py +++ /dev/null @@ -1,370 +0,0 @@ -import functools -import math -from math import sqrt - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import einsum - -from models.diffusion.nn import conv_nd, normalization, zero_module -from models.diffusion.unet_diffusion import Upsample, Downsample, AttentionBlock -from models.vqvae.vqvae import Quantize -from trainer.networks import register_model -from utils.util import opt_get, checkpoint - - -def default(val, d): - return val if val is not None else d - - -def eval_decorator(fn): - def inner(model, *args, **kwargs): - was_training = model.training - model.eval() - out = fn(model, *args, **kwargs) - model.train(was_training) - return out - return inner - - -class ResBlock(nn.Module): - def __init__( - self, - channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - up=False, - down=False, - kernel_size=3, - do_checkpoint=True, - ): - super().__init__() - self.channels = channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_scale_shift_norm = use_scale_shift_norm - self.do_checkpoint = do_checkpoint - padding = 1 if kernel_size == 3 else 2 - - self.in_layers = nn.Sequential( - normalization(channels), 
- nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, kernel_size, padding=padding - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x): - if self.do_checkpoint: - return checkpoint( - self._forward, x - ) - else: - return self._forward(x) - - def _forward(self, x): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DisjointUnet(nn.Module): - def __init__( - self, - attention_resolutions, - channel_mult_down, - channel_mult_up, - in_channels = 3, - model_channels = 64, - out_channels = 3, - dims=2, - num_res_blocks = 2, - stride = 2, - dropout=0, - num_heads=4, - ): - super().__init__() - - self.enc_input_blocks = nn.ModuleList( - [ - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ] - ) - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult_down): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - dropout, - out_channels=mult * model_channels, - dims=dims, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=-1, - ) - ) - self.enc_input_blocks.append(nn.Sequential(*layers)) - input_block_chans.append(ch) - if level != len(channel_mult_down) - 1: - out_ch = ch - self.enc_input_blocks.append( - Downsample( - ch, True, dims=dims, out_channels=out_ch, factor=stride - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - - self.enc_middle_block = nn.Sequential( - ResBlock( - ch, - dropout, - dims=dims, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=-1, - ), - ResBlock( - ch, - dropout, - dims=dims, - ), - ) - - self.enc_output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult_up)): - for i in range(num_res_blocks + 1): - if len(input_block_chans) > 0: - ich = input_block_chans.pop() - else: - ich = 0 - layers = [ - ResBlock( - ch + ich, - dropout, - out_channels=model_channels * mult, - dims=dims, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=-1, - ) - ) - if level != len(channel_mult_up)-1 and i == num_res_blocks: - out_ch = ch - layers.append( - Upsample(ch, True, dims=dims, out_channels=out_ch, factor=stride) - ) - ds //= 2 - self.enc_output_blocks.append(nn.Sequential(*layers)) - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - conv_nd(dims, ch, out_channels, 3, padding=1), - ) - - def forward(self, x): - hs = [] - h = x - for module in self.enc_input_blocks: - h 
= module(h) - hs.append(h) - h = self.enc_middle_block(h) - for module in self.enc_output_blocks: - if len(hs) > 0: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h) - h = h.type(x.dtype) - return self.out(h) - - -class DiscreteVAE(nn.Module): - def __init__( - self, - attention_resolutions, - in_channels = 3, - model_channels = 64, - out_channels = 3, - channel_mult=(1, 2, 4, 8), - dims=2, - num_tokens = 512, - codebook_dim = 512, - convergence_layer=2, - num_res_blocks = 0, - stride = 2, - straight_through = False, - dropout=0, - num_heads=4, - record_codes=True, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.num_tokens = num_tokens - self.num_layers = len(channel_mult) - self.straight_through = straight_through - self.codebook = Quantize(codebook_dim, num_tokens) - self.positional_dims = dims - self.dropout = dropout - self.num_heads = num_heads - self.record_codes = record_codes - if record_codes: - self.codes = torch.zeros((32768,), dtype=torch.long) - self.code_ind = 0 - self.internal_step = 0 - - enc_down = channel_mult - enc_up = list(reversed(channel_mult[convergence_layer:])) - self.encoder = DisjointUnet(attention_resolutions, enc_down, enc_up, in_channels=in_channels, model_channels=model_channels, - out_channels=codebook_dim, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout, - stride=stride) - dec_down = list(reversed(enc_up)) - dec_up = list(reversed(enc_down)) - self.decoder = DisjointUnet(attention_resolutions, dec_down, dec_up, in_channels=codebook_dim, model_channels=model_channels, - out_channels=out_channels, dims=dims, num_res_blocks=num_res_blocks, num_heads=num_heads, dropout=dropout, - stride=stride) - - def get_debug_values(self, step, __): - if self.record_codes: - # Report annealing schedule - return {'histogram_codes': self.codes} - else: - return {} - - @torch.no_grad() - @eval_decorator - def get_codebook_indices(self, images): - img = images - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) - return codes - - def decode( - self, - img_seq - ): - image_embeds = self.codebook.embed_code(img_seq) - b, n, d = image_embeds.shape - - kwargs = {} - if self.positional_dims == 1: - arrange = 'b n d -> b d n' - else: - h = w = int(sqrt(n)) - arrange = 'b (h w) d -> b d h w' - kwargs = {'h': h, 'w': w} - image_embeds = rearrange(image_embeds, arrange, **kwargs) - images = self.decoder(image_embeds) - return images - - def infer(self, img): - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) - return self.decode(codes) - - # Note: This module is not meant to be run in forward() except while training. It has special logic which performs - # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially - # more lossy (but useful for determining network performance). 
- def forward( - self, - img - ): - logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) - sampled, commitment_loss, codes = self.codebook(logits) - sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) - - if self.training: - out = sampled - out = self.decoder(out) - else: - # This is non-differentiable, but gives a better idea of how the network is actually performing. - out = self.decode(codes) - - # reconstruction loss - recon_loss = F.mse_loss(img, out, reduction='none') - - # This is so we can debug the distribution of codes being learned. - if self.record_codes and self.internal_step % 50 == 0: - codes = codes.flatten() - l = codes.shape[0] - i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l - self.codes[i:i+l] = codes.cpu() - self.code_ind = self.code_ind + l - if self.code_ind >= self.codes.shape[0]: - self.code_ind = 0 - self.internal_step += 1 - - return recon_loss, commitment_loss, out - - -@register_model -def register_my_dvae(opt_net, opt): - return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {})) - - -if __name__ == '__main__': - net = DiscreteVAE((8, 16), channel_mult=(1,2,4,8,8), in_channels=80, model_channels=128, out_channels=80, dims=1, num_res_blocks=2) - inp = torch.randn((2,80,512)) - print([j.shape for j in net(inp)]) diff --git a/codes/scripts/audio/preparation/save_mels_to_disk.py b/codes/scripts/audio/preparation/save_mels_to_disk.py new file mode 100644 index 00000000..37c534b0 --- /dev/null +++ b/codes/scripts/audio/preparation/save_mels_to_disk.py @@ -0,0 +1,40 @@ +import argparse + +import numpy +import torch +from spleeter.audio.adapter import AudioAdapter +from tqdm import tqdm + +from data.util import find_audio_files +# Uses pydub to process a directory of audio files, splitting them into clips at points where it detects a small amount +# of silence. 
+from trainer.injectors.base_injectors import MelSpectrogramInjector + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--path') + args = parser.parse_args() + files = find_audio_files(args.path, include_nonwav=True) + mel_inj = MelSpectrogramInjector({'in':'in', 'out':'out'}, {}) + audio_loader = AudioAdapter.default() + for e, wav_file in enumerate(tqdm(files)): + if e < 272583: + continue + print(f"Processing {wav_file}..") + outfile = f'{wav_file}.npz' + + try: + wave, sample_rate = audio_loader.load(wav_file, sample_rate=22050) + wave = torch.tensor(wave)[:,0].unsqueeze(0) + wave = wave / wave.abs().max() + except: + print(f"Error with {wav_file}") + continue + + inj = mel_inj({'in': wave}) + numpy.savez_compressed(outfile, inj['out'].numpy()) + + +if __name__ == '__main__': + main() diff --git a/codes/scripts/audio/preparation/split_on_silence.py b/codes/scripts/audio/preparation/split_on_silence.py index a53f60d3..b05798a7 100644 --- a/codes/scripts/audio/preparation/split_on_silence.py +++ b/codes/scripts/audio/preparation/split_on_silence.py @@ -19,7 +19,7 @@ def main(): maximum_duration = 20 files = find_audio_files(args.path, include_nonwav=True) for e, wav_file in enumerate(tqdm(files)): - if e < 2759: + if e < 12593: continue print(f"Processing {wav_file}..") outdir = os.path.join(args.out, f'{e}_{os.path.basename(wav_file[:-4])}').replace('.', '').strip() diff --git a/codes/train.py b/codes/train.py index f4de0a21..e618b709 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_dvae_clips.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()
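
For reference, below is a minimal sketch of driving the reworked DiffusionDVAE directly on mel spectrograms, mirroring the new __main__ smoke test added to codes/models/diffusion/diffusion_dvae.py in this patch. Tensor shapes and constructor arguments are copied from that test (80 mel bins, roughly 4 seconds of audio at 22050Hz); the import path assumes the codes/ directory is on PYTHONPATH, as the repository's other imports do. Treat it as an illustration of the changed interface under those assumptions, not as part of the commit itself.

import torch

# Assumes the repository layout shown in this patch: the model lives in
# codes/models/diffusion/diffusion_dvae.py and codes/ is the import root.
from models.diffusion.diffusion_dvae import DiffusionDVAE

# Batch of 4 mel spectrograms (80 bins x 416 frames), per-item diffusion timesteps, and the
# conditioning tensors the smoke test passes even though conditioning_inputs_provided=False.
spec = torch.randn(4, 80, 416)
cond = torch.randn(4, 5, 80, 200)
num_cond = torch.tensor([2, 4, 5, 3], dtype=torch.long)
ts = torch.LongTensor([432, 234, 100, 555])

# Constructor arguments taken from the patch's smoke test; out_channels=160 is the
# mean/variance pair for each of the 80 spectrogram channels.
model = DiffusionDVAE(model_channels=128, num_res_blocks=1, in_channels=80, out_channels=160,
                      spectrogram_conditioning_levels=[1, 2], channel_mult=(1, 2, 4),
                      attention_resolutions=[4], num_heads=4, kernel_size=3, scale_steps=2,
                      conditioning_inputs_provided=False)

# forward() now consumes a noised spectrogram-shaped tensor rather than a raw waveform and
# returns (diffusion prediction, VQ commitment loss).
prediction, commitment_loss = model(torch.randn_like(spec), ts, spec, cond, num_cond)
print(prediction.shape)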