Mods to unet_diffusion_tts6 to support super resolution mode

This commit is contained in:
James Betker 2022-02-03 19:59:39 -07:00
parent 4249681c4b
commit bc506d4bcd

View File

@ -1,4 +1,5 @@
import functools
import random
from collections import OrderedDict
import torch
@ -176,16 +177,21 @@ class DiffusionTts(nn.Module):
num_heads_upsample=-1,
kernel_size=3,
scale_factor=2,
conditioning_inputs_provided=True,
time_embed_dim_multiplier=4,
transformer_depths=8,
cond_transformer_depth=8,
mid_transformer_depth=8,
nil_guidance_fwd_proportion=.3,
super_sampling=False,
max_positions=-1,
fully_disable_tokens_percent=0, # When specified, this percent of the time tokens are entirely ignored.
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if super_sampling:
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -200,7 +206,9 @@ class DiffusionTts(nn.Module):
self.dims = dims
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
self.mask_token_id = num_tokens
self.super_sampling_enabled = super_sampling
self.max_positions = max_positions
self.fully_disable_tokens_percent = fully_disable_tokens_percent
padding = 1 if kernel_size == 3 else 2
time_embed_dim = model_channels * time_embed_dim_multiplier
@ -212,17 +220,15 @@ class DiffusionTts(nn.Module):
embedding_dim = model_channels * 8
self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim)
self.conditioning_enabled = conditioning_inputs_provided
if conditioning_inputs_provided:
self.contextual_embedder = AudioMiniEncoder(in_channels, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(embedding_dim*2, embedding_dim, 1)
self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(embedding_dim*2, embedding_dim, 1)
self.conditioning_encoder = CheckpointedXTransformerEncoder(
max_seq_len=-1, # Should be unused
use_pos_emb=False,
attn_layers=Encoder(
dim=embedding_dim,
depth=transformer_depths,
depth=cond_transformer_depth,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
@ -294,7 +300,7 @@ class DiffusionTts(nn.Module):
use_pos_emb=False,
attn_layers=Encoder(
dim=ch,
depth=transformer_depths,
depth=mid_transformer_depth,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
@ -378,36 +384,53 @@ class DiffusionTts(nn.Module):
def forward(self, x, timesteps, tokens, conditioning_input=None):
def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param tokens: an aligned text input.
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
:return: an [N x C x ...] Tensor of outputs.
"""
assert conditioning_input is not None
if self.super_sampling_enabled:
assert lr_input is not None
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
x = torch.cat([x, lr_input], dim=1)
if tokens is not None and self.fully_disable_tokens_percent > random.random():
tokens = None
if tokens is not None and self.max_positions > 0 and x.shape[-1] > self.max_positions:
proportion_x_removed = self.max_positions/x.shape[-1]
x = x[:,:,:self.max_positions] # TODO: extract random subsets of x (favored towards the front). This should help diversity in training.
tokens = tokens[:,:int(proportion_x_removed*tokens.shape[-1])]
with autocast(x.device.type):
orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 2048)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
if self.conditioning_enabled:
assert conditioning_input is not None
if tokens is not None:
tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
# Mask out guidance tokens for un-guided diffusion.
if self.training and self.nil_guidance_fwd_proportion > 0:
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
tokens = torch.where(token_mask, self.mask_token_id, tokens)
code_emb = self.code_embedding(tokens).permute(0,2,1)
if self.conditioning_enabled:
cond_emb = self.contextual_embedder(conditioning_input)
cond_emb = self.contextual_embedder(conditioning_input)
if tokens is not None:
# Mask out guidance tokens for un-guided diffusion.
if self.training and self.nil_guidance_fwd_proportion > 0:
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
tokens = torch.where(token_mask, self.mask_token_id, tokens)
code_emb = self.code_embedding(tokens).permute(0,2,1)
code_emb = self.conditioning_conv(torch.cat([cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]), code_emb], dim=1))
else:
code_emb = cond_emb.unsqueeze(-1)
code_emb = self.conditioning_encoder(code_emb)
first = True
@ -445,6 +468,7 @@ if __name__ == '__main__':
tok = torch.randint(0,30, (2,388))
cond = torch.randn(2, 1, 44000)
ts = torch.LongTensor([600, 600])
lr = torch.randn(2,1,10000)
model = DiffusionTts(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
@ -453,8 +477,8 @@ if __name__ == '__main__':
num_heads=8,
kernel_size=3,
scale_factor=2,
conditioning_inputs_provided=True,
time_embed_dim_multiplier=4)
model(clip, ts, tok, cond)
time_embed_dim_multiplier=4, super_sampling=True)
model(clip, ts, tok, cond, lr)
model(clip, ts, None, cond, lr)
torch.save(model.state_dict(), 'test_out.pth')