diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 5901ba03..5be289d5 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -365,11 +365,11 @@ class Mel2Vec(nn.Module): super().__init__() self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), - nn.SiLU(), + nn.GELU(), nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), - nn.SiLU(), + nn.GELU(), nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), - nn.SiLU(), + nn.GELU(), ) self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout) self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,))