From a5d2123daafd41daa9d5716af51ca59067616837 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 19 Jun 2022 21:04:51 -0600
Subject: [PATCH] more cleanup

---
 .../music/unet_diffusion_waveform_gen3.py | 43 ++++---------------
 codes/trainer/eval/music_diffusion_fid.py |  5 ++-
 2 files changed, 11 insertions(+), 37 deletions(-)

diff --git a/codes/models/audio/music/unet_diffusion_waveform_gen3.py b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
index 28fee999..165c8cd2 100644
--- a/codes/models/audio/music/unet_diffusion_waveform_gen3.py
+++ b/codes/models/audio/music/unet_diffusion_waveform_gen3.py
@@ -146,8 +146,6 @@ class DiffusionWaveformGen(nn.Module):
     :param num_res_blocks: number of residual blocks per downsample.
     :param dropout: the dropout probability.
     :param channel_mult: channel multiplier for each level of the UNet.
-    :param conv_resample: if True, use learned convolutions for upsampling and
-        downsampling.
     :param dims: determines if the signal is 1D, 2D, or 3D.
     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
     :param resblock_updown: use residual blocks for up/downsampling.
@@ -165,39 +163,24 @@ class DiffusionWaveformGen(nn.Module):
             num_res_blocks=(1,1,0),
             token_conditioning_resolutions=(1,4),
             mid_resnet_depth=10,
-            conv_resample=True,
             dims=1,
             use_fp16=False,
-            kernel_size=3,
-            scale_factor=2,
             time_embed_dim_multiplier=1,
-            freeze_main_net=False,
-            use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
-            # Parameters for super-sampling.
-            super_sampling=False,
-            super_sampling_max_noising_factor=.1,
     ):
         super().__init__()

-        if super_sampling:
-            in_channels *= 2  # In super-sampling mode, the LR input is concatenated directly onto the input.
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
         self.dropout = dropout
         self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
         self.dims = dims
-        self.super_sampling_enabled = super_sampling
-        self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
         self.alignment_size = 2 ** (len(channel_mult)+1)
-        self.freeze_main_net = freeze_main_net
         self.in_mel_channels = in_mel_channels
-        padding = 1 if kernel_size == 3 else 2

         time_embed_dim = model_channels * time_embed_dim_multiplier
         self.time_embed = nn.Sequential(
@@ -217,7 +200,7 @@
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                 )
             ]
         )
@@ -242,8 +225,8 @@
                         dropout,
                         out_channels=int(mult * model_channels),
                         dims=dims,
-                        kernel_size=kernel_size,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=3,
+                        use_scale_shift_norm=True,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -255,7 +238,7 @@
                 self.input_blocks.append(
                     TimestepEmbedSequential(
                         Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
+                            ch, True, dims=dims, out_channels=out_ch, factor=2, ksize=3, pad=1
                         )
                     )
                 )
@@ -279,15 +262,15 @@
                         dropout,
                         out_channels=int(model_channels * mult),
                         dims=dims,
-                        kernel_size=kernel_size,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=3,
+                        use_scale_shift_norm=True,
                     )
                 ]
                 ch = int(model_channels * mult)
                 if level and i == num_blocks:
                     out_ch = ch
                     layers.append(
-                        Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
+                        Upsample(ch, True, dims=dims, out_channels=out_ch, factor=2)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -296,20 +279,10 @@
         self.out = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )

-        if self.freeze_main_net:
-            mains = [self.time_embed, self.contextual_embedder, self.unconditioned_embedding, self.conditioning_timestep_integrator,
-                     self.input_blocks, self.middle_block, self.output_blocks, self.out]
-            for m in mains:
-                for p in m.parameters():
-                    p.requires_grad = False
-                    p.DO_NOT_TRAIN = True
-
     def get_grad_norm_parameter_groups(self):
-        if self.freeze_main_net:
-            return {}
         groups = {
             'input_blocks': list(self.input_blocks.parameters()),
             'output_blocks': list(self.output_blocks.parameters()),

diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 066b2e28..4c3ca21c 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -61,6 +61,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         if mode == 'spec_decode':
             self.diffusion_fn = self.perform_diffusion_spec_decode
+            self.squeeze_ratio = opt_eval['squeeze_ratio']
         elif 'from_codes' == mode:
             self.diffusion_fn = self.perform_diffusion_from_codes
             self.local_modules['codegen'] = get_music_codegen()
@@ -81,11 +82,11 @@
     def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
         real_resampled = audio
         audio = audio.unsqueeze(0)
-        output_shape = (1, 256, audio.shape[-1] // 256)
+        output_shape = (1, self.squeeze_ratio, audio.shape[-1] // self.squeeze_ratio)
         mel = self.spec_fn({'in': audio})['out']
         gen = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': mel})
-        gen = pixel_shuffle_1d(gen, 256)
+        gen = pixel_shuffle_1d(gen, self.squeeze_ratio)
         return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
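
Note on the retained unconditioned_percentage parameter: per its inline comment, it implements a mechanism similar to classifier-free guidance training. In this model family that usually means swapping the conditioning embedding for a learned "unconditioned" embedding on a random fraction of each batch. A minimal sketch of that idea, assuming illustrative names and shapes (code_emb, unconditioned_embedding) rather than the exact code in unet_diffusion_waveform_gen3.py:

    import torch

    def maybe_drop_conditioning(code_emb, unconditioned_embedding,
                                unconditioned_percentage, training):
        # With probability `unconditioned_percentage`, replace each batch
        # element's conditioning with the learned unconditioned embedding.
        # Shapes assumed: code_emb (B, C, T); unconditioned_embedding (1, C, 1),
        # broadcast over batch and time.
        if training and unconditioned_percentage > 0:
            drop = torch.rand((code_emb.shape[0], 1, 1),
                              device=code_emb.device) < unconditioned_percentage
            code_emb = torch.where(drop, unconditioned_embedding.to(code_emb.dtype), code_emb)
        return code_emb

The same unconditioned path is what enables classifier-free guidance at sampling time, by mixing the conditional and unconditional predictions.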
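Because squeeze_ratio is now read from opt_eval in 'spec_decode' mode, eval configurations must supply it explicitly or the MusicDiffusionFid constructor will raise a KeyError. An illustrative options dict; only the 'squeeze_ratio' key is confirmed by the hunk above:

    opt_eval = {
        # ... other evaluator options ...
        'squeeze_ratio': 256,  # 256 reproduces the previously hardcoded value
    }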
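For context on the decode path: in 'spec_decode' mode the model generates a channel-squeezed waveform of shape (1, squeeze_ratio, L // squeeze_ratio), and pixel_shuffle_1d folds the channel axis back into time to recover a (1, 1, L) signal. A minimal sketch of a 1D pixel shuffle, assuming the codebase's helper follows the usual (B, C*r, T) -> (B, C, T*r) rearrangement:

    import torch

    def pixel_shuffle_1d(x, upscale_factor):
        # 1D analog of nn.PixelShuffle: fold groups of `upscale_factor`
        # channels into the time axis, turning (B, C*r, T) into (B, C, T*r).
        b, c, t = x.shape
        assert c % upscale_factor == 0
        x = x.view(b, c // upscale_factor, upscale_factor, t)
        x = x.permute(0, 1, 3, 2).contiguous()
        return x.view(b, c // upscale_factor, t * upscale_factor)

    # With squeeze_ratio=256 (the old hardcoded value), a generated
    # (1, 256, L // 256) tensor unfolds back into a (1, 1, L) waveform:
    gen = torch.randn(1, 256, 4)
    assert pixel_shuffle_1d(gen, 256).shape == (1, 1, 1024)

Making the ratio a config value lets the same FID evaluator serve models trained with different squeeze factors instead of silently assuming 256.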