Enable collated data for diffusion purposes
This commit is contained in:
parent
dc9cd8c206
commit
bcd8cc51e1
|
@ -144,6 +144,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
return self[rv]
|
return self[rv]
|
||||||
orig_output = wav.shape[-1]
|
orig_output = wav.shape[-1]
|
||||||
orig_text_len = tseq.shape[0]
|
orig_text_len = tseq.shape[0]
|
||||||
|
orig_aligned_code_length = aligned_codes.shape[0]
|
||||||
if wav.shape[-1] != self.max_wav_len:
|
if wav.shape[-1] != self.max_wav_len:
|
||||||
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
|
wav = F.pad(wav, (0, self.max_wav_len - wav.shape[-1]))
|
||||||
# These codes are aligned to audio inputs, so make sure to pad them as well.
|
# These codes are aligned to audio inputs, so make sure to pad them as well.
|
||||||
|
@ -154,6 +155,7 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
||||||
'real_text': text,
|
'real_text': text,
|
||||||
'padded_text': tseq,
|
'padded_text': tseq,
|
||||||
'aligned_codes': aligned_codes,
|
'aligned_codes': aligned_codes,
|
||||||
|
'aligned_code_lengths': orig_aligned_code_length,
|
||||||
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
|
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
|
||||||
'wav': wav,
|
'wav': wav,
|
||||||
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
|
||||||
|
|
|
@ -9,6 +9,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
|
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
|
||||||
|
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
from utils.util import get_mask_from_lengths
|
from utils.util import get_mask_from_lengths
|
||||||
from utils.util import checkpoint
|
from utils.util import checkpoint
|
||||||
|
@ -295,7 +296,12 @@ class DiffusionTts(nn.Module):
|
||||||
:param tokens: an aligned text input.
|
:param tokens: an aligned text input.
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at its bottom levels, thus this requirement.
|
orig_x_shape = x.shape[-1]
|
||||||
|
cm = ceil_multiple(x.shape[-1], 4096)
|
||||||
|
if cm != 0:
|
||||||
|
pc = (cm-x.shape[-1])/x.shape[-1]
|
||||||
|
x = F.pad(x, (0,cm-x.shape[-1]))
|
||||||
|
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
|
||||||
if self.conditioning_enabled:
|
if self.conditioning_enabled:
|
||||||
assert conditioning_input is not None
|
assert conditioning_input is not None
|
||||||
|
|
||||||
|
@ -320,7 +326,8 @@ class DiffusionTts(nn.Module):
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
h = module(h, emb)
|
h = module(h, emb)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
return self.out(h)
|
out = self.out(h)
|
||||||
|
return out[:, :, :orig_x_shape]
|
||||||
|
|
||||||
def benchmark(self, x, timesteps, tokens, conditioning_input):
|
def benchmark(self, x, timesteps, tokens, conditioning_input):
|
||||||
profile = OrderedDict()
|
profile = OrderedDict()
|
||||||
|
|
|
@ -183,12 +183,36 @@ class ExtensibleTrainer(BaseModel):
|
||||||
o.zero_grad()
|
o.zero_grad()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
|
||||||
|
if sort_key is not None:
|
||||||
|
sort_indices = torch.sort(data[sort_key]).indices
|
||||||
|
else:
|
||||||
|
sort_indices = None
|
||||||
|
|
||||||
batch_factor = self.batch_factor if perform_micro_batching else 1
|
batch_factor = self.batch_factor if perform_micro_batching else 1
|
||||||
self.dstate = {}
|
self.dstate = {}
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
|
if sort_indices is not None:
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = [v[i] for i in sort_indices]
|
||||||
|
else:
|
||||||
|
v = v[sort_indices]
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]
|
||||||
|
|
||||||
|
if opt_get(self.opt, ['train', 'auto_collate'], False):
|
||||||
|
for k, v in self.dstate.items():
|
||||||
|
if f'{k}_lengths' in self.dstate.keys():
|
||||||
|
for c in range(len(v)):
|
||||||
|
maxlen = self.dstate[f'{k}_lengths'][c].max()
|
||||||
|
if len(v[c].shape) == 2:
|
||||||
|
self.dstate[k][c] = self.dstate[k][c][:, :maxlen]
|
||||||
|
elif len(v[c].shape) == 3:
|
||||||
|
self.dstate[k][c] = self.dstate[k][c][:, :, :maxlen]
|
||||||
|
elif len(v[c].shape) == 4:
|
||||||
|
self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
|
||||||
|
|
||||||
|
|
||||||
def optimize_parameters(self, step, optimize=True):
|
def optimize_parameters(self, step, optimize=True):
|
||||||
# Some models need to make parametric adjustments per-step. Do that here.
|
# Some models need to make parametric adjustments per-step. Do that here.
|
||||||
for net in self.networks.values():
|
for net in self.networks.values():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user