forked from mrq/tortoise-tts

Modifications to support "v1.5"

parent 9f1aa57b8d
commit 79c74c1484

do_tts.py | 72
@@ -8,14 +8,14 @@ import torch.nn.functional as F
 import torchaudio
 import progressbar

-from models.dvae import DiscreteVAE
+from models.diffusion_decoder import DiffusionTts
 from models.autoregressive import UnifiedVoice
 from tqdm import tqdm

 from models.arch_util import TorchMelSpectrogram
-from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
 from models.text_voice_clip import VoiceCLIP
-from utils.audio import load_audio
+from models.vocoder import UnivNetGenerator
+from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
 from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from utils.tokenizer import VoiceBpeTokenizer

@@ -23,7 +23,6 @@ pbar = None
 def download_models():
     MODELS = {
         'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
-        'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin',
         'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
         'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
     }
@@ -47,12 +46,14 @@ def download_models():
         request.urlretrieve(url, f'.models/{model_name}', show_progress)
         print('Done.')


 def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
     """
     Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
     """
     return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
-                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
+                           model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+                           conditioning_free=True, conditioning_free_k=1)


 def load_conditioning(path, sample_rate=22050, cond_length=132300):
@@ -94,26 +95,26 @@ def fix_autoregressive_output(codes, stop_token):
     return codes


-def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, mean=False):
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
     """
     Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
     """
     with torch.no_grad():
-        mel = dvae_model.decode(mel_codes)[0]
-        # Pad MEL to multiples of 2048//spectrogram_compression_factor
-        msl = mel.shape[-1]
-        dsl = 2048 // spectrogram_compression_factor
+        cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
+        # Pad MEL to multiples of 32
+        msl = mel_codes.shape[-1]
+        dsl = 32
         gap = dsl - (msl % dsl)
         if gap > 0:
-            mel = torch.nn.functional.pad(mel, (0, gap))
+            mel = torch.nn.functional.pad(mel_codes, (0, gap))

-        output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
+        output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
         if mean:
-            return diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
-                                          model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
+            mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
+                                         model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
         else:
-            return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
+            mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
+        return denormalize_tacotron_mel(mel)[:,:,:msl*4]


 if __name__ == '__main__':
@@ -145,12 +146,6 @@ if __name__ == '__main__':
     download_models()

     for voice in args.voice.split(','):
-        print("Loading GPT TTS..")
-        autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
-                                      heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
-        autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
-        stop_mel_token = autoregressive.stop_mel_token
-
         print("Loading data..")
         tokenizer = VoiceBpeTokenizer()
         text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
@@ -160,7 +155,15 @@ if __name__ == '__main__':
         for cond_path in cond_paths:
             c, cond_wav = load_conditioning(cond_path)
             conds.append(c)
-        conds = torch.stack(conds, dim=1)  # And just use the last cond_wav for the diffusion model.
+        conds = torch.stack(conds, dim=1)
+        cond_diffusion = cond_wav[:, :88200]  # The diffusion model expects <= 88200 conditioning samples.
+
+        print("Loading GPT TTS..")
+        autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
+                                      heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False,
+                                      average_conditioning_embeddings=True).cuda().eval()
+        autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
+        stop_mel_token = autoregressive.stop_mel_token

         with torch.no_grad():
             print("Performing autoregressive inference..")
@@ -194,20 +197,25 @@ if __name__ == '__main__':
             # Delete the autoregressive and clip models to free up GPU memory
             del samples, clip

-            print("Loading DVAE..")
-            dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
-                               record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
-            dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)
             print("Loading Diffusion Model..")
-            diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
-                                                 spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
-                                                 conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
+            diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
+                                     channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], token_conditioning_resolutions=[1,4,8],
+                                     dropout=0, attention_resolutions=[4,8], num_heads=8, kernel_size=3, scale_factor=2,
+                                     time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
+                                     conditioning_expansion=1)
             diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
+            diffusion = diffusion.cuda().eval()
+            print("Loading vocoder..")
+            vocoder = UnivNetGenerator()
+            vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
+            vocoder = vocoder.cuda()
+            vocoder.eval(inference=True)
             diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)

             print("Performing vocoding..")
             # Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
             for b in range(best_results.shape[0]):
                 code = best_results[b].unsqueeze(0)
-                wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)
-                torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 22050)
+                mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
+                wav = vocoder.inference(mel)
+                torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)
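Taken together, the do_tts.py changes swap the DVAE + DiscreteDiffusionVocoder pair for DiffusionTts plus a UnivNet vocoder: the diffusion model now emits a 100-channel Tacotron-style mel, and the vocoder turns it into 24 kHz audio. A minimal sketch of the new tail of the pipeline, assuming `diffusion`, `diffuser`, `vocoder`, `cond_diffusion` and the best autoregressive `code` are already prepared exactly as in the diff above (the output path below is illustrative):

import torchaudio

# Sketch only: all names come from the diff above.
mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)  # (1, 100, T) mel
wav = vocoder.inference(mel)                  # roughly (1, 1, T * 256) waveform in [-1, 1]
torchaudio.save('output.wav', wav.squeeze(0).cpu(), 24000)  # v1.5 writes 24 kHz audio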
models/autoregressive.py

@@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
                  embedding_dim,
                  attn_blocks=6,
                  num_attn_heads=4,
-                 do_checkpointing=False):
+                 do_checkpointing=False,
+                 mean=False):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@@ -201,11 +202,15 @@ class ConditioningEncoder(nn.Module):
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
         self.do_checkpointing = do_checkpointing
+        self.mean = mean

     def forward(self, x):
         h = self.init(x)
         h = self.attn(h)
-        return h[:, :, 0]
+        if self.mean:
+            return h.mean(dim=2)
+        else:
+            return h[:, :, 0]


 class LearnedPositionEmbeddings(nn.Module):
@@ -275,7 +280,7 @@ class UnifiedVoice(nn.Module):
                  mel_length_compression=1024, number_text_tokens=256,
                  start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
                  stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
-                 checkpointing=True):
+                 checkpointing=True, average_conditioning_embeddings=False):
         """
         Args:
             layers: Number of layers in transformer stack.
@@ -294,6 +299,7 @@ class UnifiedVoice(nn.Module):
             train_solo_embeddings:
             use_mel_codes_as_input:
             checkpointing:
+            average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
         """
         super().__init__()

@@ -311,6 +317,7 @@ class UnifiedVoice(nn.Module):
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
+        self.average_conditioning_embeddings = average_conditioning_embeddings
         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
@@ -408,6 +415,8 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
+        if self.average_conditioning_embeddings:
+            conds = conds.mean(dim=1).unsqueeze(1)

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
@@ -446,6 +455,8 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
+        if self.average_conditioning_embeddings:
+            conds = conds.mean(dim=1).unsqueeze(1)

         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
         text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
@@ -472,6 +483,8 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
+        if self.average_conditioning_embeddings:
+            conds = conds.mean(dim=1).unsqueeze(1)

         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
         if raw_mels is not None:
@@ -508,6 +521,8 @@ class UnifiedVoice(nn.Module):
         for j in range(speech_conditioning_input.shape[1]):
             conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
         conds = torch.stack(conds, dim=1)
+        if self.average_conditioning_embeddings:
+            conds = conds.mean(dim=1).unsqueeze(1)

         emb = torch.cat([conds, text_emb], dim=1)
         self.inference_model.store_mel_emb(emb)
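The new `mean` flag on ConditioningEncoder and the `average_conditioning_embeddings` flag on UnifiedVoice both collapse several conditioning clips into a single embedding instead of feeding them to the transformer piecewise. A tiny standalone sketch of just the pooling step (shapes are illustrative, not taken from a checkpoint):

import torch

model_dim = 1024
# One embedding per conditioning clip, as ConditioningEncoder produces: (batch, n_clips, model_dim)
conds = torch.stack([torch.randn(1, model_dim) for _ in range(2)], dim=1)
# With average_conditioning_embeddings=True (as do_tts.py now passes), the clips are
# averaged into a single conditioning token.
conds = conds.mean(dim=1).unsqueeze(1)  # (1, 1, 1024)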
models/diffusion_decoder.py | 598 (new file)

"""
This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
and an audio conditioning input. It has also been simplified somewhat.
Credit: https://github.com/openai/improved-diffusion
"""
import functools
import math
from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from x_transformers import ContinuousTransformerWrapper, Encoder

from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock


def is_latent(t):
    return t.dtype == torch.float


def is_sequence(t):
    return t.dtype == torch.long


def ceil_multiple(base, multiple):
    res = base % multiple
    if res == 0:
        return base
    return base + (multiple - res)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class ResBlock(TimestepBlock):
    def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            out_channels=None,
            kernel_size=3,
            efficient_config=True,
            use_scale_shift_norm=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_scale_shift_norm = use_scale_shift_norm
        padding = {1: 0, 3: 1, 5: 2}[kernel_size]
        eff_kernel = 1 if efficient_config else 3
        eff_padding = 0 if efficient_config else 1

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
        )

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, x, emb
        )

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
    checkpoint for all other args.
    """
    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap

    def forward(self, x, *args, **kwargs):
        for k, v in kwargs.items():
            assert not (isinstance(v, torch.Tensor) and v.requires_grad)  # This would screw up checkpointing.
        partial = functools.partial(self.wrap, **kwargs)
        return torch.utils.checkpoint.checkpoint(partial, x, *args)


class CheckpointedXTransformerEncoder(nn.Module):
    """
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
    to channels-last that XTransformer expects.
    """
    def __init__(self, needs_permute=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
        self.needs_permute = needs_permute

        for i in range(len(self.transformer.attn_layers.layers)):
            n, b, r = self.transformer.attn_layers.layers[i]
            self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])

    def forward(self, x, **kwargs):
        if self.needs_permute:
            x = x.permute(0,2,1)
        h = self.transformer(x, **kwargs)
        return h.permute(0,2,1)


class DiffusionTts(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    Customized to be conditioned on an aligned prior derived from a autoregressive
    GPT-style model.

    :param in_channels: channels in the input Tensor.
    :param in_latent_channels: channels from the input latent.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
            self,
            model_channels,
            in_channels=1,
            in_latent_channels=1024,
            in_tokens=8193,
            conditioning_dim_factor=8,
            conditioning_expansion=4,
            out_channels=2,  # mean and variance
            dropout=0,
            # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
            channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
            num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2,   2,  2,  2,  2),
            # spec_cond:    1, 0, 0, 1, 0, 0, 1, 0,   0,  1,  0,  0)
            # attn:         0, 0, 0, 0, 0, 0, 0, 0,   0,  1,  1,  1
            token_conditioning_resolutions=(1,16,),
            attention_resolutions=(512,1024,2048),
            conv_resample=True,
            use_fp16=False,
            num_heads=1,
            num_head_channels=-1,
            num_heads_upsample=-1,
            kernel_size=3,
            scale_factor=2,
            time_embed_dim_multiplier=4,
            freeze_main_net=False,
            efficient_convs=True,  # Uses kernels with width of 1 in several places rather than 3.
            use_scale_shift_norm=True,
            # Parameters for regularization.
            unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
            # Parameters for super-sampling.
            super_sampling=False,
            super_sampling_max_noising_factor=.1,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if super_sampling:
            in_channels *= 2  # In super-sampling mode, the LR input is concatenated directly onto the input.
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.super_sampling_enabled = super_sampling
        self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
        self.unconditioned_percentage = unconditioned_percentage
        self.enable_fp16 = use_fp16
        self.alignment_size = 2 ** (len(channel_mult)+1)
        self.freeze_main_net = freeze_main_net
        padding = 1 if kernel_size == 3 else 2
        down_kernel = 1 if efficient_convs else 3

        time_embed_dim = model_channels * time_embed_dim_multiplier
        self.time_embed = nn.Sequential(
            Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            Linear(time_embed_dim, time_embed_dim),
        )

        conditioning_dim = model_channels * conditioning_dim_factor
        # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
        # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
        # transformer network.
        self.code_converter = nn.Sequential(
            nn.Embedding(in_tokens, conditioning_dim),
            CheckpointedXTransformerEncoder(
                needs_permute=False,
                max_seq_len=-1,
                use_pos_emb=False,
                attn_layers=Encoder(
                    dim=conditioning_dim,
                    depth=3,
                    heads=num_heads,
                    ff_dropout=dropout,
                    attn_dropout=dropout,
                    use_rmsnorm=True,
                    ff_glu=True,
                    rotary_emb_dim=True,
                )
            ))
        self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
        self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
        if in_channels > 60:  # It's a spectrogram.
            self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
                                                     CheckpointedXTransformerEncoder(
                                                         needs_permute=True,
                                                         max_seq_len=-1,
                                                         use_pos_emb=False,
                                                         attn_layers=Encoder(
                                                             dim=conditioning_dim,
                                                             depth=4,
                                                             heads=num_heads,
                                                             ff_dropout=dropout,
                                                             attn_dropout=dropout,
                                                             use_rmsnorm=True,
                                                             ff_glu=True,
                                                             rotary_emb_dim=True,
                                                         )
                                                     ))
        else:
            self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                        attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
        self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
        self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
            AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
            AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
        )
        self.conditioning_expansion = conditioning_expansion

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
                )
            ]
        )
        token_conditioning_blocks = []
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
            if ds in token_conditioning_resolutions:
                token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
                token_conditioning_block.weight.data *= .02
                self.input_blocks.append(token_conditioning_block)
                token_conditioning_blocks.append(token_conditioning_block)

            for _ in range(num_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        kernel_size=kernel_size,
                        efficient_config=efficient_convs,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample(
                            ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                kernel_size=kernel_size,
                efficient_config=efficient_convs,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                kernel_size=kernel_size,
                efficient_config=efficient_convs,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
            for i in range(num_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        kernel_size=kernel_size,
                        efficient_config=efficient_convs,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                        )
                    )
                if level and i == num_blocks:
                    out_ch = ch
                    layers.append(
                        Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
        )

    def fix_alignment(self, x, aligned_conditioning):
        """
        The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
        padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
        """
        cm = ceil_multiple(x.shape[-1], self.alignment_size)
        if cm != 0:
            pc = (cm-x.shape[-1])/x.shape[-1]
            x = F.pad(x, (0,cm-x.shape[-1]))
            # Also fix aligned_latent, which is aligned to x.
            if is_latent(aligned_conditioning):
                aligned_conditioning = torch.cat([aligned_conditioning,
                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
            else:
                aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
        return x, aligned_conditioning

    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
        :param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert conditioning_input is not None
        if self.super_sampling_enabled:
            assert lr_input is not None
            if self.training and self.super_sampling_max_noising_factor > 0:
                noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
                lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
            lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
            x = torch.cat([x, lr_input], dim=1)

        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

        # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
        orig_x_shape = x.shape[-1]
        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)

        with autocast(x.device.type, enabled=self.enable_fp16):
            hs = []
            time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

            # Note: this block does not need to repeated on inference, since it is not timestep-dependent.
            if conditioning_free:
                code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
            else:
                cond_emb = self.contextual_embedder(conditioning_input)
                if len(cond_emb.shape) == 3:  # Just take the first element.
                    cond_emb = cond_emb[:, :, 0]
                if is_latent(aligned_conditioning):
                    code_emb = self.latent_converter(aligned_conditioning)
                else:
                    code_emb = self.code_converter(aligned_conditioning)
                cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
                code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
            # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
            if self.training and self.unconditioned_percentage > 0:
                unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                                   device=code_emb.device) < self.unconditioned_percentage
                code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
                                       code_emb)

            # Everything after this comment is timestep dependent.
            code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
            code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)

            first = True
            time_emb = time_emb.float()
            h = x
            for k, module in enumerate(self.input_blocks):
                if isinstance(module, nn.Conv1d):
                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
                    h = h + h_tok
                else:
                    with autocast(x.device.type, enabled=self.enable_fp16 and not first):
                        # First block has autocast disabled to allow a high precision signal to be properly vectorized.
                        h = module(h, time_emb)
                    hs.append(h)
                first = False
            h = self.middle_block(h, time_emb)
            for module in self.output_blocks:
                h = torch.cat([h, hs.pop()], dim=1)
                h = module(h, time_emb)

        # Last block also has autocast disabled for high-precision outputs.
        h = h.float()
        out = self.out(h)

        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
        extraneous_addition = 0
        params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
        for p in params:
            extraneous_addition = extraneous_addition + p.mean()
        out = out + extraneous_addition * 0

        return out[:, :, :orig_x_shape]


if __name__ == '__main__':
    clip = torch.randn(2, 1, 32868)
    aligned_latent = torch.randn(2,388,1024)
    aligned_sequence = torch.randint(0,8192,(2,388))
    cond = torch.randn(2, 1, 44000)
    ts = torch.LongTensor([600, 600])
    model = DiffusionTts(128,
                         channel_mult=[1,1.5,2, 3, 4, 6, 8],
                         num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
                         token_conditioning_resolutions=[1,4,16,64],
                         attention_resolutions=[],
                         num_heads=8,
                         kernel_size=3,
                         scale_factor=2,
                         time_embed_dim_multiplier=4,
                         super_sampling=False,
                         efficient_convs=False)
    # Test with latent aligned conditioning
    o = model(clip, ts, aligned_latent, cond)
    # Test with sequence aligned conditioning
    o = model(clip, ts, aligned_sequence, cond)
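For context, do_tts.py constructs this model with a much smaller configuration than the defaults above. A hedged smoke-test sketch in the spirit of the file's own __main__ block, using the script's constructor arguments from the diff; the input tensor shapes below are assumptions, not values from the repository:

import torch
from models.diffusion_decoder import DiffusionTts

model = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
                     channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], token_conditioning_resolutions=[1, 4, 8],
                     dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
                     time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
                     conditioning_expansion=1).eval()

x = torch.randn(1, 100, 412)               # noised 100-channel mel being denoised (assumed length)
codes = torch.randint(0, 8192, (1, 103))   # aligned mel codes from the autoregressive model
cond_mel = torch.randn(1, 100, 344)        # univnet-style mel of the reference clip (assumed length)
ts = torch.LongTensor([500])
with torch.no_grad():
    out = model(x, ts, codes, cond_mel)    # (1, 200, 412): predicted epsilon plus learned-variance channels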
325
models/vocoder.py
Normal file
325
models/vocoder.py
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
MAX_WAV_VALUE = 32768.0
|
||||||
|
|
||||||
|
class KernelPredictor(torch.nn.Module):
|
||||||
|
''' Kernel predictor for the location-variable convolutions'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cond_channels,
|
||||||
|
conv_in_channels,
|
||||||
|
conv_out_channels,
|
||||||
|
conv_layers,
|
||||||
|
conv_kernel_size=3,
|
||||||
|
kpnet_hidden_channels=64,
|
||||||
|
kpnet_conv_size=3,
|
||||||
|
kpnet_dropout=0.0,
|
||||||
|
kpnet_nonlinear_activation="LeakyReLU",
|
||||||
|
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
||||||
|
):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
cond_channels (int): number of channel for the conditioning sequence,
|
||||||
|
conv_in_channels (int): number of channel for the input sequence,
|
||||||
|
conv_out_channels (int): number of channel for the output sequence,
|
||||||
|
conv_layers (int): number of layers
|
||||||
|
'''
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_in_channels = conv_in_channels
|
||||||
|
self.conv_out_channels = conv_out_channels
|
||||||
|
self.conv_kernel_size = conv_kernel_size
|
||||||
|
self.conv_layers = conv_layers
|
||||||
|
|
||||||
|
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||||
|
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||||
|
|
||||||
|
self.input_conv = nn.Sequential(
|
||||||
|
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.residual_convs = nn.ModuleList()
|
||||||
|
padding = (kpnet_conv_size - 1) // 2
|
||||||
|
for _ in range(3):
|
||||||
|
self.residual_convs.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Dropout(kpnet_dropout),
|
||||||
|
nn.utils.weight_norm(
|
||||||
|
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
||||||
|
bias=True)),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
nn.utils.weight_norm(
|
||||||
|
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
|
||||||
|
bias=True)),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.kernel_conv = nn.utils.weight_norm(
|
||||||
|
nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
|
||||||
|
self.bias_conv = nn.utils.weight_norm(
|
||||||
|
nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
|
||||||
|
|
||||||
|
def forward(self, c):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||||
|
'''
|
||||||
|
batch, _, cond_length = c.shape
|
||||||
|
c = self.input_conv(c)
|
||||||
|
for residual_conv in self.residual_convs:
|
||||||
|
residual_conv.to(c.device)
|
||||||
|
c = c + residual_conv(c)
|
||||||
|
k = self.kernel_conv(c)
|
||||||
|
b = self.bias_conv(c)
|
||||||
|
kernels = k.contiguous().view(
|
||||||
|
batch,
|
||||||
|
self.conv_layers,
|
||||||
|
self.conv_in_channels,
|
||||||
|
self.conv_out_channels,
|
||||||
|
self.conv_kernel_size,
|
||||||
|
cond_length,
|
||||||
|
)
|
||||||
|
bias = b.contiguous().view(
|
||||||
|
batch,
|
||||||
|
self.conv_layers,
|
||||||
|
self.conv_out_channels,
|
||||||
|
cond_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return kernels, bias
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
nn.utils.remove_weight_norm(self.input_conv[0])
|
||||||
|
nn.utils.remove_weight_norm(self.kernel_conv)
|
||||||
|
nn.utils.remove_weight_norm(self.bias_conv)
|
||||||
|
for block in self.residual_convs:
|
||||||
|
nn.utils.remove_weight_norm(block[1])
|
||||||
|
nn.utils.remove_weight_norm(block[3])
|
||||||
|
|
||||||
|
|
||||||
|
class LVCBlock(torch.nn.Module):
|
||||||
|
'''the location-variable convolutions'''
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
cond_channels,
|
||||||
|
stride,
|
||||||
|
dilations=[1, 3, 9, 27],
|
||||||
|
lReLU_slope=0.2,
|
||||||
|
conv_kernel_size=3,
|
||||||
|
cond_hop_length=256,
|
||||||
|
kpnet_hidden_channels=64,
|
||||||
|
kpnet_conv_size=3,
|
||||||
|
kpnet_dropout=0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_hop_length = cond_hop_length
|
||||||
|
self.conv_layers = len(dilations)
|
||||||
|
self.conv_kernel_size = conv_kernel_size
|
||||||
|
|
||||||
|
self.kernel_predictor = KernelPredictor(
|
||||||
|
cond_channels=cond_channels,
|
||||||
|
conv_in_channels=in_channels,
|
||||||
|
conv_out_channels=2 * in_channels,
|
||||||
|
conv_layers=len(dilations),
|
||||||
|
conv_kernel_size=conv_kernel_size,
|
||||||
|
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||||
|
kpnet_conv_size=kpnet_conv_size,
|
||||||
|
kpnet_dropout=kpnet_dropout,
|
||||||
|
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convt_pre = nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
|
||||||
|
padding=stride // 2 + stride % 2, output_padding=stride % 2)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_blocks = nn.ModuleList()
|
||||||
|
for dilation in dilations:
|
||||||
|
self.conv_blocks.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
|
||||||
|
padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
''' forward propagation of the location-variable convolutions.
|
||||||
|
Args:
|
||||||
|
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||||
|
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the output sequence (batch, in_channels, in_length)
|
||||||
|
'''
|
||||||
|
_, in_channels, _ = x.shape # (B, c_g, L')
|
||||||
|
|
||||||
|
x = self.convt_pre(x) # (B, c_g, stride * L')
|
||||||
|
kernels, bias = self.kernel_predictor(c)
|
||||||
|
|
||||||
|
for i, conv in enumerate(self.conv_blocks):
|
||||||
|
output = conv(x) # (B, c_g, stride * L')
|
||||||
|
|
||||||
|
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||||
|
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
||||||
|
|
||||||
|
output = self.location_variable_convolution(output, k, b,
|
||||||
|
hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
|
||||||
|
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
||||||
|
output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
||||||
|
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||||
|
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||||
|
Args:
|
||||||
|
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||||
|
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||||
|
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||||
|
dilation (int): the dilation of convolution.
|
||||||
|
hop_size (int): the hop_size of the conditioning sequence.
|
||||||
|
Returns:
|
||||||
|
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||||
|
'''
|
||||||
|
batch, _, in_length = x.shape
|
||||||
|
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||||
|
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||||
|
|
||||||
|
padding = dilation * int((kernel_size - 1) / 2)
|
||||||
|
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
|
||||||
|
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||||
|
|
||||||
|
if hop_size < dilation:
|
||||||
|
x = F.pad(x, (0, dilation), 'constant', 0)
|
||||||
|
x = x.unfold(3, dilation,
|
||||||
|
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||||
|
x = x[:, :, :, :, :hop_size]
|
||||||
|
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||||
|
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||||
|
|
||||||
|
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
|
||||||
|
o = o.to(memory_format=torch.channels_last_3d)
|
||||||
|
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
||||||
|
o = o + bias
|
||||||
|
o = o.contiguous().view(batch, out_channels, -1)
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
self.kernel_predictor.remove_weight_norm()
|
||||||
|
nn.utils.remove_weight_norm(self.convt_pre[1])
|
||||||
|
for block in self.conv_blocks:
|
||||||
|
nn.utils.remove_weight_norm(block[1])
|
||||||
|
|
||||||
|
|
||||||
|
class UnivNetGenerator(nn.Module):
|
||||||
|
"""UnivNet Generator"""
|
||||||
|
|
||||||
|
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
|
||||||
|
# Below are MEL configurations options that this generator requires.
|
||||||
|
hop_length=256, n_mel_channels=100):
|
||||||
|
super(UnivNetGenerator, self).__init__()
|
||||||
|
self.mel_channel = n_mel_channels
|
||||||
|
self.noise_dim = noise_dim
|
||||||
|
self.hop_length = hop_length
|
||||||
|
channel_size = channel_size
|
||||||
|
kpnet_conv_size = kpnet_conv_size
|
||||||
|
|
||||||
|
self.res_stack = nn.ModuleList()
|
||||||
|
hop_length = 1
|
||||||
|
for stride in strides:
|
||||||
|
hop_length = stride * hop_length
|
||||||
|
self.res_stack.append(
|
||||||
|
LVCBlock(
|
||||||
|
channel_size,
|
||||||
|
n_mel_channels,
|
||||||
|
stride=stride,
|
||||||
|
dilations=dilations,
|
||||||
|
lReLU_slope=lReLU_slope,
|
||||||
|
cond_hop_length=hop_length,
|
||||||
|
kpnet_conv_size=kpnet_conv_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_pre = \
|
||||||
|
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
|
||||||
|
|
||||||
|
self.conv_post = nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, c, z):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
|
||||||
|
z (Tensor): the noise sequence (batch, noise_dim, in_length)
|
||||||
|
|
||||||
|
'''
|
||||||
|
z = self.conv_pre(z) # (B, c_g, L)
|
||||||
|
|
||||||
|
for res_block in self.res_stack:
|
||||||
|
res_block.to(z.device)
|
||||||
|
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
|
||||||
|
|
||||||
|
z = self.conv_post(z) # (B, 1, L * 256)
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def eval(self, inference=False):
|
||||||
|
super(UnivNetGenerator, self).eval()
|
||||||
|
# don't remove weight norm while validation in training loop
|
||||||
|
if inference:
|
||||||
|
self.remove_weight_norm()
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print('Removing weight norm...')
|
||||||
|
|
||||||
|
nn.utils.remove_weight_norm(self.conv_pre)
|
||||||
|
|
||||||
|
for layer in self.conv_post:
|
||||||
|
if len(layer.state_dict()) != 0:
|
||||||
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
|
||||||
|
for res_block in self.res_stack:
|
||||||
|
res_block.remove_weight_norm()
|
||||||
|
|
||||||
|
    def inference(self, c, z=None):
        # pad the input mel with zeros to cut artifacts at the end of the clip
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
        mel = torch.cat((c, zero), dim=2)

        if z is None:
            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)

        audio = self.forward(mel, z)
        audio = audio[:, :, :-(self.hop_length * 10)]
        audio = audio.clamp(min=-1, max=1)
        return audio


if __name__ == '__main__':
    model = UnivNetGenerator()

    c = torch.randn(3, 100, 10)
    z = torch.randn(3, 64, 10)
    print(c.shape)

    y = model(c, z)
    print(y.shape)
    assert y.shape == torch.Size([3, 1, 2560])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)
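For orientation, a minimal vocoding sketch (editor's illustration, not part of this commit): a 100-bin mel conditions the generator, inference() draws the noise internally, and the output is a waveform with hop_length (256) samples per mel frame at 24 kHz. The random weights here are only a stand-in for the released checkpoint.

import torch
from models.vocoder import UnivNetGenerator

vocoder = UnivNetGenerator()           # random weights; real weights come from the released checkpoint
vocoder.eval(inference=True)           # eval(inference=True) also strips weight norm for inference

mel = torch.randn(1, 100, 80)          # stand-in for a denormalized 100-bin mel
with torch.no_grad():
    wav = vocoder.inference(mel)       # (1, 1, 80 * 256) samples of 24 kHz audio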

@@ -6,4 +6,5 @@ tokenizers
inflect
progressbar
einops
unidecode
x-transformers

@@ -3,6 +3,8 @@ import torchaudio
import numpy as np
from scipy.io.wavfile import read

from utils.stft import STFT


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)

@@ -43,4 +45,86 @@ def load_audio(audiopath, sampling_rate):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)

    return audio.unsqueeze(0)


TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
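A short round-trip check (editor's illustration, not part of the commit) makes the contract of these two helpers concrete: normalize_tacotron_mel maps log-mel values from [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] into [-1, 1], and denormalize_tacotron_mel inverts it.

import torch

mel = torch.full((1, 100, 10), -5.0)                # hypothetical log-mel tensor
norm = normalize_tacotron_mel(mel)                  # values now in [-1, 1]
assert torch.allclose(denormalize_tacotron_mel(norm), mel)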

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        from librosa.filters import mel as librosa_mel_fn
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waveforms
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -10
        assert torch.max(y.data) <= 10
        y = torch.clip(y, min=-1, max=1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output


def wav_to_univnet_mel(wav, do_normalization=False):
    stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
    stft = stft.cuda()
    mel = stft.mel_spectrogram(wav)
    if do_normalization:
        mel = normalize_tacotron_mel(mel)
    return mel
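To make the new audio path easier to follow, a small end-to-end sketch (editor's illustration, not part of the commit): a reference clip is loaded with load_audio, resampled to the 24 kHz rate that wav_to_univnet_mel expects (the helper builds a 24 kHz, 100-bin TacotronSTFT and moves it to CUDA), and converted to the normalized mel used as conditioning. The 22 kHz load rate, the file path, and the explicit resampling step are assumptions about how callers use these helpers.

import torch
import torchaudio
from utils.audio import load_audio, wav_to_univnet_mel

wav = load_audio('voices/sample.wav', 22050)             # hypothetical path; returns (1, T) in [-1, 1]
wav = torchaudio.functional.resample(wav, 22050, 24000)  # wav_to_univnet_mel assumes 24 kHz input
wav = wav.cuda()

cond_mel = wav_to_univnet_mel(wav, do_normalization=True)  # (1, 100, frames), values in [-1, 1]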

@@ -197,11 +197,17 @@ class GaussianDiffusion:
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        conditioning_free=False,
        conditioning_free_k=1,
        ramp_conditioning_free=True,
    ):
        self.model_mean_type = ModelMeanType(model_mean_type)
        self.model_var_type = ModelVarType(model_var_type)
        self.loss_type = LossType(loss_type)
        self.rescale_timesteps = rescale_timesteps
        self.conditioning_free = conditioning_free
        self.conditioning_free_k = conditioning_free_k
        self.ramp_conditioning_free = ramp_conditioning_free

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)

@@ -332,10 +338,14 @@ class GaussianDiffusion:
        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, self._scale_timesteps(t), **model_kwargs)
        if self.conditioning_free:
            model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.conditioning_free:
                model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)

@@ -364,6 +374,14 @@ class GaussianDiffusion:
        model_variance = _extract_into_tensor(model_variance, t, x.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        if self.conditioning_free:
            if self.ramp_conditioning_free:
                assert t.shape[0] == 1  # This should only be used in inference.
                cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
            else:
                cfk = self.conditioning_free_k
            model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
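The new conditioning_free path is a classifier-free-guidance-style correction: the model is evaluated once with conditioning and once without, and the two predictions are blended as (1 + k) * out_cond - k * out_uncond, with k optionally ramped down as the timestep index grows toward num_timesteps. A standalone sketch of that blend (editor's illustration; the tensor shapes, function name, and sizes are hypothetical, the arithmetic and ramp mirror the diff):

import torch

def blend_conditioning_free(out_cond, out_uncond, t, num_timesteps, k=1.0, ramp=True):
    # out_cond / out_uncond: model outputs with and without conditioning, same shape.
    # With ramp=True the guidance weight shrinks linearly as t approaches num_timesteps,
    # matching the (1 - t / num_timesteps) factor in the diff.
    cfk = k * (1 - t / num_timesteps) if ramp else k
    return (1 + cfk) * out_cond - cfk * out_uncond

out_c = torch.randn(1, 100, 400)      # hypothetical conditioned prediction
out_u = torch.randn(1, 100, 400)      # hypothetical unconditioned prediction
guided = blend_conditioning_free(out_c, out_u, t=50, num_timesteps=200)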

193  utils/stft.py  (new file)

@@ -0,0 +1,193 @@
"""
|
||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2017, Prem Seetharaman
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
* Redistributions in binary form must reproduce the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
* Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from this
|
||||||
|
software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
|
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
|
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||||
|
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||||
|
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||||
|
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||||
|
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from scipy.signal import get_window
|
||||||
|
from librosa.util import pad_center, tiny
|
||||||
|
import librosa.util as librosa_util
|
||||||
|
|
||||||
|
|
||||||
|


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
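For reviewers unfamiliar with this helper, a minimal round-trip sketch (editor's illustration, not part of the commit): TacotronSTFT uses STFT.transform to get magnitude and phase at hop 256, and STFT.inverse reconstructs an approximate waveform, so a quick transform/inverse pass is a handy sanity check. The one-second random input is purely illustrative.

import torch
from utils.stft import STFT

stft = STFT(filter_length=1024, hop_length=256, win_length=1024)  # same settings TacotronSTFT passes in
wav = torch.randn(1, 24000).clamp(-1, 1)                          # one second of fake 24 kHz audio

mag, phase = stft.transform(wav)      # (1, 513, frames) each
recon = stft.inverse(mag, phase)      # approximate reconstruction of wav, shape (1, 1, T)
print(mag.shape, recon.shape)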