Modifications to support "v1.5"

This commit is contained in:
James Betker 2022-03-22 11:52:46 -06:00
parent 9f1aa57b8d
commit 79c74c1484
8 changed files with 1279 additions and 37 deletions

View File

@ -8,14 +8,14 @@ import torch.nn.functional as F
import torchaudio
import progressbar
from models.dvae import DiscreteVAE
from models.diffusion_decoder import DiffusionTts
from models.autoregressive import UnifiedVoice
from tqdm import tqdm
from models.arch_util import TorchMelSpectrogram
from models.discrete_diffusion_vocoder import DiscreteDiffusionVocoder
from models.text_voice_clip import VoiceCLIP
from utils.audio import load_audio
from models.vocoder import UnivNetGenerator
from utils.audio import load_audio, wav_to_univnet_mel, denormalize_tacotron_mel
from utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from utils.tokenizer import VoiceBpeTokenizer
@ -23,7 +23,6 @@ pbar = None
def download_models():
MODELS = {
'clip.pth': 'https://huggingface.co/jbetker/tortoise-tts-clip/resolve/main/pytorch-model.bin',
'dvae.pth': 'https://huggingface.co/jbetker/voice-dvae/resolve/main/pytorch_model.bin',
'diffusion.pth': 'https://huggingface.co/jbetker/tortoise-tts-diffusion-v1/resolve/main/pytorch-model.bin',
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-autoregressive/resolve/main/pytorch-model.bin'
}
@ -47,12 +46,14 @@ def download_models():
request.urlretrieve(url, f'.models/{model_name}', show_progress)
print('Done.')
def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200):
"""
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type='epsilon',
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps))
model_var_type='learned_range', loss_type='mse', betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=True, conditioning_free_k=1)
def load_conditioning(path, sample_rate=22050, cond_length=132300):
@ -94,26 +95,26 @@ def fix_autoregressive_output(codes, stop_token):
return codes
def do_spectrogram_diffusion(diffusion_model, dvae_model, diffuser, mel_codes, conditioning_input, spectrogram_compression_factor=128, mean=False):
def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, mean=False):
"""
Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
"""
with torch.no_grad():
mel = dvae_model.decode(mel_codes)[0]
# Pad MEL to multiples of 2048//spectrogram_compression_factor
msl = mel.shape[-1]
dsl = 2048 // spectrogram_compression_factor
cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
# Pad MEL to multiples of 32
msl = mel_codes.shape[-1]
dsl = 32
gap = dsl - (msl % dsl)
if gap > 0:
mel = torch.nn.functional.pad(mel, (0, gap))
mel = torch.nn.functional.pad(mel_codes, (0, gap))
output_shape = (mel.shape[0], 1, mel.shape[-1] * spectrogram_compression_factor)
output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
if mean:
return diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=torch.zeros(output_shape, device=mel_codes.device),
model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
else:
return diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'spectrogram': mel, 'conditioning_input': conditioning_input})
mel = diffuser.p_sample_loop(diffusion_model, output_shape, model_kwargs={'aligned_conditioning': mel_codes, 'conditioning_input': cond_mel})
return denormalize_tacotron_mel(mel)[:,:,:msl*4]
if __name__ == '__main__':
@ -145,12 +146,6 @@ if __name__ == '__main__':
download_models()
for voice in args.voice.split(','):
print("Loading GPT TTS..")
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False).cuda().eval()
autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
stop_mel_token = autoregressive.stop_mel_token
print("Loading data..")
tokenizer = VoiceBpeTokenizer()
text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
@ -160,7 +155,15 @@ if __name__ == '__main__':
for cond_path in cond_paths:
c, cond_wav = load_conditioning(cond_path)
conds.append(c)
conds = torch.stack(conds, dim=1) # And just use the last cond_wav for the diffusion model.
conds = torch.stack(conds, dim=1)
cond_diffusion = cond_wav[:, :88200] # The diffusion model expects <= 88200 conditioning samples.
print("Loading GPT TTS..")
autoregressive = UnifiedVoice(max_mel_tokens=300, max_text_tokens=200, max_conditioning_inputs=2, layers=30, model_dim=1024,
heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, train_solo_embeddings=False,
average_conditioning_embeddings=True).cuda().eval()
autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
stop_mel_token = autoregressive.stop_mel_token
with torch.no_grad():
print("Performing autoregressive inference..")
@ -194,20 +197,25 @@ if __name__ == '__main__':
# Delete the autoregressive and clip models to free up GPU memory
del samples, clip
print("Loading DVAE..")
dvae = DiscreteVAE(positional_dims=1, channels=80, hidden_dim=512, num_resnet_blocks=3, codebook_dim=512, num_tokens=8192, num_layers=2,
record_codes=True, kernel_size=3, use_transposed_convs=False).cuda().eval()
dvae.load_state_dict(torch.load('.models/dvae.pth'), strict=False)
print("Loading Diffusion Model..")
diffusion = DiscreteDiffusionVocoder(model_channels=128, dvae_dim=80, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8], num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
spectrogram_conditioning_resolutions=[2,512], attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, time_embed_dim_multiplier=4).cuda().eval()
diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], token_conditioning_resolutions=[1,4,8],
dropout=0, attention_resolutions=[4,8], num_heads=8, kernel_size=3, scale_factor=2,
time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
conditioning_expansion=1)
diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
diffusion = diffusion.cuda().eval()
print("Loading vocoder..")
vocoder = UnivNetGenerator()
vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
vocoder = vocoder.cuda()
vocoder.eval(inference=True)
diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=100)
print("Performing vocoding..")
# Perform vocoding on each batch element separately: The diffusion model is very memory (and compute!) intensive.
for b in range(best_results.shape[0]):
code = best_results[b].unsqueeze(0)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, code, cond_wav, spectrogram_compression_factor=256, mean=True)
torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 22050)
mel = do_spectrogram_diffusion(diffusion, diffuser, code, cond_diffusion, mean=False)
wav = vocoder.inference(mel)
torchaudio.save(os.path.join(args.output_path, f'{voice}_{b}.wav'), wav.squeeze(0).cpu(), 24000)

View File

@ -192,7 +192,8 @@ class ConditioningEncoder(nn.Module):
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
do_checkpointing=False):
do_checkpointing=False,
mean=False):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
@ -201,10 +202,14 @@ class ConditioningEncoder(nn.Module):
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.do_checkpointing = do_checkpointing
self.mean = mean
def forward(self, x):
h = self.init(x)
h = self.attn(h)
if self.mean:
return h.mean(dim=2)
else:
return h[:, :, 0]
@ -275,7 +280,7 @@ class UnifiedVoice(nn.Module):
mel_length_compression=1024, number_text_tokens=256,
start_text_token=255, stop_text_token=0, number_mel_codes=8194, start_mel_token=8192,
stop_mel_token=8193, train_solo_embeddings=False, use_mel_codes_as_input=True,
checkpointing=True):
checkpointing=True, average_conditioning_embeddings=False):
"""
Args:
layers: Number of layers in transformer stack.
@ -294,6 +299,7 @@ class UnifiedVoice(nn.Module):
train_solo_embeddings:
use_mel_codes_as_input:
checkpointing:
average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
"""
super().__init__()
@ -311,6 +317,7 @@ class UnifiedVoice(nn.Module):
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
@ -408,6 +415,8 @@ class UnifiedVoice(nn.Module):
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
@ -446,6 +455,8 @@ class UnifiedVoice(nn.Module):
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
@ -472,6 +483,8 @@ class UnifiedVoice(nn.Module):
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
if raw_mels is not None:
@ -508,6 +521,8 @@ class UnifiedVoice(nn.Module):
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1)
self.inference_model.store_mel_emb(emb)

598
models/diffusion_decoder.py Normal file
View File

@ -0,0 +1,598 @@
"""
This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
and an audio conditioning input. It has also been simplified somewhat.
Credit: https://github.com/openai/improved-diffusion
"""
import functools
import math
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from x_transformers import ContinuousTransformerWrapper, Encoder
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock
def is_latent(t):
return t.dtype == torch.float
def is_sequence(t):
return t.dtype == torch.long
def ceil_multiple(base, multiple):
res = base % multiple
if res == 0:
return base
return base + (multiple - res)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class ResBlock(TimestepBlock):
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
kernel_size=3,
efficient_config=True,
use_scale_shift_norm=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_scale_shift_norm = use_scale_shift_norm
padding = {1: 0, 3: 1, 5: 2}[kernel_size]
eff_kernel = 1 if efficient_config else 3
eff_padding = 0 if efficient_config else 1
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, x, emb
)
def _forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class CheckpointedLayer(nn.Module):
"""
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
checkpoint for all other args.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
def forward(self, x, *args, **kwargs):
for k, v in kwargs.items():
assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing.
partial = functools.partial(self.wrap, **kwargs)
return torch.utils.checkpoint.checkpoint(partial, x, *args)
class CheckpointedXTransformerEncoder(nn.Module):
"""
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
to channels-last that XTransformer expects.
"""
def __init__(self, needs_permute=True, **xtransformer_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
self.needs_permute = needs_permute
for i in range(len(self.transformer.attn_layers.layers)):
n, b, r = self.transformer.attn_layers.layers[i]
self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
def forward(self, x, **kwargs):
if self.needs_permute:
x = x.permute(0,2,1)
h = self.transformer(x, **kwargs)
return h.permute(0,2,1)
class DiffusionTts(nn.Module):
"""
The full UNet model with attention and timestep embedding.
Customized to be conditioned on an aligned prior derived from a autoregressive
GPT-style model.
:param in_channels: channels in the input Tensor.
:param in_latent_channels: channels from the input latent.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
model_channels,
in_channels=1,
in_latent_channels=1024,
in_tokens=8193,
conditioning_dim_factor=8,
conditioning_expansion=4,
out_channels=2, # mean and variance
dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
token_conditioning_resolutions=(1,16,),
attention_resolutions=(512,1024,2048),
conv_resample=True,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
use_scale_shift_norm=True,
# Parameters for regularization.
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# Parameters for super-sampling.
super_sampling=False,
super_sampling_max_noising_factor=.1,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if super_sampling:
in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.super_sampling_enabled = super_sampling
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
self.unconditioned_percentage = unconditioned_percentage
self.enable_fp16 = use_fp16
self.alignment_size = 2 ** (len(channel_mult)+1)
self.freeze_main_net = freeze_main_net
padding = 1 if kernel_size == 3 else 2
down_kernel = 1 if efficient_convs else 3
time_embed_dim = model_channels * time_embed_dim_multiplier
self.time_embed = nn.Sequential(
Linear(model_channels, time_embed_dim),
nn.SiLU(),
Linear(time_embed_dim, time_embed_dim),
)
conditioning_dim = model_channels * conditioning_dim_factor
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
# complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
# transformer network.
self.code_converter = nn.Sequential(
nn.Embedding(in_tokens, conditioning_dim),
CheckpointedXTransformerEncoder(
needs_permute=False,
max_seq_len=-1,
use_pos_emb=False,
attn_layers=Encoder(
dim=conditioning_dim,
depth=3,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_emb_dim=True,
)
))
self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
if in_channels > 60: # It's a spectrogram.
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
CheckpointedXTransformerEncoder(
needs_permute=True,
max_seq_len=-1,
use_pos_emb=False,
attn_layers=Encoder(
dim=conditioning_dim,
depth=4,
heads=num_heads,
ff_dropout=dropout,
attn_dropout=dropout,
use_rmsnorm=True,
ff_glu=True,
rotary_emb_dim=True,
)
))
else:
self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
self.conditioning_timestep_integrator = TimestepEmbedSequential(
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
)
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
)
]
)
token_conditioning_blocks = []
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in token_conditioning_resolutions:
token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
token_conditioning_block.weight.data *= .02
self.input_blocks.append(token_conditioning_block)
token_conditioning_blocks.append(token_conditioning_block)
for _ in range(num_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(
ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=num_head_channels,
),
ResBlock(
ch,
time_embed_dim,
dropout,
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
for i in range(num_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
kernel_size=kernel_size,
efficient_config=efficient_convs,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
)
)
if level and i == num_blocks:
out_ch = ch
layers.append(
Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
)
def fix_alignment(self, x, aligned_conditioning):
"""
The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
"""
cm = ceil_multiple(x.shape[-1], self.alignment_size)
if cm != 0:
pc = (cm-x.shape[-1])/x.shape[-1]
x = F.pad(x, (0,cm-x.shape[-1]))
# Also fix aligned_latent, which is aligned to x.
if is_latent(aligned_conditioning):
aligned_conditioning = torch.cat([aligned_conditioning,
self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
else:
aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
return x, aligned_conditioning
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
:param lr_input: for super-sampling models, a guidance audio clip at a lower sampling rate.
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
:return: an [N x C x ...] Tensor of outputs.
"""
assert conditioning_input is not None
if self.super_sampling_enabled:
assert lr_input is not None
if self.training and self.super_sampling_max_noising_factor > 0:
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
lr_input = torch.randn_like(lr_input) * noising_factor + lr_input
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
x = torch.cat([x, lr_input], dim=1)
# Shuffle aligned_latent to BxCxS format
if is_latent(aligned_conditioning):
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
# Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
orig_x_shape = x.shape[-1]
x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
with autocast(x.device.type, enabled=self.enable_fp16):
hs = []
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
# Note: this block does not need to repeated on inference, since it is not timestep-dependent.
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
else:
cond_emb = self.contextual_embedder(conditioning_input)
if len(cond_emb.shape) == 3: # Just take the first element.
cond_emb = cond_emb[:, :, 0]
if is_latent(aligned_conditioning):
code_emb = self.latent_converter(aligned_conditioning)
else:
code_emb = self.code_converter(aligned_conditioning)
cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
device=code_emb.device) < self.unconditioned_percentage
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
code_emb)
# Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
first = True
time_emb = time_emb.float()
h = x
for k, module in enumerate(self.input_blocks):
if isinstance(module, nn.Conv1d):
h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
h = h + h_tok
else:
with autocast(x.device.type, enabled=self.enable_fp16 and not first):
# First block has autocast disabled to allow a high precision signal to be properly vectorized.
h = module(h, time_emb)
hs.append(h)
first = False
h = self.middle_block(h, time_emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, time_emb)
# Last block also has autocast disabled for high-precision outputs.
h = h.float()
out = self.out(h)
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
for p in params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
return out[:, :, :orig_x_shape]
if __name__ == '__main__':
clip = torch.randn(2, 1, 32868)
aligned_latent = torch.randn(2,388,1024)
aligned_sequence = torch.randint(0,8192,(2,388))
cond = torch.randn(2, 1, 44000)
ts = torch.LongTensor([600, 600])
model = DiffusionTts(128,
channel_mult=[1,1.5,2, 3, 4, 6, 8],
num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
token_conditioning_resolutions=[1,4,16,64],
attention_resolutions=[],
num_heads=8,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
super_sampling=False,
efficient_convs=False)
# Test with latent aligned conditioning
o = model(clip, ts, aligned_latent, cond)
# Test with sequence aligned conditioning
o = model(clip, ts, aligned_sequence, cond)

