residual features

This commit is contained in:
James Betker 2022-05-23 09:58:30 -06:00
parent 1f521d6a1d
commit dc471f5c6d

View File

@ -364,10 +364,10 @@ class Mel2Vec(nn.Module):
mask_time_length=10,
disable_custom_linear_init=False,
linear_init_scale=.02,
dim_reduction_multiplier=4,
feature_producer_type='standard',
):
super().__init__()
if dim_reduction_multiplier == 4:
if feature_producer_type == 'standard':
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
nn.GELU(),
@ -376,7 +376,20 @@ class Mel2Vec(nn.Module):
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
nn.GELU(),
)
elif dim_reduction_multiplier == 8:
self.dim_reduction_mult = 4
elif feature_producer_type == 'residual':
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
ResBlock(dims=1, channels=inner_dim//2, dropout=dropout),
nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
nn.GELU(),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
ResBlock(dims=1, channels=inner_dim, dropout=dropout),
)
self.dim_reduction_mult = 4
elif feature_producer_type == 'voice_8x':
self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2),
nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True),
nn.GELU(),
@ -387,8 +400,9 @@ class Mel2Vec(nn.Module):
nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
nn.GELU(),
)
self.dim_reduction_mult = 8
else:
assert False, f"dim_reduction_multiplier={dim_reduction_multiplier} not supported"
assert False, f"feature_producer_type={feature_producer_type} not supported"
self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout)
self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,))
self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
@ -396,7 +410,6 @@ class Mel2Vec(nn.Module):
self.mask_time_length = mask_time_length
self.disable_custom_linear_init = disable_custom_linear_init
self.linear_init_scale = linear_init_scale
self.dim_reduction_mult = dim_reduction_multiplier
self.mel_dim = mel_input_channels
self.apply(self.init)
@ -733,6 +746,6 @@ def register_mel2vec(opt_net, opt):
if __name__ == '__main__':
model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True)
model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True, feature_producer_type='residual')
mel = torch.randn((2,256,401))
print(model(mel))