fix nh
This commit is contained in:
parent
52a20f3aa3
commit
56f19a23cd
|
@ -96,6 +96,7 @@ class DiffusionTts(nn.Module):
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
scale_factor=2,
|
scale_factor=2,
|
||||||
|
num_heads=None,
|
||||||
time_embed_dim_multiplier=4,
|
time_embed_dim_multiplier=4,
|
||||||
nil_guidance_fwd_proportion=.15,
|
nil_guidance_fwd_proportion=.15,
|
||||||
):
|
):
|
||||||
|
@ -111,7 +112,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
|
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
|
||||||
self.mask_token_id = num_tokens
|
self.mask_token_id = num_tokens
|
||||||
num_heads = model_channels // 64
|
num_heads = model_channels // 64 if num_heads is None else num_heads
|
||||||
|
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user