forked from mrq/DL-Art-School
add reconstruction loss to m2v
This commit is contained in:
parent
1d758c3bc8
commit
1f521d6a1d
|
@ -282,19 +282,13 @@ class Upsample(nn.Module):
|
|||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None):
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if factor is None:
|
||||
if dims == 1:
|
||||
self.factor = 4
|
||||
else:
|
||||
self.factor = 2
|
||||
else:
|
||||
self.factor = factor
|
||||
self.factor = factor
|
||||
if use_conv:
|
||||
ksize = 3
|
||||
pad = 1
|
||||
|
@ -399,11 +393,11 @@ class ResBlock(nn.Module):
|
|||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
self.h_upd = Upsample(channels, use_conv, dims)
|
||||
self.x_upd = Upsample(channels, use_conv, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
self.h_upd = Downsample(channels, use_conv, dims)
|
||||
self.x_upd = Downsample(channels, use_conv, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch import distributed
|
|||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from models.arch_util import ResBlock
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
@ -396,6 +397,7 @@ class Mel2Vec(nn.Module):
|
|||
self.disable_custom_linear_init = disable_custom_linear_init
|
||||
self.linear_init_scale = linear_init_scale
|
||||
self.dim_reduction_mult = dim_reduction_multiplier
|
||||
self.mel_dim = mel_input_channels
|
||||
self.apply(self.init)
|
||||
|
||||
def init(self, module):
|
||||
|
@ -571,6 +573,7 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
|
||||
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
|
||||
codebook_size=320, codebook_groups=2, freq_mask_percent=0, inp_length_multiplier=256,
|
||||
do_reconstruction_loss=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob,
|
||||
|
@ -590,6 +593,17 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
|
||||
self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
|
||||
|
||||
self.reconstruction = do_reconstruction_loss
|
||||
if do_reconstruction_loss:
|
||||
blocks = [[ResBlock(dims=1, channels=inner_dim, dropout=dropout),
|
||||
ResBlock(dims=1, channels=inner_dim, dropout=dropout, use_conv=True, up=True)] for _ in range(int(math.log2(self.m2v.dim_reduction_mult)))]
|
||||
blocks = sum(blocks, [])
|
||||
blocks.append(nn.Conv1d(inner_dim, self.m2v.mel_dim, kernel_size=3, padding=1))
|
||||
self.reconstruction_net = nn.Sequential(
|
||||
nn.Conv1d(self.quantizer.codevector_dim, inner_dim, kernel_size=3, padding=1),
|
||||
*blocks
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_contrastive_logits(
|
||||
target_features: torch.FloatTensor,
|
||||
|
@ -700,6 +714,11 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
num_codevectors = self.quantizer.num_codevectors
|
||||
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
|
||||
|
||||
if self.reconstruction:
|
||||
reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
|
||||
reconstruction_loss = F.mse_loss(reconstruction, mel)
|
||||
return contrastive_loss, diversity_loss, reconstruction_loss
|
||||
|
||||
return contrastive_loss, diversity_loss
|
||||
|
||||
|
||||
|
@ -714,6 +733,6 @@ def register_mel2vec(opt_net, opt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = ContrastiveTrainingWrapper(freq_mask_percent=.5)
|
||||
model = ContrastiveTrainingWrapper(freq_mask_percent=.5, do_reconstruction_loss=True)
|
||||
mel = torch.randn((2,256,401))
|
||||
print(model(mel))
|
Loading…
Reference in New Issue
Block a user