325
models/vocoder.py Normal file
View File

@ -0,0 +1,325 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
MAX_WAV_VALUE = 32768.0
class KernelPredictor(torch.nn.Module):
''' Kernel predictor for the location-variable convolutions'''
def __init__(
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
'''
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int): number of layers
'''
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding,
bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
self.bias_conv = nn.utils.weight_norm(
nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))
def forward(self, c):
'''
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
'''
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
class LVCBlock(torch.nn.Module):
'''the location-variable convolutions'''
def __init__(
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}
)
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
padding=stride // 2 + stride % 2, output_padding=stride % 2)),
)
self.conv_blocks = nn.ModuleList()
for dilation in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size,
padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
''' forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
'''
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(output, k, b,
hop_size=self.cond_hop_length) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
'''
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), 'constant', 0)
x = x.unfold(3, dilation,
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
class UnivNetGenerator(nn.Module):
"""UnivNet Generator"""
def __init__(self, noise_dim=64, channel_size=32, dilations=[1,3,9,27], strides=[8,8,4], lReLU_slope=.2, kpnet_conv_size=3,
# Below are MEL configurations options that this generator requires.
hop_length=256, n_mel_channels=100):
super(UnivNetGenerator, self).__init__()
self.mel_channel = n_mel_channels
self.noise_dim = noise_dim
self.hop_length = hop_length
channel_size = channel_size
kpnet_conv_size = kpnet_conv_size
self.res_stack = nn.ModuleList()
hop_length = 1
for stride in strides:
hop_length = stride * hop_length
self.res_stack.append(
LVCBlock(
channel_size,
n_mel_channels,
stride=stride,
dilations=dilations,
lReLU_slope=lReLU_slope,
cond_hop_length=hop_length,
kpnet_conv_size=kpnet_conv_size
)
)
self.conv_pre = \
nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
nn.Tanh(),
)
def forward(self, c, z):
'''
Args:
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
z (Tensor): the noise sequence (batch, noise_dim, in_length)
'''
z = self.conv_pre(z) # (B, c_g, L)
for res_block in self.res_stack:
res_block.to(z.device)
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
z = self.conv_post(z) # (B, 1, L * 256)
return z
def eval(self, inference=False):
super(UnivNetGenerator, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
def remove_weight_norm(self):
print('Removing weight norm...')
nn.utils.remove_weight_norm(self.conv_pre)
for layer in self.conv_post:
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
for res_block in self.res_stack:
res_block.remove_weight_norm()
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, :-(self.hop_length * 10)]
audio = audio.clamp(min=-1, max=1)
return audio
if __name__ == '__main__':
model = UnivNetGenerator()
c = torch.randn(3, 100, 10)
z = torch.randn(3, 64, 10)
print(c.shape)
y = model(c, z)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

View File

@ -7,3 +7,4 @@ inflect
progressbar
einops
unidecode
x-transformers

View File

@ -3,6 +3,8 @@ import torchaudio
import numpy as np
from scipy.io.wavfile import read
from utils.stft import STFT
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
@ -44,3 +46,85 @@ def load_audio(audiopath, sampling_rate):
audio.clip_(-1, 1)
return audio.unsqueeze(0)
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254
def denormalize_tacotron_mel(norm_mel):
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
def normalize_tacotron_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=8000.0):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
from librosa.filters import mel as librosa_mel_fn
mel_basis = librosa_mel_fn(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer('mel_basis', mel_basis)
def spectral_normalize(self, magnitudes):
output = dynamic_range_compression(magnitudes)
return output
def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output
def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert(torch.min(y.data) >= -10)
assert(torch.max(y.data) <= 10)
y = torch.clip(y, min=-1, max=1)
magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
def wav_to_univnet_mel(wav, do_normalization=False):
stft = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000)
stft = stft.cuda()
mel = stft.mel_spectrogram(wav)
if do_normalization:
mel = normalize_tacotron_mel(mel)
return mel

View File

@ -197,11 +197,17 @@ class GaussianDiffusion:
model_var_type,
loss_type,
rescale_timesteps=False,
conditioning_free=False,
conditioning_free_k=1,
ramp_conditioning_free=True,
):
self.model_mean_type = ModelMeanType(model_mean_type)
self.model_var_type = ModelVarType(model_var_type)
self.loss_type = LossType(loss_type)
self.rescale_timesteps = rescale_timesteps
self.conditioning_free = conditioning_free
self.conditioning_free_k = conditioning_free_k
self.ramp_conditioning_free = ramp_conditioning_free
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
@ -332,10 +338,14 @@ class GaussianDiffusion:
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.conditioning_free:
model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
if self.conditioning_free:
model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
@ -364,6 +374,14 @@ class GaussianDiffusion:
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
if self.conditioning_free:
if self.ramp_conditioning_free:
assert t.shape[0] == 1 # This should only be used in inference.
cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps)
else:
cfk = self.conditioning_free_k
model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)

193
utils/stft.py Normal file
View File

@ -0,0 +1,193 @@
"""
BSD 3-Clause License
Copyright (c) 2017, Prem Seetharaman
All rights reserved.
* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
import librosa.util as librosa_util
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
n_fft=800, dtype=np.float32, norm=None):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
win_sq = librosa_util.pad_center(win_sq, n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann'):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer('forward_basis', forward_basis.float())
self.register_buffer('inverse_basis', inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode='reflect')
input_data = input_data.squeeze(1)
forward_transform = F.conv1d(
input_data,
Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(
torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
recombine_magnitude_phase = torch.cat(
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
if self.window is not None:
window_sum = window_sumsquare(
self.window, magnitude.size(-1), hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.filter_length,
dtype=np.float32)
# remove modulation effects
approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0])
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction