diff --git a/codes/models/audio/tts/unet_diffusion_tts10.py b/codes/models/audio/tts/unet_diffusion_tts10.py index f417905a..01c5fe70 100644 --- a/codes/models/audio/tts/unet_diffusion_tts10.py +++ b/codes/models/audio/tts/unet_diffusion_tts10.py @@ -96,6 +96,7 @@ class DiffusionTts(nn.Module): use_fp16=False, kernel_size=3, scale_factor=2, + num_heads=None, time_embed_dim_multiplier=4, nil_guidance_fwd_proportion=.15, ): @@ -111,7 +112,7 @@ class DiffusionTts(nn.Module): self.dims = dims self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion self.mask_token_id = num_tokens - num_heads = model_channels // 64 + num_heads = model_channels // 64 if num_heads is None else num_heads padding = 1 if kernel_size == 3 else 2