From 1f521d6a1d29ca363c4a0cb669e7624c12bd94ef Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 23 May 2022 09:28:41 -0600 Subject: [PATCH] add reconstruction loss to m2v --- codes/models/arch_util.py | 18 ++++++------------ codes/models/audio/mel2vec.py | 21 ++++++++++++++++++++- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 13c5f2f8..07669961 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -282,19 +282,13 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims - if factor is None: - if dims == 1: - self.factor = 4 - else: - self.factor = 2 - else: - self.factor = factor + self.factor = factor if use_conv: ksize = 3 pad = 1 @@ -399,11 +393,11 @@ class ResBlock(nn.Module): self.updown = up or down if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) + self.h_upd = Upsample(channels, use_conv, dims) + self.x_upd = Upsample(channels, use_conv, dims) elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) + self.h_upd = Downsample(channels, use_conv, dims) + self.x_upd = Downsample(channels, use_conv, dims) else: self.h_upd = self.x_upd = nn.Identity() diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py index bd2ff45d..d005d493 100644 --- a/codes/models/audio/mel2vec.py +++ b/codes/models/audio/mel2vec.py @@ -12,6 +12,7 @@ from torch import distributed from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices from transformers.deepspeed import is_deepspeed_zero3_enabled +from models.arch_util import ResBlock from trainer.networks import register_model from utils.util import checkpoint @@ -396,6 +397,7 @@ class Mel2Vec(nn.Module): self.disable_custom_linear_init = disable_custom_linear_init self.linear_init_scale = linear_init_scale self.dim_reduction_mult = dim_reduction_multiplier + self.mel_dim = mel_input_channels self.apply(self.init) def init(self, module): @@ -571,6 +573,7 @@ class ContrastiveTrainingWrapper(nn.Module): def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100, max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, codebook_size=320, codebook_groups=2, freq_mask_percent=0, inp_length_multiplier=256, + do_reconstruction_loss=False, **kwargs): super().__init__() self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, @@ -590,6 +593,17 @@ class ContrastiveTrainingWrapper(nn.Module): self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim) self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim) + self.reconstruction = do_reconstruction_loss + if do_reconstruction_loss: + blocks = [[ResBlock(dims=1, channels=inner_dim, dropout=dropout), + ResBlock(dims=1, channels=inner_dim, dropout=dropout, use_conv=True, up=True)] for _ in range(int(math.log2(self.m2v.dim_reduction_mult)))] + blocks = sum(blocks, []) + blocks.append(nn.Conv1d(inner_dim, self.m2v.mel_dim, kernel_size=3, padding=1)) + self.reconstruction_net = nn.Sequential( + nn.Conv1d(self.quantizer.codevector_dim, inner_dim, kernel_size=3, padding=1), + *blocks + ) + @staticmethod def compute_contrastive_logits( target_features: torch.FloatTensor, @@ -700,6 +714,11 @@ class ContrastiveTrainingWrapper(nn.Module): num_codevectors = self.quantizer.num_codevectors diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors + if self.reconstruction: + reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1)) + reconstruction_loss = F.mse_loss(reconstruction, mel) + return contrastive_loss, diversity_loss, reconstruction_loss + return contrastive_loss, diversity_loss @@ -714,6 +733,6 @@ def register_mel2vec(opt_net, opt): if __name__ == '__main__': - model = ContrastiveTrainingWrapper(freq_mask_percent=.5) + model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True) mel = torch.randn((2,256,401)) print(model(mel)) \ No newline at end of file