add reconstruction loss to m2v

James Betker 2022-05-23 09:28:41 -06:00
parent 1d758c3bc8
commit 1f521d6a1d
2 changed files with 26 additions and 13 deletions

View File

@@ -282,19 +282,13 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
-        if factor is None:
-            if dims == 1:
-                self.factor = 4
-            else:
-                self.factor = 2
-        else:
-            self.factor = factor
+        self.factor = factor
         if use_conv:
             ksize = 3
             pad = 1
@@ -399,11 +393,11 @@ class ResBlock(nn.Module):
         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv, dims)
+            self.x_upd = Upsample(channels, use_conv, dims)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, use_conv, dims)
+            self.x_upd = Downsample(channels, use_conv, dims)
         else:
             self.h_upd = self.x_upd = nn.Identity()
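The arch_util changes above do two things: Upsample now takes a plain factor argument defaulting to 2 (instead of inferring 4 for 1D inputs), and ResBlock forwards its use_conv flag into the Upsample/Downsample it builds, so an up/down ResBlock can smooth the resampled signal with a convolution. A minimal standalone sketch of the resulting behavior follows; the class name and internals are assumptions meant to approximate the intent of the diff, not the repo's exact code.

# Sketch only -- approximates the post-commit Upsample behavior; not copied from the repo.
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpsampleSketch(nn.Module):
    def __init__(self, channels, use_conv, dims=1, out_channels=None, factor=2):
        super().__init__()
        self.factor = factor  # plain default of 2, no dims-based special-casing anymore
        out_channels = out_channels or channels
        conv_cls = nn.Conv1d if dims == 1 else nn.Conv2d
        # use_conv adds a smoothing conv after interpolation (what ResBlock now requests)
        self.conv = conv_cls(channels, out_channels, kernel_size=3, padding=1) if use_conv else None

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        return x if self.conv is None else self.conv(x)

x = torch.randn(2, 64, 100)                        # (batch, channels, time)
print(UpsampleSketch(64, use_conv=True)(x).shape)  # -> torch.Size([2, 64, 200])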

View File

@@ -12,6 +12,7 @@ from torch import distributed
 from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 from transformers.deepspeed import is_deepspeed_zero3_enabled
+from models.arch_util import ResBlock
 from trainer.networks import register_model
 from utils.util import checkpoint
@@ -396,6 +397,7 @@ class Mel2Vec(nn.Module):
         self.disable_custom_linear_init = disable_custom_linear_init
         self.linear_init_scale = linear_init_scale
         self.dim_reduction_mult = dim_reduction_multiplier
+        self.mel_dim = mel_input_channels
         self.apply(self.init)

     def init(self, module):
@@ -571,6 +573,7 @@ class ContrastiveTrainingWrapper(nn.Module):
     def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
                  codebook_size=320, codebook_groups=2, freq_mask_percent=0, inp_length_multiplier=256,
+                 do_reconstruction_loss=False,
                  **kwargs):
         super().__init__()
         self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
@@ -590,6 +593,17 @@
         self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
         self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)

+        self.reconstruction = do_reconstruction_loss
+        if do_reconstruction_loss:
+            blocks = [[ResBlock(dims=1, channels=inner_dim, dropout=dropout),
+                       ResBlock(dims=1, channels=inner_dim, dropout=dropout, use_conv=True, up=True)] for _ in range(int(math.log2(self.m2v.dim_reduction_mult)))]
+            blocks = sum(blocks, [])
+            blocks.append(nn.Conv1d(inner_dim, self.m2v.mel_dim, kernel_size=3, padding=1))
+            self.reconstruction_net = nn.Sequential(
+                nn.Conv1d(self.quantizer.codevector_dim, inner_dim, kernel_size=3, padding=1),
+                *blocks
+            )
+
     @staticmethod
     def compute_contrastive_logits(
         target_features: torch.FloatTensor,
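The reconstruction head added above is a small 1D decoder: project the quantized codevectors up to inner_dim, undo the encoder's time reduction with log2(dim_reduction_mult) ResBlock pairs (the second block of each pair upsamples 2x with a conv, which is why the arch_util change was needed), then project down to mel_dim channels. The shape-only sketch below uses assumed sizes (dim_reduction_mult=4, codevector_dim=768; inner_dim and mel_dim match the diff's defaults and the test at the bottom of the file) and substitutes plain conv + nearest-upsample pairs for the repo's ResBlock.

# Shape-flow sketch; sizes are illustrative and plain convs stand in for the repo's ResBlock.
import math
import torch
import torch.nn as nn

codevector_dim, inner_dim, mel_dim, reduction = 768, 1024, 256, 4   # assumed values

blocks = []
for _ in range(int(math.log2(reduction))):         # each pair recovers one factor of 2 in time
    blocks += [nn.Conv1d(inner_dim, inner_dim, kernel_size=3, padding=1),
               nn.Upsample(scale_factor=2, mode="nearest")]
blocks.append(nn.Conv1d(inner_dim, mel_dim, kernel_size=3, padding=1))

decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim, kernel_size=3, padding=1), *blocks)

quantized = torch.randn(2, 100, codevector_dim)     # (batch, reduced_time, codevector_dim)
recon = decoder(quantized.permute(0, 2, 1))         # to (batch, channels, time), as in forward()
print(recon.shape)                                  # -> torch.Size([2, 256, 400]) == (batch, mel_dim, time*reduction)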
@@ -700,6 +714,11 @@
         num_codevectors = self.quantizer.num_codevectors
         diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors

+        if self.reconstruction:
+            reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
+            reconstruction_loss = F.mse_loss(reconstruction, mel)
+            return contrastive_loss, diversity_loss, reconstruction_loss
+
         return contrastive_loss, diversity_loss
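With do_reconstruction_loss enabled, forward() now returns a three-tuple instead of a pair, so callers have to handle both shapes. The commit does not show how the trainer weights the extra term; the sketch below only demonstrates unpacking, and the coefficients are placeholders, not values from the repo.

# Sketch of consuming the new return value; weights are illustrative placeholders.
import torch

def combine_losses(outputs, diversity_weight=0.1, reconstruction_weight=1.0):
    if len(outputs) == 3:                           # do_reconstruction_loss=True
        contrastive, diversity, reconstruction = outputs
        return contrastive + diversity_weight * diversity + reconstruction_weight * reconstruction
    contrastive, diversity = outputs                # reconstruction disabled
    return contrastive + diversity_weight * diversity

# Stand-in scalars; in practice these come from ContrastiveTrainingWrapper.forward().
print(combine_losses((torch.tensor(2.3), torch.tensor(0.4), torch.tensor(1.1))))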
@@ -714,6 +733,6 @@ def register_mel2vec(opt_net, opt):
 if __name__ == '__main__':
-    model = ContrastiveTrainingWrapper(freq_mask_percent=.5)
+    model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True)
     mel = torch.randn((2,256,401))
     print(model(mel))