Enable collated data for diffusion purposes

This commit is contained in:
James Betker 2022-01-19 00:35:08 -07:00
parent dc9cd8c206
commit bcd8cc51e1
3 changed files with 35 additions and 2 deletions

View File

@ -144,6 +144,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
return self[rv] return self[rv]
orig_output = wav.shape[-1] orig_output = wav.shape[-1]
orig_text_len = tseq.shape[0] orig_text_len = tseq.shape[0]
orig_aligned_code_length = aligned_codes.shape[0]
if wav.shape[-1] != self.max_wav_len: if wav.shape[-1] != self.max_wav_len:
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1])) wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
# These codes are aligned to audio inputs, so make sure to pad them as well. # These codes are aligned to audio inputs, so make sure to pad them as well.
@ -154,6 +155,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
'real_text': text, 'real_text': text,
'padded_text': tseq, 'padded_text': tseq,
'aligned_codes': aligned_codes, 'aligned_codes': aligned_codes,
'aligned_code_lengths': orig_aligned_code_length,
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long), 'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'wav': wav, 'wav': wav,
'wav_lengths': torch.tensor(orig_output, dtype=torch.long), 'wav_lengths': torch.tensor(orig_output, dtype=torch.long),

View File

@ -9,6 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import get_mask_from_lengths from utils.util import get_mask_from_lengths
from utils.util import checkpoint from utils.util import checkpoint
@ -295,7 +296,12 @@ class DiffusionTts(nn.Module):
:param tokens: an aligned text input. :param tokens: an aligned text input.
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 4096)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
if self.conditioning_enabled: if self.conditioning_enabled:
assert conditioning_input is not None assert conditioning_input is not None
@ -320,7 +326,8 @@ class DiffusionTts(nn.Module):
h = torch.cat([h, hs.pop()], dim=1) h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb) h = module(h, emb)
h = h.type(x.dtype) h = h.type(x.dtype)
return self.out(h) out = self.out(h)
return out[:, :, :orig_x_shape]
def benchmark(self, x, timesteps, tokens, conditioning_input): def benchmark(self, x, timesteps, tokens, conditioning_input):
profile = OrderedDict() profile = OrderedDict()

View File

@ -183,12 +183,36 @@ class ExtensibleTrainer(BaseModel):
o.zero_grad() o.zero_grad()
torch.cuda.empty_cache() torch.cuda.empty_cache()
sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
if sort_key is not None:
sort_indices = torch.sort(data[sort_key]).indices
else:
sort_indices = None
batch_factor = self.batch_factor if perform_micro_batching else 1 batch_factor = self.batch_factor if perform_micro_batching else 1
self.dstate = {} self.dstate = {}
for k, v in data.items(): for k, v in data.items():
if sort_indices is not None:
if isinstance(v, list):
v = [v[i] for i in sort_indices]
else:
v = v[sort_indices]
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)] self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
if opt_get(self.opt, ['train', 'auto_collate'], False):
for k, v in self.dstate.items():
if f'{k}_lengths' in self.dstate.keys():
for c in range(len(v)):
maxlen = self.dstate[f'{k}_lengths'][c].max()
if len(v[c].shape) == 2:
self.dstate[k][c] = self.dstate[k][c][:, :maxlen]
elif len(v[c].shape) == 3:
self.dstate[k][c] = self.dstate[k][c][:, :, :maxlen]
elif len(v[c].shape) == 4:
self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
def optimize_parameters(self, step, optimize=True): def optimize_parameters(self, step, optimize=True):
# Some models need to make parametric adjustments per-step. Do that here. # Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values(): for net in self.networks.values():