From ff8b0533acb6db2afe00f044607344f71c414858 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 19 Jun 2022 19:23:48 -0600
Subject: [PATCH] gen3 waveform

---
 ...rdb.py => unet_diffusion_waveform_gen3.py} | 61 ++++++-------------
 1 file changed, 20 insertions(+), 41 deletions(-)
 rename codes/models/audio/music/{unet_diffusion_waveform_gen_rrdb.py => unet_diffusion_waveform_gen3.py} (86%)

diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen_rrdb.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
similarity index 86%
rename from codes/models/audio/music/unet_diffusion_waveform_gen_rrdb.py
rename to codes/models/audio/music/unet_diffusion_waveform_gen3.py
index 2214f297..857a849e 100644
--- a/codes/models/audio/music/unet_diffusion_waveform_gen_rrdb.py
+++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
@@ -7,7 +7,7 @@ from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
     Downsample, Upsample, TimestepBlock
 from scripts.audio.gen.use_diffuse_tts import ceil_multiple
 from trainer.networks import register_model
-from utils.util import checkpoint
+from utils.util import checkpoint, print_network
 
 
 def is_sequence(t):
@@ -23,7 +23,7 @@ class ResBlock(TimestepBlock):
             out_channels=None,
             dims=2,
             kernel_size=3,
-            efficient_config=True,
+            efficient_config=False,
             use_scale_shift_norm=False,
     ):
         super().__init__()
@@ -93,6 +93,8 @@ class ResBlock(TimestepBlock):
 
 class StackedResidualBlock(TimestepBlock):
     def __init__(self, channels, emb_channels, dropout):
+        super().__init__()
+
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
             linear(
@@ -102,29 +104,30 @@
         )
 
         gc = channels // 4
-        super().__init__()
         self.initial_norm = nn.GroupNorm(num_groups=8, num_channels=channels)
         for i in range(5):
             out_channels = channels if i == 4 else gc
             self.add_module(
                 f'conv{i + 1}',
-                nn.Conv2d(channels + i * gc, out_channels, 3, 1, 1))
-            self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=channels))
+                nn.Conv1d(channels + i * gc, out_channels, 3, 1, 1))
+            self.add_module(f'gn{i+1}', nn.GroupNorm(num_groups=8, num_channels=out_channels))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         zero_module(self.conv5)
+        self.drop = nn.Dropout(p=dropout)
 
-    def forawrd(self, x, emb):
+    def forward(self, x, emb):
         return checkpoint(self.forward_, x, emb)
 
     def forward_(self, x, emb):
-        emb_out = self.emb_layers(emb).type(h.dtype)
+        emb_out = self.emb_layers(emb)
         scale, shift = torch.chunk(emb_out, 2, dim=1)
-        x0 = self.initial_norm(x) * (1 + scale) + shift
+        x0 = self.initial_norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
         x1 = self.lrelu(self.gn1(self.conv1(x0)))
         x2 = self.lrelu(self.gn2(self.conv2(torch.cat((x, x1), 1))))
         x3 = self.lrelu(self.gn3(self.conv3(torch.cat((x, x1, x2), 1))))
         x4 = self.lrelu(self.gn4(self.conv4(torch.cat((x, x1, x2, x3), 1))))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        x5 = self.drop(x5)
         return x5 + x
 
 
@@ -152,15 +155,14 @@ class DiffusionWaveformGen(nn.Module):
 
     def __init__(
             self,
-            model_channels,
+            model_channels=512,
            in_channels=64,
             in_mel_channels=256,
-            conditioning_dim_factor=8,
-            conditioning_expansion=4,
+            conditioning_dim_factor=4,
             out_channels=128, # mean and variance
             dropout=0,
             channel_mult= (1,1.5,2),
-            num_res_blocks=(1,1,1),
+            num_res_blocks=(1,1,0),
             token_conditioning_resolutions=(1,4),
             mid_resnet_depth=10,
             conv_resample=True,
@@ -168,9 +170,8 @@
             use_fp16=False,
             kernel_size=3,
             scale_factor=2,
-            time_embed_dim_multiplier=4,
+            time_embed_dim_multiplier=1,
             freeze_main_net=False,
-            efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
             use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
@@ -197,7 +198,6 @@ class DiffusionWaveformGen(nn.Module):
         self.freeze_main_net = freeze_main_net
         self.in_mel_channels = in_mel_channels
         padding = 1 if kernel_size == 3 else 2
-        down_kernel = 1 if efficient_convs else 3
 
         time_embed_dim = model_channels * time_embed_dim_multiplier
         self.time_embed = nn.Sequential(
@@ -213,12 +213,6 @@ class DiffusionWaveformGen(nn.Module):
         # transformer network.
         self.mel_converter = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
-        self.conditioning_timestep_integrator = TimestepEmbedSequential(
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-            ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
-        )
-        self.conditioning_expansion = conditioning_expansion
 
         self.input_blocks = nn.ModuleList(
             [
@@ -249,7 +243,6 @@ class DiffusionWaveformGen(nn.Module):
                         out_channels=int(mult * model_channels),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -262,7 +255,7 @@ class DiffusionWaveformGen(nn.Module):
                 self.input_blocks.append(
                     TimestepEmbedSequential(
                         Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
+                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
                         )
                     )
                 )
@@ -286,7 +279,6 @@ class DiffusionWaveformGen(nn.Module):
                         out_channels=int(model_channels * mult),
                         dims=dims,
                         kernel_size=kernel_size,
-                        efficient_config=efficient_convs,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -360,11 +352,6 @@ class DiffusionWaveformGen(nn.Module):
         else:
             code_emb = self.mel_converter(aligned_conditioning)
 
-        # Everything after this comment is timestep dependent.
-        code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
-        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
-
-        first = True
         time_emb = time_emb.float()
         h = x
         for k, module in enumerate(self.input_blocks):
@@ -374,7 +361,6 @@ class DiffusionWaveformGen(nn.Module):
             else:
                 h = module(h, time_emb)
             hs.append(h)
-            first = False
         h = self.middle_block(h, time_emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
@@ -398,18 +384,11 @@ def register_unet_diffusion_waveform_gen3(opt_net, opt):
 
 
 if __name__ == '__main__':
-    clip = torch.randn(2, 1, 32868)
-    aligned_sequence = torch.randn(2,120,220)
+    clip = torch.randn(2, 64, 880)
+    aligned_sequence = torch.randn(2,256,220)
     ts = torch.LongTensor([600, 600])
-    model = DiffusionWaveformGen(128,
-                                 channel_mult=[1,1.5,2, 3, 4, 6, 8],
-                                 num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
-                                 token_conditioning_resolutions=[1,4,16,64],
-                                 kernel_size=3,
-                                 scale_factor=2,
-                                 time_embed_dim_multiplier=4,
-                                 super_sampling=False,
-                                 efficient_convs=False)
+    model = DiffusionWaveformGen()
 
     # Test with sequence aligned conditioning
     o = model(clip, ts, aligned_sequence)
+    print_network(model)
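Note on the scale/shift fix in StackedResidualBlock.forward_: self.emb_layers emits a
(batch, 2*channels) tensor, while the Conv1d activations are (batch, channels, time), so
after torch.chunk the scale and shift need a trailing axis before they can broadcast (the
old line also referenced the undefined name h.dtype). A minimal standalone sketch of that
shape logic follows; the tensor sizes are borrowed from the patch's __main__ smoke test,
and everything else is illustrative rather than part of the patch:

import torch

B, C, T = 2, 64, 880                # batch, channels, time
x = torch.randn(B, C, T)            # 1d feature map, as produced by nn.Conv1d
emb_out = torch.randn(B, 2 * C)     # stand-in for self.emb_layers(emb)

scale, shift = torch.chunk(emb_out, 2, dim=1)   # each (B, C)
# (B, C) cannot broadcast against (B, C, T); unsqueezing a trailing axis
# turns scale and shift into (B, C, 1), which broadcasts over the time dim.
x0 = x * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)
assert x0.shape == (B, C, T)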