Support CVVP & fix for major bug in API

This commit is contained in:
James Betker 2022-04-18 14:47:44 -06:00
parent a4bc51cb6d
commit f717d24b0b
5 changed files with 161 additions and 13 deletions

26
api.py
View File

@ -7,12 +7,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import progressbar import progressbar
from models.cvvp import CVVP
from models.diffusion_decoder import DiffusionTts from models.diffusion_decoder import DiffusionTts
from models.autoregressive import UnifiedVoice from models.autoregressive import UnifiedVoice
from tqdm import tqdm from tqdm import tqdm
from models.arch_util import TorchMelSpectrogram from models.arch_util import TorchMelSpectrogram
from models.text_voice_clip import VoiceCLIP from models.clvp import CLVP
from models.vocoder import UnivNetGenerator from models.vocoder import UnivNetGenerator
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
@ -175,11 +176,15 @@ class TextToSpeech:
average_conditioning_embeddings=True).cpu().eval() average_conditioning_embeddings=True).cpu().eval()
self.autoregressive_for_diffusion.load_state_dict(torch.load('.models/autoregressive.pth')) self.autoregressive_for_diffusion.load_state_dict(torch.load('.models/autoregressive.pth'))
self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
text_seq_len=350, text_heads=8, text_seq_len=350, text_heads=8,
num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
use_xformers=True).cpu().eval() use_xformers=True).cpu().eval()
self.clip.load_state_dict(torch.load('.models/clip.pth')) self.clvp.load_state_dict(torch.load('.models/clip.pth'))
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
self.cvvp.load_state_dict(torch.load('.models/cvvp.pth'))
self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
@ -216,6 +221,8 @@ class TextToSpeech:
def tts(self, text, voice_samples, k=1, def tts(self, text, voice_samples, k=1,
# autoregressive generation parameters follow # autoregressive generation parameters follow
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
# CLVP & CVVP parameters
clvp_cvvp_slider=.5,
# diffusion generation parameters follow # diffusion generation parameters follow
diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0,
**hf_generate_kwargs): **hf_generate_kwargs):
@ -253,15 +260,22 @@ class TextToSpeech:
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()
clip_results = [] clip_results = []
self.clip = self.clip.cuda() self.clvp = self.clvp.cuda()
self.cvvp = self.cvvp.cuda()
for batch in samples: for batch in samples:
for i in range(batch.shape[0]): for i in range(batch.shape[0]):
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
clip_results.append(self.clip(text.repeat(batch.shape[0], 1), batch, return_loss=False)) clvp = self.clvp(text.repeat(batch.shape[0], 1), batch, return_loss=False)
cvvp_accumulator = 0
for cl in range(conds.shape[1]):
cvvp_accumulator = cvvp_accumulator + self.cvvp(conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False )
cvvp = cvvp_accumulator / conds.shape[1]
clip_results.append(clvp * clvp_cvvp_slider + cvvp * (1-clvp_cvvp_slider))
clip_results = torch.cat(clip_results, dim=0) clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
best_results = samples[torch.topk(clip_results, k=k).indices] best_results = samples[torch.topk(clip_results, k=k).indices]
self.clip = self.clip.cpu() self.clvp = self.clvp.cpu()
self.cvvp = self.cvvp.cpu()
del samples del samples
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning

View File

@ -562,7 +562,8 @@ class UnifiedVoice(nn.Module):
logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() logits_processor = LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length max_length = trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token, gen = self.inference_model.generate(inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
max_length=max_length, logits_processor=logits_processor, **hf_generate_kwargs) max_length=max_length, logits_processor=logits_processor,
num_return_sequences=num_return_sequences, **hf_generate_kwargs)
return gen[:, trunc_index:] return gen[:, trunc_index:]

View File

