Add conditioning-free guidance
This commit is contained in:
parent
ac920798bb
commit
436fe24822
|
@ -194,7 +194,9 @@ class DiffusionTts(nn.Module):
|
|||
time_embed_dim_multiplier=4,
|
||||
cond_transformer_depth=8,
|
||||
mid_transformer_depth=8,
|
||||
# Parameters for regularization.
|
||||
nil_guidance_fwd_proportion=.3,
|
||||
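unconditioned_percentage=.1, # Probability that a batch element's conditioning is replaced by the learned unconditioned embedding (similar to the conditioning dropout used when training for classifier-free guidance).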
|
||||
# Parameters for super-sampling.
|
||||
super_sampling=False,
|
||||
super_sampling_max_noising_factor=.1,
|
||||
|
@ -226,6 +228,7 @@ class DiffusionTts(nn.Module):
|
|||
self.mask_token_id = num_tokens
|
||||
self.super_sampling_enabled = super_sampling
|
||||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||
self.unconditioned_percentage = unconditioned_percentage
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||
|
@ -274,6 +277,7 @@ class DiffusionTts(nn.Module):
|
|||
cross_attend=self.enable_unaligned_inputs,
|
||||
)
|
||||
)
|
||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,embedding_dim,1))
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
|
@ -461,6 +465,11 @@ class DiffusionTts(nn.Module):
|
|||
else:
|
||||
code_emb = self.conditioning_encoder(code_emb)
|
||||
|
||||
# Replace the conditioning embedding of randomly selected whole batch elements with the learned unconditioned embedding, similar to the conditioning dropout used to train for classifier-free guidance.
|
||||
if self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb)
|
||||
|
||||
first = True
|
||||
time_emb = time_emb.float()
|
||||
h = x
|
||||
|
|
Loading…
Reference in New Issue
Block a user