integrate tfd12 with cheater network

James Betker 2022-06-17 09:08:56 -06:00
parent 9d7ce42630
commit 87a86ae6a8
2 changed files with 110 additions and 7 deletions

models/audio/music/gpt_music2.py

@ -131,7 +131,6 @@ class GptMusicLower(nn.Module):
        return groups


@register_model
def register_music_gpt_lower2(opt_net, opt):
    return GptMusicLower(**opt_get(opt_net, ['kwargs'], {}))
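For readers unfamiliar with the registry pattern used throughout this repo: @register_model adds the decorated factory to a name-to-constructor table that the trainer later resolves from config, and opt_get(opt_net, ['kwargs'], {}) is a nested lookup with an empty-dict default. The decorator's implementation is not part of this diff; the following is a minimal sketch of the convention, where the registry variable and the prefix-stripping behavior are assumptions, not the repo's actual code.

_MODEL_REGISTRY = {}

def register_model(fn):
    # Assumed convention: strip the 'register_' prefix so configs can refer
    # to models by short names such as 'music_gpt_lower2'.
    _MODEL_REGISTRY[fn.__name__.replace('register_', '', 1)] = fn
    return fn
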

models/audio/music/transformer_diffusion12.py

@ -6,6 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import ResBlock
from models.audio.music.gpt_music2 import UpperEncoder, GptMusicLower
from models.audio.music.music_quantizer2 import MusicQuantizer2
from models.audio.tts.lucidrains_dvae import DiscreteVAE
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
@ -168,11 +169,16 @@ class TransformerDiffusion(nn.Module):
        for p in self.parameters():
            p.DO_NOT_TRAIN = True
            p.requires_grad = False
-        for m in [self.ar_input, self.ar_prior_intg]:
-            for p in m.parameters():
-                del p.DO_NOT_TRAIN
-                p.requires_grad = True
+        if hasattr(self, 'ar_input'):
+            for m in [self.ar_input, self.ar_prior_intg]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
+        if hasattr(self, 'code_converter'):
+            for m in [self.code_converter, self.input_converter]:
+                for p in m.parameters():
+                    del p.DO_NOT_TRAIN
+                    p.requires_grad = True
        self.debug_codes = {}
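The DO_NOT_TRAIN attribute is a convention the surrounding trainer uses to exclude parameters from optimization even while they remain in the module tree: the hunk above freezes everything, then re-enables only the AR-input/prior-integration and code-converter modules when they exist. How the trainer consumes the flag is not shown in this diff; a minimal sketch under that assumption:

# Assumption: the trainer builds its optimizer only from parameters that are
# neither requires_grad=False nor tagged with DO_NOT_TRAIN. 'model' is any
# module instance, e.g. a TransformerDiffusion.
trainable = [p for p in model.parameters()
             if p.requires_grad and not getattr(p, 'DO_NOT_TRAIN', False)]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)  # optimizer choice and lr are illustrative
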
@ -502,6 +508,61 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
                p.grad *= .2


class TransformerDiffusionWithCheaterLatent(nn.Module):
    def __init__(self, freeze_encoder_until=50000, **kwargs):
        super().__init__()
        self.internal_step = 0
        self.freeze_encoder_until = freeze_encoder_until
        self.diff = TransformerDiffusion(**kwargs)
        self.encoder = UpperEncoder(256, 1024, 256)
        self.encoder = self.encoder.eval()

    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
        with torch.set_grad_enabled(encoder_grad_enabled):
            proj = self.encoder(truth_mel).permute(0, 2, 1)
        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
        return diff

    def get_debug_values(self, step, __):
        self.internal_step = step

    def get_grad_norm_parameter_groups(self):
        attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
        attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.diff.layers]))
        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.diff.layers]))
        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.diff.layers]))
        blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers]))
        groups = {
            'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers])),
            'blk1_attention_layers': attn1,
            'blk2_attention_layers': attn2,
            'attention_layers': attn1 + attn2,
            'blk1_ff_layers': ff1,
            'blk2_ff_layers': ff2,
            'ff_layers': ff1 + ff2,
            'block_out_layers': blkout_layers,
            'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),
            'out': list(self.diff.out.parameters()),
            'x_proj': list(self.diff.inp_block.parameters()),
            'layers': list(self.diff.layers.parameters()),
            'code_converters': list(self.diff.input_converter.parameters()) + list(self.diff.code_converter.parameters()),
            'time_embed': list(self.diff.time_embed.parameters()),
            'encoder': list(self.encoder.parameters()),
        }
        return groups

    def before_step(self, step):
        scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
                                 list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.diff.layers]))
        # Scale back the gradients of the block-out and prenorm layers by a constant factor. These get two orders of
        # magnitude higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this
        # trickier than directly fiddling with the gradients.
        for p in scaled_grad_parameters:
            p.grad *= .2
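before_step and get_debug_values are hooks invoked by the training harness, and their ordering matters: the gradient rescale must run after backward() but before the optimizer consumes the gradients, and internal_step must track the global step for the freeze_encoder_until schedule in forward. A hypothetical loop illustrating the placement (loader, compute_loss, and optimizer are stand-ins, not part of this diff):

for step, batch in enumerate(loader):          # 'loader' is a stand-in data iterator
    model.get_debug_values(step, None)         # keeps internal_step current so the encoder unfreezes on schedule
    loss = compute_loss(model, batch)          # stand-in for the harness's diffusion loss
    loss.backward()
    model.before_step(step)                    # rescale prenorm/block-out grads before they are applied
    optimizer.step()
    optimizer.zero_grad()
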
@register_model
def register_transformer_diffusion12(opt_net, opt):
    return TransformerDiffusion(**opt_net['kwargs'])

@ -524,6 +585,10 @@ def register_transformer_diffusion_12_with_pretrained_vqvae(opt_net, opt):
def register_transformer_diffusion_12_with_multi_vqvae(opt_net, opt):
    return TransformerDiffusionWithMultiPretrainedVqvae(**opt_net['kwargs'])


@register_model
def register_transformer_diffusion_12_with_cheater_latent(opt_net, opt):
    return TransformerDiffusionWithCheaterLatent(**opt_net['kwargs'])
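Given the opt_net['kwargs'] access above, a registered factory can be exercised directly by handing it a dict of that shape. The kwargs below mirror the constructor call in test_cheater_model() further down; the 'which_model' key name is an assumption about how configs reference the registry and is unused by the factory itself:

opt_net = {
    'which_model': 'transformer_diffusion_12_with_cheater_latent',  # assumed config key
    'kwargs': {
        'in_channels': 256, 'out_channels': 512, 'model_channels': 1024,
        'contraction_dim': 512, 'prenet_channels': 1024, 'num_heads': 8,
        'input_vec_dim': 256, 'num_layers': 12, 'prenet_layers': 6,
        'dropout': .1, 'freeze_encoder_until': 50000,
    },
}
model = register_transformer_diffusion_12_with_cheater_latent(opt_net, opt={})
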
def test_quant_model():
    clip = torch.randn(2, 256, 400)
@ -646,6 +711,45 @@ def test_ar_model():
    model(clip, ts, cond, conditioning_input=cond)


def test_cheater_model():
    clip = torch.randn(2, 256, 400)
    ts = torch.LongTensor([600, 600])

    # For music:
    model = TransformerDiffusionWithCheaterLatent(in_channels=256, out_channels=512,
                                                  model_channels=1024, contraction_dim=512,
                                                  prenet_channels=1024, num_heads=8,
                                                  input_vec_dim=256, num_layers=12, prenet_layers=6,
                                                  dropout=.1)
    diff_weights = torch.load('extracted_diff.pth')
    model.diff.load_state_dict(diff_weights, strict=False)

    cheater_ar_weights = torch.load('X:\\dlas\\experiments\\train_music_gpt_cheater\\models\\19500_generator_ema.pth')
    cheater_ar = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4,
                               vqargs={'positional_dims': 1, 'channels': 64,
                                       'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
                                       'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False})
    cheater_ar.load_state_dict(cheater_ar_weights)
    model.encoder.load_state_dict(cheater_ar.upper_encoder.state_dict(), strict=True)
    torch.save(model.state_dict(), 'sample.pth')

    print_network(model)
    o = model(clip, ts, clip)
    pg = model.get_grad_norm_parameter_groups()
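The strict=True load above is the actual integration this commit adds: the diffusion model's encoder is initialized from the cheater AR model's upper encoder. A hedged sanity check one could append to the test, assuming cheater_ar.upper_encoder shares the UpperEncoder architecture (which the strict load implies):

    # Hypothetical check: after the state_dict transplant, both encoders
    # should be numerically identical on the same input.
    with torch.no_grad():
        mel = torch.randn(1, 256, 400)
        a = model.encoder.eval()(mel)
        b = cheater_ar.upper_encoder.eval()(mel)
        assert torch.allclose(a, b, atol=1e-6)
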
def extract_diff(in_f, out_f, remove_head=False):
    p = torch.load(in_f)
    out = {}
    for k, v in p.items():
        if k.startswith('diff.'):
            if remove_head and (k.startswith('diff.input_converter') or k.startswith('diff.code_converter')):
                continue
            out[k.replace('diff.', '')] = v
    torch.save(out, out_f)


if __name__ == '__main__':
    test_vqvae_model()
    #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
    test_cheater_model()