Initial commit

This commit is contained in:
James Betker 2022-01-27 23:19:29 -07:00
parent 051f500010
commit 5a958b4f4b
14 changed files with 3720 additions and 2 deletions

3
.gitignore vendored
View File

@ -127,3 +127,6 @@ dmypy.json
# Pyre type checker
.pyre/
.idea/*
.models/*

View File

@ -1,2 +1,41 @@
# tortoise-tts
A multi-voice TTS system trained with an emphasis on quality
# Tortoise-TTS
Tortoise TTS is an experimental text-to-speech program that uses recent machine learning techniques to generate
high-quality speech samples.
This repo contains all the code needed to run Tortoise TTS in inference mode.
## What's in a name?
I'm naming my speech-related repos after Mojave desert flora and fauna. Tortoise is a bit tongue in cheek: this model
is insanely slow. It leverages both an autoregressive speech alignment model and a diffusion model, both of which
are known for their slow inference. It also performs CLIP sampling, which slows things down even further. You can
expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardware. Still, the results are pretty cool.
## What the heck is this?
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together:
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
similar to the GPT model used by DALLE, except it operates on speech data.
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
decoding creates considerably better results.
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
a wav file.
## How do I use this?
<incoming>
## How do I train this?
Frankly - you don't. Building this model has been a labor of love for me, consuming most of my 6 RTX3090s worth of
resources for the better part of 6 months. It uses a dataset I've gathered, refined and transcribed that consists of
a lot of audio data which I cannot distribute because it is copyrighted or lacks an open license.
With that said, I'm willing to help you out if you really want to give it a shot. DM me.

1
data/tokenizer.json Normal file
View File

@ -0,0 +1 @@
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our"
:160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c 
k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}

168
do_tts.py Normal file
View File

@ -0,0 +1,168 @@
import argparse
import os
import random
import torch
import torch.nn.functional as F
import torchaudio
import yaml
from tqdm import tqdm
from models.arch_util import TorchMelSpectrogram
from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
from models.lucidrains_dvae import DiscreteVAE
from models.text_voice_clip import VoiceCLIP
from models.unified_voice import UnifiedVoice
from utils.audio import load_audio
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from utils.tokenizer import VoiceBpeTokenizer
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
    """
    Build the SpacedDiffusion instance used when the diffusion model acts as a vocoder.

    :param trained_diffusion_steps: step count handed to both space_timesteps() and the 'linear' beta schedule.
    :param desired_diffusion_steps: step count handed to space_timesteps() as the single section to sample.
    :return: a SpacedDiffusion configured with epsilon mean type, learned_range variance type and mse loss.
    """
    timesteps = space_timesteps(trained_diffusion_steps, [desired_diffusion_steps])
    betas = get_named_beta_schedule('linear', trained_diffusion_steps)
    return SpacedDiffusion(
        use_timesteps=timesteps,
        model_mean_type='epsilon',
        model_var_type='learned_range',
        loss_type='mse',
        betas=betas,
    )
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128):
    """
    Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.

    :param diffusion_model: model handed to diffuser.p_sample_loop().
    :param dvae_model: object whose decode(mel_codes)[0] yields the MEL spectrogram.
    :param diffuser: object providing p_sample_loop(model, shape, model_kwargs=...).
    :param mel_codes: codes passed to dvae_model.decode().
    :param conditioning_input: forwarded to the diffusion model as the 'conditioning_input' kwarg.
    :param spectrogram_compression_factor: number of output samples generated per MEL frame.
    :return: whatever diffuser.p_sample_loop() returns for the computed output shape.
    """
    with torch.no_grad():
        mel = dvae_model.decode(mel_codes)[0]

        # Pad MEL up to the next multiple of 2048//spectrogram_compression_factor.
        # (-msl) % dsl is 0 when the length is already aligned; the previous `dsl - (msl % dsl)` was never 0 and
        # therefore appended a whole extra chunk of padding to MELs that needed none.
        msl = mel.shape[-1]
        dsl = 2048 // spectrogram_compression_factor
        gap = -msl % dsl
        if gap > 0:
            mel = torch.nn.functional.pad(mel, (0, gap))

        output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
        return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
def load_conditioning(path, sample_rate=22050, cond_length=44100):
    """
    Load an audio clip, force it to exactly cond_length samples and compute its MEL spectrogram.

    Clips shorter than cond_length are zero-padded on the right; longer clips are cropped to a window that starts at
    a random offset.

    :param path: audio file passed to load_audio().
    :param sample_rate: sample rate passed to load_audio().
    :param cond_length: number of samples the returned clip is fitted to.
    :return: (mel, clip), each with an extra leading dimension and moved to the CUDA device.
    """
    clip = load_audio(path, sample_rate)
    excess = clip.shape[-1] - cond_length
    if excess > 0:
        offset = random.randint(0, excess)
        clip = clip[:, offset:offset + cond_length]
    elif excess < 0:
        clip = F.pad(clip, pad=(0, -excess))
    mel = TorchMelSpectrogram()(clip.unsqueeze(0)).squeeze(0)
    return mel.unsqueeze(0).cuda(), clip.unsqueeze(0).cuda()
def fix_autoregressive_output(codes, stop_token):
    """
    This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
    trained on and what the autoregressive code generator creates (which has no padding or end).
    This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
    a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
    and copying out the last few codes.
    Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.

    :param codes: 1-D tensor of codes; it is modified in place.
    :param stop_token: value that marks the end of the generated sequence.
    :return: the same `codes` tensor. When no stop token is present it is returned unmodified.
    """
    # Strip off the autoregressive stop token and add padding.
    stop_token_indices = (codes == stop_token).nonzero()
    if len(stop_token_indices) == 0:
        print("No stop tokens found, enjoy that output of yours!")
        # Return the codes untouched instead of None: callers assign the result straight back into their batch.
        return codes

    codes[stop_token_indices] = 83
    # Everything from the first stop token onwards is overwritten with the padding code.
    stm = stop_token_indices.min().item()
    codes[stm:] = 83
    if stm - 3 < codes.shape[0]:
        # The last three positions get the fixed ending codes.
        codes[-3] = 45
        codes[-2] = 45
        codes[-1] = 248
    return codes
if __name__ == '__main__':
    # Command-line entry point: text -> autoregressive codes -> CLIP re-ranking -> DVAE + diffusion vocoding -> wav files.

    # Preset conditioning voices: each name maps to the reference clip(s) used to condition the models.
    # NOTE(review): these are absolute paths on a 'Y:' drive - they need to be changed to run elsewhere.
    preselected_cond_voices = {
        'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'],
        'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'],
        'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'],
        'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'],
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
    parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
    # All checkpoints default to the same '.models/' directory (previously these two pointed at './models/').
    parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='.models/diffusion_vocoder.pth')
    parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='.models/dvae.pth')
    parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
    parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dan_carlin')
    parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
    parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2)
    parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
    parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)

    print("Loading GPT TTS..")
    # Every model is moved to the GPU because all of the input tensors built below live there.
    autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
                                  heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False,
                                  train_solo_embeddings=False).cuda().eval()
    autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
    stop_mel_token = autoregressive.stop_mel_token

    print("Loading data..")
    tokenizer = VoiceBpeTokenizer()
    text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
    text = F.pad(text, (0,1))  # This may not be necessary.
    cond_paths = preselected_cond_voices[args.cond_preset]
    conds = []
    for cond_path in cond_paths:
        c, cond_wav = load_conditioning(cond_path, cond_length=132300)
        conds.append(c)
    conds = torch.stack(conds, dim=1)  # And just use the last cond_wav for the diffusion model.

    with torch.no_grad():
        print("Performing GPT inference..")
        samples = []
        for b in tqdm(range(args.num_batches)):
            codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
                                                    temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
            # Bring every batch to a common length of 250 codes so the batches can be concatenated.
            padding_needed = 250 - codes.shape[1]
            codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
            samples.append(codes)
        samples = torch.cat(samples, dim=0)
        del autoregressive

        print("Loading CLIP..")
        clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
                         num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).cuda().eval()
        clip.load_state_dict(torch.load(args.clip_model_path))
        print("Performing CLIP filtering..")
        for i in range(samples.shape[0]):
            samples[i] = fix_autoregressive_output(samples[i], stop_mel_token)
        clip_results = clip(text.repeat(samples.shape[0], 1),
                            torch.full((samples.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
                            samples, torch.full((samples.shape[0],), fill_value=samples.shape[1]*1024, dtype=torch.long, device='cuda'),
                            return_loss=False)
        # Keep only the num_outputs candidates that CLIP scores highest.
        best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]

        # Delete the autoregressive and clip models to free up GPU memory
        del samples, clip

        print("Loading DVAE..")
        dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
                           record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
        dvae.load_state_dict(torch.load(args.dvae_model_path))
        print("Loading Diffusion Model..")
        diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
                                             spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
                                             conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
        diffusion.load_state_dict(torch.load(args.diffusion_model_path))
        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)

        print("Performing vocoding..")
        # Perform vocoding on each batch element separately: Vocoding is very memory (and compute!) intensive.
        for b in range(best_results.shape[0]):
            code = best_results[b].unsqueeze(0)
            wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256)
            torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 22050)

319
models/arch_util.py Normal file
View File

@ -0,0 +1,319 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
def zero_module(module):
    """
    Overwrite every parameter of `module` with zeros, in place, and hand the same module back.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class GroupNorm32(nn.GroupNorm):
    """
    GroupNorm that computes in float32 and casts the result back to the dtype of its input.
    """

    def forward(self, x):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
def normalization(channels):
    """
    Create the standard normalization layer for a given channel count.

    The group count starts at 8 (channels <= 16), 16 (channels <= 64) or 32 (otherwise) and is halved until it
    divides `channels` evenly.

    :param channels: number of input channels.
    :return: a GroupNorm32 module for normalization.
    """
    if channels <= 16:
        num_groups = 8
    elif channels <= 64:
        num_groups = 16
    else:
        num_groups = 32
    while channels % num_groups:
        num_groups //= 2
    assert num_groups > 2
    return GroupNorm32(num_groups, channels)
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T] tensor that is multiplied onto the attention weights of each batch element
                     along the attended-to (key) axis, after the softmax.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Heads are folded into the batch axis; row index is batch * n_heads + head.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            # repeat_interleave keeps each batch element's mask next to that element's heads, matching the
            # batch-major row order produced by the reshape above. (A plain repeat() tiles the masks
            # batch-after-batch and hands heads the mask of a different batch element whenever N > 1.)
            mask = mask.repeat_interleave(self.n_heads, dim=0).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
    """
    Self-attention block over the positions of its input, wrapped in a residual connection.

    The input is flattened to [batch, channels, positions] before attention and restored to its original shape
    afterwards.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    :param channels: number of input (and output) channels.
    :param num_heads: number of attention heads; used only when num_head_channels is -1.
    :param num_head_channels: if not -1, the head count is derived as channels // num_head_channels.
    """

    def __init__(self, channels, num_heads=1, num_head_channels=-1):
        super().__init__()
        self.channels = channels
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.attention = QKVAttentionLegacy(self.num_heads)
        # The output projection starts at zero, so the block is initially an identity mapping.
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

    def forward(self, x, mask=None):
        return self._forward(x, mask)

    def _forward(self, x, mask=None):
        batch, chans, *spatial = x.shape
        flat = x.reshape(batch, chans, -1)
        attended = self.attention(self.qkv(self.norm(flat)), mask)
        projected = self.proj_out(attended)
        return (flat + projected).reshape(batch, chans, *spatial)
class Upsample(nn.Module):
    """
    Nearest-neighbour upsampling along the last dimension, optionally followed by a convolution.

    :param channels: channels expected in the input.
    :param use_conv: a bool determining if a convolution is applied after interpolation.
    :param out_channels: channels produced by the convolution; defaults to `channels`.
    :param factor: scale factor given to the interpolation.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.factor = factor
        if use_conv:
            # Kernel size 5 with padding 2 keeps the interpolated length unchanged.
            self.conv = nn.Conv1d(self.channels, self.out_channels, 5, padding=2)

    def forward(self, x):
        assert x.shape[1] == self.channels
        upsampled = F.interpolate(x, scale_factor=self.factor, mode="nearest")
        if not self.use_conv:
            return upsampled
        return self.conv(upsampled)
class Downsample(nn.Module):
    """
    Downsampling along the last dimension, done either by a strided convolution or by average pooling.

    :param channels: channels expected in the input.
    :param use_conv: a bool determining if a strided convolution is used instead of average pooling.
    :param out_channels: channels produced; must equal `channels` when pooling is used. Defaults to `channels`.
    :param factor: stride of the convolution / kernel size and stride of the pooling.
    :param ksize: kernel size of the convolution.
    :param pad: padding of the convolution.
    """

    def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=factor, stride=factor)
        else:
            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=factor, padding=pad)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class ResBlock(nn.Module):
    """
    Residual block: (norm, SiLU, conv) -> optional up/downsampling -> (norm, SiLU, dropout, zero-initialized conv),
    added to a skip path that is resampled and projected to match.

    :param channels: the number of input channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels; defaults to `channels`.
    :param use_conv: if True and the channel count changes, the skip connection uses a kernel_size convolution
        instead of a 1x1 convolution.
    :param use_scale_shift_norm: stored on the module; not otherwise used by this block.
    :param up: if True, upsample inside the block.
    :param down: if True (and `up` is False), downsample inside the block.
    :param kernel_size: convolution kernel size; padding is 1 for size 3 and 2 otherwise.
    """

    def __init__(self, channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False,
                 up=False, down=False, kernel_size=3):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            # Both paths share one Identity module when no resampling happens.
            passthrough = nn.Identity()
            self.h_upd = passthrough
            self.x_upd = passthrough

        final_conv = zero_module(
            nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            final_conv,
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # Resample between the activation and the convolution, and resample the skip input the same way.
            h = self.in_layers[:-1](x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = self.in_layers[-1](h)
        return self.skip_connection(x) + self.out_layers(h)
class AudioMiniEncoder(nn.Module):
    """
    Small convolutional encoder: an input conv, `depth` stages of ResBlocks followed by a channel-doubling
    Downsample, a projection to `embedding_dim`, and a stack of AttentionBlocks. The forward pass returns only the
    first position of the last dimension.

    :param spec_dim: channels of the input.
    :param embedding_dim: channels of the produced embedding.
    :param base_channels: channels after the input conv; doubled at every stage.
    :param depth: number of ResBlock + Downsample stages.
    :param resnet_blocks: number of ResBlocks per stage.
    :param attn_blocks: number of AttentionBlocks applied after the projection.
    :param num_attn_heads: heads per AttentionBlock.
    :param dropout: dropout rate given to the ResBlocks.
    :param downsample_factor: factor given to each Downsample.
    :param kernel_size: kernel size given to the ResBlocks.
    """

    def __init__(self, spec_dim, embedding_dim, base_channels=128, depth=2, resnet_blocks=2, attn_blocks=4,
                 num_attn_heads=4, dropout=0, downsample_factor=2, kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))

        width = base_channels
        stages = []
        for _ in range(depth):
            stages.extend(ResBlock(width, dropout, kernel_size=kernel_size) for _ in range(resnet_blocks))
            stages.append(Downsample(width, use_conv=True, out_channels=width * 2, factor=downsample_factor))
            width *= 2
        self.res = nn.Sequential(*stages)

        self.final = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            nn.Conv1d(width, embedding_dim, 1),
        )
        self.attn = nn.Sequential(*[AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
        self.dim = embedding_dim

    def forward(self, x):
        encoded = self.attn(self.final(self.res(self.init(x))))
        return encoded[:, :, 0]
class TorchMelSpectrogram(nn.Module):
    """
    Computes a log-compressed MEL spectrogram with torchaudio, optionally divided by norms loaded from a file.

    :param filter_length: n_fft of the underlying MelSpectrogram transform.
    :param hop_length: hop length of the transform.
    :param win_length: window length of the transform.
    :param n_mel_channels: number of MEL bins.
    :param mel_fmin: lowest frequency of the MEL filterbank.
    :param mel_fmax: highest frequency of the MEL filterbank.
    :param sampling_rate: sample rate of the audio given to forward().
    :param normalize: passed to the transform as `normalized`.
    :param mel_norm_file: file loaded with torch.load() and used to divide the output; None disables this.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False, mel_norm_file='data/mel_norms.pth'):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            n_fft=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            power=2,
            normalized=normalize,
            sample_rate=sampling_rate,
            f_min=mel_fmin,
            f_max=mel_fmax,
            n_mels=n_mel_channels,
            norm="slaney",
        )
        self.mel_norm_file = mel_norm_file
        self.mel_norms = None if mel_norm_file is None else torch.load(mel_norm_file)

    def forward(self, inp):
        # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
        if inp.dim() == 3:
            inp = inp.squeeze(1)
        assert inp.dim() == 2
        self.mel_stft = self.mel_stft.to(inp.device)
        # Perform dynamic range compression
        mel = torch.log(torch.clamp(self.mel_stft(inp), min=1e-5))
        if self.mel_norms is None:
            return mel
        self.mel_norms = self.mel_norms.to(mel.device)
        return mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)

View File

@ -0,0 +1,468 @@
"""
This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
and an audio conditioning input. It has also been simplified somewhat.
Credit: https://github.com/openai/improved-diffusion
"""
import math
from abc import abstractmethod
import torch
import torch.nn as nn
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of timesteps.

    The first dim // 2 columns hold cosines and the next dim // 2 hold sines; an odd `dim` gets one trailing column
    of zeros.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
    freqs = torch.exp(exponents / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    columns = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        columns.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(columns, dim=-1)
class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() receives the timestep embeddings as its second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Run the module on `x`, conditioned on the timestep embeddings `emb`.
        """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that forwards the timestep embeddings to every child that is a TimestepBlock and calls
    all other children with the activations only.
    """

    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, TimestepBlock) else layer(x)
        return x
class TimestepResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is projected to a scale and a shift that modulate the
        normalized activations; otherwise the projected embedding is simply added to the activations.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param kernel_size: convolution kernel size; padding is 1 for size 3, 2 for size 5 and 0 otherwise.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = 1 if kernel_size == 3 else (2 if kernel_size == 5 else 0)

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        # The resamplers are built without an extra positional argument: this constructor has no `dims`
        # parameter, so the former `Upsample(channels, False, dims)` calls raised NameError.
        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                # Scale-shift conditioning needs two values per channel.
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to `x` given `emb` timestep embeddings.
        :param x: the input activations.
        :param emb: the timestep embeddings, projected by emb_layers before use.
        :return: the skip-connected output.
        """
        if self.updown:
            # Resample between the activation and the convolution, and resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton dimensions so the embedding broadcasts against h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
class DiscreteSpectrogramConditioningBlock(nn.Module):
    """
    Projects a conditioning signal through a small convolutional stack, resizes it to the length of `x` and
    concatenates it onto `x` along the channel dimension.

    :param dvae_channels: channels of the conditioning input.
    :param channels: channels produced by the projection.
    :param level: stored on the module as `self.level`.
    """

    def __init__(self, dvae_channels, channels, level):
        super().__init__()
        projection = [
            nn.Conv1d(dvae_channels, channels, kernel_size=1),
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, channels, kernel_size=3),
        ]
        self.intg = nn.Sequential(*projection)
        self.level = level

    def forward(self, x, dvae_in):
        """
        Embeds the conditioning input and concatenates it onto x.
        :param x: b x c x S waveform latent
        :param dvae_in: 3-D conditioning input whose channel dimension matches dvae_channels
        :return: x with the resized embedding appended along dim 1; the length S is unchanged.
        """
        # Unpacking both shapes requires each input to have exactly three dimensions.
        _, _, target_len = x.shape
        _, _, _ = dvae_in.shape
        emb = self.intg(dvae_in)
        emb = nn.functional.interpolate(emb, size=(target_len,), mode='nearest')
        return torch.cat([x, emb], dim=1)
class DiscreteDiffusionVocoder(nn.Module):
    """
    UNet model with attention and timestep embedding.

    Customized to be conditioned on a spectrogram-like prior (the `spectrogram` argument of forward()) and,
    optionally, on a conditioning waveform that is folded into the timestep embedding.

    :param model_channels: base channel count for the model.
    :param in_channels: channels in the input Tensor.
    :param out_channels: channels in the output Tensor.
    :param dvae_dim: channel count handed to each DiscreteSpectrogramConditioningBlock.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param num_res_blocks: number of residual blocks at each level; zipped with channel_mult.
    :param spectrogram_conditioning_resolutions: downsample rates at which a spectrogram conditioning block is
                                                 inserted on the downsampling path.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param conv_resample: forwarded to Downsample/Upsample.
    :param dims: stored on the module; the convolutions built here are nn.Conv1d.
    :param use_fp16: if True, forward() casts its input to float16 instead of float32.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: forwarded to every AttentionBlock.
    :param num_heads_upsample: number of heads on the upsampling path; -1 means "same as num_heads".
    :param use_scale_shift_norm: forwarded to every TimestepResBlock.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param kernel_size: kernel size of the input/output convolutions and residual blocks.
    :param scale_factor: forwarded to Downsample/Upsample as `factor`.
    :param conditioning_inputs_provided: if True, an AudioMiniEncoder is built and forward() requires
                                         `conditioning_input`.
    :param time_embed_dim_multiplier: the timestep embedding width is model_channels times this value.
    """

    def __init__(
            self,
            model_channels,
            in_channels=1,
            out_channels=2,  # mean and variance
            dvae_dim=512,
            dropout=0,
            # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
            channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
            num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2,   2,  2,  2,  2),
            # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0,   0,  1,  0,  0)
            # attn:         0, 0, 0, 0, 0, 0, 0, 0,   0,  1,  1,  1
            spectrogram_conditioning_resolutions=(512,),
            attention_resolutions=(512,1024,2048),
            conv_resample=True,
            dims=1,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            use_scale_shift_norm=False,
            resblock_updown=False,
            kernel_size=3,
            scale_factor=2,
            conditioning_inputs_provided=True,
            time_embed_dim_multiplier=4,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.dims = dims

        padding = 1 if kernel_size == 3 else 2

        time_embed_dim = model_channels * time_embed_dim_multiplier
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        self.conditioning_enabled = conditioning_inputs_provided
        if conditioning_inputs_provided:
            self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                        attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)

        seqlyr = TimestepEmbedSequential(
            nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
        )
        seqlyr.level = 0
        self.input_blocks = nn.ModuleList([seqlyr])
        self._feature_size = model_channels
        # One entry per input block whose output is kept as a skip connection. The spectrogram conditioning
        # blocks are deliberately not recorded here (and not pushed onto `hs` in forward()).
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        # Downsampling path.
        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
            if ds in spectrogram_conditioning_resolutions:
                spec_cond_block = DiscreteSpectrogramConditioningBlock(dvae_dim, ch, 2 ** level)
                self.input_blocks.append(spec_cond_block)
                # The conditioning block concatenates `ch` embedding channels onto its input.
                ch *= 2

            for _ in range(num_blocks):
                layers = [
                    TimestepResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                        )
                    )
                layer = TimestepEmbedSequential(*layers)
                layer.level = 2 ** level
                self.input_blocks.append(layer)
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                downblk = TimestepEmbedSequential(
                    TimestepResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                        kernel_size=kernel_size,
                    )
                    if resblock_updown
                    else Downsample(
                        ch, conv_resample, out_channels=out_ch, factor=scale_factor
                    )
                )
                downblk.level = 2 ** level
                self.input_blocks.append(downblk)
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            TimestepResBlock(
                ch,
                time_embed_dim,
                dropout,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
            ),
            TimestepResBlock(
                ch,
                time_embed_dim,
                dropout,
                use_scale_shift_norm=use_scale_shift_norm,
                kernel_size=kernel_size,
            ),
        )
        self._feature_size += ch

        # Upsampling path: walks the levels in reverse and consumes the recorded skip channel counts.
        self.output_blocks = nn.ModuleList([])
        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
            for i in range(num_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    TimestepResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        use_scale_shift_norm=use_scale_shift_norm,
                        kernel_size=kernel_size,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                        )
                    )
                if level and i == num_blocks:
                    out_ch = ch
                    layers.append(
                        TimestepResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                            kernel_size=kernel_size,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
                    )
                    ds //= 2
                layer = TimestepEmbedSequential(*layers)
                layer.level = 2 ** level
                self.output_blocks.append(layer)
                self._feature_size += ch

        # `ch` is int(model_channels * channel_mult[0]) at this point. The output convolution must take `ch`
        # input channels to match the normalization in front of it; hard-coding `model_channels` only works
        # when channel_mult[0] == 1.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(nn.Conv1d(ch, out_channels, kernel_size, padding=padding)),
        )

    def forward(self, x, timesteps, spectrogram, conditioning_input=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x S] Tensor of inputs; S must be a multiple of 2048.
        :param timesteps: a 1-D batch of timesteps.
        :param spectrogram: conditioning Tensor handed to every spectrogram conditioning block.
        :param conditioning_input: Tensor fed to the contextual embedder; required when the model was built with
                                   conditioning_inputs_provided=True.
        :return: the output of the final projection applied to the last UNet activation.
        """
        assert x.shape[-1] % 2048 == 0  # This model operates at base//2048 at its bottom levels, thus this requirement.
        if self.conditioning_enabled:
            assert conditioning_input is not None

        hs = []
        emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.conditioning_enabled:
            emb2 = self.contextual_embedder(conditioning_input)
            emb = emb1 + emb2
        else:
            emb = emb1

        h = x.type(self.dtype)
        for module in self.input_blocks:
            if isinstance(module, DiscreteSpectrogramConditioningBlock):
                # Conditioning blocks take the spectrogram instead of the timestep embedding and do not
                # contribute a skip connection.
                h = module(h, spectrogram)
            else:
                h = module(h, emb)
                hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)
# Smoke test on a 40960-sample clip (~1.9 seconds at 22050Hz).
if __name__ == '__main__':
    clip = torch.randn(2, 1, 40960)
    spec = torch.randn(2,80,160)
    cond = torch.randn(2, 1, 40960)
    ts = torch.LongTensor([555, 556])
    # DiscreteDiffusionVocoder.__init__ accepts no `conditioning_input_dim` argument, so it is not passed here.
    model = DiscreteDiffusionVocoder(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
                                     num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1 ], spectrogram_conditioning_resolutions=[2,512],
                                     dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
                                     conditioning_inputs_provided=True, time_embed_dim_multiplier=4,
                                     dvae_dim=80)

    print(model(clip, ts, spec, cond).shape)

390
models/lucidrains_dvae.py Normal file
View File

@ -0,0 +1,390 @@
import functools
from math import sqrt
import torch
import torch.distributed as distributed
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def eval_decorator(fn):
    """
    Wrap `fn(model, *args, **kwargs)` so that it runs with `model` in eval mode.

    The wrapper records `model.training`, calls `model.eval()`, invokes `fn` and then restores the recorded mode
    through `model.train(was_training)`. The mode is restored even when `fn` raises.
    """
    @functools.wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # Always hand the model back in the mode it arrived in.
            model.train(was_training)
    return inner
# Quantizer implemented by the rosinality vqvae repo.
# Credit: https://github.com/rosinality/vq-vae-2-pytorch
class Quantize(nn.Module):
    """
    Vector quantizer whose codebook is maintained with exponential moving averages.

    Buffers:
      embed        - (dim, n_embed) codebook, initialised from a standard normal.
      cluster_size - (n_embed,) running count of assignments per code.
      embed_avg    - (dim, n_embed) running sum of the inputs assigned to each code.

    When `balancing_heuristic` is enabled, the selected code indices are recorded and, once more than
    `max_codes` of them have been collected, codes used by more than 90% or less than 1% of the recorded
    assignments are re-initialised randomly.
    """
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        self.balancing_heuristic = balancing_heuristic
        # History of selected code indices; only populated when balancing_heuristic is on.
        self.codes = None
        self.max_codes = 64000
        self.codes_full = False
        # Selects between the two 3-tuple return layouts of forward().
        self.new_return_order = new_return_order

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input, return_soft_codes=False):
        """
        Quantize `input` against the codebook.

        `input` is flattened to rows of `self.dim` values; the returned indices have the shape of `input`
        without its last dimension.

        :return: (quantize, diff, embed_ind, soft_codes) when return_soft_codes is set, otherwise
                 (quantize, embed_ind, diff) when new_return_order is set, otherwise (quantize, diff, embed_ind).
        """
        if self.balancing_heuristic and self.codes_full:
            # Fraction of the recorded assignments that went to each code.
            h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
            mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
            ep = self.embed.permute(1,0)
            ea = self.embed_avg.permute(1,0)
            # Masked codes get fresh random vectors; all other codes are kept.
            rand_embed = torch.randn_like(ep) * mask
            self.embed = (ep * ~mask + rand_embed).permute(1,0)
            self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
            self.cluster_size = self.cluster_size * ~mask.squeeze()
            if torch.any(mask):
                print(f"Reset {torch.sum(mask)} embedding codes.")
                self.codes = None
                self.codes_full = False

        flatten = input.reshape(-1, self.dim)
        # Squared euclidean distance to every code, expanded as |x|^2 - 2*x.e + |e|^2.
        dist = (
                flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True)
        )
        soft_codes = -dist
        # The nearest code is the one with the largest negated distance.
        _, embed_ind = soft_codes.max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.balancing_heuristic:
            if self.codes is None:
                self.codes = embed_ind.flatten()
            else:
                self.codes = torch.cat([self.codes, embed_ind.flatten()])
                if len(self.codes) > self.max_codes:
                    # Keep only the most recent max_codes entries.
                    self.codes = self.codes[-self.max_codes:]
                    self.codes_full = True

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            # Combine the statistics of all workers when running distributed.
            if distributed.is_initialized() and distributed.get_world_size() > 1:
                distributed.all_reduce(embed_onehot_sum)
                distributed.all_reduce(embed_sum)

            # Exponential moving averages of the per-code counts and input sums.
            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            # Smoothed counts; eps keeps the divisor below away from zero for unused codes.
            cluster_size = (
                    (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        diff = (quantize.detach() - input).pow(2).mean()
        # Forward value is the quantized vector; the gradient flows to `input` only.
        quantize = input + (quantize - input).detach()

        if return_soft_codes:
            return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
        elif self.new_return_order:
            return quantize, embed_ind, diff
        else:
            return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        """Look up the codebook vectors for the given code indices."""
        return F.embedding(embed_id, self.embed.transpose(0, 1))
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
# values with the specified expected variance.
class DiscretizationLoss(nn.Module):
    """
    Negative log-likelihood, under a zero-mean normal distribution, of the mean-centred usage profile of a
    soft-discretized input along one dimension.

    :param discrete_bins: number of discrete values; sets the width of the history buffer.
    :param dim: dimension of the input that is kept (all others are summed out).
    :param expected_variance: passed as `scale` to torch.distributions.Normal.
    :param store_past: when > 0, number of past usage profiles kept for averaging.
    """
    def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
        super().__init__()
        self.discrete_bins = discrete_bins
        self.dim = dim
        self.dist = torch.distributions.Normal(0, scale=expected_variance)
        self.record_past = store_past > 0
        if self.record_past:
            self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
            self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
            self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))

    def forward(self, x):
        # Sum over every dimension except self.dim and normalise by the grand total.
        reduce_dims = tuple(d for d in range(len(x.shape)) if d != self.dim)
        usage = x.sum(dim=reduce_dims) / x.sum()
        usage = usage - usage.mean()

        if self.record_past:
            slots = self.accumulator.shape[0]
            snapshot = usage.detach().clone()
            if self.accumulator_filled > 0:
                # Blend the stored history with the current (differentiable) profile.
                history = torch.mean(self.accumulator, dim=0)
                usage = history * (slots-1) / slots + usage / slots

            # Also push the current profile into the accumulator.
            self.accumulator[self.accumulator_index] = snapshot
            self.accumulator_index += 1
            if self.accumulator_index >= slots:
                # Wrap around; the first wrap marks the history as filled.
                self.accumulator_index *= 0
                if self.accumulator_filled <= 0:
                    self.accumulator_filled += 1

        nll = -self.dist.log_prob(usage)
        return torch.sum(nll)
class ResBlock(nn.Module):
    """
    Residual block: conv(3) -> act -> conv(3) -> act -> conv(1), added back onto the input.

    :param chan: channel count, kept constant through the block.
    :param conv: convolution constructor called as conv(in, out, kernel, ...).
    :param activation: zero-argument activation constructor.
    """
    def __init__(self, chan, conv, activation):
        super().__init__()
        stages = [
            conv(chan, chan, 3, padding = 1),
            activation(),
            conv(chan, chan, 3, padding = 1),
            activation(),
            conv(chan, chan, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return residual + x
class UpsampledConv(nn.Module):
    """
    Nearest-neighbour upsampling followed by a convolution.

    The mandatory `stride` keyword is used as the upsampling factor and is NOT forwarded to `conv`; every other
    positional and keyword argument is.
    """
    def __init__(self, conv, *args, **kwargs):
        super().__init__()
        assert 'stride' in kwargs.keys()
        # Pull stride out so the wrapped convolution is built without it.
        self.stride = kwargs.pop('stride')
        self.conv = conv(*args, **kwargs)

    def forward(self, x):
        upsampled = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
        return self.conv(upsampled)
# DiscreteVAE partially derived from lucidrains DALLE implementation
# Credit: https://github.com/lucidrains/DALLE-pytorch
class DiscreteVAE(nn.Module):
    """
    Convolutional auto-encoder with a quantized bottleneck (see Quantize), for 1d or 2d inputs.

    :param positional_dims: 1 or 2; selects Conv1d/ConvTranspose1d or Conv2d/ConvTranspose2d.
    :param num_tokens: number of codebook entries.
    :param codebook_dim: width of each codebook entry and of the encoder output.
    :param num_layers: number of strided encoder layers (and matching decoder layers).
    :param num_resnet_blocks: number of ResBlocks added at the bottleneck of both encoder and decoder.
    :param hidden_dim: base channel count; strided layer i uses hidden_dim * 2 ** i channels.
    :param channels: channel count of the input and of the reconstruction.
    :param stride: stride of the encoder/decoder layers.
    :param kernel_size: kernel size of the encoder/decoder layers.
    :param use_transposed_convs: if False, the decoder uses UpsampledConv instead of transposed convolutions.
    :param encoder_norm: if True, a GroupNorm(8, ...) follows every strided encoder layer.
    :param activation: 'relu' or 'silu'.
    :param smooth_l1_loss: use F.smooth_l1_loss instead of F.mse_loss for reconstruction.
    :param straight_through: stored on the module; not otherwise used by this class.
    :param normalization: optional (means, stds) pair applied per channel in norm().
    :param record_codes: if True, a sample of the produced codes is kept for debugging.
    :param discretization_loss_averaging_steps: passed to DiscretizationLoss as `store_past`.
    :param lr_quantizer_args: accepted for interface compatibility; not used by this class.
    """
    def __init__(
        self,
        positional_dims=2,
        num_tokens = 512,
        codebook_dim = 512,
        num_layers = 3,
        num_resnet_blocks = 0,
        hidden_dim = 64,
        channels = 3,
        stride = 2,
        kernel_size = 4,
        use_transposed_convs = True,
        encoder_norm = False,
        activation = 'relu',
        smooth_l1_loss = False,
        straight_through = False,
        normalization = None, # ((0.5,) * 3, (0.5,) * 3),
        record_codes = False,
        discretization_loss_averaging_steps = 100,
        lr_quantizer_args = {},
    ):
        super().__init__()
        has_resblocks = num_resnet_blocks > 0

        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.straight_through = straight_through
        self.positional_dims = positional_dims
        self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)

        assert positional_dims > 0 and positional_dims < 3  # This VAE only supports 1d and 2d inputs for now.
        if positional_dims == 2:
            conv = nn.Conv2d
            conv_transpose = nn.ConvTranspose2d
        else:
            conv = nn.Conv1d
            conv_transpose = nn.ConvTranspose1d
        if not use_transposed_convs:
            conv_transpose = functools.partial(UpsampledConv, conv)

        if activation == 'relu':
            act = nn.ReLU
        elif activation == 'silu':
            act = nn.SiLU
        else:
            # Fail here with a clear error rather than later on an unbound `act`.
            raise NotImplementedError(f"Unsupported activation: {activation!r}")

        enc_layers = []
        dec_layers = []

        if num_layers > 0:
            enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
            dec_chans = list(reversed(enc_chans))

            enc_chans = [channels, *enc_chans]

            dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
            dec_chans = [dec_init_chan, *dec_chans]

            # Consecutive (in, out) channel pairs for every strided layer.
            enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

            pad = (kernel_size - 1) // 2
            for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
                enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
                if encoder_norm:
                    enc_layers.append(nn.GroupNorm(8, enc_out))
                dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
            dec_out_chans = dec_chans[-1]
            innermost_dim = dec_chans[0]
        else:
            enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
            dec_out_chans = hidden_dim
            innermost_dim = hidden_dim

        for _ in range(num_resnet_blocks):
            dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
            enc_layers.append(ResBlock(innermost_dim, conv, act))

        if num_resnet_blocks > 0:
            dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))

        enc_layers.append(conv(innermost_dim, codebook_dim, 1))
        dec_layers.append(conv(dec_out_chans, channels, 1))

        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

        self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
        self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)

        # take care of normalization within class
        self.normalization = normalization
        self.record_codes = record_codes
        if record_codes:
            self.codes = torch.zeros((1228800,), dtype=torch.long)
            self.code_ind = 0
            self.total_codes = 0
        self.internal_step = 0

    def norm(self, images):
        """Return `images` normalized with the configured (means, stds), or unchanged when none is configured."""
        if self.normalization is None:
            return images

        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
        arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
        means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
        # Work on a copy so the caller's tensor is not modified by the in-place ops.
        images = images.clone()
        images.sub_(means).div_(stds)
        return images

    def get_debug_values(self, step, __):
        """Return the recorded codes as {'histogram_codes': ...}, or {} when nothing has been recorded."""
        if self.record_codes and self.total_codes > 0:
            # Report annealing schedule
            return {'histogram_codes': self.codes[:self.total_codes]}
        else:
            return {}

    @torch.no_grad()
    @eval_decorator
    def get_codebook_indices(self, images):
        """Encode `images` and return the codebook indices (runs in eval mode without gradients)."""
        img = self.norm(images)
        # Move channels last so the quantizer sees one vector per position.
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, codes, _ = self.codebook(logits)
        self.log_codes(codes)
        return codes

    def decode(
        self,
        img_seq
    ):
        """
        Decode a (batch, n) tensor of codebook indices.

        :return: (output of the last decoder layer, input of the last decoder layer).
        """
        self.log_codes(img_seq)
        if hasattr(self.codebook, 'embed_code'):
            image_embeds = self.codebook.embed_code(img_seq)
        else:
            image_embeds = F.embedding(img_seq, self.codebook.codebook)
        b, n, d = image_embeds.shape

        kwargs = {}
        if self.positional_dims == 1:
            arrange = 'b n d -> b d n'
        else:
            # 2d sequences are reshaped into a square grid.
            h = w = int(sqrt(n))
            arrange = 'b (h w) d -> b d h w'
            kwargs = {'h': h, 'w': w}
        image_embeds = rearrange(image_embeds, arrange, **kwargs)
        images = [image_embeds]
        for layer in self.decoder:
            images.append(layer(images[-1]))
        return images[-1], images[-2]

    def infer(self, img):
        """Encode `img`, quantize it and return the result of decode() on the selected codes."""
        img = self.norm(img)
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, codes, commitment_loss = self.codebook(logits)
        return self.decode(codes)

    # Note: This module is not meant to be run in forward() except while training. It has special logic which performs
    # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
    # more lossy (but useful for determining network performance).
    def forward(
        self,
        img
    ):
        """
        :return: (unreduced reconstruction loss, commitment loss from the codebook, reconstruction).
        """
        img = self.norm(img)
        logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
        sampled, codes, commitment_loss = self.codebook(logits)
        # Back to channels-first for the decoder.
        sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))

        if self.training:
            out = sampled
            for d in self.decoder:
                out = d(out)
            self.log_codes(codes)
        else:
            # This is non-differentiable, but gives a better idea of how the network is actually performing.
            out, _ = self.decode(codes)

        # reconstruction loss
        recon_loss = self.loss_fn(img, out, reduction='none')

        return recon_loss, commitment_loss, out

    def log_codes(self, codes):
        """Store the given codes in the debug buffer on every 10th call (only when record_codes is set)."""
        # This is so we can debug the distribution of codes being learned.
        if self.record_codes and self.internal_step % 10 == 0:
            codes = codes.flatten()
            l = codes.shape[0]
            # Write at the cursor when the codes fit, otherwise at the very end of the buffer.
            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
            self.codes[i:i+l] = codes.cpu()
            self.code_ind = self.code_ind + l
            if self.code_ind >= self.codes.shape[0]:
                self.code_ind = 0
            self.total_codes += 1
        self.internal_step += 1
if __name__ == '__main__':
    # Smoke test: one forward pass and one decode on a 1d, 80-channel configuration.
    vae = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
                      hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False)
    sample = torch.randn(1,80,256)
    recon_loss, commitment_loss, reconstruction = vae(sample)
    vae.decode(torch.randint(0,8192,(1,256)))
    print(reconstruction.shape, commitment_loss.shape)

125
models/text_voice_clip.py Normal file
View File

@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from models.transformer import Transformer
def exists(val):
    """Return True for every value except None."""
    return not (val is None)
def masked_mean(t, mask, dim = 1):
    """
    Average `t` over dimension 1, counting only positions where `mask` is True.

    The mask is indexed as mask[:, :, None], so it must be 2-D. The `dim` argument is accepted but not
    consulted: the reduction always runs over dimension 1.
    """
    keep = mask[:, :, None]
    zeroed = t.masked_fill(~keep, 0.)
    counts = mask.sum(dim = 1)[..., None]
    return zeroed.sum(dim = 1) / counts
class VoiceCLIP(nn.Module):
    """
    CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
    transcribed text.

    Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
    """

    def __init__(
            self,
            *,
            dim_text=512,
            dim_speech=512,
            dim_latent=512,
            num_text_tokens=256,
            text_enc_depth=6,
            text_seq_len=120,
            text_heads=8,
            num_speech_tokens=8192,
            speech_enc_depth=6,
            speech_heads=8,
            speech_seq_len=250,
            text_mask_percentage=0,
            voice_mask_percentage=0,
            wav_token_compression=1024,
    ):
        super().__init__()
        # Text tower: token + position embeddings, non-causal transformer, projection to the shared latent.
        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
        self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
                                            heads=text_heads)
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)

        # Speech tower, mirroring the text tower.
        self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
        # NOTE(review): the position table is sized by num_speech_tokens, whereas the text tower sizes its table
        # by the sequence length. Left as-is; resizing it would change the parameter shape.
        self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
        self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
                                              depth=speech_enc_depth, heads=speech_heads)
        self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)

        self.temperature = nn.Parameter(torch.tensor(1.))
        self.text_mask_percentage = text_mask_percentage
        self.voice_mask_percentage = voice_mask_percentage
        self.wav_token_compression = wav_token_compression

    def forward(
            self,
            text,
            text_lengths,
            speech_tokens,
            wav_lengths,
            return_loss=False
    ):
        """
        :param text: (batch, seq) text token ids.
        :param text_lengths: per-sample text lengths; only the maximum is used, to trim padding.
        :param speech_tokens: (batch, seq) speech token ids.
        :param wav_lengths: per-sample lengths that are divided by wav_token_compression to obtain the number of
                            speech tokens to keep; only the maximum is used.
        :param return_loss: if True, return the symmetric cross-entropy loss over the batch similarity matrix;
                            otherwise return one similarity score per sample.
        """
        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
        # chopping the inputs by the maximum actual length.
        longest_text = text_lengths.max()
        text = text[:, :longest_text]
        longest_speech = wav_lengths.max() // self.wav_token_compression
        speech_tokens = speech_tokens[:, :longest_speech]

        batch_size = text.shape[0]
        device = text.device
        if self.training:
            # Randomly drop positions; a position is kept when its random draw exceeds the mask percentage.
            text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
            voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
        else:
            text_mask = torch.ones_like(text.float()).bool()
            voice_mask = torch.ones_like(speech_tokens.float()).bool()

        text_positions = torch.arange(text.shape[1], device=device)
        text_emb = self.text_emb(text)
        text_emb += self.text_pos_emb(text_positions)

        speech_emb = self.speech_emb(speech_tokens)
        speech_positions = torch.arange(speech_emb.shape[1], device=device)
        speech_emb += self.speech_pos_emb(speech_positions)

        enc_text = self.text_transformer(text_emb, mask=text_mask)
        enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)

        # Pool each sequence over its unmasked positions, project, then L2-normalize.
        pooled_text = masked_mean(enc_text, text_mask, dim=1)
        pooled_speech = masked_mean(enc_speech, voice_mask, dim=1)
        text_latents = self.to_text_latent(pooled_text)
        speech_latents = self.to_speech_latent(pooled_speech)
        text_latents = F.normalize(text_latents, p=2, dim=-1)
        speech_latents = F.normalize(speech_latents, p=2, dim=-1)

        temp = self.temperature.exp()

        if return_loss:
            # Every text is scored against every clip; matching pairs sit on the diagonal.
            sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
            labels = torch.arange(batch_size, device=device)
            loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
            return loss

        sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
        return sim
if __name__ == '__main__':
    clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2)
    # wav lengths are floor-divided by wav_token_compression (1024 by default) to get the number of speech tokens
    # that are kept, so they must be given in that scale. Lengths such as 101/102 would keep zero tokens.
    wav_lengths = torch.tensor([200 * 1024, 250 * 1024])
    clip(torch.randint(0,256,(2,120)),
         torch.tensor([50,100]),
         torch.randint(0,8192,(2,250)),
         wav_lengths,
         return_loss=True)
    nonloss = clip(torch.randint(0,256,(2,120)),
                   torch.tensor([50,100]),
                   torch.randint(0,8192,(2,250)),
                   wav_lengths,
                   return_loss=False)
    print(nonloss.shape)

219
models/transformer.py Normal file
View File

@ -0,0 +1,219 @@
from functools import partial
import torch
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding, broadcat
from torch import nn
# helpers
def exists(val):
    """Return True for every value except None."""
    return not (val is None)
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def cast_tuple(val, depth = 1):
    """
    Coerce `val` to a tuple.

    Tuples are returned unchanged, lists are converted, and any other value is repeated `depth` times.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
def max_neg_value(t):
    """Return the most negative finite value representable in the dtype of `t`."""
    limits = torch.finfo(t.dtype)
    return -limits.max
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
    """
    Softmax computed on values that are first scaled down by `alpha`, shifted by their (detached) maximum along
    `dim`, and scaled back up.
    """
    scaled = t / alpha
    peak = torch.amax(scaled, dim = dim, keepdim = True).detach()
    shifted = scaled - peak
    return (shifted * alpha).softmax(dim = dim)
def route_args(router, args, depth):
    """
    Distribute keyword arguments over `depth` layers of two functions each.

    :param router: maps an argument name to one (to_first, to_second) flag pair per layer.
    :param args: keyword arguments to distribute; names missing from `router` are dropped.
    :param depth: number of layers.
    :return: list of `depth` (first_kwargs, second_kwargs) tuples.
    """
    routed = [({}, {}) for _ in range(depth)]

    for name, value in args.items():
        if name not in router:
            continue
        for layer_idx, ((f_kwargs, g_kwargs), (to_f, to_g)) in enumerate(zip(routed, router[name])):
            extra_f = {name: value} if to_f else {}
            extra_g = {name: value} if to_g else {}
            routed[layer_idx] = ({**f_kwargs, **extra_f}, {**g_kwargs, **extra_g})
    return routed
# classes
class SequentialSequence(nn.Module):
    """
    Runs a sequence of layer pairs, adding the output of each function back onto its input.

    :param layers: iterable of (first_fn, second_fn) pairs.
    :param args_route: per-argument routing table, see route_args(); needs one entry per layer.
    :param layer_dropout: stored on the module; not used by forward().
    """
    def __init__(self, layers, args_route = {}, layer_dropout = 0.):
        super().__init__()
        num_layers = len(layers)
        assert all(len(route) == num_layers for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        per_layer_kwargs = route_args(self.args_route, kwargs, len(self.layers))

        for (first_fn, second_fn), (first_kwargs, second_kwargs) in zip(self.layers, per_layer_kwargs):
            x = x + first_fn(x, **first_kwargs)
            x = x + second_fn(x, **second_kwargs)
        return x
class DivideMax(nn.Module):
    """Divides the input by its (detached) maximum along a fixed dimension."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # The maximum is detached, so no gradient flows through the divisor.
        peak = x.amax(dim = self.dim, keepdim = True).detach()
        return x / peak
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    """
    Multiplies the output of `fn` by a learned per-channel scale.

    The initial scale depends on `depth`: 0.1 up to depth 18, 1e-5 up to depth 24 and 1e-6 beyond.
    """
    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        initial = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(initial)
        self.fn = fn

    def forward(self, x, **kwargs):
        result = self.fn(x, **kwargs)
        return result * self.scale
# layer norm
class PreNorm(nn.Module):
    """
    Applies LayerNorm before `fn`; with `sandwich=True` a second LayerNorm is applied to the output of `fn` too.
    """
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        if sandwich:
            self.norm_out = nn.LayerNorm(dim)
        else:
            self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        result = self.fn(normed, **kwargs)
        return self.norm_out(result)
# feed forward
class GEGLU(nn.Module):
    """Splits the last dimension in two halves and gates the first half with GELU of the second."""
    def forward(self, x):
        values, gates = x.chunk(2, dim = -1)
        activated = F.gelu(gates)
        return values * activated
class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear -> GEGLU -> Dropout -> Linear.

    :param dim: input and output feature size.
    :param dropout: dropout probability applied after the gated activation.
    :param mult: expansion factor of the hidden layer; the first Linear produces twice the hidden width because
                 GEGLU halves it again.
    """
    def __init__(self, dim, dropout = 0., mult = 4.):
        super().__init__()
        # nn.Linear requires integer feature sizes, but the default `mult` is a float.
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim)
        )

    def forward(self, x):
        return self.net(x)
# Attention
class Attention(nn.Module):
    """
    Multi-head self-attention with an optional key mask and optional causal masking.

    :param dim: input and output feature size.
    :param seq_len: stored on the module; not used by forward().
    :param causal: if True, positions may not attend to later positions.
    :param heads: number of attention heads.
    :param dim_head: feature size per head.
    :param dropout: dropout probability applied after the output projection.
    """
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5

        self.causal = causal

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask = None):
        # Unpacking doubles as a rank check: x must be 3-D.
        _batch, _seq, _ = x.shape
        heads = self.heads
        device = x.device

        projected = self.to_qkv(x).chunk(3, dim = -1)
        queries, keys, values = (rearrange(part, 'b n (h d) -> b h n d', h = heads) for part in projected)

        queries = queries * self.scale

        scores = torch.einsum('b h i d, b h j d -> b h i j', queries, keys)
        fill_value = max_neg_value(scores)

        if exists(mask):
            # Broadcast the (batch, key) mask over heads and query positions.
            key_mask = rearrange(mask, 'b j -> b () () j')
            scores.masked_fill_(~key_mask, fill_value)

        if self.causal:
            rows, cols = scores.shape[-2:]
            future = torch.ones(rows, cols, device = device).triu_(cols - rows + 1).bool()
            scores.masked_fill_(future, fill_value)

        weights = torch.softmax(scores, dim=-1)

        attended = torch.einsum('b h i j, b h j d -> b h i d', weights, values)
        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
# main transformer class
class Transformer(nn.Module):
    """
    Stack of (attention, feed-forward) layer pairs, each wrapped in PreNorm and LayerScale and run with residual
    connections by SequentialSequence. A `mask` keyword passed to forward() is routed to the attention layers only.

    `sparse_attn` is only used to determine how many layers are built (it is zipped with range(depth)).
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        attn_dropout = 0.,
        ff_dropout = 0.,
        sparse_attn = False,
        sandwich_norm = False,
    ):
        super().__init__()
        blocks = nn.ModuleList([])
        per_layer_sparse = cast_tuple(sparse_attn, depth)

        for layer_idx, _sparse in zip(range(depth), per_layer_sparse):
            layer_depth = layer_idx + 1
            attention = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
            feed_forward = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
            wrapped_attention = LayerScale(dim, layer_depth, PreNorm(dim, attention, sandwich = sandwich_norm))
            wrapped_feed_forward = LayerScale(dim, layer_depth, PreNorm(dim, feed_forward, sandwich = sandwich_norm))
            blocks.append(nn.ModuleList([wrapped_attention, wrapped_feed_forward]))

        # Per layer: (True, False) sends `mask` to the first function (attention) and not to the second.
        mask_routes = {'mask': ((True, False),) * depth}
        self.layers = SequentialSequence(blocks, args_route = mask_routes)

    def forward(self, x, **kwargs):
        return self.layers(x, **kwargs)

530
models/unified_voice.py Normal file
View File

@ -0,0 +1,530 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from models.arch_util import AttentionBlock
def null_position_embeddings(range, dim):
    """
    Return an all-zero tensor of shape (range.shape[0], range.shape[1], dim) on the device of `range`.
    """
    batch, length = range.shape[0], range.shape[1]
    return torch.zeros((batch, length, dim), device=range.device)
class ResBlock(nn.Module):
    """
    Basic residual convolutional block that uses GroupNorm.

    Two conv(3) + GroupNorm stages with a ReLU in between; the input is added back before a final ReLU.
    """
    def __init__(self, chan):
        super().__init__()
        groups = chan//8
        stages = [
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        summed = self.net(x) + x
        return F.relu(summed)
class GPT2InferenceModel(GPT2PreTrainedModel):
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
super().__init__(config)
self.transformer = gpt
self.text_pos_embedding = text_pos_emb
self.embeddings = embeddings
self.lm_head = nn.Sequential(norm, linear)
# Model parallel
self.model_parallel = False
self.device_map = None
self.cached_mel_emb = None
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def store_mel_emb(self, mel_emb):
self.cached_mel_emb = mel_emb
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
if input_ids.shape[1] != 1:
text_inputs = input_ids[:, mel_len:]
text_emb = self.embeddings(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
else:
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
    """
    This function is used to re-order the :obj:`past_key_values` cache if
    :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
    called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.

    Every tensor found in ``past`` (a sequence of per-layer sequences of tensors) is re-indexed along
    dimension 0 with ``beam_idx``, which is first moved to the device of the tensor being indexed. The result
    is a tuple of tuples with the same nesting as ``past``.
    """
    reordered = []
    for layer_past in past:
        layer_states = []
        for past_state in layer_past:
            selection = beam_idx.to(past_state.device)
            layer_states.append(past_state.index_select(0, selection))
        reordered.append(tuple(layer_states))
    return tuple(reordered)
class ConditioningEncoder(nn.Module):
    """
    Reduces a conditioning spectrogram to one embedding vector per batch element.

    The input is projected from ``spec_dim`` to ``embedding_dim`` channels with a kernel-size-1 convolution,
    run through ``attn_blocks`` AttentionBlock modules, and the first position along the last axis of the
    result is returned.
    """
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4,
                 do_checkpointing=False):
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        self.attn = nn.Sequential(*[AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)])
        self.dim = embedding_dim
        # Only stored; nothing in this class reads the flag.
        self.do_checkpointing = do_checkpointing

    def forward(self, x):
        # x is fed to a Conv1d, i.e. (batch, spec_dim, sequence).
        projected = self.init(x)
        attended = self.attn(projected)
        return attended[:, :, 0]
class LearnedPositionEmbeddings(nn.Module):
    """
    Learned absolute position table holding one ``model_dim`` vector for each position in ``[0, seq_len)``.
    """
    def __init__(self, seq_len, model_dim, init=.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        with torch.no_grad():
            self.emb.weight.normal_(mean=0.0, std=init)

    def forward(self, x):
        """Return the embeddings of positions ``0 .. x.shape[1] - 1`` as a ``(x.shape[1], model_dim)`` tensor."""
        positions = torch.arange(0, x.shape[1], device=x.device)
        return self.emb(positions)

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single position ``ind`` with shape ``(1, 1, model_dim)``."""
        index = torch.tensor([ind], device=dev)
        return self.emb(index).unsqueeze(0)
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
    """
    GPT-2 implemented by the HuggingFace library.

    Builds a GPT2Model sized for ``max_mel_seq_len + max_text_seq_len`` positions, swaps its built-in position
    embedding for ``null_position_embeddings`` and removes its token embedding. Returns a 5-tuple:
    the transformer, a learned position embedding for MEL positions, a learned position embedding for text
    positions, and two ``None`` placeholders.
    """
    from transformers import GPT2Config, GPT2Model
    total_positions = max_mel_seq_len + max_text_seq_len
    gpt_config = GPT2Config(vocab_size=256,  # Unused.
                            n_positions=total_positions,
                            n_ctx=total_positions,
                            n_embd=model_dim,
                            n_layer=layers,
                            n_head=heads,
                            gradient_checkpointing=checkpointing,
                            use_cache=not checkpointing)
    gpt = GPT2Model(gpt_config)

    # Override the built in positional embeddings
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
    # Built-in token embeddings are unused.
    del gpt.wte

    mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
    text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
    return gpt, mel_pos_embedding, text_pos_embedding, None, None
class MelEncoder(nn.Module):
    """
    Convolutional encoder applied to raw MEL spectrograms.

    Two stride-2 convolutions widen the channel count from ``channels//4`` to ``channels``; each stage is
    followed by ``resblocks_per_reduction`` ResBlock modules. The output is permuted so the channel axis is last.
    """
    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
        super().__init__()
        self.channels = channels
        quarter, half = channels//4, channels//2

        def res_stack(width):
            # One group of residual blocks operating at a fixed channel width.
            return nn.Sequential(*[ResBlock(width) for _ in range(resblocks_per_reduction)])

        stages = [
            nn.Conv1d(mel_channels, quarter, kernel_size=3, padding=1),
            res_stack(quarter),
            nn.Conv1d(quarter, half, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels//16, half),
            nn.ReLU(),
            res_stack(half),
            nn.Conv1d(half, channels, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels//8, channels),
            nn.ReLU(),
            res_stack(channels),
        ]
        self.encoder = nn.Sequential(*stages)
        # Product of the two stride-2 convolutions above.
        self.reduction = 4

    def forward(self, x):
        hidden = x
        for stage in self.encoder:
            hidden = stage(hidden)
        return hidden.permute(0, 2, 1)
class UnifiedVoice(nn.Module):
    """
    Autoregressive GPT-2 model over a joint sequence of conditioning, text and MEL embeddings.

    Each conditioning MEL is reduced to one embedding by a ConditioningEncoder. Text tokens and MEL codes get
    their own token embeddings and learned position embeddings, and two linear heads predict text tokens and
    MEL codes from the shared transformer output.
    """
    def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
                 mel_length_compression=1024, number_text_tokens=256,
                 start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                 stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
                 checkpointing=True):
        """
        Args:
            layers: Number of layers in transformer stack.
            model_dim: Operating dimensions of the transformer
            heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim//64
            max_text_tokens: Maximum number of text tokens that will be encountered by model.
            max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
            mel_length_compression: The factor between number_input_samples and mel_tokens. Used to compute MEL code padding given wav input length.
            number_text_tokens: Size of the text embedding table and of the text head output.
            start_text_token: Token id prepended to every text sequence.
            stop_text_token: Token id appended to every text sequence.
            number_mel_codes: Size of the MEL code embedding table and of the MEL head output.
            start_mel_token: Token id prepended to every MEL code sequence.
            stop_mel_token: Token id appended to every MEL code sequence and written over padded MEL positions.
            train_solo_embeddings: If True, creates learned (1,1,model_dim) embeddings that are added to the inputs of text_forward() and speech_forward(). Otherwise 0 is added.
            use_mel_codes_as_input: If True, MEL inputs are embedded with an nn.Embedding over MEL codes; otherwise a MelEncoder is built to embed raw MELs.
            checkpointing: Forwarded to build_hf_gpt_transformer().
        """
        super().__init__()

        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token
        self.number_mel_codes = number_mel_codes
        self.start_mel_token = start_mel_token
        self.stop_mel_token = stop_mel_token
        self.layers = layers
        self.heads = heads
        self.max_mel_tokens = max_mel_tokens
        self.max_text_tokens = max_text_tokens
        self.model_dim = model_dim
        self.max_conditioning_inputs = max_conditioning_inputs
        self.mel_length_compression = mel_length_compression
        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
        if use_mel_codes_as_input:
            self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
        else:
            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
        # The position tables leave room for the start/stop tokens (+2) and, on the MEL side, for the conditioning inputs.
        self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
            build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
        if train_solo_embeddings:
            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
        else:
            self.mel_solo_embedding = 0
            self.text_solo_embedding = 0

        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
        self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

        # Initialize the embeddings per the GPT-2 scheme
        embeddings = [self.text_embedding]
        if use_mel_codes_as_input:
            embeddings.append(self.mel_embedding)
        for module in embeddings:
            module.weight.data.normal_(mean=0.0, std=.02)

    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
        """
        Builds the shifted (input, target) pair used for next-token prediction: the input gets start_token
        prepended on the last axis, the target gets stop_token appended.
        """
        inp = F.pad(input, (1,0), value=start_token)
        tar = F.pad(input, (0,1), value=stop_token)
        return inp, tar

    def set_mel_padding(self, mel_input_tokens, wav_lengths):
        """
        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
        preformatting to create a working TTS model.

        mel_input_tokens is modified in place and also returned.
        """
        # Set padding areas within MEL (currently it is coded with the MEL code for zero).
        mel_lengths = wav_lengths // self.mel_length_compression
        for b in range(len(mel_lengths)):
            actual_end = mel_lengths[b] + 1  # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
            if actual_end < mel_input_tokens.shape[-1]:
                mel_input_tokens[b, actual_end:] = self.stop_mel_token
        return mel_input_tokens

    def _encode_conditioning(self, speech_conditioning_input):
        """
        Runs the conditioning encoder over every conditioning clip and stacks the results on dim 1.
        A 3-dimensional input is treated as a single conditioning clip per batch element.
        """
        if len(speech_conditioning_input.shape) == 3:
            speech_conditioning_input = speech_conditioning_input.unsqueeze(1)
        conds = [self.conditioning_encoder(speech_conditioning_input[:, j])
                 for j in range(speech_conditioning_input.shape[1])]
        return torch.stack(conds, dim=1)

    def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
        """
        Runs the transformer over [conditioning, first_inputs, (second_inputs)] and applies the given heads.

        Returns the attention maps of the transformer when get_attns is set. Otherwise returns the logits for
        first_inputs (and for second_inputs when given), each permuted to (batch, classes, sequence).
        """
        if second_inputs is not None:
            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
        else:
            emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)

        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
        if get_attns:
            return gpt_out.attentions

        # Drop every position occupied by a conditioning embedding. Dropping only the first one misaligns the
        # first logits with their targets as soon as more than one conditioning input is provided.
        offset = speech_conditioning_inputs.shape[1]
        enc = gpt_out.last_hidden_state[:, offset:]
        enc = self.final_norm(enc)
        first_logits = enc[:, :first_inputs.shape[1]]
        first_logits = first_head(first_logits)
        first_logits = first_logits.permute(0,2,1)
        if second_inputs is not None:
            second_logits = enc[:, -second_inputs.shape[1]:]
            second_logits = second_head(second_logits)
            second_logits = second_logits.permute(0,2,1)
            return first_logits, second_logits
        else:
            return first_logits

    def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
        """
        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
        (actuated by `text_first`).

        speech_conditioning_input: MEL float tensor, (b,80,s)
        text_inputs: long tensor, (b,t)
        text_lengths: long tensor, (b,)
        mel_codes: long tensor, (b,m)
        wav_lengths: long tensor, (b,)
        raw_mels: MEL float tensor (b,80,s)
        return_attentions: when set, the attention maps returned by get_logits() are returned instead of losses.

        Returns (loss_text, loss_mel, mel_logits) unless return_attentions is set.
        """
        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
        # chopping the inputs by the maximum actual length.
        max_text_len = text_lengths.max()
        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
        max_mel_len = wav_lengths.max() // self.mel_length_compression
        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
        if raw_mels is not None:
            raw_mels = raw_mels[:, :, :max_mel_len*4]
        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)

        conds = self._encode_conditioning(speech_conditioning_input)

        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
        if raw_mels is not None:
            # NOTE(review): speech_forward() pads raw MELs by 4 frames where this method pads by 8 - confirm which is intended.
            mel_inp = F.pad(raw_mels, (0, 8))
        else:
            mel_inp = mel_codes
        mel_emb = self.mel_embedding(mel_inp)
        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)

        if text_first:
            outputs = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
        else:
            outputs = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)

        if return_attentions:
            # In this mode get_logits() hands back the attention maps, not a (first, second) logits pair, so
            # they must be returned as-is rather than unpacked into two names.
            return outputs

        if text_first:
            text_logits, mel_logits = outputs
        else:
            mel_logits, text_logits = outputs

        loss_text = F.cross_entropy(text_logits, text_targets.long())
        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
        return loss_text.mean(), loss_mel.mean(), mel_logits

    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
        """
        Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
        model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).

        Returns the text cross-entropy loss.
        """
        assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'

        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
        # chopping the inputs by the maximum actual length.
        max_text_len = text_lengths.max()
        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)

        conds = self._encode_conditioning(speech_conditioning_input)

        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
        text_logits = self.get_logits(conds, text_emb, self.text_head)
        loss_text = F.cross_entropy(text_logits, text_targets.long())
        return loss_text.mean()

    def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
        """
        Performs autoregressive modeling on only speech data.

        Returns the MEL cross-entropy loss.
        """
        assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'

        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
        # chopping the inputs by the maximum actual length.
        max_mel_len = wav_lengths.max() // self.mel_length_compression
        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
        if raw_mels is not None:
            raw_mels = raw_mels[:, :, :max_mel_len*4]

        conds = self._encode_conditioning(speech_conditioning_input)

        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
        if raw_mels is not None:
            mel_inp = F.pad(raw_mels, (0, 4))
        else:
            mel_inp = mel_codes
        mel_emb = self.mel_embedding(mel_inp)
        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
        mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
        return loss_mel.mean()

    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
        """
        Generates MEL codes for text_inputs, conditioned on speech_conditioning_input.

        The GPT2InferenceModel wrapper is built on first use and kept on the module. Extra keyword arguments are
        forwarded to its generate() call. Returns the generated ids that follow the placeholder prompt.
        """
        seq_length = self.max_mel_tokens + self.max_text_tokens + 2
        if not hasattr(self, 'inference_model'):
            # TODO: Decouple gpt_config from this inference model.
            gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
                                    n_positions=seq_length,
                                    n_ctx=seq_length,
                                    n_embd=self.model_dim,
                                    n_layer=self.layers,
                                    n_head=self.heads,
                                    gradient_checkpointing=False,
                                    use_cache=True)
            self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
            self.gpt.wte = self.mel_embedding

        text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
        text_inputs, _ = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)

        conds = self._encode_conditioning(speech_conditioning_input)

        # The conditioning + text prefix is cached on the inference model; generate() only sees placeholder ids for it.
        emb = torch.cat([conds, text_emb], dim=1)
        self.inference_model.store_mel_emb(emb)

        # NOTE(review): emb already spans the conditioning positions, so this allocates conds.shape[1] slots past the
        # cached prefix and only the last one holds the start token - confirm this is intended for more than one
        # conditioning input.
        fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
        fake_inputs[:,-1] = self.start_mel_token

        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
                                            max_length=seq_length, **hf_generate_kwargs)
        return gen[:, fake_inputs.shape[1]:]
if __name__ == '__main__':
    # Smoke test: one joint text+speech forward pass with three conditioning clips per batch element,
    # followed by a text-only forward pass with a single conditioning clip.
    model = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
    conditioning = torch.randn(2, 3, 80, 800)
    text = torch.randint(high=120, size=(2,120))
    text_lengths = torch.tensor([32, 120])
    mel_codes = torch.randint(high=8192, size=(2,250))
    wav_lengths = torch.tensor([250*256,195*256])
    outputs = model(conditioning, text, text_lengths, mel_codes, wav_lengths)
    model.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
torch
torchaudio
rotary_embedding_torch
transformers
tokenizers
pyfastmp3decoder
inflect

44
utils/audio.py Normal file
View File

@ -0,0 +1,44 @@
import numpy as np
import torch
import torchaudio
from scipy.io.wavfile import read
def load_wav_to_torch(full_path):
    """
    Read a WAV file and return its samples scaled to floating point.

    Args:
        full_path: Location of the WAV file, handed to ``scipy.io.wavfile.read``.

    Returns:
        A ``(samples, sampling_rate)`` tuple. ``samples`` is a float32 tensor holding the file's data divided by
        a per-dtype scale: 2**31 for int32, 2**15 for int16 and 1 for float16/float32.

    Raises:
        NotImplementedError: If the file stores its samples in any other dtype.
    """
    sampling_rate, data = read(full_path)
    if data.dtype == np.int32:
        norm_fix = 2 ** 31
    elif data.dtype == np.int16:
        norm_fix = 2 ** 15
    elif data.dtype == np.float16 or data.dtype == np.float32:
        norm_fix = 1.
    else:
        # NotImplemented is a comparison sentinel, not an exception class; calling it raised a TypeError that
        # hid this message.
        raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}")
    return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
def load_audio(audiopath, sampling_rate):
    """
    Load a .wav or .mp3 file as a single-channel float tensor at ``sampling_rate``.

    Args:
        audiopath: Path string of the audio file. The format is chosen from its last four characters,
            compared case-insensitively.
        sampling_rate: Sampling rate the returned audio should have; audio loaded at a different rate is
            resampled with torchaudio.

    Returns:
        The first channel of the audio, clipped in place to [-1, 1], with a leading dimension of size 1 added.

    Raises:
        ValueError: If ``audiopath`` does not end in ``.wav`` or ``.mp3``.
    """
    extension = audiopath[-4:].lower()
    if extension == '.wav':
        audio, lsr = load_wav_to_torch(audiopath)
    elif extension == '.mp3':
        # https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
        from pyfastmp3decoder.mp3decoder import load_mp3
        audio, lsr = load_mp3(audiopath, sampling_rate)
        audio = torch.FloatTensor(audio)
    else:
        # Without this branch the code below failed with an UnboundLocalError on `audio`.
        raise ValueError(f"Unsupported audio format for {audiopath}: only .wav and .mp3 files can be loaded.")

    # Remove any channel data.
    if len(audio.shape) > 1:
        # Whichever axis has fewer than 5 entries is taken to be the channel axis.
        if audio.shape[0] < 5:
            audio = audio[0]
        else:
            assert audio.shape[1] < 5
            audio = audio[:, 0]

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
    # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
    if torch.any(audio > 2) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)

    return audio.unsqueeze(0)

1232
utils/diffusion.py Normal file

File diff suppressed because it is too large Load Diff

173
utils/tokenizer.py Normal file
View File

@ -0,0 +1,173 @@
import re
import inflect
import torch
from tokenizers import Tokenizer
# Regular expression matching whitespace:
from unidecode import unidecode
# Matches one or more consecutive whitespace characters.
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
    (re.compile('\\b%s\\.' % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in [
        ('mrs', 'misess'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
    ]
]
def expand_abbreviations(text):
    '''Applies every (pattern, replacement) pair of _abbreviations to text, in list order.'''
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
# Engine from the `inflect` package; its number_to_words() is used below to spell numbers out.
_inflect = inflect.engine()
# A digit, then one or more digits or commas, then a digit (e.g. "1,000").
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
# Digits, a dot, then digits (e.g. "3.14").
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
# A pound sign followed by digits/commas ending in a digit; the amount is captured without the sign.
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
# A dollar sign followed by digits/dots/commas ending in a digit; the amount is captured without the sign.
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
# Digits immediately followed by one of the suffixes st, nd, rd or th.
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# Any run of digits.
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
    '''Returns the first captured group of match m with every comma deleted.'''
    digits = m.group(1)
    return digits.replace(',', '')
def _expand_decimal_point(m):
    '''Returns the first captured group of match m with each "." spelled out as " point ".'''
    number = m.group(1)
    return ' point '.join(number.split('.'))
def _expand_dollars(m):
    '''
    Spells out the dollar amount captured in group 1 of match m.

    The amount is split on "." into a dollars part and a cents part. Returns e.g. "3 dollars, 7 cents",
    "12 dollars", "1 cent", or "zero dollars" when both parts are zero. Amounts containing more than one "."
    are returned unchanged with " dollars" appended.
    '''
    match = m.group(1)
    # Commas can survive in the captured amount (e.g. "1,.50"); int() would raise ValueError on them.
    parts = match.replace(',', '').split('.')
    if len(parts) > 2:
        return match + ' dollars'  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0

    phrases = []
    if dollars:
        phrases.append('%s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars'))
    if cents:
        phrases.append('%s %s' % (cents, 'cent' if cents == 1 else 'cents'))
    if not phrases:
        return 'zero dollars'
    return ', '.join(phrases)
def _expand_ordinal(m):
    '''Passes the full text of match m to the inflect engine's number_to_words() and returns the result.'''
    ordinal = m.group(0)
    return _inflect.number_to_words(ordinal)
def _expand_number(m):
    '''
    Spells out the integer matched by m.

    Numbers strictly between 1000 and 3000 get special handling: 2000 is "two thousand", 2001-2009 are
    "two thousand" plus the last two digits, exact hundreds are read as "N hundred", and everything else is
    read in groups of two digits. All other numbers go straight to number_to_words().
    '''
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if 2000 < num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    grouped = _inflect.number_to_words(num, andword='', zero='oh', group=2)
    return grouped.replace(', ', ' ')
def normalize_numbers(text):
    '''
    Rewrites the numeric expressions in text as words.

    The substitutions run in a fixed order: commas inside numbers are removed, pound and dollar amounts are
    expanded, decimal points are spelled out, then ordinals and finally all remaining digit runs are converted.
    '''
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)
    return text
def expand_numbers(text):
    '''Alias of normalize_numbers(): returns text with its numeric expressions written as words.'''
    expanded = normalize_numbers(text)
    return expanded
def lowercase(text):
    '''Returns text converted to lowercase.'''
    lowered = text.lower()
    return lowered
def collapse_whitespace(text):
    '''Replaces every match of _whitespace_re in text with a single space.'''
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    '''Returns the result of passing text through unidecode().'''
    transliterated = unidecode(text)
    return transliterated
def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    return collapse_whitespace(lowercase(text))
def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    ascii_text = convert_to_ascii(text)
    return collapse_whitespace(lowercase(ascii_text))
def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    # The steps run in this order; double quotes are stripped last.
    for step in (convert_to_ascii, lowercase, expand_numbers, expand_abbreviations, collapse_whitespace):
        text = step(text)
    return text.replace('"', '')
class VoiceBpeTokenizer:
    '''
    Thin wrapper around a `tokenizers.Tokenizer` that cleans text with english_cleaners() and encodes
    spaces as the literal token text "[SPACE]".
    '''
    def __init__(self, vocab_file='data/tokenizer.json'):
        # With vocab_file=None no tokenizer attribute is created at all.
        if vocab_file is None:
            return
        self.tokenizer = Tokenizer.from_file(vocab_file)

    def preprocess_text(self, txt):
        '''Returns txt after running it through english_cleaners().'''
        return english_cleaners(txt)

    def encode(self, txt):
        '''Cleans txt, replaces each space with "[SPACE]" and returns the token ids of the result.'''
        cleaned = self.preprocess_text(txt)
        marked = cleaned.replace(' ', '[SPACE]')
        return self.tokenizer.encode(marked).ids

    def decode(self, seq):
        '''
        Turns a sequence of token ids back into text. Tensors are moved to the CPU and converted to numpy
        first. Spaces in the raw decode are dropped, "[SPACE]" becomes a space, and "[STOP]" and "[UNK]" are removed.
        '''
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
        for marker, substitute in (('[SPACE]', ' '), ('[STOP]', ''), ('[UNK]', '')):
            txt = txt.replace(marker, substitute)
        return txt