forked from mrq/tortoise-tts
Initial commit
This commit is contained in:
parent
e52926391c
commit
e16ab82597
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -127,3 +127,6 @@ dmypy.json
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
|
.idea/*
|
||||||
|
.models/*
|
||||||
|
|
43
README.md
43
README.md
|
@ -1,2 +1,41 @@
|
||||||
# tortoise-tts
|
# Tortoise-TTS
|
||||||
A multi-voice TTS system trained with an emphasis on quality
|
|
||||||
|
Tortoise TTS is an experimental text-to-speech program that uses recent machine learning techniques to generate
|
||||||
|
high-quality speech samples.
|
||||||
|
|
||||||
|
This repo contains all the code needed to run Tortoise TTS in inference mode.
|
||||||
|
|
||||||
|
## What's in a name?
|
||||||
|
|
||||||
|
I'm naming my speech-related repos after Mojave desert flora and fauna. Tortoise is a bit tongue in cheek: this model
|
||||||
|
is insanely slow. It leverages both an autoregressive speech alignment model and a diffusion model, both of which
|
||||||
|
are known for their slow inference. It also performs CLIP sampling, which slows things down even further. You can
|
||||||
|
expect ~5 seconds of speech to take ~30 seconds to produce on the latest hardware. Still, the results are pretty cool.
|
||||||
|
|
||||||
|
## What the heck is this?
|
||||||
|
|
||||||
|
Tortoise TTS is inspired by OpenAI's DALLE, applied to speech data. It is made up of 4 separate models that work together:
|
||||||
|
|
||||||
|
First, an autoregressive transformer stack predicts discrete speech "tokens" given a text prompt. This model is very
|
||||||
|
similar to the GPT model used by DALLE, except it operates on speech data.
|
||||||
|
|
||||||
|
Next, a CLIP model judges a batch of outputs from the autoregressive transformer against the provided text and stack
|
||||||
|
ranks the outputs according to most probable. You could use greedy or beam-search decoding but in my experience CLIP
|
||||||
|
decoding creates considerably better results.
|
||||||
|
|
||||||
|
Next, the speech "tokens" are decoded into a low-quality MEL spectrogram using a VQVAE.
|
||||||
|
|
||||||
|
Finally, the output of the VQVAE is further decoded by a UNet diffusion model into raw audio, which can be placed in
|
||||||
|
a wav file.
|
||||||
|
|
||||||
|
## How do I use this?
|
||||||
|
|
||||||
|
<incoming>
|
||||||
|
|
||||||
|
## How do I train this?
|
||||||
|
|
||||||
|
Frankly - you don't. Building this model has been a labor of love for me, consuming most of my 6 RTX3090s worth of
|
||||||
|
resources for the better part of 6 months. It uses a dataset I've gathered, refined and transcribed that consists of
|
||||||
|
a lot of audio data which I cannot distribute because of copywrite or no open licenses.
|
||||||
|
|
||||||
|
With that said, I'm willing to help you out if you really want to give it a shot. DM me.
|
1
data/tokenizer.json
Normal file
1
data/tokenizer.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
|
168
do_tts.py
Normal file
168
do_tts.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
import yaml
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from models.arch_util import TorchMelSpectrogram
|
||||||
|
from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
|
||||||
|
from models.lucidrains_dvae import DiscreteVAE
|
||||||
|
from models.text_voice_clip import VoiceCLIP
|
||||||
|
from models.unified_voice import UnifiedVoice
|
||||||
|
from utils.audio import load_audio
|
||||||
|
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
|
||||||
|
from utils.tokenizer import VoiceBpeTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
|
||||||
|
"""
|
||||||
|
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
|
||||||
|
"""
|
||||||
|
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
|
||||||
|
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
|
||||||
|
|
||||||
|
|
||||||
|
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128):
|
||||||
|
"""
|
||||||
|
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
mel = dvae_model.decode(mel_codes)[0]
|
||||||
|
|
||||||
|
# Pad MEL to multiples of 2048//spectrogram_compression_factor
|
||||||
|
msl = mel.shape[-1]
|
||||||
|
dsl = 2048 // spectrogram_compression_factor
|
||||||
|
gap = dsl - (msl % dsl)
|
||||||
|
if gap > 0:
|
||||||
|
mel = torch.nn.functional.pad(mel, (0, gap))
|
||||||
|
|
||||||
|
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
|
||||||
|
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
|
||||||
|
|
||||||
|
|
||||||
|
def load_conditioning(path, sample_rate=22050, cond_length=44100):
|
||||||
|
rel_clip = load_audio(path, sample_rate)
|
||||||
|
gap = rel_clip.shape[-1] - cond_length
|
||||||
|
if gap < 0:
|
||||||
|
rel_clip = F.pad(rel_clip, pad=(0, abs(gap)))
|
||||||
|
elif gap > 0:
|
||||||
|
rand_start = random.randint(0, gap)
|
||||||
|
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
|
||||||
|
mel_clip = TorchMelSpectrogram()(rel_clip.unsqueeze(0)).squeeze(0)
|
||||||
|
return mel_clip.unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def fix_autoregressive_output(codes, stop_token):
|
||||||
|
"""
|
||||||
|
This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was
|
||||||
|
trained on and what the autoregressive code generator creates (which has no padding or end).
|
||||||
|
This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with
|
||||||
|
a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE
|
||||||
|
and copying out the last few codes.
|
||||||
|
|
||||||
|
Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar.
|
||||||
|
"""
|
||||||
|
# Strip off the autoregressive stop token and add padding.
|
||||||
|
stop_token_indices = (codes == stop_token).nonzero()
|
||||||
|
if len(stop_token_indices) == 0:
|
||||||
|
print("No stop tokens found, enjoy that output of yours!")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
codes[stop_token_indices] = 83
|
||||||
|
stm = stop_token_indices.min().item()
|
||||||
|
codes[stm:] = 83
|
||||||
|
if stm - 3 < codes.shape[0]:
|
||||||
|
codes[-3] = 45
|
||||||
|
codes[-2] = 45
|
||||||
|
codes[-1] = 248
|
||||||
|
|
||||||
|
return codes
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
preselected_cond_voices = {
|
||||||
|
'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'],
|
||||||
|
'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'],
|
||||||
|
'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'],
|
||||||
|
'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'],
|
||||||
|
}
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('-autoregressive_model_path', type=str, help='Autoregressive model checkpoint to load.', default='.models/unified_voice.pth')
|
||||||
|
parser.add_argument('-clip_model_path', type=str, help='CLIP model checkpoint to load.', default='.models/clip.pth')
|
||||||
|
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='./models/diffusion_vocoder.pth')
|
||||||
|
parser.add_argument('-dvae_model_path', type=str, help='DVAE model checkpoint to load.', default='./models/dvae.pth')
|
||||||
|
parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
|
||||||
|
parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='dan_carlin')
|
||||||
|
parser.add_argument('-num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=32)
|
||||||
|
parser.add_argument('-num_batches', type=int, help='How many batches those samples should be produced over.', default=2)
|
||||||
|
parser.add_argument('-num_outputs', type=int, help='Number of outputs to produce.', default=2)
|
||||||
|
parser.add_argument('-output_path', type=str, help='Where to store outputs.', default='results/')
|
||||||
|
args = parser.parse_args()
|
||||||
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
print("Loading GPT TTS..")
|
||||||
|
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024, heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).eval()
|
||||||
|
autoregressive.load_state_dict(torch.load(args.autoregressive_model_path))
|
||||||
|
stop_mel_token = autoregressive.stop_mel_token
|
||||||
|
|
||||||
|
print("Loading data..")
|
||||||
|
tokenizer = VoiceBpeTokenizer()
|
||||||
|
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
|
||||||
|
text = F.pad(text, (0,1)) # This may not be necessary.
|
||||||
|
cond_paths = preselected_cond_voices[args.cond_preset]
|
||||||
|
conds = []
|
||||||
|
for cond_path in cond_paths:
|
||||||
|
c, cond_wav = load_conditioning(cond_path, cond_length=132300)
|
||||||
|
conds.append(c)
|
||||||
|
conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
print("Performing GPT inference..")
|
||||||
|
samples = []
|
||||||
|
for b in tqdm(range(args.num_batches)):
|
||||||
|
codes = autoregressive.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
|
||||||
|
temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
|
||||||
|
padding_needed = 250 - codes.shape[1]
|
||||||
|
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
|
||||||
|
samples.append(codes)
|
||||||
|
samples = torch.cat(samples, dim=0)
|
||||||
|
del autoregressive
|
||||||
|
|
||||||
|
print("Loading CLIP..")
|
||||||
|
clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=8, text_seq_len=120, text_heads=8,
|
||||||
|
num_speech_tokens=8192, speech_enc_depth=10, speech_heads=8, speech_seq_len=250).eval()
|
||||||
|
clip.load_state_dict(torch.load(args.clip_model_path))
|
||||||
|
print("Performing CLIP filtering..")
|
||||||
|
for i in range(samples.shape[0]):
|
||||||
|
samples[i] = fix_autoregressive_output(samples[i], stop_mel_token)
|
||||||
|
clip_results = clip(text.repeat(samples.shape[0], 1),
|
||||||
|
torch.full((samples.shape[0],), fill_value=text.shape[1]-1, dtype=torch.long, device='cuda'),
|
||||||
|
samples, torch.full((samples.shape[0],), fill_value=samples.shape[1]*1024, dtype=torch.long, device='cuda'),
|
||||||
|
return_loss=False)
|
||||||
|
best_results = samples[torch.topk(clip_results, k=args.num_outputs).indices]
|
||||||
|
|
||||||
|
# Delete the autoregressive and clip models to free up GPU memory
|
||||||
|
del samples, clip
|
||||||
|
|
||||||
|
print("Loading DVAE..")
|
||||||
|
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
|
||||||
|
record_codes=True, kernel_size=3, use_transposed_convs=False).eval()
|
||||||
|
dvae.load_state_dict(torch.load(args.dvae_model_path))
|
||||||
|
print("Loading Diffusion Model..")
|
||||||
|
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
|
||||||
|
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
|
||||||
|
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).eval()
|
||||||
|
diffusion.load_state_dict(torch.load(args.diffusion_model_path))
|
||||||
|
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
|
||||||
|
|
||||||
|
print("Performing vocoding..")
|
||||||
|
# Perform vocoding on each batch element separately: Vocoding is very memory (and compute!) intensive.
|
||||||
|
for b in range(best_results.shape[0]):
|
||||||
|
code = best_results[b].unsqueeze(0)
|
||||||
|
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256)
|
||||||
|
torchaudio.save(os.path.join(args.output_path, f'gpt_tts_output_{b}.wav'), wav.squeeze(0).cpu(), 22050)
|
319
models/arch_util.py
Normal file
319
models/arch_util.py
Normal file
|
@ -0,0 +1,319 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
"""
|
||||||
|
Zero out the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm32(nn.GroupNorm):
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def normalization(channels):
|
||||||
|
"""
|
||||||
|
Make a standard normalization layer.
|
||||||
|
|
||||||
|
:param channels: number of input channels.
|
||||||
|
:return: an nn.Module for normalization.
|
||||||
|
"""
|
||||||
|
groups = 32
|
||||||
|
if channels <= 16:
|
||||||
|
groups = 8
|
||||||
|
elif channels <= 64:
|
||||||
|
groups = 16
|
||||||
|
while channels % groups != 0:
|
||||||
|
groups = int(groups / 2)
|
||||||
|
assert groups > 2
|
||||||
|
return GroupNorm32(groups, channels)
|
||||||
|
|
||||||
|
|
||||||
|
class QKVAttentionLegacy(nn.Module):
|
||||||
|
"""
|
||||||
|
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = n_heads
|
||||||
|
|
||||||
|
def forward(self, qkv, mask=None):
|
||||||
|
"""
|
||||||
|
Apply QKV attention.
|
||||||
|
|
||||||
|
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||||
|
:return: an [N x (H * C) x T] tensor after attention.
|
||||||
|
"""
|
||||||
|
bs, width, length = qkv.shape
|
||||||
|
assert width % (3 * self.n_heads) == 0
|
||||||
|
ch = width // (3 * self.n_heads)
|
||||||
|
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||||
|
weight = torch.einsum(
|
||||||
|
"bct,bcs->bts", q * scale, k * scale
|
||||||
|
) # More stable with f16 than dividing afterwards
|
||||||
|
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
if mask is not None:
|
||||||
|
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
|
||||||
|
mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
|
||||||
|
weight = weight * mask
|
||||||
|
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||||
|
|
||||||
|
return a.reshape(bs, -1, length)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
An attention block that allows spatial positions to attend to each other.
|
||||||
|
|
||||||
|
Originally ported from here, but adapted to the N-d case.
|
||||||
|
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
num_heads=1,
|
||||||
|
num_head_channels=-1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
if num_head_channels == -1:
|
||||||
|
self.num_heads = num_heads
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
channels % num_head_channels == 0
|
||||||
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||||
|
self.num_heads = channels // num_head_channels
|
||||||
|
self.norm = normalization(channels)
|
||||||
|
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||||
|
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||||
|
|
||||||
|
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
if mask is not None:
|
||||||
|
return self._forward(x, mask)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
def _forward(self, x, mask=None):
|
||||||
|
b, c, *spatial = x.shape
|
||||||
|
x = x.reshape(b, c, -1)
|
||||||
|
qkv = self.qkv(self.norm(x))
|
||||||
|
h = self.attention(qkv, mask)
|
||||||
|
h = self.proj_out(h)
|
||||||
|
return (x + h).reshape(b, c, *spatial)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
"""
|
||||||
|
An upsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
:param channels: channels in the inputs and outputs.
|
||||||
|
:param use_conv: a bool determining if a convolution is applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, use_conv, out_channels=None, factor=4):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_conv = use_conv
|
||||||
|
self.factor = factor
|
||||||
|
if use_conv:
|
||||||
|
ksize = 5
|
||||||
|
pad = 2
|
||||||
|
self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.channels
|
||||||
|
x = F.interpolate(x, scale_factor=self.factor, mode="nearest")
|
||||||
|
if self.use_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
:param channels: channels in the inputs and outputs.
|
||||||
|
:param use_conv: a bool determining if a convolution is applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
stride = factor
|
||||||
|
if use_conv:
|
||||||
|
self.op = nn.Conv1d(
|
||||||
|
self.channels, self.out_channels, ksize, stride=stride, padding=pad
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.channels == self.out_channels
|
||||||
|
self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.channels
|
||||||
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
dropout,
|
||||||
|
out_channels=None,
|
||||||
|
use_conv=False,
|
||||||
|
use_scale_shift_norm=False,
|
||||||
|
up=False,
|
||||||
|
down=False,
|
||||||
|
kernel_size=3,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.dropout = dropout
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_conv = use_conv
|
||||||
|
self.use_scale_shift_norm = use_scale_shift_norm
|
||||||
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
|
self.in_layers = nn.Sequential(
|
||||||
|
normalization(channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.updown = up or down
|
||||||
|
|
||||||
|
if up:
|
||||||
|
self.h_upd = Upsample(channels, False)
|
||||||
|
self.x_upd = Upsample(channels, False)
|
||||||
|
elif down:
|
||||||
|
self.h_upd = Downsample(channels, False)
|
||||||
|
self.x_upd = Downsample(channels, False)
|
||||||
|
else:
|
||||||
|
self.h_upd = self.x_upd = nn.Identity()
|
||||||
|
|
||||||
|
self.out_layers = nn.Sequential(
|
||||||
|
normalization(self.out_channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Dropout(p=dropout),
|
||||||
|
zero_module(
|
||||||
|
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.out_channels == channels:
|
||||||
|
self.skip_connection = nn.Identity()
|
||||||
|
elif use_conv:
|
||||||
|
self.skip_connection = nn.Conv1d(
|
||||||
|
channels, self.out_channels, kernel_size, padding=padding
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.updown:
|
||||||
|
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||||
|
h = in_rest(x)
|
||||||
|
h = self.h_upd(h)
|
||||||
|
x = self.x_upd(x)
|
||||||
|
h = in_conv(h)
|
||||||
|
else:
|
||||||
|
h = self.in_layers(x)
|
||||||
|
h = self.out_layers(h)
|
||||||
|
return self.skip_connection(x) + h
|
||||||
|
|
||||||
|
|
||||||
|
class AudioMiniEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
spec_dim,
|
||||||
|
embedding_dim,
|
||||||
|
base_channels=128,
|
||||||
|
depth=2,
|
||||||
|
resnet_blocks=2,
|
||||||
|
attn_blocks=4,
|
||||||
|
num_attn_heads=4,
|
||||||
|
dropout=0,
|
||||||
|
downsample_factor=2,
|
||||||
|
kernel_size=3):
|
||||||
|
super().__init__()
|
||||||
|
self.init = nn.Sequential(
|
||||||
|
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
|
||||||
|
)
|
||||||
|
ch = base_channels
|
||||||
|
res = []
|
||||||
|
for l in range(depth):
|
||||||
|
for r in range(resnet_blocks):
|
||||||
|
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
|
||||||
|
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
|
||||||
|
ch *= 2
|
||||||
|
self.res = nn.Sequential(*res)
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Conv1d(ch, embedding_dim, 1)
|
||||||
|
)
|
||||||
|
attn = []
|
||||||
|
for a in range(attn_blocks):
|
||||||
|
attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
|
||||||
|
self.attn = nn.Sequential(*attn)
|
||||||
|
self.dim = embedding_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.init(x)
|
||||||
|
h = self.res(h)
|
||||||
|
h = self.final(h)
|
||||||
|
h = self.attn(h)
|
||||||
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class TorchMelSpectrogram(nn.Module):
|
||||||
|
def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, mel_fmin=0, mel_fmax=8000,
|
||||||
|
sampling_rate=22050, normalize=False, mel_norm_file='data/mel_norms.pth'):
|
||||||
|
super().__init__()
|
||||||
|
# These are the default tacotron values for the MEL spectrogram.
|
||||||
|
self.filter_length = filter_length
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.win_length = win_length
|
||||||
|
self.n_mel_channels = n_mel_channels
|
||||||
|
self.mel_fmin = mel_fmin
|
||||||
|
self.mel_fmax = mel_fmax
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length, power=2, normalized=normalize,
|
||||||
|
sample_rate=self.sampling_rate, f_min=self.mel_fmin,
|
||||||
|
f_max=self.mel_fmax, n_mels=self.n_mel_channels,
|
||||||
|
norm="slaney")
|
||||||
|
self.mel_norm_file = mel_norm_file
|
||||||
|
if self.mel_norm_file is not None:
|
||||||
|
self.mel_norms = torch.load(self.mel_norm_file)
|
||||||
|
else:
|
||||||
|
self.mel_norms = None
|
||||||
|
|
||||||
|
def forward(self, inp):
|
||||||
|
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
|
||||||
|
inp = inp.squeeze(1)
|
||||||
|
assert len(inp.shape) == 2
|
||||||
|
self.mel_stft = self.mel_stft.to(inp.device)
|
||||||
|
mel = self.mel_stft(inp)
|
||||||
|
# Perform dynamic range compression
|
||||||
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
if self.mel_norms is not None:
|
||||||
|
self.mel_norms = self.mel_norms.to(mel.device)
|
||||||
|
mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||||
|
return mel
|
390
models/lucidrains_dvae.py
Normal file
390
models/lucidrains_dvae.py
Normal file
|
@ -0,0 +1,390 @@
|
||||||
|
import functools
|
||||||
|
from math import sqrt
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as distributed
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if val is not None else d
|
||||||
|
|
||||||
|
|
||||||
|
def eval_decorator(fn):
|
||||||
|
def inner(model, *args, **kwargs):
|
||||||
|
was_training = model.training
|
||||||
|
model.eval()
|
||||||
|
out = fn(model, *args, **kwargs)
|
||||||
|
model.train(was_training)
|
||||||
|
return out
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
# Quantizer implemented by the rosinality vqvae repo.
|
||||||
|
# Credit: https://github.com/rosinality/vq-vae-2-pytorch
|
||||||
|
class Quantize(nn.Module):
|
||||||
|
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.n_embed = n_embed
|
||||||
|
self.decay = decay
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.balancing_heuristic = balancing_heuristic
|
||||||
|
self.codes = None
|
||||||
|
self.max_codes = 64000
|
||||||
|
self.codes_full = False
|
||||||
|
self.new_return_order = new_return_order
|
||||||
|
|
||||||
|
embed = torch.randn(dim, n_embed)
|
||||||
|
self.register_buffer("embed", embed)
|
||||||
|
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
||||||
|
self.register_buffer("embed_avg", embed.clone())
|
||||||
|
|
||||||
|
def forward(self, input, return_soft_codes=False):
|
||||||
|
if self.balancing_heuristic and self.codes_full:
|
||||||
|
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
|
||||||
|
mask = torch.logical_or(h > .9, h < .01).unsqueeze(1)
|
||||||
|
ep = self.embed.permute(1,0)
|
||||||
|
ea = self.embed_avg.permute(1,0)
|
||||||
|
rand_embed = torch.randn_like(ep) * mask
|
||||||
|
self.embed = (ep * ~mask + rand_embed).permute(1,0)
|
||||||
|
self.embed_avg = (ea * ~mask + rand_embed).permute(1,0)
|
||||||
|
self.cluster_size = self.cluster_size * ~mask.squeeze()
|
||||||
|
if torch.any(mask):
|
||||||
|
print(f"Reset {torch.sum(mask)} embedding codes.")
|
||||||
|
self.codes = None
|
||||||
|
self.codes_full = False
|
||||||
|
|
||||||
|
flatten = input.reshape(-1, self.dim)
|
||||||
|
dist = (
|
||||||
|
flatten.pow(2).sum(1, keepdim=True)
|
||||||
|
- 2 * flatten @ self.embed
|
||||||
|
+ self.embed.pow(2).sum(0, keepdim=True)
|
||||||
|
)
|
||||||
|
soft_codes = -dist
|
||||||
|
_, embed_ind = soft_codes.max(1)
|
||||||
|
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
|
||||||
|
embed_ind = embed_ind.view(*input.shape[:-1])
|
||||||
|
quantize = self.embed_code(embed_ind)
|
||||||
|
|
||||||
|
if self.balancing_heuristic:
|
||||||
|
if self.codes is None:
|
||||||
|
self.codes = embed_ind.flatten()
|
||||||
|
else:
|
||||||
|
self.codes = torch.cat([self.codes, embed_ind.flatten()])
|
||||||
|
if len(self.codes) > self.max_codes:
|
||||||
|
self.codes = self.codes[-self.max_codes:]
|
||||||
|
self.codes_full = True
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
embed_onehot_sum = embed_onehot.sum(0)
|
||||||
|
embed_sum = flatten.transpose(0, 1) @ embed_onehot
|
||||||
|
|
||||||
|
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||||
|
distributed.all_reduce(embed_onehot_sum)
|
||||||
|
distributed.all_reduce(embed_sum)
|
||||||
|
|
||||||
|
self.cluster_size.data.mul_(self.decay).add_(
|
||||||
|
embed_onehot_sum, alpha=1 - self.decay
|
||||||
|
)
|
||||||
|
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
||||||
|
n = self.cluster_size.sum()
|
||||||
|
cluster_size = (
|
||||||
|
(self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
||||||
|
)
|
||||||
|
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
||||||
|
self.embed.data.copy_(embed_normalized)
|
||||||
|
|
||||||
|
diff = (quantize.detach() - input).pow(2).mean()
|
||||||
|
quantize = input + (quantize - input).detach()
|
||||||
|
|
||||||
|
if return_soft_codes:
|
||||||
|
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
|
||||||
|
elif self.new_return_order:
|
||||||
|
return quantize, embed_ind, diff
|
||||||
|
else:
|
||||||
|
return quantize, diff, embed_ind
|
||||||
|
|
||||||
|
def embed_code(self, embed_id):
|
||||||
|
return F.embedding(embed_id, self.embed.transpose(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
|
||||||
|
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
|
||||||
|
# values with the specified expected variance.
|
||||||
|
class DiscretizationLoss(nn.Module):
|
||||||
|
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
|
||||||
|
super().__init__()
|
||||||
|
self.discrete_bins = discrete_bins
|
||||||
|
self.dim = dim
|
||||||
|
self.dist = torch.distributions.Normal(0, scale=expected_variance)
|
||||||
|
if store_past > 0:
|
||||||
|
self.record_past = True
|
||||||
|
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device='cpu'))
|
||||||
|
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device='cpu'))
|
||||||
|
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
|
||||||
|
else:
|
||||||
|
self.record_past = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
other_dims = set(range(len(x.shape)))-set([self.dim])
|
||||||
|
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
|
||||||
|
averaged = averaged - averaged.mean()
|
||||||
|
|
||||||
|
if self.record_past:
|
||||||
|
acc_count = self.accumulator.shape[0]
|
||||||
|
avg = averaged.detach().clone()
|
||||||
|
if self.accumulator_filled > 0:
|
||||||
|
averaged = torch.mean(self.accumulator, dim=0) * (acc_count-1) / acc_count + \
|
||||||
|
averaged / acc_count
|
||||||
|
|
||||||
|
# Also push averaged into the accumulator.
|
||||||
|
self.accumulator[self.accumulator_index] = avg
|
||||||
|
self.accumulator_index += 1
|
||||||
|
if self.accumulator_index >= acc_count:
|
||||||
|
self.accumulator_index *= 0
|
||||||
|
if self.accumulator_filled <= 0:
|
||||||
|
self.accumulator_filled += 1
|
||||||
|
|
||||||
|
return torch.sum(-self.dist.log_prob(averaged))
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(self, chan, conv, activation):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
conv(chan, chan, 3, padding = 1),
|
||||||
|
activation(),
|
||||||
|
conv(chan, chan, 3, padding = 1),
|
||||||
|
activation(),
|
||||||
|
conv(chan, chan, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x) + x
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampledConv(nn.Module):
|
||||||
|
def __init__(self, conv, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
assert 'stride' in kwargs.keys()
|
||||||
|
self.stride = kwargs['stride']
|
||||||
|
del kwargs['stride']
|
||||||
|
self.conv = conv(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
|
||||||
|
return self.conv(up)
|
||||||
|
|
||||||
|
|
||||||
|
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
||||||
|
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
||||||
|
class DiscreteVAE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
positional_dims=2,
|
||||||
|
num_tokens = 512,
|
||||||
|
codebook_dim = 512,
|
||||||
|
num_layers = 3,
|
||||||
|
num_resnet_blocks = 0,
|
||||||
|
hidden_dim = 64,
|
||||||
|
channels = 3,
|
||||||
|
stride = 2,
|
||||||
|
kernel_size = 4,
|
||||||
|
use_transposed_convs = True,
|
||||||
|
encoder_norm = False,
|
||||||
|
activation = 'relu',
|
||||||
|
smooth_l1_loss = False,
|
||||||
|
straight_through = False,
|
||||||
|
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
||||||
|
record_codes = False,
|
||||||
|
discretization_loss_averaging_steps = 100,
|
||||||
|
lr_quantizer_args = {},
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
has_resblocks = num_resnet_blocks > 0
|
||||||
|
|
||||||
|
self.num_tokens = num_tokens
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.straight_through = straight_through
|
||||||
|
self.positional_dims = positional_dims
|
||||||
|
self.discrete_loss = DiscretizationLoss(num_tokens, 2, 1 / (num_tokens*2), discretization_loss_averaging_steps)
|
||||||
|
|
||||||
|
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||||
|
if positional_dims == 2:
|
||||||
|
conv = nn.Conv2d
|
||||||
|
conv_transpose = nn.ConvTranspose2d
|
||||||
|
else:
|
||||||
|
conv = nn.Conv1d
|
||||||
|
conv_transpose = nn.ConvTranspose1d
|
||||||
|
if not use_transposed_convs:
|
||||||
|
conv_transpose = functools.partial(UpsampledConv, conv)
|
||||||
|
|
||||||
|
if activation == 'relu':
|
||||||
|
act = nn.ReLU
|
||||||
|
elif activation == 'silu':
|
||||||
|
act = nn.SiLU
|
||||||
|
else:
|
||||||
|
assert NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
enc_layers = []
|
||||||
|
dec_layers = []
|
||||||
|
|
||||||
|
if num_layers > 0:
|
||||||
|
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
|
||||||
|
dec_chans = list(reversed(enc_chans))
|
||||||
|
|
||||||
|
enc_chans = [channels, *enc_chans]
|
||||||
|
|
||||||
|
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
|
||||||
|
dec_chans = [dec_init_chan, *dec_chans]
|
||||||
|
|
||||||
|
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
|
||||||
|
|
||||||
|
pad = (kernel_size - 1) // 2
|
||||||
|
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
||||||
|
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
|
||||||
|
if encoder_norm:
|
||||||
|
enc_layers.append(nn.GroupNorm(8, enc_out))
|
||||||
|
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
|
||||||
|
dec_out_chans = dec_chans[-1]
|
||||||
|
innermost_dim = dec_chans[0]
|
||||||
|
else:
|
||||||
|
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
|
||||||
|
dec_out_chans = hidden_dim
|
||||||
|
innermost_dim = hidden_dim
|
||||||
|
|
||||||
|
for _ in range(num_resnet_blocks):
|
||||||
|
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
|
||||||
|
enc_layers.append(ResBlock(innermost_dim, conv, act))
|
||||||
|
|
||||||
|
if num_resnet_blocks > 0:
|
||||||
|
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
||||||
|
|
||||||
|
|
||||||
|
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
||||||
|
dec_layers.append(conv(dec_out_chans, channels, 1))
|
||||||
|
|
||||||
|
self.encoder = nn.Sequential(*enc_layers)
|
||||||
|
self.decoder = nn.Sequential(*dec_layers)
|
||||||
|
|
||||||
|
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
||||||
|
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
||||||
|
|
||||||
|
# take care of normalization within class
|
||||||
|
self.normalization = normalization
|
||||||
|
self.record_codes = record_codes
|
||||||
|
if record_codes:
|
||||||
|
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||||
|
self.code_ind = 0
|
||||||
|
self.total_codes = 0
|
||||||
|
self.internal_step = 0
|
||||||
|
|
||||||
|
def norm(self, images):
|
||||||
|
if not self.normalization is not None:
|
||||||
|
return images
|
||||||
|
|
||||||
|
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
||||||
|
arrange = 'c -> () c () ()' if self.positional_dims == 2 else 'c -> () c ()'
|
||||||
|
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
||||||
|
images = images.clone()
|
||||||
|
images.sub_(means).div_(stds)
|
||||||
|
return images
|
||||||
|
|
||||||
|
def get_debug_values(self, step, __):
|
||||||
|
if self.record_codes and self.total_codes > 0:
|
||||||
|
# Report annealing schedule
|
||||||
|
return {'histogram_codes': self.codes[:self.total_codes]}
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@eval_decorator
|
||||||
|
def get_codebook_indices(self, images):
|
||||||
|
img = self.norm(images)
|
||||||
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
sampled, codes, _ = self.codebook(logits)
|
||||||
|
self.log_codes(codes)
|
||||||
|
return codes
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
img_seq
|
||||||
|
):
|
||||||
|
self.log_codes(img_seq)
|
||||||
|
if hasattr(self.codebook, 'embed_code'):
|
||||||
|
image_embeds = self.codebook.embed_code(img_seq)
|
||||||
|
else:
|
||||||
|
image_embeds = F.embedding(img_seq, self.codebook.codebook)
|
||||||
|
b, n, d = image_embeds.shape
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if self.positional_dims == 1:
|
||||||
|
arrange = 'b n d -> b d n'
|
||||||
|
else:
|
||||||
|
h = w = int(sqrt(n))
|
||||||
|
arrange = 'b (h w) d -> b d h w'
|
||||||
|
kwargs = {'h': h, 'w': w}
|
||||||
|
image_embeds = rearrange(image_embeds, arrange, **kwargs)
|
||||||
|
images = [image_embeds]
|
||||||
|
for layer in self.decoder:
|
||||||
|
images.append(layer(images[-1]))
|
||||||
|
return images[-1], images[-2]
|
||||||
|
|
||||||
|
def infer(self, img):
|
||||||
|
img = self.norm(img)
|
||||||
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
sampled, codes, commitment_loss = self.codebook(logits)
|
||||||
|
return self.decode(codes)
|
||||||
|
|
||||||
|
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
||||||
|
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
||||||
|
# more lossy (but useful for determining network performance).
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
img
|
||||||
|
):
|
||||||
|
img = self.norm(img)
|
||||||
|
logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
sampled, codes, commitment_loss = self.codebook(logits)
|
||||||
|
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
out = sampled
|
||||||
|
for d in self.decoder:
|
||||||
|
out = d(out)
|
||||||
|
self.log_codes(codes)
|
||||||
|
else:
|
||||||
|
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||||
|
out, _ = self.decode(codes)
|
||||||
|
|
||||||
|
# reconstruction loss
|
||||||
|
recon_loss = self.loss_fn(img, out, reduction='none')
|
||||||
|
|
||||||
|
return recon_loss, commitment_loss, out
|
||||||
|
|
||||||
|
def log_codes(self, codes):
|
||||||
|
# This is so we can debug the distribution of codes being learned.
|
||||||
|
if self.record_codes and self.internal_step % 10 == 0:
|
||||||
|
codes = codes.flatten()
|
||||||
|
l = codes.shape[0]
|
||||||
|
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||||
|
self.codes[i:i+l] = codes.cpu()
|
||||||
|
self.code_ind = self.code_ind + l
|
||||||
|
if self.code_ind >= self.codes.shape[0]:
|
||||||
|
self.code_ind = 0
|
||||||
|
self.total_codes += 1
|
||||||
|
self.internal_step += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
|
||||||
|
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False)
|
||||||
|
r,l,o=v(torch.randn(1,80,256))
|
||||||
|
v.decode(torch.randint(0,8192,(1,256)))
|
||||||
|
print(o.shape, l.shape)
|
125
models/text_voice_clip.py
Normal file
125
models/text_voice_clip.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import einsum
|
||||||
|
from models.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def masked_mean(t, mask, dim = 1):
|
||||||
|
t = t.masked_fill(~mask[:, :, None], 0.)
|
||||||
|
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceCLIP(nn.Module):
|
||||||
|
"""
|
||||||
|
CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding
|
||||||
|
transcribed text.
|
||||||
|
|
||||||
|
Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim_text=512,
|
||||||
|
dim_speech=512,
|
||||||
|
dim_latent=512,
|
||||||
|
num_text_tokens=256,
|
||||||
|
text_enc_depth=6,
|
||||||
|
text_seq_len=120,
|
||||||
|
text_heads=8,
|
||||||
|
num_speech_tokens=8192,
|
||||||
|
speech_enc_depth=6,
|
||||||
|
speech_heads=8,
|
||||||
|
speech_seq_len=250,
|
||||||
|
text_mask_percentage=0,
|
||||||
|
voice_mask_percentage=0,
|
||||||
|
wav_token_compression=1024,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
||||||
|
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
||||||
|
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
||||||
|
heads=text_heads)
|
||||||
|
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False)
|
||||||
|
|
||||||
|
self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
||||||
|
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
||||||
|
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
||||||
|
depth=speech_enc_depth, heads=speech_heads)
|
||||||
|
self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False)
|
||||||
|
|
||||||
|
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||||
|
self.text_mask_percentage = text_mask_percentage
|
||||||
|
self.voice_mask_percentage = voice_mask_percentage
|
||||||
|
self.wav_token_compression = wav_token_compression
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
speech_tokens,
|
||||||
|
wav_lengths,
|
||||||
|
return_loss=False
|
||||||
|
):
|
||||||
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
|
# chopping the inputs by the maximum actual length.
|
||||||
|
max_text_len = text_lengths.max()
|
||||||
|
text = text[:, :max_text_len]
|
||||||
|
max_mel_len = wav_lengths.max() // self.wav_token_compression
|
||||||
|
speech_tokens = speech_tokens[:, :max_mel_len]
|
||||||
|
|
||||||
|
b, device = text.shape[0], text.device
|
||||||
|
if self.training:
|
||||||
|
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
||||||
|
voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage
|
||||||
|
else:
|
||||||
|
text_mask = torch.ones_like(text.float()).bool()
|
||||||
|
voice_mask = torch.ones_like(speech_tokens.float()).bool()
|
||||||
|
|
||||||
|
text_emb = self.text_emb(text)
|
||||||
|
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device))
|
||||||
|
|
||||||
|
speech_emb = self.speech_emb(speech_tokens)
|
||||||
|
speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device))
|
||||||
|
|
||||||
|
enc_text = self.text_transformer(text_emb, mask=text_mask)
|
||||||
|
enc_speech = self.speech_transformer(speech_emb, mask=voice_mask)
|
||||||
|
|
||||||
|
text_latents = masked_mean(enc_text, text_mask, dim=1)
|
||||||
|
speech_latents = masked_mean(enc_speech, voice_mask, dim=1)
|
||||||
|
|
||||||
|
text_latents = self.to_text_latent(text_latents)
|
||||||
|
speech_latents = self.to_speech_latent(speech_latents)
|
||||||
|
|
||||||
|
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
||||||
|
|
||||||
|
temp = self.temperature.exp()
|
||||||
|
|
||||||
|
if not return_loss:
|
||||||
|
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
|
||||||
|
return sim
|
||||||
|
|
||||||
|
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
||||||
|
labels = torch.arange(b, device=device)
|
||||||
|
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
clip = VoiceCLIP(text_mask_percentage=.2, voice_mask_percentage=.2)
|
||||||
|
clip(torch.randint(0,256,(2,120)),
|
||||||
|
torch.tensor([50,100]),
|
||||||
|
torch.randint(0,8192,(2,250)),
|
||||||
|
torch.tensor([101,102]),
|
||||||
|
return_loss=True)
|
||||||
|
nonloss = clip(torch.randint(0,256,(2,120)),
|
||||||
|
torch.tensor([50,100]),
|
||||||
|
torch.randint(0,8192,(2,250)),
|
||||||
|
torch.tensor([101,102]),
|
||||||
|
return_loss=False)
|
||||||
|
print(nonloss.shape)
|
219
models/transformer.py
Normal file
219
models/transformer.py
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from rotary_embedding_torch import RotaryEmbedding, broadcat
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val, d):
|
||||||
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
def cast_tuple(val, depth = 1):
|
||||||
|
if isinstance(val, list):
|
||||||
|
val = tuple(val)
|
||||||
|
return val if isinstance(val, tuple) else (val,) * depth
|
||||||
|
|
||||||
|
|
||||||
|
def max_neg_value(t):
|
||||||
|
return -torch.finfo(t.dtype).max
|
||||||
|
|
||||||
|
|
||||||
|
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||||
|
t = t / alpha
|
||||||
|
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
||||||
|
return (t * alpha).softmax(dim = dim)
|
||||||
|
|
||||||
|
|
||||||
|
def route_args(router, args, depth):
|
||||||
|
routed_args = [(dict(), dict()) for _ in range(depth)]
|
||||||
|
matched_keys = [key for key in args.keys() if key in router]
|
||||||
|
|
||||||
|
for key in matched_keys:
|
||||||
|
val = args[key]
|
||||||
|
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
||||||
|
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
|
||||||
|
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
||||||
|
return routed_args
|
||||||
|
|
||||||
|
|
||||||
|
# classes
|
||||||
|
class SequentialSequence(nn.Module):
|
||||||
|
def __init__(self, layers, args_route = {}, layer_dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
|
||||||
|
self.layers = layers
|
||||||
|
self.args_route = args_route
|
||||||
|
self.layer_dropout = layer_dropout
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
args = route_args(self.args_route, kwargs, len(self.layers))
|
||||||
|
layers_and_args = list(zip(self.layers, args))
|
||||||
|
|
||||||
|
for (f, g), (f_args, g_args) in layers_and_args:
|
||||||
|
x = x + f(x, **f_args)
|
||||||
|
x = x + g(x, **g_args)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DivideMax(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
maxes = x.amax(dim = self.dim, keepdim = True).detach()
|
||||||
|
return x / maxes
|
||||||
|
|
||||||
|
|
||||||
|
# https://arxiv.org/abs/2103.17239
|
||||||
|
class LayerScale(nn.Module):
|
||||||
|
def __init__(self, dim, depth, fn):
|
||||||
|
super().__init__()
|
||||||
|
if depth <= 18:
|
||||||
|
init_eps = 0.1
|
||||||
|
elif depth > 18 and depth <= 24:
|
||||||
|
init_eps = 1e-5
|
||||||
|
else:
|
||||||
|
init_eps = 1e-6
|
||||||
|
|
||||||
|
scale = torch.zeros(1, 1, dim).fill_(init_eps)
|
||||||
|
self.scale = nn.Parameter(scale)
|
||||||
|
self.fn = fn
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
return self.fn(x, **kwargs) * self.scale
|
||||||
|
|
||||||
|
# layer norm
|
||||||
|
|
||||||
|
|
||||||
|
class PreNorm(nn.Module):
|
||||||
|
def __init__(self, dim, fn, sandwich = False):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.LayerNorm(dim)
|
||||||
|
self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.fn(x, **kwargs)
|
||||||
|
return self.norm_out(x)
|
||||||
|
|
||||||
|
# feed forward
|
||||||
|
|
||||||
|
|
||||||
|
class GEGLU(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x, gates = x.chunk(2, dim = -1)
|
||||||
|
return x * F.gelu(gates)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dropout = 0., mult = 4.):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim * mult * 2),
|
||||||
|
GEGLU(),
|
||||||
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(dim * mult, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
# Attention
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropout = 0.):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||||
|
self.to_out = nn.Sequential(
|
||||||
|
nn.Linear(inner_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask = None):
|
||||||
|
b, n, _, h, device = *x.shape, self.heads, x.device
|
||||||
|
softmax = torch.softmax
|
||||||
|
|
||||||
|
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
||||||
|
mask_value = max_neg_value(dots)
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = rearrange(mask, 'b j -> b () () j')
|
||||||
|
dots.masked_fill_(~mask, mask_value)
|
||||||
|
del mask
|
||||||
|
|
||||||
|
if self.causal:
|
||||||
|
i, j = dots.shape[-2:]
|
||||||
|
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
|
||||||
|
dots.masked_fill_(mask, mask_value)
|
||||||
|
|
||||||
|
attn = softmax(dots, dim=-1)
|
||||||
|
|
||||||
|
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||||
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# main transformer class
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
seq_len,
|
||||||
|
causal = True,
|
||||||
|
heads = 8,
|
||||||
|
dim_head = 64,
|
||||||
|
ff_mult = 4,
|
||||||
|
attn_dropout = 0.,
|
||||||
|
ff_dropout = 0.,
|
||||||
|
sparse_attn = False,
|
||||||
|
sandwich_norm = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layers = nn.ModuleList([])
|
||||||
|
sparse_layer = cast_tuple(sparse_attn, depth)
|
||||||
|
|
||||||
|
for ind, sparse_attn in zip(range(depth), sparse_layer):
|
||||||
|
attn = Attention(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
|
||||||
|
|
||||||
|
ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
|
||||||
|
|
||||||
|
layers.append(nn.ModuleList([
|
||||||
|
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
|
||||||
|
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
|
||||||
|
]))
|
||||||
|
|
||||||
|
execute_type = SequentialSequence
|
||||||
|
route_attn = ((True, False),) * depth
|
||||||
|
attn_route_map = {'mask': route_attn}
|
||||||
|
|
||||||
|
self.layers = execute_type(layers, args_route = attn_route_map)
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
return self.layers(x, **kwargs)
|
530
models/unified_voice.py
Normal file
530
models/unified_voice.py
Normal file
|
@ -0,0 +1,530 @@
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import GPT2Config, GPT2PreTrainedModel
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||||
|
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
|
||||||
|
from models.arch_util import AttentionBlock
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def null_position_embeddings(range, dim):
|
||||||
|
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Basic residual convolutional block that uses GroupNorm.
|
||||||
|
"""
|
||||||
|
def __init__(self, chan):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
|
nn.GroupNorm(chan//8, chan),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1d(chan, chan, kernel_size=3, padding=1),
|
||||||
|
nn.GroupNorm(chan//8, chan)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.relu(self.net(x) + x)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||||
|
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
|
||||||
|
super().__init__(config)
|
||||||
|
self.transformer = gpt
|
||||||
|
self.text_pos_embedding = text_pos_emb
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self.lm_head = nn.Sequential(norm, linear)
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
self.cached_mel_emb = None
|
||||||
|
|
||||||
|
def parallelize(self, device_map=None):
|
||||||
|
self.device_map = (
|
||||||
|
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
||||||
|
if device_map is None
|
||||||
|
else device_map
|
||||||
|
)
|
||||||
|
assert_device_map(self.device_map, len(self.transformer.h))
|
||||||
|
self.transformer.parallelize(self.device_map)
|
||||||
|
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
||||||
|
self.model_parallel = True
|
||||||
|
|
||||||
|
def deparallelize(self):
|
||||||
|
self.transformer.deparallelize()
|
||||||
|
self.transformer = self.transformer.to("cpu")
|
||||||
|
self.lm_head = self.lm_head.to("cpu")
|
||||||
|
self.model_parallel = False
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def store_mel_emb(self, mel_emb):
|
||||||
|
self.cached_mel_emb = mel_emb
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
|
|
||||||
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
|
if past:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
assert self.cached_mel_emb is not None
|
||||||
|
assert inputs_embeds is None # Not supported by this inference model.
|
||||||
|
assert labels is None # Training not supported by this inference model.
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# Create embedding
|
||||||
|
mel_len = self.cached_mel_emb.shape[1]
|
||||||
|
if input_ids.shape[1] != 1:
|
||||||
|
text_inputs = input_ids[:, mel_len:]
|
||||||
|
text_emb = self.embeddings(text_inputs)
|
||||||
|
text_emb = text_emb + self.text_pos_embedding(text_emb)
|
||||||
|
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
||||||
|
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0]//self.cached_mel_emb.shape[0], 0)
|
||||||
|
else:
|
||||||
|
mel_emb = self.cached_mel_emb
|
||||||
|
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||||
|
else:
|
||||||
|
emb = self.embeddings(input_ids)
|
||||||
|
emb = emb + self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1]-mel_len, attention_mask.device)
|
||||||
|
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
|
inputs_embeds=emb,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
# Set device for model parallelism
|
||||||
|
if self.model_parallel:
|
||||||
|
torch.cuda.set_device(self.transformer.first_device)
|
||||||
|
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
||||||
|
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (lm_logits,) + transformer_outputs[1:]
|
||||||
|
|
||||||
|
return CausalLMOutputWithCrossAttentions(
|
||||||
|
loss=None,
|
||||||
|
logits=lm_logits,
|
||||||
|
past_key_values=transformer_outputs.past_key_values,
|
||||||
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
|
attentions=transformer_outputs.attentions,
|
||||||
|
cross_attentions=transformer_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past, beam_idx):
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`past_key_values` cache if
|
||||||
|
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
||||||
|
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return tuple(
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||||
|
for layer_past in past
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
spec_dim,
|
||||||
|
embedding_dim,
|
||||||
|
attn_blocks=6,
|
||||||
|
num_attn_heads=4,
|
||||||
|
do_checkpointing=False):
|
||||||
|
super().__init__()
|
||||||
|
attn = []
|
||||||
|
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||||
|
for a in range(attn_blocks):
|
||||||
|
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||||
|
self.attn = nn.Sequential(*attn)
|
||||||
|
self.dim = embedding_dim
|
||||||
|
self.do_checkpointing = do_checkpointing
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = self.init(x)
|
||||||
|
h = self.attn(h)
|
||||||
|
return h[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class LearnedPositionEmbeddings(nn.Module):
|
||||||
|
def __init__(self, seq_len, model_dim, init=.02):
|
||||||
|
super().__init__()
|
||||||
|
self.emb = nn.Embedding(seq_len, model_dim)
|
||||||
|
# Initializing this way is standard for GPT-2
|
||||||
|
self.emb.weight.data.normal_(mean=0.0, std=init)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
sl = x.shape[1]
|
||||||
|
return self.emb(torch.arange(0, sl, device=x.device))
|
||||||
|
|
||||||
|
def get_fixed_embedding(self, ind, dev):
|
||||||
|
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
||||||
|
"""
|
||||||
|
GPT-2 implemented by the HuggingFace library.
|
||||||
|
"""
|
||||||
|
from transformers import GPT2Config, GPT2Model
|
||||||
|
gpt_config = GPT2Config(vocab_size=256, # Unused.
|
||||||
|
n_positions=max_mel_seq_len+max_text_seq_len,
|
||||||
|
n_ctx=max_mel_seq_len+max_text_seq_len,
|
||||||
|
n_embd=model_dim,
|
||||||
|
n_layer=layers,
|
||||||
|
n_head=heads,
|
||||||
|
gradient_checkpointing=checkpointing,
|
||||||
|
use_cache=not checkpointing)
|
||||||
|
gpt = GPT2Model(gpt_config)
|
||||||
|
# Override the built in positional embeddings
|
||||||
|
del gpt.wpe
|
||||||
|
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||||
|
# Built-in token embeddings are unused.
|
||||||
|
del gpt.wte
|
||||||
|
return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
|
||||||
|
None, None
|
||||||
|
|
||||||
|
|
||||||
|
class MelEncoder(nn.Module):
|
||||||
|
def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels//4, kernel_size=3, padding=1),
|
||||||
|
nn.Sequential(*[ResBlock(channels//4) for _ in range(resblocks_per_reduction)]),
|
||||||
|
nn.Conv1d(channels//4, channels//2, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.GroupNorm(channels//16, channels//2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Sequential(*[ResBlock(channels//2) for _ in range(resblocks_per_reduction)]),
|
||||||
|
nn.Conv1d(channels//2, channels, kernel_size=3, stride=2, padding=1),
|
||||||
|
nn.GroupNorm(channels//8, channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||||
|
)
|
||||||
|
self.reduction = 4
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for e in self.encoder:
|
||||||
|
x = e(x)
|
||||||
|
return x.permute(0,2,1)
|
||||||
|
|
||||||
|
|
||||||
|
class UnifiedVoice(nn.Module):
|
||||||
|
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
|
||||||
|
mel_length_compression=1024, number_text_tokens=256,
|
||||||
|
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
|
||||||
|
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
|
||||||
|
checkpointing=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
layers: Number of layers in transformer stack.
|
||||||
|
model_dim: Operating dimensions of the transformer
|
||||||
|
heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
|
||||||
|
max_text_tokens: Maximum number of text tokens that will be encountered by model.
|
||||||
|
max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
|
||||||
|
max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
|
||||||
|
mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
|
||||||
|
number_text_tokens:
|
||||||
|
start_text_token:
|
||||||
|
stop_text_token:
|
||||||
|
number_mel_codes:
|
||||||
|
start_mel_token:
|
||||||
|
stop_mel_token:
|
||||||
|
train_solo_embeddings:
|
||||||
|
use_mel_codes_as_input:
|
||||||
|
checkpointing:
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.number_text_tokens = number_text_tokens
|
||||||
|
self.start_text_token = start_text_token
|
||||||
|
self.stop_text_token = stop_text_token
|
||||||
|
self.number_mel_codes = number_mel_codes
|
||||||
|
self.start_mel_token = start_mel_token
|
||||||
|
self.stop_mel_token = stop_mel_token
|
||||||
|
self.layers = layers
|
||||||
|
self.heads = heads
|
||||||
|
self.max_mel_tokens = max_mel_tokens
|
||||||
|
self.max_text_tokens = max_text_tokens
|
||||||
|
self.model_dim = model_dim
|
||||||
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
|
self.mel_length_compression = mel_length_compression
|
||||||
|
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||||
|
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||||
|
if use_mel_codes_as_input:
|
||||||
|
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
||||||
|
else:
|
||||||
|
self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
|
||||||
|
self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
|
||||||
|
build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing)
|
||||||
|
if train_solo_embeddings:
|
||||||
|
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||||
|
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
|
||||||
|
else:
|
||||||
|
self.mel_solo_embedding = 0
|
||||||
|
self.text_solo_embedding = 0
|
||||||
|
|
||||||
|
self.final_norm = nn.LayerNorm(model_dim)
|
||||||
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
|
self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
|
||||||
|
|
||||||
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
|
embeddings = [self.text_embedding]
|
||||||
|
if use_mel_codes_as_input:
|
||||||
|
embeddings.append(self.mel_embedding)
|
||||||
|
for module in embeddings:
|
||||||
|
module.weight.data.normal_(mean=0.0, std=.02)
|
||||||
|
|
||||||
|
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||||
|
inp = F.pad(input, (1,0), value=start_token)
|
||||||
|
tar = F.pad(input, (0,1), value=stop_token)
|
||||||
|
return inp, tar
|
||||||
|
|
||||||
|
def set_mel_padding(self, mel_input_tokens, wav_lengths):
|
||||||
|
"""
|
||||||
|
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
|
||||||
|
that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
|
||||||
|
preformatting to create a working TTS model.
|
||||||
|
"""
|
||||||
|
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||||
|
mel_lengths = wav_lengths // self.mel_length_compression
|
||||||
|
for b in range(len(mel_lengths)):
|
||||||
|
actual_end = mel_lengths[b] + 1 # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
|
||||||
|
if actual_end < mel_input_tokens.shape[-1]:
|
||||||
|
mel_input_tokens[b, actual_end:] = self.stop_mel_token
|
||||||
|
return mel_input_tokens
|
||||||
|
|
||||||
|
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
|
||||||
|
if second_inputs is not None:
|
||||||
|
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||||
|
else:
|
||||||
|
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||||
|
|
||||||
|
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
|
if get_attns:
|
||||||
|
return gpt_out.attentions
|
||||||
|
|
||||||
|
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
||||||
|
enc = self.final_norm(enc)
|
||||||
|
first_logits = enc[:, :first_inputs.shape[1]]
|
||||||
|
first_logits = first_head(first_logits)
|
||||||
|
first_logits = first_logits.permute(0,2,1)
|
||||||
|
if second_inputs is not None:
|
||||||
|
second_logits = enc[:, -second_inputs.shape[1]:]
|
||||||
|
second_logits = second_head(second_logits)
|
||||||
|
second_logits = second_logits.permute(0,2,1)
|
||||||
|
return first_logits, second_logits
|
||||||
|
else:
|
||||||
|
return first_logits
|
||||||
|
|
||||||
|
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False):
|
||||||
|
"""
|
||||||
|
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||||
|
(actuated by `text_first`).
|
||||||
|
|
||||||
|
speech_conditioning_input: MEL float tensor, (b,80,s)
|
||||||
|
text_inputs: long tensor, (b,t)
|
||||||
|
text_lengths: long tensor, (b,)
|
||||||
|
mel_inputs: long tensor, (b,m)
|
||||||
|
wav_lengths: long tensor, (b,)
|
||||||
|
raw_mels: MEL float tensor (b,80,s)
|
||||||
|
"""
|
||||||
|
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
|
||||||
|
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||||
|
|
||||||
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
|
# chopping the inputs by the maximum actual length.
|
||||||
|
max_text_len = text_lengths.max()
|
||||||
|
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
|
||||||
|
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||||
|
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
|
||||||
|
if raw_mels is not None:
|
||||||
|
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||||
|
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||||
|
|
||||||
|
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||||
|
conds = []
|
||||||
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
|
conds = torch.stack(conds, dim=1)
|
||||||
|
|
||||||
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||||
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
|
if raw_mels is not None:
|
||||||
|
mel_inp = F.pad(raw_mels, (0, 8))
|
||||||
|
else:
|
||||||
|
mel_inp = mel_codes
|
||||||
|
mel_emb = self.mel_embedding(mel_inp)
|
||||||
|
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
|
||||||
|
if text_first:
|
||||||
|
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
|
||||||
|
else:
|
||||||
|
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
|
||||||
|
|
||||||
|
if return_attentions:
|
||||||
|
return mel_logits
|
||||||
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
|
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||||
|
|
||||||
|
def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
|
||||||
|
"""
|
||||||
|
Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
|
||||||
|
model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
|
||||||
|
"""
|
||||||
|
assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}'
|
||||||
|
|
||||||
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
|
# chopping the inputs by the maximum actual length.
|
||||||
|
max_text_len = text_lengths.max()
|
||||||
|
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
|
||||||
|
|
||||||
|
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||||
|
conds = []
|
||||||
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
|
conds = torch.stack(conds, dim=1)
|
||||||
|
|
||||||
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
|
||||||
|
text_logits = self.get_logits(conds, text_emb, self.text_head)
|
||||||
|
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||||
|
return loss_text.mean()
|
||||||
|
|
||||||
|
def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
|
||||||
|
"""
|
||||||
|
Performs autoregressive modeling on only speech data.
|
||||||
|
"""
|
||||||
|
assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}'
|
||||||
|
|
||||||
|
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
|
||||||
|
# chopping the inputs by the maximum actual length.
|
||||||
|
max_mel_len = wav_lengths.max() // self.mel_length_compression
|
||||||
|
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
|
||||||
|
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
|
||||||
|
if raw_mels is not None:
|
||||||
|
raw_mels = raw_mels[:, :, :max_mel_len*4]
|
||||||
|
|
||||||
|
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||||
|
conds = []
|
||||||
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
|
conds = torch.stack(conds, dim=1)
|
||||||
|
|
||||||
|
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
|
||||||
|
if raw_mels is not None:
|
||||||
|
mel_inp = F.pad(raw_mels, (0, 4))
|
||||||
|
else:
|
||||||
|
mel_inp = mel_codes
|
||||||
|
mel_emb = self.mel_embedding(mel_inp)
|
||||||
|
mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
|
||||||
|
mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
|
||||||
|
loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
|
||||||
|
return loss_mel.mean()
|
||||||
|
|
||||||
|
def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
|
||||||
|
seq_length = self.max_mel_tokens + self.max_text_tokens + 2
|
||||||
|
if not hasattr(self, 'inference_model'):
|
||||||
|
# TODO: Decouple gpt_config from this inference model.
|
||||||
|
gpt_config = GPT2Config(vocab_size=self.max_mel_tokens,
|
||||||
|
n_positions=seq_length,
|
||||||
|
n_ctx=seq_length,
|
||||||
|
n_embd=self.model_dim,
|
||||||
|
n_layer=self.layers,
|
||||||
|
n_head=self.heads,
|
||||||
|
gradient_checkpointing=False,
|
||||||
|
use_cache=True)
|
||||||
|
self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
|
||||||
|
self.gpt.wte = self.mel_embedding
|
||||||
|
|
||||||
|
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||||
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||||
|
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||||
|
|
||||||
|
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
|
||||||
|
conds = []
|
||||||
|
for j in range(speech_conditioning_input.shape[1]):
|
||||||
|
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||||
|
conds = torch.stack(conds, dim=1)
|
||||||
|
|
||||||
|
emb = torch.cat([conds, text_emb], dim=1)
|
||||||
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
|
||||||
|
fake_inputs = torch.full((emb.shape[0], conds.shape[1]+emb.shape[1],), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||||
|
fake_inputs[:,-1] = self.start_mel_token
|
||||||
|
|
||||||
|
gen = self.inference_model.generate(fake_inputs, bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, eos_token_id=self.stop_mel_token,
|
||||||
|
max_length=seq_length, **hf_generate_kwargs)
|
||||||
|
return gen[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
gpt = UnifiedVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True, max_conditioning_inputs=4)
|
||||||
|
l = gpt(torch.randn(2, 3, 80, 800),
|
||||||
|
torch.randint(high=120, size=(2,120)),
|
||||||
|
torch.tensor([32, 120]),
|
||||||
|
torch.randint(high=8192, size=(2,250)),
|
||||||
|
torch.tensor([250*256,195*256]))
|
||||||
|
gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
torch
|
||||||
|
torchaudio
|
||||||
|
rotary_embedding_torch
|
||||||
|
transformers
|
||||||
|
tokenizers
|
||||||
|
pyfastmp3decoder
|
||||||
|
inflect
|
44
utils/audio.py
Normal file
44
utils/audio.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
def load_wav_to_torch(full_path):
|
||||||
|
sampling_rate, data = read(full_path)
|
||||||
|
if data.dtype == np.int32:
|
||||||
|
norm_fix = 2 ** 31
|
||||||
|
elif data.dtype == np.int16:
|
||||||
|
norm_fix = 2 ** 15
|
||||||
|
elif data.dtype == np.float16 or data.dtype == np.float32:
|
||||||
|
norm_fix = 1.
|
||||||
|
else:
|
||||||
|
raise NotImplemented(f"Provided data dtype not supported: {data.dtype}")
|
||||||
|
return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(audiopath, sampling_rate):
|
||||||
|
if audiopath[-4:] == '.wav':
|
||||||
|
audio, lsr = load_wav_to_torch(audiopath)
|
||||||
|
elif audiopath[-4:] == '.mp3':
|
||||||
|
# https://github.com/neonbjb/pyfastmp3decoder - Definitely worth it.
|
||||||
|
from pyfastmp3decoder.mp3decoder import load_mp3
|
||||||
|
audio, lsr = load_mp3(audiopath, sampling_rate)
|
||||||
|
audio = torch.FloatTensor(audio)
|
||||||
|
|
||||||
|
# Remove any channel data.
|
||||||
|
if len(audio.shape) > 1:
|
||||||
|
if audio.shape[0] < 5:
|
||||||
|
audio = audio[0]
|
||||||
|
else:
|
||||||
|
assert audio.shape[1] < 5
|
||||||
|
audio = audio[:, 0]
|
||||||
|
|
||||||
|
if lsr != sampling_rate:
|
||||||
|
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
|
||||||
|
|
||||||
|
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
|
||||||
|
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
|
||||||
|
if torch.any(audio > 2) or not torch.any(audio < 0):
|
||||||
|
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
||||||
|
audio.clip_(-1, 1)
|
||||||
|
|
||||||
|
return audio.unsqueeze(0)
|
1232
utils/diffusion.py
Normal file
1232
utils/diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
173
utils/tokenizer.py
Normal file
173
utils/tokenizer.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
import inflect
|
||||||
|
import torch
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# Regular expression matching whitespace:
|
||||||
|
from unidecode import unidecode
|
||||||
|
|
||||||
|
_whitespace_re = re.compile(r'\s+')
|
||||||
|
|
||||||
|
|
||||||
|
# List of (regular expression, replacement) pairs for abbreviations:
|
||||||
|
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||||
|
('mrs', 'misess'),
|
||||||
|
('mr', 'mister'),
|
||||||
|
('dr', 'doctor'),
|
||||||
|
('st', 'saint'),
|
||||||
|
('co', 'company'),
|
||||||
|
('jr', 'junior'),
|
||||||
|
('maj', 'major'),
|
||||||
|
('gen', 'general'),
|
||||||
|
('drs', 'doctors'),
|
||||||
|
('rev', 'reverend'),
|
||||||
|
('lt', 'lieutenant'),
|
||||||
|
('hon', 'honorable'),
|
||||||
|
('sgt', 'sergeant'),
|
||||||
|
('capt', 'captain'),
|
||||||
|
('esq', 'esquire'),
|
||||||
|
('ltd', 'limited'),
|
||||||
|
('col', 'colonel'),
|
||||||
|
('ft', 'fort'),
|
||||||
|
]]
|
||||||
|
|
||||||
|
|
||||||
|
def expand_abbreviations(text):
|
||||||
|
for regex, replacement in _abbreviations:
|
||||||
|
text = re.sub(regex, replacement, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
_inflect = inflect.engine()
|
||||||
|
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||||
|
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||||
|
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||||
|
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||||
|
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||||
|
_number_re = re.compile(r'[0-9]+')
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_commas(m):
|
||||||
|
return m.group(1).replace(',', '')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_decimal_point(m):
|
||||||
|
return m.group(1).replace('.', ' point ')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_dollars(m):
|
||||||
|
match = m.group(1)
|
||||||
|
parts = match.split('.')
|
||||||
|
if len(parts) > 2:
|
||||||
|
return match + ' dollars' # Unexpected format
|
||||||
|
dollars = int(parts[0]) if parts[0] else 0
|
||||||
|
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||||
|
if dollars and cents:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||||
|
elif dollars:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
return '%s %s' % (dollars, dollar_unit)
|
||||||
|
elif cents:
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s' % (cents, cent_unit)
|
||||||
|
else:
|
||||||
|
return 'zero dollars'
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_ordinal(m):
|
||||||
|
return _inflect.number_to_words(m.group(0))
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_number(m):
|
||||||
|
num = int(m.group(0))
|
||||||
|
if num > 1000 and num < 3000:
|
||||||
|
if num == 2000:
|
||||||
|
return 'two thousand'
|
||||||
|
elif num > 2000 and num < 2010:
|
||||||
|
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||||
|
elif num % 100 == 0:
|
||||||
|
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(num, andword='')
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_numbers(text):
|
||||||
|
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||||
|
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||||
|
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||||
|
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||||
|
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||||
|
text = re.sub(_number_re, _expand_number, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def expand_numbers(text):
|
||||||
|
return normalize_numbers(text)
|
||||||
|
|
||||||
|
|
||||||
|
def lowercase(text):
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def collapse_whitespace(text):
|
||||||
|
return re.sub(_whitespace_re, ' ', text)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_ascii(text):
|
||||||
|
return unidecode(text)
|
||||||
|
|
||||||
|
|
||||||
|
def basic_cleaners(text):
|
||||||
|
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||||
|
text = lowercase(text)
|
||||||
|
text = collapse_whitespace(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def transliteration_cleaners(text):
|
||||||
|
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||||
|
text = convert_to_ascii(text)
|
||||||
|
text = lowercase(text)
|
||||||
|
text = collapse_whitespace(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def english_cleaners(text):
|
||||||
|
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||||
|
text = convert_to_ascii(text)
|
||||||
|
text = lowercase(text)
|
||||||
|
text = expand_numbers(text)
|
||||||
|
text = expand_abbreviations(text)
|
||||||
|
text = collapse_whitespace(text)
|
||||||
|
text = text.replace('"', '')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceBpeTokenizer:
|
||||||
|
def __init__(self, vocab_file='data/tokenizer.json'):
|
||||||
|
if vocab_file is not None:
|
||||||
|
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||||
|
|
||||||
|
def preprocess_text(self, txt):
|
||||||
|
txt = english_cleaners(txt)
|
||||||
|
return txt
|
||||||
|
|
||||||
|
def encode(self, txt):
|
||||||
|
txt = self.preprocess_text(txt)
|
||||||
|
txt = txt.replace(' ', '[SPACE]')
|
||||||
|
return self.tokenizer.encode(txt).ids
|
||||||
|
|
||||||
|
def decode(self, seq):
|
||||||
|
if isinstance(seq, torch.Tensor):
|
||||||
|
seq = seq.cpu().numpy()
|
||||||
|
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
|
||||||
|
txt = txt.replace('[SPACE]', ' ')
|
||||||
|
txt = txt.replace('[STOP]', '')
|
||||||
|
txt = txt.replace('[UNK]', '')
|
||||||
|
return txt
|
Loading…
Reference in New Issue
Block a user