Add a trainable network for converting a normal distribution into a latent space
This commit is contained in:
parent
e402089556
commit
c42c53e75a
63
codes/models/audio/tts/random_latent_converter.py
Normal file
63
codes/models/audio/tts/random_latent_converter.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
|
||||
if bias is not None:
|
||||
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
||||
return (
|
||||
F.leaky_relu(
|
||||
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
|
||||
)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
return F.leaky_relu(input, negative_slope=0.2) * scale
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(
|
||||
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
else:
|
||||
self.bias = None
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
return out
|
||||
|
||||
|
||||
class RandomLatentConverter(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
|
||||
nn.Linear(channels, channels))
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, ref):
|
||||
r = torch.randn(ref.shape[0], self.channels, device=ref.device)
|
||||
y = self.layers(r)
|
||||
return y
|
||||
|
||||
|
||||
@register_model
|
||||
def register_random_latent_converter(opt_net, opt):
|
||||
return RandomLatentConverter(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = RandomLatentConverter(512)
|
||||
model(torch.randn(5,512))
|
|
@ -300,6 +300,14 @@ class DiffusionTtsFlat(nn.Module):
|
|||
return out, mel_pred
|
||||
return out
|
||||
|
||||
def get_conditioning_latent(self, conditioning_input):
|
||||
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
||||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
return conds.mean(dim=-1)
|
||||
|
||||
@register_model
|
||||
def register_diffusion_tts_flat(opt_net, opt):
|
||||
|
|
|
@ -376,6 +376,7 @@ class UnifiedVoice(nn.Module):
|
|||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
return conds
|
||||
|
||||
|
||||
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
|
||||
return_latent=False, clip_inputs=True):
|
||||
"""
|
||||
|
|
|
@ -327,7 +327,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen_diffuser.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
|
||||
from trainer.inject import Injector
|
||||
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
||||
|
||||
|
@ -229,4 +230,50 @@ class ReverseUnivnetInjector(Injector):
|
|||
labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
|
||||
output = torch.where(labels, original_audio, decoded_mel)
|
||||
|
||||
return {self.output: output, self.label_output_key: labels[:,0,0].long()}
|
||||
return {self.output: output, self.label_output_key: labels[:,0,0].long()}
|
||||
|
||||
|
||||
class ConditioningLatentDistributionDivergenceInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
if 'gpt_config' in opt.keys():
|
||||
# The unified_voice model.
|
||||
cfg = opt_get(opt, ['gpt_config'], "../experiments/train_gpt_tts_unified.yml")
|
||||
model_name = opt_get(opt, ['gpt_name'], 'gpt')
|
||||
pretrained_path = opt['gpt_path']
|
||||
self.latent_producer = load_model_from_config(cfg, model_name=model_name,
|
||||
also_load_savepoint=False, load_path=pretrained_path).eval()
|
||||
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
|
||||
else:
|
||||
self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
|
||||
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
|
||||
num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
|
||||
self.latent_producer.load_state_dict(torch.load(opt['diffusion_path']))
|
||||
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_fmax': 12000, 'sampling_rate': 24000, 'n_mel_channels': 100},{})
|
||||
self.needs_move = True
|
||||
# Aux input keys.
|
||||
self.conditioning_key = opt['conditioning_clip']
|
||||
# Output keys
|
||||
self.var_loss_key = opt['var_loss']
|
||||
|
||||
def to_mel(self, t):
|
||||
return self.mel_inj({'wav': t})['mel']
|
||||
|
||||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
state_preds = state[self.input]
|
||||
state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
|
||||
mel_conds = []
|
||||
for k in range(state_cond.shape[1]):
|
||||
mel_conds.append(self.to_mel(state_cond[:, k]))
|
||||
mel_conds = torch.stack(mel_conds, dim=1)
|
||||
|
||||
if self.needs_move:
|
||||
self.latent_producer = self.latent_producer.to(mel_conds.device)
|
||||
latents = self.latent_producer.get_conditioning_latent(mel_conds)
|
||||
|
||||
sp_means, sp_vars = state_preds.mean(dim=0), state_preds.var(dim=0)
|
||||
tr_means, tr_vars = latents.mean(dim=0), latents.var(dim=0)
|
||||
mean_loss = F.mse_loss(sp_means, tr_means)
|
||||
var_loss = F.mse_loss(sp_vars, tr_vars)
|
||||
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
||||
|
|
Loading…
Reference in New Issue
Block a user