default to residual encoder

This commit is contained in:
James Betker 2022-05-23 12:24:00 -06:00
parent f432bdf7ae
commit 560b83e770

View File

@ -371,30 +371,6 @@ class Mel2Vec(nn.Module):
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
nn.GELU(),
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
nn.GELU(),
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
nn.GELU(),
)
self.dim_reduction_mult = 4
elif feature_producer_type == 'residual':
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
)
self.dim_reduction_mult = 4
elif feature_producer_type == 'deep_residual':
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
@ -411,12 +387,20 @@ class Mel2Vec(nn.Module):
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim//4, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//4, dropout=dropout),
nn.Conv1d(inner_dim//4, inner_dim//2, kernel_size=3, padding=1, stride=2),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
nn.GELU(),
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
)
self.dim_reduction_mult = 8
else: