diff --git a/api.py b/api.py index be07783..6c3fb1e 100644 --- a/api.py +++ b/api.py @@ -49,6 +49,15 @@ def download_models(): print('Done.') +def pad_or_truncate(t, length): + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + + def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1): """ Helper function to load a GaussianDiffusion instance configured for use as a vocoder. @@ -96,26 +105,25 @@ def fix_autoregressive_output(codes, stop_token): return codes -def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1): +def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1): """ - Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip. + Uses the specified diffusion model to convert discrete codes into a spectrogram. """ with torch.no_grad(): - cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False) - # Pad MEL to multiples of 32 - msl = mel_codes.shape[-1] - dsl = 32 - gap = dsl - (msl % dsl) - if gap > 0: - mel = torch.nn.functional.pad(mel_codes, (0, gap)) + cond_mels = [] + for sample in conditioning_samples: + sample = pad_or_truncate(sample, 102400) + cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False) + cond_mels.append(cond_mel) + cond_mels = torch.stack(cond_mels, dim=1) - output_shape = (mel.shape[0], 100, mel.shape[-1]*4) - precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel) + output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4) + precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False) noise = torch.randn(output_shape, device=mel_codes.device) * temperature mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise, model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings}) - return denormalize_tacotron_mel(mel)[:,:,:msl*4] + return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4] class TextToSpeech: @@ -137,12 +145,9 @@ class TextToSpeech: use_xformers=True).cpu().eval() self.clip.load_state_dict(torch.load('.models/clip.pth')) - self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024, - channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3], - token_conditioning_resolutions=[1, 4, 8], - dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2, - time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2, - conditioning_expansion=1).cpu().eval() + self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, + in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, + layer_drop=0, unconditioned_percentage=0).cpu().eval() self.diffusion.load_state_dict(torch.load('.models/diffusion.pth')) self.vocoder = UnivNetGenerator().cpu() @@ -164,12 +169,6 @@ class TextToSpeech: for vs in voice_samples: conds.append(load_conditioning(vs)) conds = torch.stack(conds, dim=1) - cond_diffusion = voice_samples[0].cuda() - # The diffusion model expects = 88200 conditioning samples. - if cond_diffusion.shape[-1] < 88200: - cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1])) - else: - cond_diffusion = cond_diffusion[:, :88200] diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) @@ -211,7 +210,7 @@ class TextToSpeech: self.vocoder = self.vocoder.cuda() for b in range(best_results.shape[0]): code = best_results[b].unsqueeze(0) - mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature) + mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature) wav = self.vocoder.inference(mel) wav_candidates.append(wav.cpu()) self.diffusion = self.diffusion.cpu() diff --git a/models/arch_util.py b/models/arch_util.py index d374594..89488f4 100644 --- a/models/arch_util.py +++ b/models/arch_util.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.functional as F import torchaudio from x_transformers import ContinuousTransformerWrapper +from x_transformers.x_transformers import RelativePositionBias def zero_module(module): @@ -49,7 +50,7 @@ class QKVAttentionLegacy(nn.Module): super().__init__() self.n_heads = n_heads - def forward(self, qkv, mask=None): + def forward(self, qkv, mask=None, rel_pos=None): """ Apply QKV attention. @@ -64,6 +65,8 @@ class QKVAttentionLegacy(nn.Module): weight = torch.einsum( "bct,bcs->bts", q * scale, k * scale ) # More stable with f16 than dividing afterwards + if rel_pos is not None: + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1]) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. @@ -87,9 +90,12 @@ class AttentionBlock(nn.Module): channels, num_heads=1, num_head_channels=-1, + do_checkpoint=True, + relative_pos_embeddings=False, ): super().__init__() self.channels = channels + self.do_checkpoint = do_checkpoint if num_head_channels == -1: self.num_heads = num_heads else: @@ -99,21 +105,20 @@ class AttentionBlock(nn.Module): self.num_heads = channels // num_head_channels self.norm = normalization(channels) self.qkv = nn.Conv1d(channels, channels * 3, 1) + # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) + else: + self.relative_pos_embeddings = None def forward(self, x, mask=None): - if mask is not None: - return self._forward(x, mask) - else: - return self._forward(x) - - def _forward(self, x, mask=None): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) - h = self.attention(qkv, mask) + h = self.attention(qkv, mask, self.relative_pos_embeddings) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py index c57e9fb..cacdfc1 100644 --- a/models/diffusion_decoder.py +++ b/models/diffusion_decoder.py @@ -1,22 +1,13 @@ -""" -This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal -and an audio conditioning input. It has also been simplified somewhat. -Credit: https://github.com/openai/improved-diffusion -""" -import functools import math +import random from abc import abstractmethod import torch import torch.nn as nn import torch.nn.functional as F from torch import autocast -from torch.nn import Linear -from torch.utils.checkpoint import checkpoint -from x_transformers import ContinuousTransformerWrapper, Encoder -from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \ - CheckpointedXTransformerEncoder +from models.arch_util import normalization, AttentionBlock def is_latent(t): @@ -27,13 +18,6 @@ def is_sequence(t): return t.dtype == torch.long -def ceil_multiple(base, multiple): - res = base % multiple - if res == 0: - return base - return base + (multiple - res) - - def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. @@ -56,10 +40,6 @@ def timestep_embedding(timesteps, dim, max_period=10000): class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - @abstractmethod def forward(self, x, emb): """ @@ -68,11 +48,6 @@ class TimestepBlock(nn.Module): class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - def forward(self, x, emb): for layer in self: if isinstance(layer, TimestepBlock): @@ -89,6 +64,7 @@ class ResBlock(TimestepBlock): emb_channels, dropout, out_channels=None, + dims=2, kernel_size=3, efficient_config=True, use_scale_shift_norm=False, @@ -111,7 +87,7 @@ class ResBlock(TimestepBlock): self.emb_layers = nn.Sequential( nn.SiLU(), - Linear( + nn.Linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), @@ -120,9 +96,7 @@ class ResBlock(TimestepBlock): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding) - ), + nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding), ) if self.out_channels == channels: @@ -131,18 +105,6 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): @@ -158,372 +120,205 @@ class ResBlock(TimestepBlock): return self.skip_connection(x) + h +class DiffusionLayer(TimestepBlock): + def __init__(self, model_channels, dropout, num_heads): + super().__init__() + self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + + def forward(self, x, time_emb): + y = self.resblk(x, time_emb) + return self.attn(y) + + class DiffusionTts(nn.Module): - """ - The full UNet model with attention and timestep embedding. - - Customized to be conditioned on an aligned prior derived from a autoregressive - GPT-style model. - - :param in_channels: channels in the input Tensor. - :param in_latent_channels: channels from the input latent. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - def __init__( self, - model_channels, - in_channels=1, - in_latent_channels=1024, + model_channels=512, + num_layers=8, + in_channels=100, + in_latent_channels=512, in_tokens=8193, - conditioning_dim_factor=8, - conditioning_expansion=4, - out_channels=2, # mean and variance + out_channels=200, # mean and variance dropout=0, - # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K - channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48), - num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), - # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) - # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - token_conditioning_resolutions=(1,16,), - attention_resolutions=(512,1024,2048), - conv_resample=True, use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - kernel_size=3, - scale_factor=2, - time_embed_dim_multiplier=4, - freeze_main_net=False, - efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3. - use_scale_shift_norm=True, + num_heads=16, # Parameters for regularization. + layer_drop=.1, unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - # Parameters for super-sampling. - super_sampling=False, - super_sampling_max_noising_factor=.1, ): super().__init__() - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if super_sampling: - in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input. self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - self.attention_resolutions = attention_resolutions self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.super_sampling_enabled = super_sampling - self.super_sampling_max_noising_factor = super_sampling_max_noising_factor self.unconditioned_percentage = unconditioned_percentage self.enable_fp16 = use_fp16 - self.alignment_size = 2 ** (len(channel_mult)+1) - self.freeze_main_net = freeze_main_net - padding = 1 if kernel_size == 3 else 2 - down_kernel = 1 if efficient_convs else 3 + self.layer_drop = layer_drop - time_embed_dim = model_channels * time_embed_dim_multiplier + self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1) self.time_embed = nn.Sequential( - Linear(model_channels, time_embed_dim), + nn.Linear(model_channels, model_channels), nn.SiLU(), - Linear(time_embed_dim, time_embed_dim), + nn.Linear(model_channels, model_channels), ) - conditioning_dim = model_channels * conditioning_dim_factor # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive # transformer network. + self.code_embedding = nn.Embedding(in_tokens, model_channels) self.code_converter = nn.Sequential( - nn.Embedding(in_tokens, conditioning_dim), - CheckpointedXTransformerEncoder( - needs_permute=False, - max_seq_len=-1, - use_pos_emb=False, - attn_layers=Encoder( - dim=conditioning_dim, - depth=3, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_emb_dim=True, - ) - )) - self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1) - self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1)) - if in_channels > 60: # It's a spectrogram. - self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2), - CheckpointedXTransformerEncoder( - needs_permute=True, - max_seq_len=-1, - use_pos_emb=False, - attn_layers=Encoder( - dim=conditioning_dim, - depth=4, - heads=num_heads, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_emb_dim=True, - ) - )) - else: - self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1, - attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5) - self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1)) + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.code_norm = normalization(model_channels) + self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1) + self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), + nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), + AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False)) + self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) self.conditioning_timestep_integrator = TimestepEmbedSequential( - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), - AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), - ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), ) - self.conditioning_expansion = conditioning_expansion + self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding) - ) - ] - ) - token_conditioning_blocks = [] - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)): - if ds in token_conditioning_resolutions: - token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1) - token_conditioning_block.weight.data *= .02 - self.input_blocks.append(token_conditioning_block) - token_conditioning_blocks.append(token_conditioning_block) - - for _ in range(num_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - kernel_size=kernel_size, - efficient_config=efficient_convs, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - Downsample( - ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1 - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - kernel_size=kernel_size, - efficient_config=efficient_convs, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - kernel_size=kernel_size, - efficient_config=efficient_convs, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]: - for i in range(num_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - kernel_size=kernel_size, - efficient_config=efficient_convs, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - ) - ) - if level and i == num_blocks: - out_ch = ch - layers.append( - Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch + self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + + [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) self.out = nn.Sequential( - normalization(ch), + normalization(model_channels), nn.SiLU(), - zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)), + nn.Conv1d(model_channels, out_channels, 3, padding=1), ) - def fix_alignment(self, x, aligned_conditioning): - """ - The UNet requires that the input is a certain multiple of 2, defined by the UNet depth. Enforce this by - padding both and before forward propagation and removing the padding before returning. - """ - cm = ceil_multiple(x.shape[-1], self.alignment_size) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - # Also fix aligned_latent, which is aligned to x. - if is_latent(aligned_conditioning): - aligned_conditioning = torch.cat([aligned_conditioning, - self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) - else: - aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1]))) - return x, aligned_conditioning + def get_grad_norm_parameter_groups(self): + groups = { + 'minicoder': list(self.contextual_embedder.parameters()), + 'layers': list(self.layers.parameters()), + 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()), + 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), + 'time_embed': list(self.time_embed.parameters()), + } + return groups - def timestep_independent(self, aligned_conditioning, conditioning_input): + def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred): # Shuffle aligned_latent to BxCxS format if is_latent(aligned_conditioning): aligned_conditioning = aligned_conditioning.permute(0, 2, 1) - with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16): - cond_emb = self.contextual_embedder(conditioning_input) - if len(cond_emb.shape) == 3: # Just take the first element. - cond_emb = cond_emb[:, :, 0] - if is_latent(aligned_conditioning): - code_emb = self.latent_converter(aligned_conditioning) - else: - code_emb = self.code_converter(aligned_conditioning) - cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1]) - code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1)) - return code_emb + # Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent. + speech_conditioning_input = conditioning_input.unsqueeze(1) if len( + conditioning_input.shape) == 3 else conditioning_input + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds = torch.cat(conds, dim=-1) + cond_emb = conds.mean(dim=-1) + cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) + if is_latent(aligned_conditioning): + code_emb = self.latent_converter(aligned_conditioning) + else: + code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) + code_emb = self.code_converter(code_emb) + code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) - def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False): - assert x.shape[-1] % self.alignment_size == 0 + unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1), + device=code_emb.device) < self.unconditioned_percentage + code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), + code_emb) + expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest') - with autocast(x.device.type, enabled=self.enable_fp16): - if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) - else: + if not return_code_pred: + return expanded_code_emb + else: + mel_pred = self.mel_head(expanded_code_emb) + # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss. + mel_pred = mel_pred * unconditioned_batches.logical_not() + return expanded_code_emb, mel_pred + + + def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded. + :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() + :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. + :return: an [N x C x ...] Tensor of outputs. + """ + assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None) + assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive. + + unused_params = [] + if conditioning_free: + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend(list(self.latent_converter.parameters())) + else: + if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings - - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) - code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) - - first = True - time_emb = time_emb.float() - h = x - hs = [] - for k, module in enumerate(self.input_blocks): - if isinstance(module, nn.Conv1d): - h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest') - h = h + h_tok + else: + code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True) + if is_latent(aligned_conditioning): + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - with autocast(x.device.type, enabled=self.enable_fp16 and not first): - # First block has autocast disabled to allow a high precision signal to be properly vectorized. - h = module(h, time_emb) - hs.append(h) - first = False - h = self.middle_block(h, time_emb) - for module in self.output_blocks: - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, time_emb) + unused_params.extend(list(self.latent_converter.parameters())) + unused_params.append(self.unconditioned_embedding) - # Last block also has autocast disabled for high-precision outputs. - h = h.float() - out = self.out(h) + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + x = self.inp_block(x) + x = torch.cat([x, code_emb], dim=1) + x = self.integrating_conv(x) + for i, lyr in enumerate(self.layers): + # Do layer drop where applicable. Do not drop first and last layers. + if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop: + unused_params.extend(list(lyr.parameters())) + else: + # First and last blocks will have autocast disabled for improved precision. + with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): + x = lyr(x, time_emb) + x = x.float() + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + if return_code_pred: + return out, mel_pred return out if __name__ == '__main__': - clip = torch.randn(2, 1, 32868) - aligned_latent = torch.randn(2,388,1024) - aligned_sequence = torch.randint(0,8192,(2,388)) - cond = torch.randn(2, 1, 44000) + clip = torch.randn(2, 100, 400) + aligned_latent = torch.randn(2,388,512) + aligned_sequence = torch.randint(0,8192,(2,100)) + cond = torch.randn(2, 100, 400) ts = torch.LongTensor([600, 600]) - model = DiffusionTts(128, - channel_mult=[1,1.5,2, 3, 4, 6, 8], - num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - token_conditioning_resolutions=[1,4,16,64], - attention_resolutions=[], - num_heads=8, - kernel_size=3, - scale_factor=2, - time_embed_dim_multiplier=4, - super_sampling=False, - efficient_convs=False) + model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5) # Test with latent aligned conditioning - o = model(clip, ts, aligned_latent, cond) + #o = model(clip, ts, aligned_latent, cond) # Test with sequence aligned conditioning o = model(clip, ts, aligned_sequence, cond) +