From bc506d4bcd7277a90409a8c5795e94fc8fbbc5a4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 3 Feb 2022 19:59:39 -0700
Subject: [PATCH] Mods to unet_diffusion_tts6 to support super resolution mode

---
 codes/models/gpt_voice/unet_diffusion_tts6.py | 72 ++++++++++++-------
 1 file changed, 48 insertions(+), 24 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts6.py b/codes/models/gpt_voice/unet_diffusion_tts6.py
index bf6e6215..a965b7fe 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts6.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts6.py
@@ -1,4 +1,5 @@
 import functools
+import random
 from collections import OrderedDict
 
 import torch
@@ -176,16 +177,21 @@ class DiffusionTts(nn.Module):
             num_heads_upsample=-1,
             kernel_size=3,
             scale_factor=2,
-            conditioning_inputs_provided=True,
             time_embed_dim_multiplier=4,
-            transformer_depths=8,
+            cond_transformer_depth=8,
+            mid_transformer_depth=8,
             nil_guidance_fwd_proportion=.3,
+            super_sampling=False,
+            max_positions=-1,
+            fully_disable_tokens_percent=0,  # When specified, this percent of the time tokens are entirely ignored.
     ):
         super().__init__()
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 
+        if super_sampling:
+            in_channels *= 2  # In super-sampling mode, the LR input is concatenated directly onto the input.
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
@@ -200,7 +206,9 @@ class DiffusionTts(nn.Module):
         self.dims = dims
         self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
         self.mask_token_id = num_tokens
-
+        self.super_sampling_enabled = super_sampling
+        self.max_positions = max_positions
+        self.fully_disable_tokens_percent = fully_disable_tokens_percent
         padding = 1 if kernel_size == 3 else 2
 
         time_embed_dim = model_channels * time_embed_dim_multiplier
@@ -212,17 +220,15 @@ class DiffusionTts(nn.Module):
 
         embedding_dim = model_channels * 8
         self.code_embedding = nn.Embedding(num_tokens+1, embedding_dim)
-        self.conditioning_enabled = conditioning_inputs_provided
-        if conditioning_inputs_provided:
-            self.contextual_embedder = AudioMiniEncoder(in_channels, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
-                                                        attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
-            self.conditioning_conv = nn.Conv1d(embedding_dim*2, embedding_dim, 1)
+        self.contextual_embedder = AudioMiniEncoder(1, embedding_dim, base_channels=32, depth=6, resnet_blocks=1,
+                                                    attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
+        self.conditioning_conv = nn.Conv1d(embedding_dim*2, embedding_dim, 1)
         self.conditioning_encoder = CheckpointedXTransformerEncoder(
                 max_seq_len=-1,  # Should be unused
                 use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=embedding_dim,
-                    depth=transformer_depths,
+                    depth=cond_transformer_depth,
                     heads=num_heads,
                     ff_dropout=dropout,
                     attn_dropout=dropout,
@@ -294,7 +300,7 @@ class DiffusionTts(nn.Module):
                 use_pos_emb=False,
                 attn_layers=Encoder(
                     dim=ch,
-                    depth=transformer_depths,
+                    depth=mid_transformer_depth,
                     heads=num_heads,
                     ff_dropout=dropout,
                     attn_dropout=dropout,
@@ -378,36 +384,53 @@ class DiffusionTts(nn.Module):
 
 
 
-    def forward(self, x, timesteps, tokens, conditioning_input=None):
+    def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None):
         """
         Apply the model to an input batch.
 
         :param x: an [N x C x ...] Tensor of inputs.
         :param timesteps: a 1-D batch of timesteps.
         :param tokens: an aligned text input.
+        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
+        :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        assert conditioning_input is not None
+        if self.super_sampling_enabled:
+            assert lr_input is not None
+            lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
+            x = torch.cat([x, lr_input], dim=1)
+
+        if tokens is not None and self.fully_disable_tokens_percent > random.random():
+            tokens = None
+
+        if tokens is not None and self.max_positions > 0 and x.shape[-1] > self.max_positions:
+            proportion_x_removed = self.max_positions/x.shape[-1]
+            x = x[:,:,:self.max_positions]  # TODO: extract random subsets of x (favored towards the front). This should help diversity in training.
+            tokens = tokens[:,:int(proportion_x_removed*tokens.shape[-1])]
+
         with autocast(x.device.type):
            orig_x_shape = x.shape[-1]
            cm = ceil_multiple(x.shape[-1], 2048)
            if cm != 0:
                pc = (cm-x.shape[-1])/x.shape[-1]
                x = F.pad(x, (0,cm-x.shape[-1]))
-               tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
-           if self.conditioning_enabled:
-               assert conditioning_input is not None
+               if tokens is not None:
+                   tokens = F.pad(tokens, (0,int(pc*tokens.shape[-1])))
 
            hs = []
            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
-           # Mask out guidance tokens for un-guided diffusion.
-           if self.training and self.nil_guidance_fwd_proportion > 0:
-               token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
-               tokens = torch.where(token_mask, self.mask_token_id, tokens)
-           code_emb = self.code_embedding(tokens).permute(0,2,1)
-           if self.conditioning_enabled:
-               cond_emb = self.contextual_embedder(conditioning_input)
+           cond_emb = self.contextual_embedder(conditioning_input)
+           if tokens is not None:
+               # Mask out guidance tokens for un-guided diffusion.
+               if self.training and self.nil_guidance_fwd_proportion > 0:
+                   token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
+                   tokens = torch.where(token_mask, self.mask_token_id, tokens)
+               code_emb = self.code_embedding(tokens).permute(0,2,1)
                code_emb = self.conditioning_conv(torch.cat([cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1]), code_emb], dim=1))
+           else:
+               code_emb = cond_emb.unsqueeze(-1)
           code_emb = self.conditioning_encoder(code_emb)
 
           first = True
@@ -445,6 +468,7 @@ if __name__ == '__main__':
     tok = torch.randint(0,30, (2,388))
     cond = torch.randn(2, 1, 44000)
     ts = torch.LongTensor([600, 600])
+    lr = torch.randn(2,1,10000)
     model = DiffusionTts(128,
                          channel_mult=[1,1.5,2, 3, 4, 6, 8],
                          num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
@@ -453,8 +477,8 @@ if __name__ == '__main__':
                          num_heads=8,
                          kernel_size=3,
                          scale_factor=2,
-                         conditioning_inputs_provided=True,
-                         time_embed_dim_multiplier=4)
-    model(clip, ts, tok, cond)
+                         time_embed_dim_multiplier=4, super_sampling=True)
+    model(clip, ts, tok, cond, lr)
+    model(clip, ts, None, cond, lr)
     torch.save(model.state_dict(), 'test_out.pth')
 
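
A minimal sketch of how the new super-sampling path consumes its inputs, assuming purely illustrative shapes (batch of 2, mono audio, 44000 high-rate samples and 10000 low-rate guidance samples, mirroring the tensors in the __main__ test above); this is not part of the patch itself:

    import torch
    import torch.nn.functional as F

    # Hypothetical tensors; only the channel/length relationship matters.
    x = torch.randn(2, 1, 44000)         # noised high-rate audio fed to the diffusion model
    lr_input = torch.randn(2, 1, 10000)  # low-sampling-rate guidance clip

    # forward() stretches the LR clip to the length of x with nearest-neighbor
    # interpolation, then stacks it onto x as an extra input channel.
    lr_stretched = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
    model_input = torch.cat([x, lr_stretched], dim=1)
    assert model_input.shape == (2, 2, 44000)  # why __init__ does `in_channels *= 2` when super_sampling=True

With super_sampling=True the first convolution therefore sees twice the configured in_channels, and tokens may now be omitted entirely (tokens=None), in which case the conditioning embedding alone is fed to the conditioning encoder.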