Add conditioning-free guidance
This commit is contained in:
parent
ac920798bb
commit
436fe24822
|
@ -194,7 +194,9 @@ class DiffusionTts(nn.Module):
|
||||||
time_embed_dim_multiplier=4,
|
time_embed_dim_multiplier=4,
|
||||||
cond_transformer_depth=8,
|
cond_transformer_depth=8,
|
||||||
mid_transformer_depth=8,
|
mid_transformer_depth=8,
|
||||||
|
# Parameters for regularization.
|
||||||
nil_guidance_fwd_proportion=.3,
|
nil_guidance_fwd_proportion=.3,
|
||||||
|
unconditioned_percentage=.1, # Probability (0-1) that a batch element has its conditioning replaced by the learned unconditioned embedding, similar to classifier-free guidance training.
|
||||||
# Parameters for super-sampling.
|
# Parameters for super-sampling.
|
||||||
super_sampling=False,
|
super_sampling=False,
|
||||||
super_sampling_max_noising_factor=.1,
|
super_sampling_max_noising_factor=.1,
|
||||||
|
@ -226,6 +228,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.mask_token_id = num_tokens
|
self.mask_token_id = num_tokens
|
||||||
self.super_sampling_enabled = super_sampling
|
self.super_sampling_enabled = super_sampling
|
||||||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||||
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
|
@ -274,6 +277,7 @@ class DiffusionTts(nn.Module):
|
||||||
cross_attend=self.enable_unaligned_inputs,
|
cross_attend=self.enable_unaligned_inputs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.unconditioned_embedding = nn.Parameter(torch.randn(1,embedding_dim,1))
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -461,6 +465,11 @@ class DiffusionTts(nn.Module):
|
||||||
else:
|
else:
|
||||||
code_emb = self.conditioning_encoder(code_emb)
|
code_emb = self.conditioning_encoder(code_emb)
|
||||||
|
|
||||||
|
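# Replace the conditioning of randomly chosen whole batch elements with the learned unconditioned embedding, implementing something similar to classifier-free guidance.
# NOTE(review): the visible guard only checks unconditioned_percentage, not self.training - confirm this should also apply at inference.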
|
||||||
|
if self.unconditioned_percentage > 0:
|
||||||
|
unconditioned_batches = torch.rand((code_emb.shape[0],1,1), device=code_emb.device) < self.unconditioned_percentage
|
||||||
|
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, code_emb.shape[2]), code_emb)
|
||||||
|
|
||||||
first = True
|
first = True
|
||||||
time_emb = time_emb.float()
|
time_emb = time_emb.float()
|
||||||
h = x
|
h = x
|
||||||
|
|
Loading…
Reference in New Issue
Block a user