diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen2.py b/codes/models/audio/music/unet_diffusion_waveform_gen2.py index 1e334b46..cd7eb5d2 100644 --- a/codes/models/audio/music/unet_diffusion_waveform_gen2.py +++ b/codes/models/audio/music/unet_diffusion_waveform_gen2.py @@ -144,25 +144,41 @@ class ResBlockSimple(nn.Module): return self.skip_connection(x) + h -class StructuralProcessor(nn.Module): +class AudioVAE(nn.Module): def __init__(self, channels, dropout): super().__init__() - # 256,128,64,32,16,8,4,2,1 - level_resblocks = [3, 3, 2, 2, 2,1,1,1] - level_ch_div = [1, 1, 2, 4, 4,8,8,16] + # 1, 4, 16, 64, 256 + level_resblocks = [1, 1, 2, 2, 2] + level_ch_mult = [1, 2, 4, 6, 8] levels = [] - lastdiv = 1 - for resblks, chdiv in zip(level_resblocks, level_ch_div): - levels.append(nn.Sequential(*([nn.Conv1d(channels//lastdiv, channels//chdiv, 1)] + - [ResBlockSimple(channels//chdiv, dropout) for _ in range(resblks)]))) + for i, (resblks, chdiv) in enumerate(zip(level_resblocks, level_ch_mult)): + blocks = [ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)] + if i != len(level_ch_mult)-1: + blocks.append(nn.Conv1d(channels*chdiv, channels*level_ch_mult[i+1], kernel_size=5, padding=2, stride=4)) + levels.append(nn.Sequential(*blocks)) + self.down_levels = nn.ModuleList(levels) + + levels = [] + lastdiv = None + for resblks, chdiv in reversed(list(zip(level_resblocks, level_ch_mult))): + if lastdiv is not None: + blocks = [nn.Conv1d(channels*lastdiv, channels*chdiv, kernel_size=5, padding=2)] + else: + blocks = [] + blocks.extend([ResBlockSimple(channels*chdiv, dropout=dropout, kernel_size=5) for _ in range(resblks)]) + levels.append(nn.Sequential(*blocks)) lastdiv = chdiv - self.levels = nn.ModuleList(levels) + self.up_levels = nn.ModuleList(levels) def forward(self, x): h = x - for level in self.levels: + for level in self.down_levels: h = level(h) - h = F.interpolate(h, scale_factor=2, mode='linear') + + for k, level in enumerate(self.up_levels): + h = level(h) + if k != len(self.up_levels)-1: + h = F.interpolate(h, scale_factor=4, mode='linear') return h @@ -178,20 +194,10 @@ class DiffusionTts(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially @@ -202,8 +208,6 @@ class DiffusionTts(nn.Module): self, model_channels, in_channels=1, - in_mel_channels=120, - conditioning_dim_factor=8, out_channels=2, # mean and variance dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K @@ -211,13 +215,9 @@ class DiffusionTts(nn.Module): num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2), # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0) # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1 - attention_resolutions=(512,1024,2048), conv_resample=True, dims=1, use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4, @@ -229,24 +229,16 @@ class DiffusionTts(nn.Module): ): super().__init__() - if num_heads_upsample == -1: - num_heads_upsample = num_heads - self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample self.dims = dims self.unconditioned_percentage = unconditioned_percentage self.enable_fp16 = use_fp16 - self.alignment_size = 2 ** (len(channel_mult)+1) - self.in_mel_channels = in_mel_channels + self.alignment_size = max(2 ** (len(channel_mult)+1), 256) padding = 1 if kernel_size == 3 else 2 down_kernel = 1 if efficient_convs else 3 @@ -257,18 +249,17 @@ class DiffusionTts(nn.Module): linear(time_embed_dim, time_embed_dim), ) - conditioning_dim = model_channels * conditioning_dim_factor - self.structural_cond_input = nn.Conv1d(in_mel_channels, conditioning_dim, 3, padding=1) - self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_mel_channels,1)) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1)) - self.structural_processor = StructuralProcessor(conditioning_dim, dropout) - self.surrogate_head = nn.Conv1d(conditioning_dim//16, in_channels, 1) + self.structural_cond_input = nn.Conv1d(in_channels, model_channels, kernel_size=5, padding=2) + self.aligned_latent_padding_embedding = nn.Parameter(torch.zeros(1,in_channels,1)) + self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) + self.structural_processor = AudioVAE(model_channels, dropout) + self.surrogate_head = nn.Conv1d(model_channels, in_channels, 1) - self.input_block = conv_nd(dims, in_channels, model_channels//2, kernel_size, padding=padding) + self.input_block = conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, model_channels, model_channels, kernel_size, padding=padding) + conv_nd(dims, model_channels*2, model_channels, 1) ) ] ) @@ -292,14 +283,6 @@ class DiffusionTts(nn.Module): ) ] ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ) - ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -327,20 +310,6 @@ class DiffusionTts(nn.Module): efficient_config=efficient_convs, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - num_heads=num_heads, - num_head_channels=num_head_channels, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - kernel_size=kernel_size, - efficient_config=efficient_convs, - use_scale_shift_norm=use_scale_shift_norm, - ), ) self._feature_size += ch @@ -361,14 +330,6 @@ class DiffusionTts(nn.Module): ) ] ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - ) - ) if level and i == num_blocks: out_ch = ch layers.append( @@ -403,9 +364,6 @@ class DiffusionTts(nn.Module): } return groups - def is_latent(self, t): - return t.shape[1] != self.in_mel_channels - def fix_alignment(self, x, aligned_conditioning): """ The UNet requires that the input is a certain multiple of 2, defined by the UNet depth. Enforce this by @@ -415,36 +373,26 @@ class DiffusionTts(nn.Module): if cm != 0: pc = (cm-x.shape[-1])/x.shape[-1] x = F.pad(x, (0,cm-x.shape[-1])) - # Also fix aligned_latent, which is aligned to x. - if self.is_latent(aligned_conditioning): - aligned_conditioning = torch.cat([aligned_conditioning, - self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) - else: - aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) return x, aligned_conditioning - def forward(self, x, timesteps, aligned_conditioning, conditioning_free=False): + def forward(self, x, timesteps, conditioning, conditioning_free=False): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. - :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param conditioning: should just be the truth value. produces a latent through an autoencoder, then uses diffusion to decode that latent. + at inference, only the latent is passed in. :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. :return: an [N x C x ...] Tensor of outputs. """ - # Shuffle aligned_latent to BxCxS format - if self.is_latent(aligned_conditioning): - aligned_conditioning = aligned_conditioning.permute(0, 2, 1) - # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net. orig_x_shape = x.shape[-1] - x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning) + x, aligned_conditioning = self.fix_alignment(x, conditioning) with autocast(x.device.type, enabled=self.enable_fp16): - hs = [] - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) # Note: this block does not need to repeated on inference, since it is not timestep-dependent. if conditioning_free: @@ -456,10 +404,12 @@ class DiffusionTts(nn.Module): code_emb = F.interpolate(code_emb, size=(x.shape[-1],), mode='linear') surrogate = self.surrogate_head(code_emb) - # Everything after this comment is timestep dependent. x = self.input_block(x) x = torch.cat([x, code_emb], dim=1) + # Everything after this comment is timestep dependent. + hs = [] + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) time_emb = time_emb.float() h = x for k, module in enumerate(self.input_blocks): @@ -493,13 +443,11 @@ def register_unet_diffusion_waveform_gen2(opt_net, opt): if __name__ == '__main__': clip = torch.randn(2, 1, 32868) - aligned_sequence = torch.randn(2,120,128) + aligned_sequence = torch.randn(2,1,32868) ts = torch.LongTensor([600, 600]) model = DiffusionTts(128, channel_mult=[1,1.5,2, 3, 4, 6, 8], num_res_blocks=[2, 2, 2, 2, 2, 2, 1], - attention_resolutions=[], - num_heads=8, kernel_size=3, scale_factor=2, time_embed_dim_multiplier=4,