From c42c53e75a27b4ab82f9ee42180c23ac4d8425ee Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 2 May 2022 09:47:30 -0600
Subject: [PATCH] Add a trainable network for converting a normal distribution
 into a latent space

---
 .../audio/tts/random_latent_converter.py    | 63 +++++++++++++++++++
 .../audio/tts/unet_diffusion_tts_flat.py    |  8 +++
 codes/models/audio/tts/unified_voice2.py    |  1 +
 codes/train.py                              |  2 +-
 codes/trainer/injectors/audio_injectors.py  | 49 ++++++++++++++-
 5 files changed, 121 insertions(+), 2 deletions(-)
 create mode 100644 codes/models/audio/tts/random_latent_converter.py

diff --git a/codes/models/audio/tts/random_latent_converter.py b/codes/models/audio/tts/random_latent_converter.py
new file mode 100644
index 00000000..d4b5dd00
--- /dev/null
+++ b/codes/models/audio/tts/random_latent_converter.py
@@ -0,0 +1,63 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
+    if bias is not None:
+        rest_dim = [1] * (input.ndim - bias.ndim - 1)
+        return (
+            F.leaky_relu(
+                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+            )
+            * scale
+        )
+    else:
+        return F.leaky_relu(input, negative_slope=negative_slope) * scale
+
+
+class EqualLinear(nn.Module):
+    def __init__(
+        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
+    ):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+        else:
+            self.bias = None
+        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+        self.lr_mul = lr_mul
+
+    def forward(self, input):
+        out = F.linear(input, self.weight * self.scale)
+        out = fused_leaky_relu(out, self.bias * self.lr_mul if self.bias is not None else None)
+        return out
+
+
+class RandomLatentConverter(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.layers = nn.Sequential(*[EqualLinear(channels, channels, lr_mul=.1) for _ in range(5)],
+                                    nn.Linear(channels, channels))
+        self.channels = channels
+
+    def forward(self, ref):
+        r = torch.randn(ref.shape[0], self.channels, device=ref.device)
+        y = self.layers(r)
+        return y
+
+
+@register_model
+def register_random_latent_converter(opt_net, opt):
+    return RandomLatentConverter(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = RandomLatentConverter(512)
+    model(torch.randn(5,512))
\ No newline at end of file
diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index ef06396d..ce0ef9e5 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -300,6 +300,14 @@ class DiffusionTtsFlat(nn.Module):
             return out, mel_pred
         return out
 
+    def get_conditioning_latent(self, conditioning_input):
+        speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
+            conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(speech_conditioning_input.shape[1]):
+            conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+        conds = torch.cat(conds, dim=-1)
+        return conds.mean(dim=-1)
 
 @register_model
 def register_diffusion_tts_flat(opt_net, opt):
diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index eb481c91..d53e595a 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -376,6 +376,7 @@ class UnifiedVoice(nn.Module):
         conds = conds.mean(dim=1).unsqueeze(1)
         return conds
 
+
     def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, types=None,
                 text_first=True, raw_mels=None, return_attentions=False, return_latent=False, clip_inputs=True):
         """
diff --git a/codes/train.py b/codes/train.py
index f40435f5..01ee2044 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_tortoise_random_latent_gen_diffuser.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 4ed8d5c8..7181c32d 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 
+from models.audio.tts.unet_diffusion_tts_flat import DiffusionTtsFlat
 from trainer.inject import Injector
 from utils.util import opt_get, load_model_from_config, pad_or_truncate
 
@@ -229,4 +230,50 @@ class ReverseUnivnetInjector(Injector):
         labels = (torch.rand(mel.shape[0], 1, 1, device=mel.device) > .5)
         output = torch.where(labels, original_audio, decoded_mel)
 
-        return {self.output: output, self.label_output_key: labels[:,0,0].long()}
\ No newline at end of file
+        return {self.output: output, self.label_output_key: labels[:,0,0].long()}
+
+
+class ConditioningLatentDistributionDivergenceInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        if 'gpt_config' in opt.keys():
+            # The unified_voice model.
+            cfg = opt_get(opt, ['gpt_config'], "../experiments/train_gpt_tts_unified.yml")
+            model_name = opt_get(opt, ['gpt_name'], 'gpt')
+            pretrained_path = opt['gpt_path']
+            self.latent_producer = load_model_from_config(cfg, model_name=model_name,
+                                                          also_load_savepoint=False, load_path=pretrained_path).eval()
+            self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
+        else:
+            self.latent_producer = DiffusionTtsFlat(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                                    in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False,
+                                                    num_heads=16, layer_drop=0, unconditioned_percentage=0).eval()
+            self.latent_producer.load_state_dict(torch.load(opt['diffusion_path']))
+            self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_fmax': 12000, 'sampling_rate': 24000, 'n_mel_channels': 100},{})
+        self.needs_move = True
+        # Aux input keys.
+        self.conditioning_key = opt['conditioning_clip']
+        # Output keys
+        self.var_loss_key = opt['var_loss']
+
+    def to_mel(self, t):
+        return self.mel_inj({'wav': t})['mel']
+
+    def forward(self, state):
+        with torch.no_grad():
+            state_preds = state[self.input]
+            state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
+            mel_conds = []
+            for k in range(state_cond.shape[1]):
+                mel_conds.append(self.to_mel(state_cond[:, k]))
+            mel_conds = torch.stack(mel_conds, dim=1)
+
+            if self.needs_move:
+                self.latent_producer = self.latent_producer.to(mel_conds.device)
+            latents = self.latent_producer.get_conditioning_latent(mel_conds)
+
+        sp_means, sp_vars = state_preds.mean(dim=0), state_preds.var(dim=0)
+        tr_means, tr_vars = latents.mean(dim=0), latents.var(dim=0)
+        mean_loss = F.mse_loss(sp_means, tr_means)
+        var_loss = F.mse_loss(sp_vars, tr_vars)
+        return {self.output: mean_loss, self.var_loss_key: var_loss}
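
A minimal sketch of how the converter and the new injector are intended to interact
during training, presumably so a conditioning latent can later be sampled without any
reference audio. The channel width (1024), batch size, and tensor shapes below are
assumptions for illustration, not values taken from this patch.

    import torch
    from models.audio.tts.random_latent_converter import RandomLatentConverter

    # The converter maps N(0, 1) noise to the distribution of conditioning latents.
    # 1024 channels is an assumption; it must match the width of the latents emitted
    # by the frozen latent producer (DiffusionTtsFlat or UnifiedVoice) in a given config.
    converter = RandomLatentConverter(1024)

    ref = torch.randn(4, 1024)        # only the batch size and device of 'ref' are used
    fake_latents = converter(ref)     # (4, 1024) latents drawn from the learned mapping

    # During training, ConditioningLatentDistributionDivergenceInjector encodes real
    # conditioning clips with the frozen latent producer and returns two MSE terms
    # (per-channel mean and variance mismatch) that, used as losses, push the converter's
    # output statistics toward those of real conditioning latents.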