From dabd87246d8f9829a4de28863276c39ac0d3d7a6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 31 Aug 2021 14:38:33 -0600 Subject: [PATCH] Add unet_diffusion_vocoder --- codes/data/audio/wavfile_dataset.py | 5 + codes/models/diffusion/nn.py | 7 +- codes/models/diffusion/unet_diffusion.py | 10 +- .../diffusion/unet_diffusion_vocoder.py | 302 ++++++++++++++++++ codes/models/gpt_voice/lucidrains_dvae.py | 6 + codes/train.py | 2 +- codes/trainer/injectors/base_injectors.py | 12 + 7 files changed, 338 insertions(+), 6 deletions(-) create mode 100644 codes/models/diffusion/unet_diffusion_vocoder.py diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py index 707ef136..17abe266 100644 --- a/codes/data/audio/wavfile_dataset.py +++ b/codes/data/audio/wavfile_dataset.py @@ -33,6 +33,7 @@ class WavfileDataset(torch.utils.data.Dataset): self.pad_to = opt_get(opt, ['pad_to_seconds'], None) if self.pad_to is not None: self.pad_to *= self.sampling_rate + self.min_sz = opt_get(opt, ['minimum_samples'], 0) self.augment = opt_get(opt, ['do_augmentation'], False) if self.augment: @@ -90,6 +91,10 @@ class WavfileDataset(torch.utils.data.Dataset): #print(f"Warning! Truncating clip {filename} from {audio_norm.shape[-1]} to {self.pad_to}") audio_norm = audio_norm[:, :self.pad_to] + # Bail and try the next clip if there is not enough data. + if audio_norm.shape[-1] < self.min_sz: + return self[(index + 1) % len(self)] + output = { 'clip': audio_norm, 'path': filename, diff --git a/codes/models/diffusion/nn.py b/codes/models/diffusion/nn.py index a4cd59c2..bf451585 100644 --- a/codes/models/diffusion/nn.py +++ b/codes/models/diffusion/nn.py @@ -97,7 +97,12 @@ def normalization(channels): :param channels: number of input channels. :return: an nn.Module for normalization. """ - return GroupNorm32(32, channels) + if channels <= 16: + return GroupNorm32(8, channels) + elif channels <= 64: + return GroupNorm32(16, channels) + else: + return GroupNorm32(32, channels) def timestep_embedding(timesteps, dim, max_period=10000): diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py index 83f6c5f5..2446f913 100644 --- a/codes/models/diffusion/unet_diffusion.py +++ b/codes/models/diffusion/unet_diffusion.py @@ -186,6 +186,8 @@ class ResBlock(TimestepBlock): dims=2, up=False, down=False, + kernel_size=3, + padding=1, ): super().__init__() self.channels = channels @@ -198,7 +200,7 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), ) self.updown = up or down @@ -224,7 +226,7 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) ), ) @@ -232,7 +234,7 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 + dims, channels, self.out_channels, kernel_size, padding=padding ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -922,4 +924,4 @@ if __name__ == '__main__': l = torch.randn(1,3,32,32) ts = torch.LongTensor([555]) y = srm(x, ts, low_res=l) - print(y.shape, y.mean(), y.std(), y.min(), y.max()) + print(y.shape, y.mean(), y.std(), y.min(), y.max()) \ No newline at end of file diff --git a/codes/models/diffusion/unet_diffusion_vocoder.py b/codes/models/diffusion/unet_diffusion_vocoder.py new file mode 100644 index 00000000..27324f03 --- /dev/null +++ b/codes/models/diffusion/unet_diffusion_vocoder.py @@ -0,0 +1,302 @@ +from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16 +from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear +from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \ + Downsample, Upsample +import torch +import torch.nn as nn + +from trainer.networks import register_model + + +class DiffusionVocoder(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + Customized to be conditioned on a spectrogram prior. + + :param in_channels: channels in the input Tensor. + :param spectrogram_channels: channels in the conditioning spectrogram. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + model_channels, + num_res_blocks, + in_channels=1, + out_channels=2, # mean and variance + spectrogram_channels=80, + spectrogram_conditioning_level=3, # Level at which spectrogram conditioning is applied to the waveform. + dropout=0, + # 106496 -> 26624 -> 6656 -> 16664 -> 416 -> 104 -> 26 for ~5secs@22050Hz + channel_mult=(1, 2, 4, 8, 16, 32, 64), + attention_resolutions=(16,32,64), + conv_resample=True, + dims=1, + num_classes=None, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.spectrogram_channels = spectrogram_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.dims = dims + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + spec_chs = channel_mult[spectrogram_conditioning_level] * model_channels + self.spectrogram_conditioner = nn.Sequential( + conv_nd(dims, self.spectrogram_channels, spec_chs, 1), + normalization(spec_chs), + nn.SiLU(), + conv_nd(dims, spec_chs, spec_chs, 1) + ) + self.convergence_conv = nn.Sequential( + normalization(spec_chs*2), + nn.SiLU(), + conv_nd(dims, spec_chs*2, spec_chs*2, 1) + ) + + for level, mult in enumerate(channel_mult): + if level == spectrogram_conditioning_level+1: + ch *= 2 # At this level, the spectrogram is concatenated onto the input. + + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + if level == spectrogram_conditioning_level: + self.input_block_injection_point = len(self.input_blocks)-1 + input_block_chans[-1] *= 2 + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps, spectrogram): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert x.shape[-1] % 4096 == 0 # This model operates at base//4096 at it's bottom levels, thus this requirement. + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + conditioning = self.spectrogram_conditioner(spectrogram) + + h = x.type(self.dtype) + for k, module in enumerate(self.input_blocks): + h = module(h, emb) + if k == self.input_block_injection_point: + cond = nn.functional.interpolate(conditioning, size=h.shape[-self.dims:], mode='nearest') + h = torch.cat([h, cond], dim=1) + h = self.convergence_conv(h) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(x.dtype) + return self.out(h) + + +@register_model +def register_unet_diffusion_vocoder(opt_net, opt): + return DiffusionVocoder(**opt_net['kwargs']) + + +# Test for ~4 second audio clip at 22050Hz +if __name__ == '__main__': + clip = torch.randn(1, 1, 81920) + spec = torch.randn(1, 80, 416) + ts = torch.LongTensor([555]) + model = DiffusionVocoder(16, 2) + print(model(clip, ts, spec).shape) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 6d81dab2..31583d7e 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -190,6 +190,12 @@ class DiscreteVAE(nn.Module): images = self.decoder(image_embeds) return images + def infer(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0,2,3,1) if len(img.shape) == 4 else (0,2,1)) + sampled, commitment_loss, codes = self.codebook(logits) + return self.decode(codes) + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially # more lossy (but useful for determining network performance). diff --git a/codes/train.py b/codes/train.py index e618b709..83f990ae 100644 --- a/codes/train.py +++ b/codes/train.py @@ -284,7 +284,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_lrdvae_audio_clips.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_from_dvae_clips.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index 9d9ae00a..c77fd963 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -535,6 +535,18 @@ class MelSpectrogramInjector(Injector): return {self.output: self.stft.mel_spectrogram(inp)} +class RandomAudioCropInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + self.crop_sz = opt['crop_size'] + + def forward(self, state): + inp = state[self.input] + len = inp.shape[-1] + margin = len - self.crop_sz + start = random.randint(0, margin) + return {self.output: inp[:, :, start:start+self.crop_sz]} + if __name__ == '__main__': inj = DecomposeDimensionInjector({'dim':2, 'in': 'x', 'out': 'y'}, None)