Adjust diffusion vocoder to allow training individual levels

James Betker 2022-01-19 13:37:59 -07:00
parent ac13bfefe8
commit 4af8525dc3


@@ -11,12 +11,13 @@ from utils.util import get_mask_from_lengths

 class DiscreteSpectrogramConditioningBlock(nn.Module):
-    def __init__(self, dvae_channels, channels):
+    def __init__(self, dvae_channels, channels, level):
         super().__init__()
         self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
                                   normalization(channels),
                                   nn.SiLU(),
                                   nn.Conv1d(channels, channels, kernel_size=3))
+        self.level = level

     """
     Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
@@ -91,7 +92,7 @@ class DiffusionVocoderWithRef(nn.Module):
                  conditioning_inputs_provided=True,
                  conditioning_input_dim=80,
                  time_embed_dim_multiplier=4,
-                 only_train_dvae_connection_layers=False,
+                 freeze_layers_below=None,  # powers of 2; ex: 1,2,4,8,16,32,etc..
                  ):
         super().__init__()
@@ -125,13 +126,11 @@ class DiffusionVocoderWithRef(nn.Module):
             self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
                                                         attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)

-        self.input_blocks = nn.ModuleList(
-            [
-                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
-                )
-            ]
-        )
+        seqlyr = TimestepEmbedSequential(
+            conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
+        )
+        seqlyr.level = 0
+        self.input_blocks = nn.ModuleList([seqlyr])
         spectrogram_blocks = []
         self._feature_size = model_channels
         input_block_chans = [model_channels]
@@ -140,7 +139,7 @@ class DiffusionVocoderWithRef(nn.Module):
         for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
-                spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch)
+                spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level)
                 self.input_blocks.append(spec_cond_block)
                 spectrogram_blocks.append(spec_cond_block)
                 ch *= 2
@@ -167,13 +166,14 @@ class DiffusionVocoderWithRef(nn.Module):
                             use_new_attention_order=use_new_attention_order,
                         )
                     )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                layer = TimestepEmbedSequential(*layers)
+                layer.level = 2 ** level
+                self.input_blocks.append(layer)
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
                 out_ch = ch
-                self.input_blocks.append(
-                    TimestepEmbedSequential(
+                upblk = TimestepEmbedSequential(
                         ResBlock(
                             ch,
                             time_embed_dim,
@@ -189,7 +189,8 @@ class DiffusionVocoderWithRef(nn.Module):
                             ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
                         )
                     )
-                )
+                upblk.level = 2 ** level
+                self.input_blocks.append(upblk)
                 ch = out_ch
                 input_block_chans.append(ch)
                 ds *= 2
@@ -263,7 +264,9 @@ class DiffusionVocoderWithRef(nn.Module):
                     else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
                 )
                 ds //= 2
-            self.output_blocks.append(TimestepEmbedSequential(*layers))
+            layer = TimestepEmbedSequential(*layers)
+            layer.level = 2 ** level
+            self.output_blocks.append(layer)
             self._feature_size += ch

         self.out = nn.Sequential(
@@ -272,14 +275,31 @@ class DiffusionVocoderWithRef(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )

-        if only_train_dvae_connection_layers:
+        if freeze_layers_below is not None:
+            # Freeze all parameters first.
             for p in self.parameters():
                 p.DO_NOT_TRAIN = True
                 p.requires_grad = False
-            for sb in spectrogram_blocks:
-                for p in sb.parameters():
+            # Now un-freeze the modules we actually want to train.
+            unfrozen_modules = [self.out]
+            for blk in self.input_blocks:
+                if blk.level <= freeze_layers_below:
+                    unfrozen_modules.append(blk)
+            last_frozen_output_block = None
+            for blk in self.output_blocks:
+                if blk.level <= freeze_layers_below:
+                    unfrozen_modules.append(blk)
+                else:
+                    last_frozen_output_block = blk
+            # And finally, the last upsample block in output blocks.
+            unfrozen_modules.append(last_frozen_output_block[1])
+            unfrozen_params = 0
+            for m in unfrozen_modules:
+                for p in m.parameters():
                     del p.DO_NOT_TRAIN
                     p.requires_grad = True
+                    unfrozen_params += 1
+            print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")

     def forward(self, x, timesteps, spectrogram, conditioning_input=None):
         """
@@ -317,6 +337,23 @@ class DiffusionVocoderWithRef(nn.Module):
         return self.out(h)


+def move_all_layers_down(pretrained_path, output_path, layers_to_be_added=3):
+    # layers_to_be_added should be = num_res_blocks + 1 + [1 if spectrogram_conditioning_resolutions else 0]
+    sd = torch.load(pretrained_path)
+    out = sd.copy()
+    replaced = []
+    for n, p in sd.items():
+        if n.startswith('input_blocks.') and not n.startswith('input_blocks.0.'):
+            if n not in replaced:
+                del out[n]
+            components = n.split('.')
+            components[1] = str(int(components[1]) + layers_to_be_added)
+            new_name = '.'.join(components)
+            out[new_name] = p
+            replaced.append(new_name)
+    torch.save(out, output_path)
+
+
 @register_model
 def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
     return DiffusionVocoderWithRef(**opt_net['kwargs'])
@@ -324,10 +361,30 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):

 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
+    path = 'X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth'
+    move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2)
+
     clip = torch.randn(2, 1, 40960)
-    #spec = torch.randint(8192, (2, 40,))
-    spec = torch.randn(2,512,160)
+    spec = torch.randn(2,80,160)
     cond = torch.randn(2, 1, 40960)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
+    model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1,1,1.5,2, 3, 4, 6, 8, 8, 8, 8],
+                                    num_res_blocks=[1,2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2,512],
+                                    dropout=.05, attention_resolutions=[512,1024], num_heads=4, kernel_size=3, scale_factor=2,
+                                    conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
+                                    discrete_codes=80, freeze_layers_below=1)
+    loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False)
+    new_params = loading_errors.missing_keys
+
+    new_params_trained = []
+    existing_params_trained = []
+    for n, p in model.named_parameters():
+        if not hasattr(p, 'DO_NOT_TRAIN'):
+            if n in new_params:
+                new_params_trained.append(n)
+            else:
+                existing_params_trained.append(n)
+    for n in new_params:
+        if n not in new_params_trained:
+            print(f"{n} is a new parameter, but it is not marked as trainable.")
+
     print(model(clip, ts, spec, cond).shape)
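
Note: freeze_layers_below only marks parameters (DO_NOT_TRAIN attribute, requires_grad=False); the surrounding trainer is expected to skip those parameters when building its optimizer. A minimal sketch of that consumption, assuming a generic PyTorch training setup (the build_optimizer helper, the AdamW choice, and the learning rate are illustrative assumptions, not code from this commit):

    import torch
    from torch import nn

    def build_optimizer(model: nn.Module, lr: float = 1e-4) -> torch.optim.Optimizer:
        # Parameters frozen by freeze_layers_below carry a DO_NOT_TRAIN attribute
        # and have requires_grad=False; collect only the ones left trainable.
        trainable = [p for p in model.parameters()
                     if not hasattr(p, 'DO_NOT_TRAIN') and p.requires_grad]
        return torch.optim.AdamW(trainable, lr=lr)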