Add conditioning-free guidance

This commit is contained in:
James Betker 2022-02-27 15:00:06 -07:00
parent ac920798bb
commit 436fe24822

View File

@ -194,7 +194,9 @@ class DiffusionTts(nn.Module):
time_embed_dim_multiplier=4, time_embed_dim_multiplier=4,
cond_transformer_depth=8, cond_transformer_depth=8,
mid_transformer_depth=8, mid_transformer_depth=8,
# Parameters for regularization.
nil_guidance_fwd_proportion=.3, nil_guidance_fwd_proportion=.3,
unconditioned_percentage=.1, # Probability that a batch element's conditioning is replaced by the learned unconditioned embedding during forward(); similar to classifier-free guidance training.
# Parameters for super-sampling. # Parameters for super-sampling.
super_sampling=False, super_sampling=False,
super_sampling_max_noising_factor=.1, super_sampling_max_noising_factor=.1,
@ -226,6 +228,7 @@ class DiffusionTts(nn.Module):
self.mask_token_id = num_tokens self.mask_token_id = num_tokens
self.super_sampling_enabled = super_sampling self.super_sampling_enabled = super_sampling
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
self.unconditioned_percentage = unconditioned_percentage
padding = 1 if kernel_size == 3 else 2 padding = 1 if kernel_size == 3 else 2
time_embed_dim = model_channels * time_embed_dim_multiplier time_embed_dim = model_channels * time_embed_dim_multiplier
@ -274,6 +277,7 @@ class DiffusionTts(nn.Module):
cross_attend=self.enable_unaligned_inputs, cross_attend=self.enable_unaligned_inputs,
) )
) )
self.unconditioned_embedding = nn.Parameter(torch.randn(1,embedding_dim,1))
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
@ -461,6 +465,11 @@ class DiffusionTts(nn.Module):
else: else:
code_emb = self.conditioning_encoder(code_emb) code_emb = self.conditioning_encoder(code_emb)
# Replace the conditioning embedding of randomly selected batch elements with the learned unconditioned embedding, similar to classifier-free guidance training.
if self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb)
first = True first = True
time_emb = time_emb.float() time_emb = time_emb.float()
h = x h = x