This commit is contained in:
James Betker 2022-05-25 12:31:56 -06:00
parent 52a20f3aa3
commit 56f19a23cd

View File

@ -96,6 +96,7 @@ class DiffusionTts(nn.Module):
use_fp16=False, use_fp16=False,
kernel_size=3, kernel_size=3,
scale_factor=2, scale_factor=2,
num_heads=None,
time_embed_dim_multiplier=4, time_embed_dim_multiplier=4,
nil_guidance_fwd_proportion=.15, nil_guidance_fwd_proportion=.15,
): ):
@ -111,7 +112,7 @@ class DiffusionTts(nn.Module):
self.dims = dims self.dims = dims
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
self.mask_token_id = num_tokens self.mask_token_id = num_tokens
num_heads = model_channels // 64 num_heads = model_channels // 64 if num_heads is None else num_heads
padding = 1 if kernel_size == 3 else 2 padding = 1 if kernel_size == 3 else 2