From 41809a633011d41cb4bfda79906eef740227310d Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 22 May 2022 20:23:16 -0600 Subject: [PATCH] Add 8x dim reductor --- codes/models/audio/mel2vec.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index 6a5ddb3e..42e6ba31 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -362,16 +362,31 @@ class Mel2Vec(nn.Module): mask_time_length=10, disable_custom_linear_init=False, linear_init_scale=.02, + dim_reduction_multiplier=4, ): super().__init__() - self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), - nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), - nn.GELU(), - nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), - nn.GELU(), - nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), - nn.GELU(), - ) + if dim_reduction_multiplier == 4: + self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), + nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), + nn.GELU(), + nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), + nn.GELU(), + nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + elif dim_reduction_multiplier == 8: + self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2), + nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True), + nn.GELU(), + nn.Conv1d(inner_dim//4, inner_dim//2, kernel_size=3, padding=1, stride=2), + nn.GELU(), + nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), + nn.GELU(), + nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), + nn.GELU(), + ) + else: + assert False, f"dim_reduction_multiplier={dim_reduction_multiplier} not supported" self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout) self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,)) self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)