From 436fe24822ab92d56cf1a00a56ed036995bfa357 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 27 Feb 2022 15:00:06 -0700 Subject: [PATCH] Add conditioning-free guidance --- codes/models/gpt_voice/unet_diffusion_tts7.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py index 90bc12af..b54fbe8c 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts7.py +++ b/codes/models/gpt_voice/unet_diffusion_tts7.py @@ -194,7 +194,9 @@ class DiffusionTts(nn.Module): time_embed_dim_multiplier=4, cond_transformer_depth=8, mid_transformer_depth=8, + # Parameters for regularization. nil_guidance_fwd_proportion=.3, + unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. # Parameters for super-sampling. super_sampling=False, super_sampling_max_noising_factor=.1, @@ -226,6 +228,7 @@ class DiffusionTts(nn.Module): self.mask_token_id = num_tokens self.super_sampling_enabled = super_sampling self.super_sampling_max_noising_factor = super_sampling_max_noising_factor + self.unconditioned_percentage = unconditioned_percentage padding = 1 if kernel_size == 3 else 2 time_embed_dim = model_channels * time_embed_dim_multiplier @@ -274,6 +277,7 @@ class DiffusionTts(nn.Module): cross_attend=self.enable_unaligned_inputs, ) ) + self.unconditioned_embedding = nn.Parameter(torch.randn(1,embedding_dim,1)) self.input_blocks = nn.ModuleList( [ @@ -461,6 +465,11 @@ class DiffusionTts(nn.Module): else: code_emb = self.conditioning_encoder(code_emb) + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.unconditioned_percentage > 0: + unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb) + first = True time_emb = time_emb.float() h = x