Add 8x dim reductor
This commit is contained in:
parent
07f7be24ce
commit
41809a6330
|
@ -362,16 +362,31 @@ class Mel2Vec(nn.Module):
|
||||||
mask_time_length=10,
|
mask_time_length=10,
|
||||||
disable_custom_linear_init=False,
|
disable_custom_linear_init=False,
|
||||||
linear_init_scale=.02,
|
linear_init_scale=.02,
|
||||||
|
dim_reduction_multiplier=4,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
|
if dim_reduction_multiplier == 4:
|
||||||
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
|
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
|
||||||
nn.GELU(),
|
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
|
||||||
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
|
nn.GELU(),
|
||||||
nn.GELU(),
|
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
|
||||||
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
|
nn.GELU(),
|
||||||
nn.GELU(),
|
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
|
||||||
)
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
elif dim_reduction_multiplier == 8:
|
||||||
|
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2),
|
||||||
|
nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv1d(inner_dim//4, inner_dim//2, kernel_size=3, padding=1, stride=2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
|
||||||
|
nn.GELU(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, f"dim_reduction_multiplier={dim_reduction_multiplier} not supported"
|
||||||
self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout)
|
self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout)
|
||||||
self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,))
|
self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,))
|
||||||
self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
|
self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user