From 560b83e770e6428366251fcec10c7a02f9284dd1 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 23 May 2022 12:24:00 -0600 Subject: [PATCH] default to residual encoder --- codes/models/audio/mel2vec.py | 36 ++++++++++------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index d2210ad1..93bf59f2 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -371,30 +371,6 @@ class Mel2Vec(nn.Module): self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), nn.GELU(), - nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), - nn.GELU(), - nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), - nn.GELU(), - ) - self.dim_reduction_mult = 4 - elif feature_producer_type == 'residual': - self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), - nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), - nn.GELU(), - ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), - ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), - nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), - nn.GELU(), - ResBlock(dims=1, channels=inner_dim, dropout=dropout), - ResBlock(dims=1, channels=inner_dim, dropout=dropout), - ) - self.dim_reduction_mult = 4 - elif feature_producer_type == 'deep_residual': - self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), - nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), - nn.GELU(), - ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), - ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), @@ -411,12 +387,20 @@ class Mel2Vec(nn.Module): self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2), nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True), nn.GELU(), + ResBlock(dims=1, channels=inner_dim//4, dropout=dropout), + ResBlock(dims=1, channels=inner_dim//4, dropout=dropout), nn.Conv1d(inner_dim//4, inner_dim//2, kernel_size=3, padding=1, stride=2), nn.GELU(), + ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), + ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), + ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), nn.GELU(), - nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), - nn.GELU(), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), ) self.dim_reduction_mult = 8 else: