Enable collated data for diffusion purposes
This commit is contained in:
parent
dc9cd8c206
commit
bcd8cc51e1
|
@ -144,6 +144,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
return self[rv]
|
||||
orig_output = wav.shape[-1]
|
||||
orig_text_len = tseq.shape[0]
|
||||
orig_aligned_code_length = aligned_codes.shape[0]
|
||||
if wav.shape[-1] != self.max_wav_len:
|
||||
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
|
||||
# These codes are aligned to audio inputs, so make sure to pad them as well.
|
||||
|
@ -154,6 +155,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
'real_text': text,
|
||||
'padded_text': tseq,
|
||||
'aligned_codes': aligned_codes,
|
||||
'aligned_code_lengths': orig_aligned_code_length,
|
||||
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
|
||||
'wav': wav,
|
||||
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from trainer.networks import register_model
|
||||
from utils.util import get_mask_from_lengths
|
||||
from utils.util import checkpoint
|
||||
|
@ -295,7 +296,12 @@ class DiffusionTts(nn.Module):
|
|||
:param tokens: an aligned text input.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement.
|
||||
orig_x_shape = x.shape[-1]
|
||||
cm = ceil_multiple(x.shape[-1], 4096)
|
||||
if cm != 0:
|
||||
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
|
||||
if self.conditioning_enabled:
|
||||
assert conditioning_input is not None
|
||||
|
||||
|
@ -320,7 +326,8 @@ class DiffusionTts(nn.Module):
|
|||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
out = self.out(h)
|
||||
return out[:, :, :orig_x_shape]
|
||||
|
||||
def benchmark(self, x, timesteps, tokens, conditioning_input):
|
||||
profile = OrderedDict()
|
||||
|
|
|
@ -183,12 +183,36 @@ class ExtensibleTrainer(BaseModel):
|
|||
o.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
|
||||
if sort_key is not None:
|
||||
sort_indices = torch.sort(data[sort_key]).indices
|
||||
else:
|
||||
sort_indices = None
|
||||
|
||||
batch_factor = self.batch_factor if perform_micro_batching else 1
|
||||
self.dstate = {}
|
||||
for k, v in data.items():
|
||||
if sort_indices is not None:
|
||||
if isinstance(v, list):
|
||||
v = [v[i] for i in sort_indices]
|
||||
else:
|
||||
v = v[sort_indices]
|
||||
if isinstance(v, torch.Tensor):
|
||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
||||
|
||||
if opt_get(self.opt, ['train', 'auto_collate'], False):
|
||||
for k, v in self.dstate.items():
|
||||
if f'{k}_lengths' in self.dstate.keys():
|
||||
for c in range(len(v)):
|
||||
maxlen = self.dstate[f'{k}_lengths'][c].max()
|
||||
if len(v[c].shape) == 2:
|
||||
self.dstate[k][c] = self.dstate[k][c][:, :maxlen]
|
||||
elif len(v[c].shape) == 3:
|
||||
self.dstate[k][c] = self.dstate[k][c][:, :, :maxlen]
|
||||
elif len(v[c].shape) == 4:
|
||||
self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
|
||||
|
||||
|
||||
def optimize_parameters(self, step, optimize=True):
|
||||
# Some models need to make parametric adjustments per-step. Do that here.
|
||||
for net in self.networks.values():
|
||||
|
|
Loading…
Reference in New Issue
Block a user