diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 5ee1fc13..fafeaef3 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -770,9 +770,11 @@ class GaussianDiffusion:
             terms["loss"] *= self.num_timesteps
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
             model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
-            model_output = model_outputs[0]
-            if len(model_outputs) > 1:
+            if isinstance(model_outputs, tuple):
+                model_output = model_outputs[0]
                 terms['extra_outputs'] = model_outputs[1:]
+            else:
+                model_output = model_outputs
 
             if self.model_var_type in [
                 ModelVarType.LEARNED,
diff --git a/codes/models/gpt_voice/mini_encoder.py b/codes/models/gpt_voice/mini_encoder.py
index 8b2d54b5..86e61d20 100644
--- a/codes/models/gpt_voice/mini_encoder.py
+++ b/codes/models/gpt_voice/mini_encoder.py
@@ -3,8 +3,7 @@ import torch.nn as nn
 
 from models.diffusion.nn import normalization, conv_nd, zero_module
-from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy
-
+from models.diffusion.unet_diffusion import Downsample, AttentionBlock, QKVAttention, QKVAttentionLegacy, Upsample
 # Combined resnet & full-attention encoder for converting an audio clip into an embedding.
 from utils.util import checkpoint
@@ -90,17 +89,17 @@ class ResBlock(nn.Module):
 
 
 class AudioMiniEncoder(nn.Module):
-    def __init__(self, spec_dim, embedding_dim, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0):
+    def __init__(self, spec_dim, embedding_dim, base_channels=128, depth=2, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, dropout=0, downsample_factor=2, kernel_size=3):
         super().__init__()
         self.init = nn.Sequential(
-            conv_nd(1, spec_dim, 128, 3, padding=1)
+            conv_nd(1, spec_dim, base_channels, 3, padding=1)
         )
-        ch = 128
+        ch = base_channels
         res = []
-        for l in range(2):
+        for l in range(depth):
             for r in range(resnet_blocks):
-                res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False))
-            res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=2))
+                res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
+            res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
         self.final = nn.Sequential(
diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
new file mode 100644
index 00000000..40eb5645
--- /dev/null
+++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
@@ -0,0 +1,335 @@
+from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
+from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
+    Downsample, Upsample
+import torch
+import torch.nn as nn
+
+from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
+from trainer.networks import register_model
+from utils.util import get_mask_from_lengths
+
+
+class DiscreteSpectrogramConditioningBlock(nn.Module):
+    def __init__(self, discrete_codes, channels):
+        super().__init__()
+        self.emb = nn.Embedding(discrete_codes, channels)
+
+    """
+    Embeds the given codes and concatenates them onto x.
+    Return shape: bx2cxS
+
+    :param x: bxcxS waveform latent
+    :param codes: bxN discrete codes, N <= S
+    """
+    def forward(self, x, codes):
+        _, c, S = x.shape
+        b, N = codes.shape
+        assert N <= S
+        emb = self.emb(codes).permute(0,2,1)
+        emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
+        return torch.cat([x, emb], dim=1)
+
+
+class DiffusionVocoderWithRef(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+
+    Customized to be conditioned on a spectrogram prior.
+
+    :param in_channels: channels in the input Tensor.
+    :param conditioning_input_dim: channels in the conditioning spectrogram.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+        a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+        of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+        increased efficiency.
+    """
+
+    def __init__(
+        self,
+        model_channels,
+        num_res_blocks,
+        in_channels=1,
+        out_channels=2,  # mean and variance
+        discrete_codes=8192,
+        dropout=0,
+        # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
+        channel_mult=(1, 1, 2, 2, 4, 8, 16, 32, 64),
+        spectrogram_conditioning_resolutions=(4,8,16,32),
+        attention_resolutions=(64,128,256),
+        conv_resample=True,
+        dims=1,
+        use_fp16=False,
+        num_heads=1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        kernel_size=3,
+        scale_factor=2,
+        conditioning_inputs_provided=True,
+        conditioning_input_dim=80,
+    ):
+        super().__init__()
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.dtype = torch.float16 if use_fp16 else torch.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.dims = dims
+
+        padding = 1 if kernel_size == 3 else 2
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.conditioning_enabled = conditioning_inputs_provided
+        if conditioning_inputs_provided:
+            self.contextual_embedder = AudioMiniEncoder(conditioning_input_dim, time_embed_dim)
+            self.query_gen = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
+                                              attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
+            self.embedding_combiner = EmbeddingCombiner(time_embed_dim)
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+
+        for level, mult in enumerate(channel_mult):
+            if ds in spectrogram_conditioning_resolutions:
+                self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
+                ch *= 2
+
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            num_heads=num_heads,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                            kernel_size=kernel_size,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
+            ),
+            AttentionBlock(
+                ch,
+                num_heads=num_heads,
+                num_head_channels=num_head_channels,
+                use_new_attention_order=use_new_attention_order,
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_scale_shift_norm=use_scale_shift_norm,
+                kernel_size=kernel_size,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(num_res_blocks + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=kernel_size,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    layers.append(
+                        AttentionBlock(
+                            ch,
+                            num_heads=num_heads_upsample,
+                            num_head_channels=num_head_channels,
+                            use_new_attention_order=use_new_attention_order,
+                        )
+                    )
+                if level and i == num_res_blocks:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                            kernel_size=kernel_size,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
+        )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps, discrete_spectrogram, conditioning_inputs=None, num_conditioning_signals=None):
+        """
+        Apply the model to an input batch.
+
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param discrete_spectrogram: an [N x S] Tensor of discrete spectrogram codes.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert x.shape[-1] % 4096 == 0  # This model operates at base//4096 at its bottom levels, thus this requirement.
+        if self.conditioning_enabled:
+            assert conditioning_inputs is not None
+            assert num_conditioning_signals is not None
+
+        hs = []
+        emb1 = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+        if self.conditioning_enabled:
+            emb2 = torch.stack([self.contextual_embedder(ci.squeeze(1)) for ci in list(torch.chunk(conditioning_inputs, conditioning_inputs.shape[1], dim=1))], dim=1)
+            emb = torch.cat([emb1.unsqueeze(1), emb2], dim=1)
+            emb = self.embedding_combiner(emb, None, self.query_gen(x))
+        else:
+            emb = emb1
+
+        h = x.type(self.dtype)
+        for k, module in enumerate(self.input_blocks):
+            if isinstance(module, DiscreteSpectrogramConditioningBlock):
+                h = module(h, discrete_spectrogram)
+            else:
+                h = module(h, emb)
+                hs.append(h)
+        h = self.middle_block(h, emb)
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop()], dim=1)
+            h = module(h, emb)
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+@register_model
+def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
+    return DiffusionVocoderWithRef(**opt_net['kwargs'])
+
+
+# Test for ~4 second audio clip at 22050Hz
+if __name__ == '__main__':
+    clip = torch.randn(2, 1, 81920)
+    spec = torch.randint(8192, (2, 500,))
+    cond = torch.randn(2, 4, 80, 600)
+    ts = torch.LongTensor([555, 556])
+    model = DiffusionVocoderWithRef(32, 2)
+    print(model(clip, ts, spec, cond, 4).shape)
diff --git a/codes/train.py b/codes/train.py
index 193cc2fc..1cc43b7a 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -284,7 +284,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_dvae_clips_with_discretization_loss.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_vocoder_clips.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
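Note on the new DiscreteSpectrogramConditioningBlock: it embeds the discrete codes, nearest-neighbor-interpolates the embedding up to the length of the waveform latent, and concatenates the result along the channel dimension, doubling the channel count. The standalone sketch below (not part of the patch) illustrates that behavior; the import path assumes codes/ is on sys.path, as the other module imports in the patch do, and the shapes are arbitrary examples.

    import torch
    from models.gpt_voice.unet_diffusion_vocoder_with_ref import DiscreteSpectrogramConditioningBlock

    block = DiscreteSpectrogramConditioningBlock(discrete_codes=8192, channels=64)
    x = torch.randn(2, 64, 2000)              # b x c x S waveform latent
    codes = torch.randint(0, 8192, (2, 500))  # b x N discrete codes, N <= S
    out = block(x, codes)                     # codes embedded, interpolated to length S, concatenated onto x
    print(out.shape)                          # torch.Size([2, 128, 2000]) -- channel count doubled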