From 0f3ca28e39c33d96baa78223513a24bac0cda4a3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 25 Jan 2022 14:26:21 -0700
Subject: [PATCH] Allow diffusion model to be trained with masking tokens

---
 .../unet_diffusion_tts_experimental.py       | 44 +++++++++++++++----
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
index aba7d926..6e2cd9d8 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
@@ -115,7 +115,7 @@ class DiffusionTts(nn.Module):
             self,
             model_channels,
             in_channels=1,
-            num_tokens=30,
+            num_tokens=32,
             out_channels=2,  # mean and variance
             dropout=0,
             # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
@@ -135,6 +135,7 @@
             scale_factor=2,
             conditioning_inputs_provided=True,
             time_embed_dim_multiplier=4,
+            nil_guidance_fwd_proportion=.3,
     ):
         super().__init__()

@@ -153,6 +154,8 @@
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.dims = dims
+        self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
+        self.mask_token_id = num_tokens

         padding = 1 if kernel_size == 3 else 2

@@ -183,7 +186,7 @@

         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Embedding(num_tokens, ch)
+                token_conditioning_block = nn.Embedding(num_tokens+1, ch)
                 token_conditioning_block.weight.data.normal_(mean=0.0, std=.02)
                 self.input_blocks.append(token_conditioning_block)
                 token_conditioning_blocks.append(token_conditioning_block)
@@ -286,6 +289,24 @@
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

+    def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
+                        strict: bool = True):
+        # Temporary hack to allow the addition of nil-guidance token embeddings to the existing guidance embeddings.
+        lsd = self.state_dict()
+        revised = 0
+        for i, blk in enumerate(self.input_blocks):
+            if isinstance(blk, nn.Embedding):
+                key = f'input_blocks.{i}.weight'
+                if state_dict[key].shape[0] != lsd[key].shape[0]:
+                    t = torch.randn_like(lsd[key]) * .02
+                    t[:state_dict[key].shape[0]] = state_dict[key]
+                    state_dict[key] = t
+                    revised += 1
+        print(f"Loaded experimental unet_diffusion_net with {revised} modifications.")
+        return super().load_state_dict(state_dict, strict)
+
+
+
     def forward(self, x, timesteps, tokens, conditioning_input=None):
         """
         Apply the model to an input batch.
@@ -307,11 +328,16 @@
         hs = []
         emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
         if self.conditioning_enabled:
-            emb2 = self.contextual_embedder(conditioning_input)
-            emb = emb1 + emb2
+            actual_cond = self.contextual_embedder(conditioning_input)
+            emb = emb1 + actual_cond
         else:
             emb = emb1

+        # Mask out guidance tokens for un-guided diffusion.
+        if self.nil_guidance_fwd_proportion > 0:
+            token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
+            tokens = torch.where(token_mask, self.mask_token_id, tokens)
+
         h = x.type(self.dtype)
         for k, module in enumerate(self.input_blocks):
             if isinstance(module, nn.Embedding):
@@ -370,13 +396,15 @@ def register_diffusion_tts_experimental(opt_net, opt):

 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 86016)
-    tok = torch.randint(0,30, (2,388))
-    cond = torch.randn(2, 1, 44000)
-    ts = torch.LongTensor([555, 556])
+    clip = torch.randn(4, 1, 86016)
+    tok = torch.randint(0,30, (4,388))
+    cond = torch.randn(4, 1, 44000)
+    ts = torch.LongTensor([555, 556, 600, 600])
     model = DiffusionTts(64, channel_mult=[1,1.5,2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 2, 4, 4, 4],
                          token_conditioning_resolutions=[1,4,16,64], attention_resolutions=[256,512], num_heads=4, kernel_size=3,
                          scale_factor=2, conditioning_inputs_provided=True, time_embed_dim_multiplier=4)
+    model(clip, ts, tok, cond)
+
     p, r = model.benchmark(clip, ts, tok, cond)
     p = {k: v / 1000000000 for k, v in p.items()}
     p = sorted(p.items(), key=operator.itemgetter(1))
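
Note, not part of the applied diff: the block added to forward() replaces a
random subset of guidance tokens with a reserved mask id so the model also
learns un-guided denoising. A minimal standalone sketch of the same idea;
NUM_TOKENS, PROPORTION, and mask_tokens are illustrative names, not
identifiers from this repository.

    import torch
    import torch.nn as nn

    NUM_TOKENS = 32   # real vocabulary; id NUM_TOKENS is reserved as the mask token
    PROPORTION = .3   # mirrors the nil_guidance_fwd_proportion default above

    def mask_tokens(tokens):
        # Independently replace each token with the mask id with probability PROPORTION.
        token_mask = torch.rand(tokens.shape, device=tokens.device) < PROPORTION
        return torch.where(token_mask, torch.full_like(tokens, NUM_TOKENS), tokens)

    tokens = torch.randint(0, NUM_TOKENS, (2, 388))
    masked = mask_tokens(tokens)
    # The embedding needs NUM_TOKENS+1 rows so the mask id stays in range, which
    # is why the patch grows nn.Embedding(num_tokens, ch) to nn.Embedding(num_tokens+1, ch).
    emb = nn.Embedding(NUM_TOKENS + 1, 64)(masked)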
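
Similarly, the load_state_dict() override pads pre-existing checkpoints so the
new mask-token embedding row exists before the weights are loaded. A sketch of
just that weight surgery; expand_embedding is a hypothetical helper, not code
from the repository.

    import torch

    def expand_embedding(saved_weight, new_rows):
        # Keep the trained rows from the checkpoint and initialize any extra
        # rows with the same std=.02 normal init the model uses for embeddings.
        expanded = torch.randn(new_rows, saved_weight.shape[1]) * .02
        expanded[:saved_weight.shape[0]] = saved_weight
        return expanded

    old = torch.randn(32, 64)           # embedding weight saved before this patch
    new = expand_embedding(old, 33)     # row 32 becomes the nil-guidance token
    assert torch.equal(new[:32], old)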