Adjust diffusion vocoder to allow training individual levels
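
The new freeze_layers_below argument replaces only_train_dvae_connection_layers. Each block in input_blocks/output_blocks is now tagged with the power-of-two downsampling level it operates at; when the argument is set, every parameter is first frozen (flagged DO_NOT_TRAIN, requires_grad=False), and then the blocks at or below the given level, the output head self.out, and the upsample bridging out of the frozen trunk are un-frozen. A minimal usage sketch, borrowing the hyperparameters from the __main__ test in this diff (the optimizer filtering is an assumption about how the flag is meant to be consumed, not part of this commit):

import torch
# DiffusionVocoderWithRef is the class modified in this file.
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
                                num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1],
                                spectrogram_conditioning_resolutions=[2, 512], dropout=.05,
                                attention_resolutions=[512, 1024], num_heads=4, kernel_size=3,
                                scale_factor=2, conditioning_inputs_provided=True,
                                conditioning_input_dim=80, time_embed_dim_multiplier=4,
                                discrete_codes=80, freeze_layers_below=1)
# Only parameters that were un-frozen (no DO_NOT_TRAIN attribute) should reach the optimizer.
trainable = [p for p in model.parameters() if not hasattr(p, 'DO_NOT_TRAIN')]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)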

James Betker 2022-01-19 13:37:59 -07:00
parent ac13bfefe8
commit 4af8525dc3


@@ -11,12 +11,13 @@ from utils.util import get_mask_from_lengths
class DiscreteSpectrogramConditioningBlock(nn.Module):
def __init__(self, dvae_channels, channels):
def __init__(self, dvae_channels, channels, level):
super().__init__()
self.intg = nn.Sequential(nn.Conv1d(dvae_channels, channels, kernel_size=1),
normalization(channels),
nn.SiLU(),
nn.Conv1d(channels, channels, kernel_size=3))
self.level = level
"""
Embeds the given codes and concatenates them onto x. Return shape is the same as x.shape.
@@ -91,7 +92,7 @@ class DiffusionVocoderWithRef(nn.Module):
conditioning_inputs_provided=True,
conditioning_input_dim=80,
time_embed_dim_multiplier=4,
only_train_dvae_connection_layers=False,
freeze_layers_below=None, # powers of 2; ex: 1,2,4,8,16,32,etc..
):
super().__init__()
@@ -125,13 +126,11 @@ class DiffusionVocoderWithRef(nn.Module):
self.contextual_embedder = AudioMiniEncoder(in_channels, time_embed_dim, base_channels=32, depth=6, resnet_blocks=1,
attn_blocks=2, num_attn_heads=2, dropout=dropout, downsample_factor=4, kernel_size=5)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
)
]
seqlyr = TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
)
seqlyr.level = 0
self.input_blocks = nn.ModuleList([seqlyr])
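# Each block appended to input_blocks (and later output_blocks) carries a .level tag: the
# power-of-two downsampling level it operates at. freeze_layers_below compares against this
# tag further down to decide which blocks stay trainable.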
spectrogram_blocks = []
self._feature_size = model_channels
input_block_chans = [model_channels]
@@ -140,7 +139,7 @@ class DiffusionVocoderWithRef(nn.Module):
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in spectrogram_conditioning_resolutions:
spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch)
spec_cond_block = DiscreteSpectrogramConditioningBlock(discrete_codes, ch, 2 ** level)
self.input_blocks.append(spec_cond_block)
spectrogram_blocks.append(spec_cond_block)
ch *= 2
@@ -167,13 +166,14 @@ class DiffusionVocoderWithRef(nn.Module):
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
layer = TimestepEmbedSequential(*layers)
layer.level = 2 ** level
self.input_blocks.append(layer)
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
upblk = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
@@ -189,7 +189,8 @@ class DiffusionVocoderWithRef(nn.Module):
ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor
)
)
)
upblk.level = 2 ** level
self.input_blocks.append(upblk)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
@@ -263,7 +264,9 @@ class DiffusionVocoderWithRef(nn.Module):
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
layer = TimestepEmbedSequential(*layers)
layer.level = 2 ** level
self.output_blocks.append(layer)
self._feature_size += ch
self.out = nn.Sequential(
@@ -272,14 +275,31 @@ class DiffusionVocoderWithRef(nn.Module):
zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
)
if only_train_dvae_connection_layers:
if freeze_layers_below is not None:
# Freeze all parameters first.
for p in self.parameters():
p.DO_NOT_TRAIN = True
p.requires_grad = False
for sb in spectrogram_blocks:
for p in sb.parameters():
# Now un-freeze the modules we actually want to train.
unfrozen_modules = [self.out]
for blk in self.input_blocks:
if blk.level <= freeze_layers_below:
unfrozen_modules.append(blk)
last_frozen_output_block = None
for blk in self.output_blocks:
if blk.level <= freeze_layers_below:
unfrozen_modules.append(blk)
else:
last_frozen_output_block = blk
# And finally, the upsample module inside the last frozen output block.
unfrozen_modules.append(last_frozen_output_block[1])
unfrozen_params = 0
for m in unfrozen_modules:
for p in m.parameters():
del p.DO_NOT_TRAIN
p.requires_grad = True
unfrozen_params += 1
print(f"freeze_layers_below specified. Training a total of {unfrozen_params} parameters.")
def forward(self, x, timesteps, spectrogram, conditioning_input=None):
"""
@@ -317,6 +337,23 @@ class DiffusionVocoderWithRef(nn.Module):
return self.out(h)
def move_all_layers_down(pretrained_path, output_path, layers_to_be_added=3):
# layers_to_be_added should be num_res_blocks + 1 + [1 if spectrogram_conditioning_resolutions else 0]
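# (What this does: every 'input_blocks.N.*' key is renamed to 'input_blocks.(N + layers_to_be_added).*'
# so the pretrained weights can be loaded with strict=False into a model that has new blocks
# inserted at a finer level near the input. Illustrative reading of the formula above: one res
# block at the new level + its resample block + a spectrogram conditioning block = 1 + 1 + 1 = 3;
# the __main__ test below passes 2, i.e. no conditioning block at the new level.)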
sd = torch.load(pretrained_path)
out = sd.copy()
replaced = []
for n, p in sd.items():
if n.startswith('input_blocks.') and not n.startswith('input_blocks.0.'):
if n not in replaced:
del out[n]
components = n.split('.')
components[1] = str(int(components[1]) + layers_to_be_added)
new_name = '.'.join(components)
out[new_name] = p
replaced.append(new_name)
torch.save(out, output_path)
@register_model
def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
return DiffusionVocoderWithRef(**opt_net['kwargs'])
@@ -324,10 +361,30 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
path = 'X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth'
move_all_layers_down(path, 'diffuse_new_lyr.pth', layers_to_be_added=2)
clip = torch.randn(2, 1, 40960)
#spec = torch.randint(8192, (2, 40,))
spec = torch.randn(2,512,160)
spec = torch.randn(2,80,160)
cond = torch.randn(2, 1, 40960)
ts = torch.LongTensor([555, 556])
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=True, time_embed_dim_multiplier=8)
model = DiffusionVocoderWithRef(model_channels=128, channel_mult=[1, 1, 1.5, 2, 3, 4, 6, 8, 8, 8, 8],
num_res_blocks=[1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1], spectrogram_conditioning_resolutions=[2, 512],
dropout=.05, attention_resolutions=[512, 1024], num_heads=4, kernel_size=3, scale_factor=2,
conditioning_inputs_provided=True, conditioning_input_dim=80, time_embed_dim_multiplier=4,
discrete_codes=80, freeze_layers_below=1)
loading_errors = model.load_state_dict(torch.load('diffuse_new_lyr.pth'), strict=False)
new_params = loading_errors.missing_keys
new_params_trained = []
existing_params_trained = []
for n,p in model.named_parameters():
if not hasattr(p, 'DO_NOT_TRAIN'):
if n in new_params:
new_params_trained.append(n)
else:
existing_params_trained.append(n)
for n in new_params:
if n not in new_params_trained:
print(f"{n} is a new parameter, but it is not marked as trainable.")
print(model(clip, ts, spec, cond).shape)