Integrate new diffusion network

James Betker 2022-04-01 14:15:17 -06:00
parent d89c51a71c
commit 4747fae381
3 changed files with 189 additions and 390 deletions

api.py

@@ -49,6 +49,15 @@ def download_models():
    print('Done.')


+def pad_or_truncate(t, length):
+    if t.shape[-1] == length:
+        return t
+    elif t.shape[-1] < length:
+        return F.pad(t, (0, length-t.shape[-1]))
+    else:
+        return t[..., :length]


def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
    """
    Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
@@ -96,26 +105,25 @@ def fix_autoregressive_output(codes, stop_token):
    return codes


-def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
+def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
    """
-    Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
+    Uses the specified diffusion model to convert discrete codes into a spectrogram.
    """
    with torch.no_grad():
-        cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
-        # Pad MEL to multiples of 32
-        msl = mel_codes.shape[-1]
-        dsl = 32
-        gap = dsl - (msl % dsl)
-        if gap > 0:
-            mel = torch.nn.functional.pad(mel_codes, (0, gap))
-        output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
-        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
+        cond_mels = []
+        for sample in conditioning_samples:
+            sample = pad_or_truncate(sample, 102400)
+            cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
+            cond_mels.append(cond_mel)
+        cond_mels = torch.stack(cond_mels, dim=1)
+        output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4)
+        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False)
        noise = torch.randn(output_shape, device=mel_codes.device) * temperature
        mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
                                     model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
-        return denormalize_tacotron_mel(mel)[:,:,:msl*4]
+        return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4]
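Editor's note: to make the new conditioning interface concrete, here is a small standalone sketch (not part of the commit). Each reference clip is padded or truncated to 102,400 samples, converted to a mel, and the per-clip mels are stacked along a new dimension that the diffusion model later averages over; fake_mel is a placeholder for wav_to_univnet_mel.

    # Standalone sketch, not the repo's code.
    import torch
    import torch.nn.functional as F

    def pad_or_truncate(t, length):
        # Force the last dimension to exactly `length` samples.
        if t.shape[-1] == length:
            return t
        elif t.shape[-1] < length:
            return F.pad(t, (0, length - t.shape[-1]))
        return t[..., :length]

    def fake_mel(wav):  # placeholder for wav_to_univnet_mel (100 mel bins assumed)
        return torch.randn(wav.shape[0], 100, wav.shape[-1] // 256)

    conditioning_samples = [torch.randn(1, 1, 60000), torch.randn(1, 1, 130000)]
    cond_mels = torch.stack(
        [fake_mel(pad_or_truncate(s.squeeze(1), 102400)) for s in conditioning_samples],
        dim=1)
    print(cond_mels.shape)  # torch.Size([1, 2, 100, 400]) -> batch, num_clips, mel bins, frames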
class TextToSpeech:
@@ -137,12 +145,9 @@ class TextToSpeech:
                                      use_xformers=True).cpu().eval()
        self.clip.load_state_dict(torch.load('.models/clip.pth'))

-        self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
-                                      channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3],
-                                      token_conditioning_resolutions=[1, 4, 8],
-                                      dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
-                                      time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
-                                      conditioning_expansion=1).cpu().eval()
+        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
+                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
        self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))

        self.vocoder = UnivNetGenerator().cpu()
@@ -164,12 +169,6 @@ class TextToSpeech:
            for vs in voice_samples:
                conds.append(load_conditioning(vs))
            conds = torch.stack(conds, dim=1)
-
-            cond_diffusion = voice_samples[0].cuda()
-            # The diffusion model expects = 88200 conditioning samples.
-            if cond_diffusion.shape[-1] < 88200:
-                cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1]))
-            else:
-                cond_diffusion = cond_diffusion[:, :88200]

        diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
@@ -211,7 +210,7 @@ class TextToSpeech:
            self.vocoder = self.vocoder.cuda()
            for b in range(best_results.shape[0]):
                code = best_results[b].unsqueeze(0)
-                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
+                mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
                wav = self.vocoder.inference(mel)
                wav_candidates.append(wav.cpu())
            self.diffusion = self.diffusion.cpu()
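Editor's note: for orientation, the loop above is the tail of a codes -> mel -> waveform pipeline. The sketch below uses toy stand-ins (fake_diffusion, FakeVocoder, and the 256-samples-per-mel-frame figure are assumptions, not the real models) purely to show the tensor shapes handed between stages; the 100-bin, 4x-upsampled mel shape comes from do_spectrogram_diffusion above.

    # Schematic sketch with toy stubs, not the real models.
    import torch

    def fake_diffusion(codes):              # stands in for do_spectrogram_diffusion(...)
        return torch.randn(codes.shape[0], 100, codes.shape[-1] * 4)

    class FakeVocoder:                      # stands in for UnivNetGenerator
        def inference(self, mel):
            return torch.randn(mel.shape[0], 1, mel.shape[-1] * 256)  # assumed 256-sample hop

    best_results = torch.randint(0, 8192, (2, 80))   # two candidate code sequences
    vocoder = FakeVocoder()
    wav_candidates = []
    for b in range(best_results.shape[0]):
        code = best_results[b].unsqueeze(0)
        mel = fake_diffusion(code)
        wav_candidates.append(vocoder.inference(mel).cpu())
    print([w.shape for w in wav_candidates])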

models/arch_util.py

@@ -6,6 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from x_transformers import ContinuousTransformerWrapper
+from x_transformers.x_transformers import RelativePositionBias


def zero_module(module):
@@ -49,7 +50,7 @@ class QKVAttentionLegacy(nn.Module):
        super().__init__()
        self.n_heads = n_heads

-    def forward(self, qkv, mask=None):
+    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

@@ -64,6 +65,8 @@ class QKVAttentionLegacy(nn.Module):
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
+        if rel_pos is not None:
+            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
@@ -87,9 +90,12 @@ class AttentionBlock(nn.Module):
        channels,
        num_heads=1,
        num_head_channels=-1,
+        do_checkpoint=True,
+        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
+        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:

@@ -99,21 +105,20 @@ class AttentionBlock(nn.Module):
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        if relative_pos_embeddings:
+            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+        else:
+            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
-        if mask is not None:
-            return self._forward(x, mask)
-        else:
-            return self._forward(x)
-
-    def _forward(self, x, mask=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv, mask)
+        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
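Editor's note: the new rel_pos hook adds a learned bias, indexed by query/key distance, to the attention logits before the softmax. The standalone sketch below shows the tensor plumbing with a simplified stand-in module (ToyRelativeBias is not x_transformers' RelativePositionBias, which additionally buckets long distances).

    import torch
    import torch.nn as nn

    class ToyRelativeBias(nn.Module):
        def __init__(self, heads, max_distance=64):
            super().__init__()
            self.heads = heads
            self.max_distance = max_distance
            self.bias = nn.Embedding(2 * max_distance + 1, heads)

        def forward(self, qk_dots):                      # (batch, heads, q_len, k_len)
            q_len, k_len = qk_dots.shape[-2:]
            rel = torch.arange(k_len)[None, :] - torch.arange(q_len)[:, None]
            rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
            bias = self.bias(rel).permute(2, 0, 1).unsqueeze(0)   # (1, heads, q, k)
            return qk_dots + bias

    heads, q = 4, 10
    logits = torch.randn(2 * heads, q, q)                # flattened (bs*heads, q, k) as in QKVAttentionLegacy
    rel_pos = ToyRelativeBias(heads)
    biased = rel_pos(logits.reshape(2, heads, q, q)).reshape(2 * heads, q, q)
    print(biased.shape)                                  # torch.Size([8, 10, 10])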

models/diffusion_decoder.py

@@ -1,22 +1,13 @@
"""
This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
and an audio conditioning input. It has also been simplified somewhat.
Credit: https://github.com/openai/improved-diffusion
"""
import functools
import math import math
import random
from abc import abstractmethod from abc import abstractmethod
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import autocast from torch import autocast
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from x_transformers import ContinuousTransformerWrapper, Encoder
from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \ from models.arch_util import normalization, AttentionBlock
CheckpointedXTransformerEncoder
def is_latent(t): def is_latent(t):
@@ -27,13 +18,6 @@ def is_sequence(t):
    return t.dtype == torch.long


-def ceil_multiple(base, multiple):
-    res = base % multiple
-    if res == 0:
-        return base
-    return base + (multiple - res)
-
-
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
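Editor's note: the hunk only shows the signature and the start of the docstring. For reference, the following standalone sketch shows the standard sinusoidal construction this helper is assumed to implement (it mirrors the well-known improved-diffusion recipe; it is not copied from this file).

    import math
    import torch

    def timestep_embedding(timesteps, dim, max_period=10000):
        # Half the channels carry cosines, the other half sines, over log-spaced frequencies.
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        args = timesteps[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    print(timestep_embedding(torch.LongTensor([600, 600]), 512).shape)  # torch.Size([2, 512])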
@@ -56,10 +40,6 @@ def timestep_embedding(timesteps, dim, max_period=10000):


class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
-
    @abstractmethod
    def forward(self, x, emb):
        """

@@ -68,11 +48,6 @@ class TimestepBlock(nn.Module):


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
-    """
-    A sequential module that passes timestep embeddings to the children that
-    support it as an extra input.
-    """
-
    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
@@ -89,6 +64,7 @@ class ResBlock(TimestepBlock):
        emb_channels,
        dropout,
        out_channels=None,
+        dims=2,
        kernel_size=3,
        efficient_config=True,
        use_scale_shift_norm=False,

@@ -111,7 +87,7 @@ class ResBlock(TimestepBlock):
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
-            Linear(
+            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
@@ -120,9 +96,7 @@ class ResBlock(TimestepBlock):
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
-            zero_module(
-                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
-            ),
+            nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
        )

        if self.out_channels == channels:

@@ -131,18 +105,6 @@ class ResBlock(TimestepBlock):
            self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)

    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        return checkpoint(
-            self._forward, x, emb
-        )
-
-    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):

@@ -158,372 +120,205 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h
+class DiffusionLayer(TimestepBlock):
+    def __init__(self, model_channels, dropout, num_heads):
+        super().__init__()
+        self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
+        self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
+
+    def forward(self, x, time_emb):
+        y = self.resblk(x, time_emb)
+        return self.attn(y)
+
+
class DiffusionTts(nn.Module):
"""
The full UNet model with attention and timestep embedding.
Customized to be conditioned on an aligned prior derived from a autoregressive
GPT-style model.
:param in_channels: channels in the input Tensor.
:param in_latent_channels: channels from the input latent.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__( def __init__(
self, self,
model_channels, model_channels=512,
in_channels=1, num_layers=8,
in_latent_channels=1024, in_channels=100,
in_latent_channels=512,
in_tokens=8193, in_tokens=8193,
conditioning_dim_factor=8, out_channels=200, # mean and variance
conditioning_expansion=4,
out_channels=2, # mean and variance
dropout=0, dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
# spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
# attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
token_conditioning_resolutions=(1,16,),
attention_resolutions=(512,1024,2048),
conv_resample=True,
use_fp16=False, use_fp16=False,
num_heads=1, num_heads=16,
num_head_channels=-1,
num_heads_upsample=-1,
kernel_size=3,
scale_factor=2,
time_embed_dim_multiplier=4,
freeze_main_net=False,
efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
use_scale_shift_norm=True,
# Parameters for regularization. # Parameters for regularization.
layer_drop=.1,
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
# Parameters for super-sampling.
super_sampling=False,
super_sampling_max_noising_factor=.1,
): ):
super().__init__() super().__init__()
-        if num_heads_upsample == -1:
-            num_heads_upsample = num_heads
-        if super_sampling:
-            in_channels *= 2  # In super-sampling mode, the LR input is concatenated directly onto the input.
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
-        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
-        self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
        self.num_heads = num_heads
-        self.num_head_channels = num_head_channels
-        self.num_heads_upsample = num_heads_upsample
-        self.super_sampling_enabled = super_sampling
-        self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
        self.unconditioned_percentage = unconditioned_percentage
        self.enable_fp16 = use_fp16
-        self.alignment_size = 2 ** (len(channel_mult)+1)
+        self.layer_drop = layer_drop
-        self.freeze_main_net = freeze_main_net
-        padding = 1 if kernel_size == 3 else 2
-        down_kernel = 1 if efficient_convs else 3

-        time_embed_dim = model_channels * time_embed_dim_multiplier
+        self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
        self.time_embed = nn.Sequential(
-            Linear(model_channels, time_embed_dim),
+            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
-            Linear(time_embed_dim, time_embed_dim),
+            nn.Linear(model_channels, model_channels),
        )
-        conditioning_dim = model_channels * conditioning_dim_factor
        # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
        # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
        # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
        # transformer network.
-        self.code_converter = nn.Sequential(
-            nn.Embedding(in_tokens, conditioning_dim),
-            CheckpointedXTransformerEncoder(
-                needs_permute=False,
-                max_seq_len=-1,
-                use_pos_emb=False,
-                attn_layers=Encoder(
-                    dim=conditioning_dim,
-                    depth=3,
-                    heads=num_heads,
-                    ff_dropout=dropout,
-                    attn_dropout=dropout,
-                    use_rmsnorm=True,
-                    ff_glu=True,
-                    rotary_emb_dim=True,
-                )
-            ))
-        self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
-        self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
-        if in_channels > 60:  # It's a spectrogram.
-            self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
-                                                     CheckpointedXTransformerEncoder(
-                                                         needs_permute=True,
-                                                         max_seq_len=-1,
-                                                         use_pos_emb=False,
-                                                         attn_layers=Encoder(
-                                                             dim=conditioning_dim,
-                                                             depth=4,
-                                                             heads=num_heads,
-                                                             ff_dropout=dropout,
-                                                             attn_dropout=dropout,
-                                                             use_rmsnorm=True,
-                                                             ff_glu=True,
-                                                             rotary_emb_dim=True,
-                                                         )
-                                                     ))
-        else:
-            self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
-                                                        attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
-        self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
-        self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
+        self.code_embedding = nn.Embedding(in_tokens, model_channels)
+        self.code_converter = nn.Sequential(
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+        )
+        self.code_norm = normalization(model_channels)
+        self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
+        self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
+                                                 nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
+        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
+            DiffusionLayer(model_channels, dropout, num_heads),
+            DiffusionLayer(model_channels, dropout, num_heads),
+            DiffusionLayer(model_channels, dropout, num_heads),
        )
-        self.conditioning_expansion = conditioning_expansion
+        self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
+        self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
-                )
-            ]
-        )
+        self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
+                                    [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
-        token_conditioning_blocks = []
-        self._feature_size = model_channels
-        input_block_chans = [model_channels]
-        ch = model_channels
-        ds = 1
-        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
-            if ds in token_conditioning_resolutions:
-                token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
-                token_conditioning_block.weight.data *= .02
-                self.input_blocks.append(token_conditioning_block)
-                token_conditioning_blocks.append(token_conditioning_block)
-            for _ in range(num_blocks):
-                layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(mult * model_channels),
-                        kernel_size=kernel_size,
-                        efficient_config=efficient_convs,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = int(mult * model_channels)
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            num_heads=num_heads,
-                            num_head_channels=num_head_channels,
-                        )
-                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
-                input_block_chans.append(ch)
-            if level != len(channel_mult) - 1:
-                out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
-                        Downsample(
-                            ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
-                        )
-                    )
-                )
-                ch = out_ch
-                input_block_chans.append(ch)
-                ds *= 2
-                self._feature_size += ch
-        self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                kernel_size=kernel_size,
-                efficient_config=efficient_convs,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-            AttentionBlock(
-                ch,
-                num_heads=num_heads,
-                num_head_channels=num_head_channels,
-            ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                kernel_size=kernel_size,
-                efficient_config=efficient_convs,
-                use_scale_shift_norm=use_scale_shift_norm,
-            ),
-        )
-        self._feature_size += ch
-        self.output_blocks = nn.ModuleList([])
-        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
-            for i in range(num_blocks + 1):
-                ich = input_block_chans.pop()
-                layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=int(model_channels * mult),
-                        kernel_size=kernel_size,
-                        efficient_config=efficient_convs,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
-                ]
-                ch = int(model_channels * mult)
-                if ds in attention_resolutions:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=num_head_channels,
-                        )
-                    )
-                if level and i == num_blocks:
-                    out_ch = ch
-                    layers.append(
-                        Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
-                    )
-                    ds //= 2
-                self.output_blocks.append(TimestepEmbedSequential(*layers))
-                self._feature_size += ch
        self.out = nn.Sequential(
-            normalization(ch),
+            normalization(model_channels),
            nn.SiLU(),
-            zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
+            nn.Conv1d(model_channels, out_channels, 3, padding=1),
        )

-    def fix_alignment(self, x, aligned_conditioning):
-        """
-        The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
-        padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
-        """
-        cm = ceil_multiple(x.shape[-1], self.alignment_size)
-        if cm != 0:
-            pc = (cm-x.shape[-1])/x.shape[-1]
-            x = F.pad(x, (0,cm-x.shape[-1]))
-            # Also fix aligned_latent, which is aligned to x.
-            if is_latent(aligned_conditioning):
-                aligned_conditioning = torch.cat([aligned_conditioning,
-                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
-            else:
-                aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
-        return x, aligned_conditioning
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'minicoder': list(self.contextual_embedder.parameters()),
+            'layers': list(self.layers.parameters()),
+            'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
+            'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
+            'time_embed': list(self.time_embed.parameters()),
+        }
+        return groups
-    def timestep_independent(self, aligned_conditioning, conditioning_input):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

-        with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16):
-            cond_emb = self.contextual_embedder(conditioning_input)
-            if len(cond_emb.shape) == 3:  # Just take the first element.
-                cond_emb = cond_emb[:, :, 0]
+        # Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
+        speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
+            conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(speech_conditioning_input.shape[1]):
+            conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+        conds = torch.cat(conds, dim=-1)
+        cond_emb = conds.mean(dim=-1)
+        cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
        if is_latent(aligned_conditioning):
            code_emb = self.latent_converter(aligned_conditioning)
        else:
-            code_emb = self.code_converter(aligned_conditioning)
-        cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
-        code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
-        return code_emb
+            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
+            code_emb = self.code_converter(code_emb)
+        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

-    def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
-        assert x.shape[-1] % self.alignment_size == 0
+        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
+        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
+        if self.training and self.unconditioned_percentage > 0:
+            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
+                                               device=code_emb.device) < self.unconditioned_percentage
+            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
+                                   code_emb)
+        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')

-        with autocast(x.device.type, enabled=self.enable_fp16):
-            if conditioning_free:
-                code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
-            else:
+        if not return_code_pred:
+            return expanded_code_emb
+        else:
+            mel_pred = self.mel_head(expanded_code_emb)
+            # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
+            mel_pred = mel_pred * unconditioned_batches.logical_not()
+            return expanded_code_emb, mel_pred
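Editor's note: the masking above is the training half of classifier-free guidance, and the new cond_free path in load_discrete_vocoder_diffuser is its inference counterpart. A minimal standalone sketch with toy tensors (shapes chosen arbitrarily) of how a random subset of batch elements has its conditioning swapped for a learned unconditioned embedding:

    import torch

    batch, channels, seq = 4, 8, 10
    code_emb = torch.randn(batch, channels, seq)
    unconditioned_embedding = torch.randn(1, channels, 1)      # a learned nn.Parameter in the model
    unconditioned_percentage = 0.5

    mask = torch.rand(batch, 1, 1) < unconditioned_percentage  # True -> drop conditioning for that element
    code_emb = torch.where(mask, unconditioned_embedding.repeat(batch, 1, 1), code_emb)
    print(mask.squeeze().tolist())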
+    def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
+        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
+        :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
+        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
+        assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.
+        unused_params = []
+        if conditioning_free:
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
+            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
+            unused_params.extend(list(self.latent_converter.parameters()))
+        else:
+            if precomputed_aligned_embeddings is not None:
                code_emb = precomputed_aligned_embeddings
+            else:
+                code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+                if is_latent(aligned_conditioning):
+                    unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
+                else:
+                    unused_params.extend(list(self.latent_converter.parameters()))
+            unused_params.append(self.unconditioned_embedding)

        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-            code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)

-            first = True
-            time_emb = time_emb.float()
-            h = x
-            hs = []
-            for k, module in enumerate(self.input_blocks):
-                if isinstance(module, nn.Conv1d):
-                    h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
-                    h = h + h_tok
-                else:
-                    with autocast(x.device.type, enabled=self.enable_fp16 and not first):
-                        # First block has autocast disabled to allow a high precision signal to be properly vectorized.
-                        h = module(h, time_emb)
-                    hs.append(h)
-                    first = False
-            h = self.middle_block(h, time_emb)
-            for module in self.output_blocks:
-                h = torch.cat([h, hs.pop()], dim=1)
-                h = module(h, time_emb)
-        # Last block also has autocast disabled for high-precision outputs.
-        h = h.float()
-        out = self.out(h)
+        x = self.inp_block(x)
+        x = torch.cat([x, code_emb], dim=1)
+        x = self.integrating_conv(x)
+        for i, lyr in enumerate(self.layers):
+            # Do layer drop where applicable. Do not drop first and last layers.
+            if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
+                unused_params.extend(list(lyr.parameters()))
+            else:
+                # First and last blocks will have autocast disabled for improved precision.
+                with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
+                    x = lyr(x, time_emb)
+        x = x.float()
+        out = self.out(x)
+        # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
+        extraneous_addition = 0
+        for p in unused_params:
+            extraneous_addition = extraneous_addition + p.mean()
+        out = out + extraneous_addition * 0
+        if return_code_pred:
+            return out, mel_pred
        return out
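Editor's note: the loop above implements layer drop. The sketch below reproduces the same control flow with trivial nn.Linear layers (toy stand-ins, not the real DiffusionLayer modules) to show which layers can be skipped during training and that the first and last layers always run.

    import random
    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(6)])
    layer_drop = 0.3
    x = torch.randn(2, 16)
    training = True

    for i, lyr in enumerate(layers):
        if training and layer_drop > 0 and i != 0 and i != len(layers) - 1 and random.random() < layer_drop:
            continue  # dropped layer; the model instead routes its parameters into unused_params for DDP
        x = lyr(x)
    print(x.shape)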
if __name__ == '__main__':
-    clip = torch.randn(2, 1, 32868)
+    clip = torch.randn(2, 100, 400)
-    aligned_latent = torch.randn(2,388,1024)
+    aligned_latent = torch.randn(2,388,512)
-    aligned_sequence = torch.randint(0,8192,(2,388))
+    aligned_sequence = torch.randint(0,8192,(2,100))
-    cond = torch.randn(2, 1, 44000)
+    cond = torch.randn(2, 100, 400)
    ts = torch.LongTensor([600, 600])
-    model = DiffusionTts(128,
-                         channel_mult=[1,1.5,2, 3, 4, 6, 8],
-                         num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
-                         token_conditioning_resolutions=[1,4,16,64],
-                         attention_resolutions=[],
-                         num_heads=8,
-                         kernel_size=3,
-                         scale_factor=2,
-                         time_embed_dim_multiplier=4,
-                         super_sampling=False,
-                         efficient_convs=False)
+    model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
    # Test with latent aligned conditioning
-    o = model(clip, ts, aligned_latent, cond)
+    #o = model(clip, ts, aligned_latent, cond)
    # Test with sequence aligned conditioning
    o = model(clip, ts, aligned_sequence, cond)
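Editor's note: taken together with api.py, the intended calling pattern is to compute the conditioning path once with timestep_independent() and then pass it to forward() as precomputed_aligned_embeddings at every diffusion step. The sketch below only illustrates that contract; StubModel and all shapes are illustrative assumptions, not the real DiffusionTts.

    import torch
    import torch.nn as nn

    class StubModel(nn.Module):
        def timestep_independent(self, codes, cond_mels, return_code_pred):
            # Stand-in for the expanded code embedding (4 frames per code, arbitrary channel count).
            return torch.randn(codes.shape[0], 32, codes.shape[-1] * 4)

        def forward(self, x, t, precomputed_aligned_embeddings=None, conditioning_free=False):
            # Stand-in for the mean-and-variance output of the real model.
            return torch.randn(x.shape[0], 200, x.shape[-1])

    model = StubModel()
    codes = torch.randint(0, 8192, (1, 80))
    cond_mels = torch.randn(1, 2, 100, 400)
    emb = model.timestep_independent(codes, cond_mels, False)     # computed once per clip
    x = torch.randn(1, 100, codes.shape[-1] * 4)
    for t in reversed(range(0, 200, 50)):                         # a few mock diffusion steps
        eps = model(x, torch.LongTensor([t]), precomputed_aligned_embeddings=emb)
    print(eps.shape)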