From bcd8cc51e1b7bdb4d9e35690a6254e0e118700f3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 19 Jan 2022 00:35:08 -0700
Subject: [PATCH] Enable collated data for diffusion purposes

---
 codes/data/audio/fast_paired_dataset.py   |  2 ++
 .../unet_diffusion_tts_experimental.py    | 11 +++++++--
 codes/trainer/ExtensibleTrainer.py        | 24 +++++++++++++++++++
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index 20811f8c..074cd647 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -144,6 +144,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             return self[rv]
         orig_output = wav.shape[-1]
         orig_text_len = tseq.shape[0]
+        orig_aligned_code_length = aligned_codes.shape[0]
         if wav.shape[-1] != self.max_wav_len:
             wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
             # These codes are aligned to audio inputs, so make sure to pad them as well.
@@ -154,6 +155,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
             'real_text': text,
             'padded_text': tseq,
             'aligned_codes': aligned_codes,
+            'aligned_code_lengths': orig_aligned_code_length,
             'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
             'wav': wav,
             'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
diff --git a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
index 91d88986..b092180d 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
@@ -9,6 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
+from scripts.audio.gen.use_diffuse_tts import ceil_multiple
 from trainer.networks import register_model
 from utils.util import get_mask_from_lengths
 from utils.util import checkpoint
@@ -295,7 +296,12 @@ class DiffusionTts(nn.Module):
         :param tokens: an aligned text input.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
+        orig_x_shape = x.shape[-1]
+        cm = ceil_multiple(x.shape[-1], 4096)
+        if cm != 0:
+            pc = (cm-x.shape[-1])/x.shape[-1]
+            x = F.pad(x, (0,cm-x.shape[-1]))
+            tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
         if self.conditioning_enabled:
             assert conditioning_input is not None
 
@@ -320,7 +326,8 @@ class DiffusionTts(nn.Module):
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
-        return self.out(h)
+        out = self.out(h)
+        return out[:, :, :orig_x_shape]
 
     def benchmark(self, x, timesteps, tokens, conditioning_input):
         profile = OrderedDict()
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 1df4b507..b7079941 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -183,12 +183,36 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
+        sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
+        if sort_key is not None:
+            sort_indices = torch.sort(data[sort_key]).indices
+        else:
+            sort_indices = None
+
         batch_factor = self.batch_factor if perform_micro_batching else 1
         self.dstate = {}
         for k, v in data.items():
+            if sort_indices is not None:
+                if isinstance(v, list):
+                    v = [v[i] for i in sort_indices]
+                else:
+                    v = v[sort_indices]
             if isinstance(v, torch.Tensor):
                 self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
 
+        if opt_get(self.opt, ['train', 'auto_collate'], False):
+            for k, v in self.dstate.items():
+                if f'{k}_lengths' in self.dstate.keys():
+                    for c in range(len(v)):
+                        maxlen = self.dstate[f'{k}_lengths'][c].max()
+                        if len(v[c].shape) == 2:
+                            self.dstate[k][c] = self.dstate[k][c][:, :maxlen]
+                        elif len(v[c].shape) == 3:
+                            self.dstate[k][c] = self.dstate[k][c][:, :, :maxlen]
+                        elif len(v[c].shape) == 4:
+                            self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
+
+
     def optimize_parameters(self, step, optimize=True):
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
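For reference, ceil_multiple() is imported from scripts/audio/gen/use_diffuse_tts.py and its body is not part of this patch. Below is a minimal, self-contained sketch of the pad-and-crop scheme forward() now uses, assuming ceil_multiple() returns 0 when the length is already aligned (which is what the `if cm != 0` guard implies); the pad_to_multiple() helper name is illustrative, not something in the repo:

import torch
import torch.nn.functional as F

def ceil_multiple(base, multiple):
    # Assumed behavior: return 0 when `base` is already a multiple of
    # `multiple`, otherwise return `base` rounded up to the next multiple.
    res = base % multiple
    if res == 0:
        return 0
    return base + (multiple - res)

def pad_to_multiple(x, tokens, multiple=4096):
    # Pad the waveform up to a multiple of `multiple` and grow the aligned
    # token sequence by the same proportion, mirroring the patched
    # DiffusionTts.forward().
    orig_len = x.shape[-1]
    cm = ceil_multiple(orig_len, multiple)
    if cm != 0:
        pc = (cm - orig_len) / orig_len  # fraction of padding added
        x = F.pad(x, (0, cm - orig_len))
        tokens = F.pad(tokens, (0, int(pc * tokens.shape[-1])))
    return x, tokens, orig_len

x = torch.randn(1, 1, 5000)
tokens = torch.zeros(1, 100, dtype=torch.long)
x, tokens, orig_len = pad_to_multiple(x, tokens)
assert x.shape[-1] == 8192 and tokens.shape[-1] == 163
# The model output is then cropped back to the input length with
# out[:, :, :orig_len], so callers never see the padding.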