From dc471f5c6df8ffc924a64f1c0698709cdbb97bb0 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 23 May 2022 09:58:30 -0600 Subject: [PATCH] residual features --- codes/models/audio/mel2vec.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index d005d493..7e1996a0 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -364,10 +364,10 @@ class Mel2Vec(nn.Module): mask_time_length=10, disable_custom_linear_init=False, linear_init_scale=.02, - dim_reduction_multiplier=4, + feature_producer_type='standard', ): super().__init__() - if dim_reduction_multiplier == 4: + if feature_producer_type == 'standard': self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), nn.GELU(), @@ -376,7 +376,20 @@ class Mel2Vec(nn.Module): nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), nn.GELU(), ) - elif dim_reduction_multiplier == 8: + self.dim_reduction_mult = 4 + elif feature_producer_type == 'residual': + self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2), + nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True), + nn.GELU(), + ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), + ResBlock(dims=1, channels=inner_dim//2, dropout=dropout), + nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2), + nn.GELU(), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ) + self.dim_reduction_mult = 4 + elif feature_producer_type == 'voice_8x': self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//4, kernel_size=5, padding=2, stride=2), nn.GroupNorm(num_groups=8, num_channels=inner_dim//4, affine=True), nn.GELU(), @@ -387,8 +400,9 @@ class Mel2Vec(nn.Module): nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1), nn.GELU(), ) + self.dim_reduction_mult = 8 else: - assert False, f"dim_reduction_multiplier={dim_reduction_multiplier} not supported" + assert False, f"feature_producer_type={feature_producer_type} not supported" self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout) self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,)) self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop) @@ -396,7 +410,6 @@ class Mel2Vec(nn.Module): self.mask_time_length = mask_time_length self.disable_custom_linear_init = disable_custom_linear_init self.linear_init_scale = linear_init_scale - self.dim_reduction_mult = dim_reduction_multiplier self.mel_dim = mel_input_channels self.apply(self.init) @@ -733,6 +746,6 @@ def register_mel2vec(opt_net, opt): if __name__ == '__main__': - model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True) + model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True, feature_producer_type='residual') mel = torch.randn((2,256,401)) print(model(mel)) \ No newline at end of file