diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py index 5af7b50d..60e64d14 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py @@ -71,7 +71,7 @@ class DiffusionTtsFlat(nn.Module): attn_dropout=dropout, use_rmsnorm=True, ff_glu=True, - rotary_emb_dim=True, + rotary_pos_emb=True, ) ) ) @@ -91,7 +91,7 @@ class DiffusionTtsFlat(nn.Module): attn_dropout=dropout, use_rmsnorm=True, ff_glu=True, - rotary_emb_dim=True, + rotary_pos_emb=True, ) )) self.conditioning_conv = nn.Conv1d(model_channels*2, model_channels, 1) @@ -110,7 +110,7 @@ class DiffusionTtsFlat(nn.Module): attn_dropout=dropout, use_rmsnorm=True, ff_glu=True, - rotary_emb_dim=True, + rotary_pos_emb=True, layerdrop_percent=0, ) ) @@ -130,9 +130,10 @@ class DiffusionTtsFlat(nn.Module): attn_dropout=dropout, use_rmsnorm=True, ff_glu=True, - rotary_emb_dim=True, + rotary_pos_emb=True, layerdrop_percent=layer_drop, zero_init_branch_output=True, + sandwich_coef=4, ) ) self.layers.transformer.norm = nn.Identity() # We don't want the final norm for the main encoder.