From 4af8525dc3872afcb336ef0b3b4819c5939a0b91 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 19 Jan 2022 13:37:59 -0700
Subject: [PATCH] Adjust diffusion vocoder to allow training individual levels

---
 .../unet_diffusion_vocoder_with_ref.py | 97 +++++++++++++++----
 1 file changed, 77 insertions(+), 20 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
index 4a5dae31..7cfc3ce5 100644
--- a/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
+++ b/codes/models/gpt_voice/unet_diffusion_vocoder_with_ref.py
@@ -11,12 +11,13 @@ from utils.util import get_mask_from_lengths
 
 
 class DiscreteSpectrogramConditioningBlock(nn.Module):
-    def __init__(self, dvae_channels, channels):
+    def __init__(self, dvae_channels, channels, level):
         super().__init__()
         self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
                                   normalization(channels),
                                   nn.SiLU(),
                                   nn.Conv1d(channels, channels, kernel_size=3))
+        self.level = level
 
     """
     Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
@@ -91,7 +92,7 @@ class DiffusionVocoderWithRef(nn.Module):
             conditioning_inputs_provided=True,
             conditioning_input_dim=80,
             time_embed_dim_multiplier=4,
-            only_train_dvae_connection_layers=False,
+            freeze_layers_below=None,  # powers of 2; e.g. 1, 2, 4, 8, 16, 32, etc.
     ):
         super().__init__()
 
@@ -125,13 +126,11 @@ class DiffusionVocoderWithRef(nn.Module):
             self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                         attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
 
-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
-                )
-            ]
+        seqlyr = TimestepEmbedSequential(
+            conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
         )
+        seqlyr.level = 0
+        self.input_blocks = nn.ModuleList([seqlyr])
         spectrogram_blocks = []
         self._feature_size = model_channels
         input_block_chans = [model_channels]
@@ -140,7 +139,7 @@ class DiffusionVocoderWithRef(nn.Module):
 
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
-                spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch)
+                spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level)
                 self.input_blocks.append(spec_cond_block)
                 spectrogram_blocks.append(spec_cond_block)
                 ch *= 2
@@ -167,13 +166,14 @@ class DiffusionVocoderWithRef(nn.Module):
                             use_new_attention_order=use_new_attention_order,
                         )
                     )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                layer = TimestepEmbedSequential(*layers)
+                layer.level = 2 ** level
+                self.input_blocks.append(layer)
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
                 out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
+                upblk = TimestepEmbedSequential(
                         ResBlock(
                             ch,
                             time_embed_dim,
@@ -189,7 +189,8 @@ class DiffusionVocoderWithRef(nn.Module):
                             ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
                         )
                     )
-                )
+                upblk.level = 2 ** level
+                self.input_blocks.append(upblk)
                 ch = out_ch
                 input_block_chans.append(ch)
                 ds *= 2
@@ -263,7 +264,9 @@ class DiffusionVocoderWithRef(nn.Module):
                     else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
                 )
                 ds //= 2
-            self.output_blocks.append(TimestepEmbedSequential(*layers))
+            layer = TimestepEmbedSequential(*layers)
+            layer.level = 2 ** level
+            self.output_blocks.append(layer)
             self._feature_size += ch
 
         self.out = nn.Sequential(
@@ -272,14 +275,31 @@ class DiffusionVocoderWithRef(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )
 
-        if only_train_dvae_connection_layers:
+        if freeze_layers_below is not None:
+            # Freeze all parameters first.
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            for sb in spectrogram_blocks:
-                for p in sb.parameters():
+            # Now un-freeze the modules we actually want to train.
+            unfrozen_modules = [self.out]
+            for blk in self.input_blocks:
+                if blk.level <= freeze_layers_below:
+                    unfrozen_modules.append(blk)
+            last_frozen_output_block = None
+            for blk in self.output_blocks:
+                if blk.level <= freeze_layers_below:
+                    unfrozen_modules.append(blk)
+                else:
+                    last_frozen_output_block = blk
+            # And finally, the last upsample block in output_blocks.
+            if last_frozen_output_block is not None:
+                unfrozen_modules.append(last_frozen_output_block[1])
+            unfrozen_params = 0
+            for m in unfrozen_modules:
+                for p in m.parameters():
                     del p.DO_NOT_TRAIN
                     p.requires_grad = True
+                    unfrozen_params += p.numel()
+            print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
 
     def forward(self, x, timesteps, spectrogram, conditioning_input=None):
         """
@@ -317,6 +337,23 @@ class DiffusionVocoderWithRef(nn.Module):
         return self.out(h)
 
 
+def move_all_layers_down(pretrained_path, output_path, layers_to_be_added=3):
+    # layers_to_be_added should be num_res_blocks + 1, plus 1 more if a spectrogram
+    # conditioning block is inserted at the new resolution (otherwise plus 0).
+    sd = torch.load(pretrained_path)
+    out = sd.copy()
+    replaced = []
+    for n, p in sd.items():
+        if n.startswith('input_blocks.') and not n.startswith('input_blocks.0.'):
+            # Don't delete keys that already hold a shifted entry from an earlier iteration.
+            if n not in replaced:
+                del out[n]
+            components = n.split('.')
+            components[1] = str(int(components[1]) + layers_to_be_added)
+            new_name = '.'.join(components)
+            out[new_name] = p
+            replaced.append(new_name)
+    torch.save(out, output_path)
+
+
 @register_model
 def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
     return DiffusionVocoderWithRef(**opt_net['kwargs'])
@@ -324,10 +361,30 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
+    path = 'X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth'
+    move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2)
+
     clip = torch.randn(2, 1, 40960)
-    #spec = torch.randint(8192, (2, 40,))
-    spec = torch.randn(2,512,160)
+    spec = torch.randn(2, 80, 160)
     cond = torch.randn(2, 1, 40960)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
+    model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
+                                    num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
+                                    spectrogram_conditioning_resolutions=[2, 512],
+                                    dropout=.05, attention_resolutions=[512, 1024], num_heads=4,
+                                    kernel_size=3, scale_factor=2, conditioning_inputs_provided=True,
+                                    conditioning_input_dim=80, time_embed_dim_multiplier=4,
+                                    discrete_codes=80, freeze_layers_below=1)
+    loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False)
+    new_params = loading_errors.missing_keys
+
+    new_params_trained = []
+    existing_params_trained = []
+    for n, p in model.named_parameters():
+        if not hasattr(p, 'DO_NOT_TRAIN'):
+            if n in new_params:
+                new_params_trained.append(n)
+            else:
+                existing_params_trained.append(n)
+
+    for n in new_params:
+        if n not in new_params_trained:
+            print(f"{n} is a new parameter, but it is not marked as trainable.")
+
     print(model(clip, ts, spec, cond).shape)
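
A minimal standalone sketch (not part of the patch) of the selection rule behind freeze_layers_below: blocks are tagged with level = 2 ** level at construction, every parameter is frozen, and then only blocks whose tag is at or below the threshold (plus self.out and one bridging upsample) are re-enabled. Despite the name, it is the shallow levels at or below the threshold that end up trainable. The stand-in blocks here are hypothetical, not the model's real modules:

import torch.nn as nn

# Stand-in blocks, tagged the same way the patch tags input_blocks/output_blocks.
blocks = nn.ModuleList([nn.Conv1d(4, 4, 3, padding=1) for _ in range(4)])
for i, blk in enumerate(blocks):
    blk.level = 2 ** i              # levels: 1, 2, 4, 8

freeze_layers_below = 2
for p in blocks.parameters():       # freeze everything first
    p.requires_grad = False
for blk in blocks:                  # then re-enable the shallow levels
    if blk.level <= freeze_layers_below:
        for p in blk.parameters():
            p.requires_grad = True

print([any(p.requires_grad for p in b.parameters()) for b in blocks])
# [True, True, False, False] -> levels 1 and 2 train; 4 and 8 stay frozen.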
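
A minimal sketch (also not part of the patch) of the checkpoint surgery that move_all_layers_down performs: every key under input_blocks.N with N > 0 is renamed to input_blocks.(N + layers_to_be_added), freeing the low slots for newly inserted shallow layers, while input_blocks.0 (the input conv) and all other keys stay put. The toy state dict and its key names are hypothetical:

sd = {
    'input_blocks.0.0.weight': 'input_conv',   # stays put
    'input_blocks.1.0.weight': 'res_block_a',  # -> input_blocks.3
    'input_blocks.2.0.weight': 'res_block_b',  # -> input_blocks.4
    'output_blocks.0.0.weight': 'untouched',
}
layers_to_be_added = 2
out = sd.copy()
for n, p in sd.items():
    if n.startswith('input_blocks.') and not n.startswith('input_blocks.0.'):
        del out[n]  # no collisions in this toy, so no 'replaced' guard needed
        components = n.split('.')
        components[1] = str(int(components[1]) + layers_to_be_added)
        out['.'.join(components)] = p
print(sorted(out))
# ['input_blocks.0.0.weight', 'input_blocks.3.0.weight',
#  'input_blocks.4.0.weight', 'output_blocks.0.0.weight']

The 'replaced' bookkeeping in the real function matters when the shift is small enough that a renamed key collides with an original key that has not been processed yet; it prevents the later deletion from discarding an already-moved tensor.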