From 28d95e31415347e525c3cded99b60898af74322e Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Jun 2022 15:09:47 -0600 Subject: [PATCH] gptmusic work --- codes/models/arch_util.py | 13 +- codes/models/audio/music/gpt_music.py | 173 +++++++++++++++--------- codes/models/audio/music/gpt_music2.py | 176 +++++++++++++++++++++++++ codes/train.py | 2 +- 4 files changed, 292 insertions(+), 72 deletions(-) create mode 100644 codes/models/audio/music/gpt_music2.py diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py index 072953ce..19791ecd 100644 --- a/codes/models/arch_util.py +++ b/codes/models/arch_util.py @@ -319,7 +319,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=None): + def __init__(self, channels, use_conv, dims=2, out_channels=None, factor=2): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -327,16 +327,7 @@ class Downsample(nn.Module): self.dims = dims ksize = 3 pad = 1 - if dims == 1: - stride = 4 - ksize = 5 - pad = 2 - elif dims == 2: - stride = 2 - else: - stride = (1,2,2) - if factor is not None: - stride = factor + stride = factor if use_conv: self.op = conv_nd( dims, self.channels, self.out_channels, ksize, stride=stride, padding=pad diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py index 9ed43e8f..0e3bfd02 100644 --- a/codes/models/audio/music/gpt_music.py +++ b/codes/models/audio/music/gpt_music.py @@ -6,9 +6,11 @@ from transformers import GPT2Config, GPT2Model from models.arch_util import AttentionBlock, ResBlock from models.audio.music.music_quantizer import MusicQuantizer from models.audio.music.music_quantizer2 import MusicQuantizer2 +from models.audio.tts.lucidrains_dvae import DiscreteVAE from models.lucidrains.x_transformers import Encoder +from models.vqvae.vqvae import Quantize from trainer.networks import register_model -from utils.util import opt_get, checkpoint +from utils.util import opt_get, checkpoint, ceil_multiple, print_network class ConditioningEncoder(nn.Module): @@ -57,66 +59,106 @@ class UpperConditioningEncoder(nn.Module): return h.mean(dim=2) -class GptMusicLower(nn.Module): - def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, - num_upper_groups=4, fp16=True, freeze_upper_until=0): +class UpperQuantizer(nn.Module): + def __init__(self, + spec_dim, + embedding_dim, + num_tokens): super().__init__() + attn = [] + def edim(m): + dd = max(embedding_dim//m, 128, spec_dim) + return ceil_multiple(dd, 8) + self.encoder = nn.Sequential( + ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True), + ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True), + ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True), + ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True), + ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1), + ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True), + ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1), + ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True), + ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), + ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), + ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1), + nn.GroupNorm(8, embedding_dim) + ) + self.quantizer = Quantize(embedding_dim, num_tokens) + + self.codes = torch.zeros((num_tokens*100,), dtype=torch.long) + self.code_ind = 0 + self.total_codes = 0 self.internal_step = 0 + + def forward(self, x): + h = x + for lyr in self.encoder: + h = lyr(h) + h = h.permute(0,2,1) + h_quant, commitment_loss, codes = self.quantizer(h) + self.log_codes(codes) + return h_quant, commitment_loss + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.internal_step % 10 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i:i+l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + self.internal_step += 1 + + +class GptMusicLower(nn.Module): + def __init__(self, dim, layers, dropout=0, num_target_vectors=8192, num_upper_vectors=32768, + fp16=True, freeze_upper_until=0, num_vaes=4, vqargs={}): + super().__init__() + self.num_vaes = num_vaes self.freeze_upper_until = freeze_upper_until - self.num_groups = num_target_groups self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False) - self.target_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256, - codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) - self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[dim, - max(512,dim-128), - max(512,dim-256), - max(512,dim-384), - max(512,dim-512), - max(512,dim-512)], codevector_dim=dim, - codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True) + self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) + self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors) self.fp16 = fp16 - # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..) - del self.target_quantizer.decoder - del self.target_quantizer.up - del self.upper_quantizer.up + self.internal_step = 0 + # Freeze the target quantizer. - for p in self.target_quantizer.parameters(): + for p in self.target_quantizers.parameters(): p.DO_NOT_TRAIN = True p.requires_grad = False - self.upper_mixer = Encoder( - dim=dim, - depth=4, - heads=dim//64, - ff_dropout=dropout, - attn_dropout=dropout, - use_rmsnorm=True, - ff_glu=True, - rotary_emb_dim=True, - ) self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64) self.gpt = GPT2Model(self.config) del self.gpt.wte # Unused, we'll do our own embeddings. - self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_target_groups) for _ in range(num_target_groups)]) - self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_target_groups)]) - + self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) def forward(self, mel, conditioning, return_latent=False): unused_params = [] + with torch.no_grad(): - self.target_quantizer.eval() - codes = self.target_quantizer.get_codes(mel) + codes = [] + partition_size = mel.shape[1] // len(self.target_quantizers) + for i, q in enumerate(self.target_quantizers): + mel_partition = mel[:, i*partition_size:(i+1)*partition_size] + codes.append(q.get_codebook_indices(mel_partition)) + codes = torch.stack(codes, dim=-1) + if self.freeze_upper_until > self.internal_step: with torch.no_grad(): - upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) + self.upper_quantizer = self.upper_quantizer.eval() + upper_vector, upper_diversity = self.upper_quantizer(mel) unused_params.extend(list(self.upper_quantizer.parameters())) else: + self.upper_quantizer = self.upper_quantizer.train() upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) - upper_vector = self.upper_mixer(upper_vector.permute(0,2,1)).permute(0,2,1) # Allow the upper vector to fully attend to itself (the whole thing is a prior.) - upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear') + upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear') upper_vector = upper_vector.permute(0,2,1) inputs = codes[:, :-1] @@ -148,33 +190,24 @@ class GptMusicLower(nn.Module): unused_adder = unused_adder + p.mean() * 0 losses = losses + unused_adder - return losses / self.num_groups, upper_diversity + return losses / self.num_vaes, upper_diversity def get_grad_norm_parameter_groups(self): groups = { 'gpt': list(self.gpt.parameters()), 'conditioning': list(self.conditioning_encoder.parameters()), - 'upper_mixer': list(self.upper_mixer.parameters()), - 'upper_quant_down': list(self.upper_quantizer.down.parameters()), - 'upper_quant_encoder': list(self.upper_quantizer.encoder.parameters()), - 'upper_quant_codebook': [self.upper_quantizer.quantizer.codevectors], + 'upper_quantizer': list(self.upper_quantizer.parameters()), + 'target_vqs': list(self.target_quantizers.parameters()), } return groups def get_debug_values(self, step, __): + self.internal_step = 0 if self.upper_quantizer.total_codes > 0: - return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes], - 'gumbel_temperature': self.upper_quantizer.quantizer.temperature} + return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]} else: return {} - def update_for_step(self, step, *args): - self.internal_step = step - self.upper_quantizer.quantizer.temperature = max( - self.upper_quantizer.max_gumbel_temperature * self.upper_quantizer.gumbel_temperature_decay**self.internal_step, - self.upper_quantizer.min_gumbel_temperature, - ) - class GptMusicUpper(nn.Module): def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True): @@ -263,18 +296,38 @@ def register_music_gpt_upper(opt_net, opt): def test_lower(): - from models.audio.music.transformer_diffusion8 import TransformerDiffusionWithQuantizer - base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024, - prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024, - dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000) - base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu'))) + model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000, + num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4, + vqargs= { + 'positional_dims': 1, 'channels': 64, + 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, + 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, + }) + quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth'] + for i, qfile in enumerate(quants): + quant_weights = torch.load(qfile) + model.target_quantizers[i].load_state_dict(quant_weights, strict=True) + torch.save(model.state_dict(), 'sample.pth') + print_network(model) - model = GptMusicLower(512, 8, fp16=False, freeze_upper_until=100) - model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False) - torch.save(model.state_dict(), "sample.pth") mel = torch.randn(2,256,400) model(mel, mel) - model.get_grad_norm_parameter_groups() + pg = model.get_grad_norm_parameter_groups() + + t = 0 + for k, vs in pg.items(): + s = 0 + for v in vs: + m = 1 + for d in v.shape: + m *= d + s += m + t += s + print(k, s/1000000) + print(t/1000000) def test_upper(): diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py new file mode 100644 index 00000000..494aa476 --- /dev/null +++ b/codes/models/audio/music/gpt_music2.py @@ -0,0 +1,176 @@ +import torch +import torch.nn.functional as F +from torch import nn +from transformers import GPT2Config, GPT2Model + +from models.arch_util import AttentionBlock, ResBlock +from models.audio.tts.lucidrains_dvae import DiscreteVAE +from trainer.networks import register_model +from utils.util import opt_get, ceil_multiple, print_network + + +class UpperEncoder(nn.Module): + def __init__(self, + spec_dim, + hidden_dim, + embedding_dim, + ): + super().__init__() + attn = [] + def edim(m): + dd = max(hidden_dim // m, 128, spec_dim) + return ceil_multiple(dd, 8) + self.downsampler = nn.Sequential( + ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True), + ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True), + ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True), + ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True), + ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1), + ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True), + ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1), + ResBlock(edim(2), out_channels=hidden_dim, use_conv=True, dims=1, down=True)) + self.encoder = nn.Sequential( + AttentionBlock(hidden_dim, 4, do_activation=True), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1), + AttentionBlock(hidden_dim, 4, do_activation=True), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1), + AttentionBlock(hidden_dim, 4, do_activation=True), + ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1), + nn.GroupNorm(8, hidden_dim), + nn.SiLU(), + nn.Conv1d(hidden_dim, embedding_dim, 1) + ) + + def forward(self, x): + h = self.downsampler(x) + h = self.encoder(h) + return h + + +class GptMusicLower(nn.Module): + def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}): + super().__init__() + self.num_vaes = num_vaes + self.start_token = nn.Parameter(torch.randn(1, 1, dim)) + self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, + n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, + use_cache=False) + + self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)]) + self.upper_encoder = UpperEncoder(256, dim, encoder_out_dim) + self.encoder_projector = nn.Conv1d(encoder_out_dim, dim, 1) + self.fp16 = fp16 + + # Freeze the target quantizer. + for p in self.target_quantizers.parameters(): + p.DO_NOT_TRAIN = True + p.requires_grad = False + # And delete the decoder, which is unused. + for tq in self.target_quantizers: + del tq.decoder + + self.gpt = GPT2Model(self.config) + del self.gpt.wte # Unused, we'll do our own embeddings. + + self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)]) + self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)]) + + def forward(self, mel, return_latent=False): + unused_params = [] + + with torch.no_grad(): + codes = [] + partition_size = mel.shape[1] // len(self.target_quantizers) + for i, q in enumerate(self.target_quantizers): + mel_partition = mel[:, i*partition_size:(i+1)*partition_size] + codes.append(q.get_codebook_indices(mel_partition)) + codes = torch.stack(codes, dim=-1) + + upper_vector = self.upper_encoder(mel) + upper_vector = self.encoder_projector(upper_vector) + # WTB slerp + upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear') + upper_vector = upper_vector.permute(0,2,1) + + inputs = codes[:, :-1] + targets = codes + upper_vector = upper_vector[:, :-1] + h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = torch.cat(h, dim=-1) + upper_vector + + with torch.autocast(mel.device.type, enabled=self.fp16): + # Stick the conditioning embedding on the front of the input sequence. + # The transformer will learn how to integrate it. + # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token. + h = torch.cat([self.start_token.repeat(h.shape[0], 1, 1), h], dim=1) + + h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state + + if return_latent: + return h.float() + + losses = 0 + for i, head in enumerate(self.heads): + logits = head(h).permute(0,2,1) + loss = F.cross_entropy(logits, targets[:,:,i]) + losses = losses + loss + + unused_adder = 0 + for p in unused_params: + unused_adder = unused_adder + p.mean() * 0 + losses = losses + unused_adder + + return losses / self.num_vaes + + def get_grad_norm_parameter_groups(self): + groups = { + 'gpt': list(self.gpt.parameters()), + 'heads': list(self.heads.parameters()), + 'embeddings': list(self.embeddings.parameters()), + 'upper_latent_encoder': list(self.upper_encoder.encoder.parameters()), + 'upper_latent_downsampler': list(self.upper_encoder.downsampler.parameters()), + } + return groups + + + +@register_model +def register_music_gpt_lower2(opt_net, opt): + return GptMusicLower(**opt_get(opt_net, ['kwargs'], {})) + + +def test_lower(): + model = GptMusicLower(dim=1024, encoder_out_dim=256, layers=16, fp16=False, num_target_vectors=8192, num_vaes=4, + vqargs= {'positional_dims': 1, 'channels': 64, + 'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192, + 'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False, + }) + quants = ['X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth', + 'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth'] + for i, qfile in enumerate(quants): + quant_weights = torch.load(qfile) + model.target_quantizers[i].load_state_dict(quant_weights, strict=False) + torch.save(model.state_dict(), 'sample.pth') + print_network(model) + + mel = torch.randn(2,256,400) + model(mel) + pg = model.get_grad_norm_parameter_groups() + + t = 0 + for k, vs in pg.items(): + s = 0 + for v in vs: + m = 1 + for d in v.shape: + m *= d + s += m + t += s + print(k, s/1000000) + print(t/1000000) + + +if __name__ == '__main__': + test_lower() \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index 3c242f36..033bb894 100644 --- a/codes/train.py +++ b/codes/train.py @@ -339,7 +339,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_unified_alignment.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True)