From 1d0b44ebc293b5c355ad3f7e43ffac504dc5524a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 15 Oct 2021 11:51:17 -0600
Subject: [PATCH] More tweaks to diffusion-vocoder

---
 codes/models/diffusion/nn.py                  | 11 ++++--
 .../unet_diffusion_vocoder_with_ref.py        | 39 ++++++++++++-------
 2 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/codes/models/diffusion/nn.py b/codes/models/diffusion/nn.py
index bf451585..169bc0d9 100644
--- a/codes/models/diffusion/nn.py
+++ b/codes/models/diffusion/nn.py
@@ -97,12 +97,15 @@ def normalization(channels):
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
+    groups = 32
     if channels <= 16:
-        return GroupNorm32(8, channels)
+        groups = 8
     elif channels <= 64:
-        return GroupNorm32(16, channels)
-    else:
-        return GroupNorm32(32, channels)
+        groups = 16
+    while channels % groups != 0:
+        groups = int(groups / 2)
+    assert groups > 2
+    return GroupNorm32(groups, channels)
 
 
 def timestep_embedding(timesteps, dim, max_period=10000):
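Why the normalization() rework is needed: the vocoder defaults in the second file below introduce a fractional channel multiplier (1.5), so with model_channels=32 some layers end up with 48 channels, and the old hard-coded GroupNorm32(32, ...) would raise because 48 is not divisible by 32. A minimal standalone sketch of the new selection logic, with the repo's GroupNorm32 approximated here by plain nn.GroupNorm:

    import torch.nn as nn

    def pick_group_norm(channels):
        # Mirror the patched normalization(): start from a size-based default,
        # then halve the group count until it evenly divides the channels.
        groups = 32
        if channels <= 16:
            groups = 8
        elif channels <= 64:
            groups = 16
        while channels % groups != 0:
            groups = int(groups / 2)
        assert groups > 2  # refuse degenerate 1- or 2-group norms
        return nn.GroupNorm(groups, channels)

    # 48 channels (1.5 * 32) now resolves to 16 groups instead of crashing.
    for ch in (16, 32, 48, 96):
        print(ch, pick_group_norm(ch).num_groups)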
diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
index a6ffecfc..9be0c264 100644
--- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
+++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
@@ -14,9 +14,15 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
     def __init__(self, discrete_codes, channels):
         super().__init__()
         self.emb = nn.Embedding(discrete_codes, channels)
+        self.norm = normalization(channels)
+        self.act = nn.SiLU()
+        self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
+                                  normalization(channels*2),
+                                  nn.SiLU(),
+                                  nn.Conv1d(channels*2, channels, kernel_size=3, padding=1))
 
     """
-    Embeds the given codes and concatenates them onto x. Return shape: bx2cxS
+    Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
 
     :param x: bxcxS waveform latent
     :param codes: bxN discrete codes, N <= S
@@ -27,7 +33,9 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
         assert N <= S
         emb = self.emb(codes).permute(0,2,1)
         emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
-        return torch.cat([x, emb], dim=1)
+        together = torch.cat([self.act(self.norm(x)), emb], dim=1)
+        together = self.intg(together)
+        return together + x
 
 
 class DiffusionVocoderWithRef(nn.Module):
@@ -68,11 +76,13 @@ class DiffusionVocoderWithRef(nn.Module):
                  out_channels=2,  # mean and variance
                  discrete_codes=8192,
                  dropout=0,
-                 # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
-                 channel_mult=  (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
-                 num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
-                 spectrogram_conditioning_resolutions=(512,),
-                 attention_resolutions=(512,1024,2048,4096),
+                 # res           1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
+                 channel_mult=  (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
+                 num_res_blocks=(1, 1, 1, 1, 1, 1, 2,  2,  2,  2,  2,  2),
+                 # spec_cond:    1, 0, 0, 1, 0, 0, 1,  0,  0,  1,  0,  0)
+                 # attn:         0, 0, 0, 0, 0, 0, 0,  0,  0,  1,  1,  1
+                 spectrogram_conditioning_resolutions=(1,8,64,512),
+                 attention_resolutions=(512,1024,2048),
                  conv_resample=True,
                  dims=1,
                  use_fp16=False,
@@ -136,7 +146,6 @@ class DiffusionVocoderWithRef(nn.Module):
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
-                ch *= 2
 
             for _ in range(num_blocks):
                 layers = [
@@ -144,13 +153,13 @@ class DiffusionVocoderWithRef(nn.Module):
                         ch,
                         time_embed_dim,
                         dropout,
-                        out_channels=mult * model_channels,
+                        out_channels=int(mult * model_channels),
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
                         kernel_size=kernel_size,
                     )
                 ]
-                ch = mult * model_channels
+                ch = int(mult * model_channels)
                 if ds in attention_resolutions:
                     layers.append(
                         AttentionBlock(
@@ -223,13 +232,13 @@ class DiffusionVocoderWithRef(nn.Module):
                         ch + ich,
                         time_embed_dim,
                         dropout,
-                        out_channels=model_channels * mult,
+                        out_channels=int(model_channels * mult),
                         dims=dims,
                         use_scale_shift_norm=use_scale_shift_norm,
                         kernel_size=kernel_size,
                     )
                 ]
-                ch = model_channels * mult
+                ch = int(model_channels * mult)
                 if ds in attention_resolutions:
                     layers.append(
                         AttentionBlock(
@@ -326,9 +335,9 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 81920)
-    spec = torch.randint(8192, (2, 160,))
-    cond = torch.randn(2, 4, 80, 600)
+    clip = torch.randn(2, 1, 40960)
+    spec = torch.randint(8192, (2, 40,))
+    cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
     print(model(clip, ts, spec, cond, 4).shape)
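The reworked DiscreteSpectrogramConditioningBlock is now residual: it embeds the codes, nearest-upsamples them to the latent's length, fuses them through a small conv stack, and adds the result back onto x, so the output keeps x's channel count (which is why the `ch *= 2` bookkeeping in the encoder loop could be deleted). A self-contained sketch under that reading, with the repo's normalization() approximated by plain nn.GroupNorm:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CondBlockSketch(nn.Module):
        def __init__(self, discrete_codes, channels):
            super().__init__()
            self.emb = nn.Embedding(discrete_codes, channels)
            self.norm = nn.GroupNorm(8, channels)
            self.act = nn.SiLU()
            self.intg = nn.Sequential(
                nn.Conv1d(channels * 2, channels * 2, kernel_size=1),
                nn.GroupNorm(8, channels * 2),
                nn.SiLU(),
                nn.Conv1d(channels * 2, channels, kernel_size=3, padding=1))

        def forward(self, x, codes):
            b, c, S = x.shape
            # Upsample code embeddings to the latent's temporal resolution.
            emb = self.emb(codes).permute(0, 2, 1)
            emb = F.interpolate(emb, size=(S,), mode='nearest')
            # Fuse, project back to c channels, and add the residual.
            together = self.intg(torch.cat([self.act(self.norm(x)), emb], dim=1))
            return together + x

    blk = CondBlockSketch(discrete_codes=8192, channels=32)
    x = torch.randn(2, 32, 600)
    codes = torch.randint(8192, (2, 40))
    print(blk(x, codes).shape)  # torch.Size([2, 32, 600]), same as x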
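The res / spec_cond / attn comment rows in the new defaults can be checked mechanically. Assuming the usual UNet convention that ds starts at 1 and doubles per level (which is what the 1..2K res row implies), and taking model_channels=32 from the test harness, this hypothetical snippet reproduces the flag rows and shows why the int() casts became necessary once mult can be 1.5. Note also that the test clip shrank from 81920 to 40960 samples, about 1.86 s at 22050 Hz, so the "~4 second" comment above it now overstates the clip length.

    channel_mult = (1, 1.5, 2, 3, 4, 6, 8, 12, 16, 24, 32, 48)
    spec_cond_res = (1, 8, 64, 512)
    attn_res = (512, 1024, 2048)

    ds = 1
    for level, mult in enumerate(channel_mult):
        # int() keeps channel counts integral now that mult may be fractional.
        print(f"level {level:2d}  ds {ds:4d}  ch {int(mult * 32):3d}  "
              f"spec_cond {int(ds in spec_cond_res)}  attn {int(ds in attn_res)}")
        ds *= 2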