Add a trainable network for converting a normal distribution into a latent space

This commit is contained in:
James Betker 2022-05-02 09:47:30 -06:00
parent e402089556
commit c42c53e75a
5 changed files with 121 additions and 2 deletions

View File

@ -0,0 +1,63 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from trainer.networks import register_model
from utils.util import opt_get
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=0.2) * scale
class EqualLinear(nn.Module):
def __init__(
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
return out
class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
nn.Linear(channels, channels))
self.channels = channels
def forward(self, ref):
r = torch.randn(ref.shape[0], self.channels, device=ref.device)
y = self.layers(r)
return y
@register_model
def register_random_latent_converter(opt_net, opt):
return RandomLatentConverter(**opt_get(opt_net, ['kwargs'], {}))
if __name__ == '__main__':
model = RandomLatentConverter(512)
model(torch.randn(5,512))

View File

@ -300,6 +300,14 @@ class DiffusionTtsFlat(nn.Module):
return out, mel_pred
return out
def get_conditioning_latent(self, conditioning_input):
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
conditioning_input.shape) == 3 else conditioning_input
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
conds = torch.cat(conds, dim=-1)
return conds.mean(dim=-1)
@register_model
def register_diffusion_tts_flat(opt_net, opt):

View File

@ -376,6 +376,7 @@ class UnifiedVoice(nn.Module):
conds = conds.mean(dim=1).unsqueeze(1)
return conds
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False, clip_inputs=True):
"""

View File

@ -327,7 +327,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen_diffuser.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args()
opt = option.parse(args.opt, is_train=True)

View File

@ -4,6 +4,7 @@ import torch
import torch.nn.functional as F
import torchaudio
from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
from trainer.inject import Injector
from utils.util import opt_get, load_model_from_config, pad_or_truncate
@ -229,4 +230,50 @@ class ReverseUnivnetInjector(Injector):
labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
output = torch.where(labels, original_audio, decoded_mel)
return {self.output: output, self.label_output_key: labels[:,0,0].long()}
return {self.output: output, self.label_output_key: labels[:,0,0].long()}
class ConditioningLatentDistributionDivergenceInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
if 'gpt_config' in opt.keys():
# The unified_voice model.
cfg = opt_get(opt, ['gpt_config'], "../experiments/train_gpt_tts_unified.yml")
model_name = opt_get(opt, ['gpt_name'], 'gpt')
pretrained_path = opt['gpt_path']
self.latent_producer = load_model_from_config(cfg, model_name=model_name,
also_load_savepoint=False, load_path=pretrained_path).eval()
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
else:
self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
self.latent_producer.load_state_dict(torch.load(opt['diffusion_path']))
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_fmax': 12000, 'sampling_rate': 24000, 'n_mel_channels': 100},{})
self.needs_move = True
# Aux input keys.
self.conditioning_key = opt['conditioning_clip']
# Output keys
self.var_loss_key = opt['var_loss']
def to_mel(self, t):
return self.mel_inj({'wav': t})['mel']
def forward(self, state):
with torch.no_grad():
state_preds = state[self.input]
state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
mel_conds = []
for k in range(state_cond.shape[1]):
mel_conds.append(self.to_mel(state_cond[:, k]))
mel_conds = torch.stack(mel_conds, dim=1)
if self.needs_move:
self.latent_producer = self.latent_producer.to(mel_conds.device)
latents = self.latent_producer.get_conditioning_latent(mel_conds)
sp_means, sp_vars = state_preds.mean(dim=0), state_preds.var(dim=0)
tr_means, tr_vars = latents.mean(dim=0), latents.var(dim=0)
mean_loss = F.mse_loss(sp_means, tr_means)
var_loss = F.mse_loss(sp_vars, tr_vars)
return {self.output: mean_loss, self.var_loss_key: var_loss}