@ -16,7 +16,7 @@ def masked_mean(t, mask, dim = 1):
t = t.masked_fill(~mask[:, :, None], 0.) t = t.masked_fill(~mask[:, :, None], 0.)
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
class VoiceCLIP(nn.Module): class CLVP(nn.Module):
""" """
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
transcribed text. transcribed text.
@ -141,7 +141,7 @@ class VoiceCLIP(nn.Module):
if __name__ == '__main__': if __name__ == '__main__':
clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2) clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
clip(torch.randint(0,256,(2,120)), clip(torch.randint(0,256,(2,120)),
torch.tensor([50,100]), torch.tensor([50,100]),
torch.randint(0,8192,(2,250)), torch.randint(0,8192,(2,250)),

View File

@ -0,0 +1,133 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from torch.utils.checkpoint import checkpoint
from models.arch_util import AttentionBlock
from models.xtransformers import ContinuousTransformerWrapper, Encoder
def exists(val):
return val is not None
def masked_mean(t, mask):
t = t.masked_fill(~mask, 0.)
return t.sum(dim = 1) / mask.sum(dim = 1)
class CollapsingTransformer(nn.Module):
def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(
max_seq_len=-1,
use_pos_emb=False,
attn_layers=Encoder(
dim=model_dim,
depth=depth,
heads=heads,
ff_dropout=dropout,
ff_mult=1,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
**encoder_kwargs,
))
self.pre_combiner = nn.Sequential(nn.Conv1d(model_dim, output_dims, 1),
AttentionBlock(output_dims, num_heads=heads, do_checkpoint=False),
nn.Conv1d(output_dims, output_dims, 1))
self.mask_percentage = mask_percentage
def forward(self, x, **transformer_kwargs):
h = self.transformer(x, **transformer_kwargs)
h = h.permute(0,2,1)
h = checkpoint(self.pre_combiner, h).permute(0,2,1)
if self.training:
mask = torch.rand_like(h.float()) > self.mask_percentage
else:
mask = torch.ones_like(h.float()).bool()
return masked_mean(h, mask)
class ConvFormatEmbedding(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.emb = nn.Embedding(*args, **kwargs)
def forward(self, x):
y = self.emb(x)
return y.permute(0,2,1)
class CVVP(nn.Module):
def __init__(
self,
model_dim=512,
transformer_heads=8,
dropout=.1,
conditioning_enc_depth=8,
cond_mask_percentage=0,
mel_channels=80,
mel_codes=None,
speech_enc_depth=8,
speech_mask_percentage=0,
latent_multiplier=1,
):
super().__init__()
latent_dim = latent_multiplier*model_dim
self.temperature = nn.Parameter(torch.tensor(1.))
self.cond_emb = nn.Sequential(nn.Conv1d(mel_channels, model_dim//2, kernel_size=5, stride=2, padding=2),
nn.Conv1d(model_dim//2, model_dim, kernel_size=3, stride=2, padding=1))
self.conditioning_transformer = CollapsingTransformer(model_dim, model_dim, transformer_heads, dropout, conditioning_enc_depth, cond_mask_percentage)
self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
if mel_codes is None:
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(model_dim, latent_dim, transformer_heads, dropout, speech_enc_depth, speech_mask_percentage)
self.to_speech_latent = nn.Linear(latent_dim, latent_dim, bias=False)
def get_grad_norm_parameter_groups(self):
return {
'conditioning': list(self.conditioning_transformer.parameters()),
'speech': list(self.speech_transformer.parameters()),
}
def forward(
self,
mel_cond,
mel_input,
return_loss=False
):
cond_emb = self.cond_emb(mel_cond).permute(0,2,1)
enc_cond = self.conditioning_transformer(cond_emb)
cond_latents = self.to_conditioning_latent(enc_cond)
speech_emb = self.speech_emb(mel_input).permute(0,2,1)
enc_speech = self.speech_transformer(speech_emb)
speech_latents = self.to_speech_latent(enc_speech)
cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
temp = self.temperature.exp()
if not return_loss:
sim = einsum('n d, n d -> n', cond_latents, speech_latents) * temp
return sim
sim = einsum('i d, j d -> i j', cond_latents, speech_latents) * temp
labels = torch.arange(cond_latents.shape[0], device=mel_input.device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
return loss
if __name__ == '__main__':
clvp = CVVP()
clvp(torch.randn(2,80,100),
torch.randn(2,80,95),
return_loss=True)

View File

@ -28,7 +28,7 @@ def split_and_recombine_text(texts, desired_length=200, max_len=300):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood.txt") parser.add_argument('--textfile', type=str, help='A file containing the text to read.', default="data/riding_hood2.txt")
parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) '
'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart') 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart')
parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/') parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/longform/')