# Listing metadata: 345 lines, 15 KiB, Python.
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from transformers import GPT2Config, GPT2Model
|
|
|
|
from models.arch_util import AttentionBlock, ResBlock
|
|
from models.audio.music.music_quantizer import MusicQuantizer
|
|
from models.audio.music.music_quantizer2 import MusicQuantizer2
|
|
from models.audio.tts.lucidrains_dvae import DiscreteVAE
|
|
from models.lucidrains.x_transformers import Encoder
|
|
from models.vqvae.vqvae import Quantize
|
|
from trainer.networks import register_model
|
|
from utils.util import opt_get, checkpoint, ceil_multiple, print_network
|
|
|
|
|
|
class ConditioningEncoder(nn.Module):
    """Collapse a conditioning spectrogram into one vector per batch element.

    The input is projected from ``spec_dim`` to ``embedding_dim`` channels by a
    stride-2 ``Conv1d``, passed through ``attn_blocks`` ``AttentionBlock`` layers
    and finally averaged over its last dimension.

    Args:
        spec_dim: Number of input channels expected by the first convolution.
        embedding_dim: Channel width of the projection and of the returned vector.
        attn_blocks: How many ``AttentionBlock`` layers to stack.
        num_attn_heads: Head count handed to every ``AttentionBlock``.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4):
        super().__init__()
        # The projection is created before the attention stack so parameter
        # initialisation happens in the same order as it always has.
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=3, stride=2, padding=1)
        self.attn = nn.Sequential(*[
            AttentionBlock(embedding_dim, num_attn_heads, do_activation=True)
            for _ in range(attn_blocks)
        ])
        self.dim = embedding_dim

    def forward(self, x):
        """Return the mean over dim 2 of the attended projection of ``x``.

        ``x`` must be a ``(batch, spec_dim, length)`` tensor, as required by the
        leading ``Conv1d``.
        """
        projected = self.init(x)
        attended = self.attn(projected)
        return attended.mean(dim=2)
|
|
|
|
|
|
class UpperConditioningEncoder(nn.Module):
    """Conditioning encoder with a deeper, strided convolutional front end.

    Five stride-2 ``Conv1d`` layers (with two ``ResBlock`` layers interleaved
    near the end) widen the channels in steps of 128, capped at
    ``embedding_dim``. A stack of ``AttentionBlock`` layers follows, and the
    result is averaged over its last dimension.

    NOTE(review): the attention blocks are built for ``embedding_dim`` channels
    while the convolutional stack emits ``min(spec_dim + 512, embedding_dim)``
    channels; these only agree when ``embedding_dim <= spec_dim + 512`` —
    confirm larger widths are never configured.

    Args:
        spec_dim: Number of input channels expected by the first convolution.
        embedding_dim: Upper bound on the channel width and width of the attention stack.
        attn_blocks: How many ``AttentionBlock`` layers to stack.
        num_attn_heads: Head count handed to every ``AttentionBlock``.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4):
        super().__init__()
        # Channel widths after each of the first four strided convolutions.
        c1, c2, c3, c4 = (min(spec_dim + 128 * k, embedding_dim) for k in range(1, 5))

        def strided_conv(c_in, c_out):
            return nn.Conv1d(c_in, c_out, kernel_size=3, stride=2, padding=1)

        self.init = nn.Sequential(
            strided_conv(spec_dim, c1),
            strided_conv(c1, c2),
            strided_conv(c2, c3),
            strided_conv(c3, c4),
            ResBlock(c4, dims=1),
            strided_conv(c4, c4),
            ResBlock(c4, dims=1),
        )
        self.attn = nn.Sequential(*[
            AttentionBlock(embedding_dim, num_attn_heads, do_activation=True)
            for _ in range(attn_blocks)
        ])
        self.dim = embedding_dim

    def forward(self, x):
        """Return the mean over dim 2 of the attended convolutional features of ``x``."""
        features = self.init(x)
        attended = self.attn(features)
        return attended.mean(dim=2)
|
|
|
|
|
|
class UpperQuantizer(nn.Module):
    """Encode a spectrogram with a ``ResBlock`` stack and quantize the result.

    The encoder widens from ``spec_dim`` to ``embedding_dim`` channels through a
    chain of ``ResBlock`` layers (six of them constructed with ``down=True``),
    finishes with a ``GroupNorm`` and feeds the channel-last activations to a
    ``Quantize`` layer with ``num_tokens`` entries.

    A CPU-side ring buffer (``self.codes``) records a sample of the emitted code
    indices so their distribution can be inspected while training.

    Args:
        spec_dim: Channel count of the input.
        embedding_dim: Channel count of the encoder output / quantizer input.
        num_tokens: Codebook size passed to ``Quantize``; the debug buffer holds
            ``num_tokens * 100`` entries.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 num_tokens):
        super().__init__()

        def edim(m):
            # Intermediate width: a fraction of the final width, never below 128
            # or the input width, rounded with ceil_multiple(..., 8).
            dd = max(embedding_dim//m, 128, spec_dim)
            return ceil_multiple(dd, 8)

        self.encoder = nn.Sequential(
            ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
            ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
            ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
            ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
            ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
            ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
            ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
            ResBlock(edim(2), out_channels=embedding_dim, use_conv=True, dims=1, down=True),
            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
            ResBlock(embedding_dim, out_channels=embedding_dim, use_conv=True, dims=1),
            nn.GroupNorm(8, embedding_dim)
        )
        self.quantizer = Quantize(embedding_dim, num_tokens)

        # Debug bookkeeping. ``codes`` is a plain attribute (not a registered
        # buffer), so it is not part of the state dict.
        self.codes = torch.zeros((num_tokens*100,), dtype=torch.long)
        self.code_ind = 0       # next write position in ``codes``
        self.total_codes = 0    # number of valid entries at the front of ``codes``
        self.internal_step = 0  # number of ``log_codes`` calls so far

    def forward(self, x):
        """Encode and quantize ``x``.

        Returns:
            ``(h_quant, commitment_loss)``: the first two values produced by the
            ``Quantize`` layer for the encoder output moved to channel-last layout.
        """
        h = self.encoder(x)
        h = h.permute(0,2,1)
        h_quant, commitment_loss, codes = self.quantizer(h)
        self.log_codes(codes)
        return h_quant, commitment_loss

    def log_codes(self, codes):
        """Record ``codes`` in the debug ring buffer on every tenth call.

        ``total_codes`` tracks how many buffer entries actually hold recorded
        codes (saturating at the buffer size), so ``self.codes[:self.total_codes]``
        is exactly the valid region. A batch larger than the whole buffer keeps
        only its most recent entries instead of failing.
        """
        # This is so we can debug the distribution of codes being learned.
        if self.internal_step % 10 == 0:
            flat = codes.flatten().cpu()
            capacity = self.codes.shape[0]
            if flat.shape[0] > capacity:
                # Cannot store more than the buffer holds; keep the tail.
                flat = flat[flat.shape[0] - capacity:]
            count = flat.shape[0]
            # Write at the cursor when the batch fits in the remaining space,
            # otherwise align the batch with the end of the buffer.
            if capacity - self.code_ind > count:
                start = self.code_ind
            else:
                start = capacity - count
            self.codes[start:start + count] = flat
            self.code_ind = self.code_ind + count
            if self.code_ind >= capacity:
                self.code_ind = 0
            self.total_codes = min(self.total_codes + count, capacity)
        self.internal_step += 1
|
|
|
|
|
|
class GptMusicLower(nn.Module):
    """GPT-2 next-token model over the codes of several frozen ``DiscreteVAE`` quantizers.

    The input ``mel`` is split along dim 1 into ``num_vaes`` equal partitions;
    each partition is turned into code indices by its own frozen
    ``DiscreteVAE``. Each code stream gets its own embedding table
    (``dim // num_vaes`` wide) and its own prediction head. An ``UpperQuantizer``
    vector, interpolated to the code sequence length, is added to the embedded
    inputs, and a ``ConditioningEncoder`` vector is prepended as the start token.

    Args:
        dim: Model width (GPT-2 ``n_embd``); the head count is ``dim // 64``.
        layers: Number of GPT-2 layers.
        dropout: Attention and residual dropout probability.
        num_target_vectors: Vocabulary size of each target code stream.
        num_upper_vectors: Codebook size of the ``UpperQuantizer``.
        fp16: Whether ``forward`` runs the transformer under ``torch.autocast``.
        freeze_upper_until: The upper quantizer is run in eval mode without
            gradients while ``internal_step`` is below this value.
        num_vaes: Number of target quantizers / mel partitions.
        vqargs: Keyword arguments for every ``DiscreteVAE``; ``None`` means no arguments.
    """

    def __init__(self, dim, layers, dropout=0, num_target_vectors=8192, num_upper_vectors=32768,
                 fp16=True, freeze_upper_until=0, num_vaes=4, vqargs=None):
        super().__init__()
        # Avoid a shared mutable default argument; None stands for "no kwargs".
        vqargs = {} if vqargs is None else vqargs
        self.num_vaes = num_vaes
        self.freeze_upper_until = freeze_upper_until
        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                 n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True, use_cache=False)
        self.target_quantizers = nn.ModuleList([DiscreteVAE(**vqargs).eval() for _ in range(num_vaes)])
        self.upper_quantizer = UpperQuantizer(256, dim, num_upper_vectors)
        self.fp16 = fp16
        self.internal_step = 0

        # Freeze the target quantizer.
        for p in self.target_quantizers.parameters():
            p.DO_NOT_TRAIN = True
            p.requires_grad = False

        self.conditioning_encoder = ConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)

        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_vaes) for _ in range(num_vaes)])
        self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_vaes)])

    def forward(self, mel, conditioning, return_latent=False):
        """Compute the next-token loss (or the transformer latents) for ``mel``.

        Args:
            mel: Tensor whose dim 1 is split evenly across the target quantizers
                and which is also fed whole to the upper quantizer.
            conditioning: Input for the conditioning encoder.
            return_latent: When true, return the final GPT hidden states as
                float32 instead of computing losses.

        Returns:
            ``h.float()`` when ``return_latent`` is set, otherwise a tuple
            ``(mean cross-entropy over the heads, upper quantizer commitment loss)``.
        """
        unused_params = []

        with torch.no_grad():
            partition_size = mel.shape[1] // len(self.target_quantizers)
            codes = torch.stack([
                q.get_codebook_indices(mel[:, i*partition_size:(i+1)*partition_size])
                for i, q in enumerate(self.target_quantizers)
            ], dim=-1)

        if self.freeze_upper_until > self.internal_step:
            with torch.no_grad():
                self.upper_quantizer.eval()
                upper_vector, upper_diversity = self.upper_quantizer(mel)
            # These parameters receive no gradient this step; they are tied into
            # the loss with a zero weight below.
            unused_params.extend(self.upper_quantizer.parameters())
        else:
            self.upper_quantizer.train()
            # UpperQuantizer.forward accepts only the input tensor; the former
            # `return_decoder_latent=True` keyword raised a TypeError here.
            upper_vector, upper_diversity = self.upper_quantizer(mel)
        upper_vector = F.interpolate(upper_vector.permute(0,2,1), size=codes.shape[1], mode='linear')
        upper_vector = upper_vector.permute(0,2,1)

        inputs = codes[:, :-1]
        targets = codes
        upper_vector = upper_vector[:, :-1]
        h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
        h = torch.cat(h, dim=-1) + upper_vector

        with torch.autocast(mel.device.type, enabled=self.fp16):
            # Stick the conditioning embedding on the front of the input sequence.
            # The transformer will learn how to integrate it.
            # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
            cond_emb = self.conditioning_encoder(conditioning).unsqueeze(1)
            h = torch.cat([cond_emb, h], dim=1)

            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

            if return_latent:
                return h.float()

            # The heads stay inside autocast: they consume the (possibly
            # low-precision) hidden states produced above.
            losses = 0
            for i, head in enumerate(self.heads):
                logits = head(h).permute(0,2,1)
                loss = F.cross_entropy(logits, targets[:,:,i])
                losses = losses + loss

        # Zero-weight term that keeps otherwise unused parameters in the graph.
        unused_adder = 0
        for p in unused_params:
            unused_adder = unused_adder + p.mean() * 0
        losses = losses + unused_adder

        return losses / self.num_vaes, upper_diversity

    def get_grad_norm_parameter_groups(self):
        """Return named parameter lists for per-group gradient norm reporting."""
        groups = {
            'gpt': list(self.gpt.parameters()),
            'conditioning': list(self.conditioning_encoder.parameters()),
            'upper_quantizer': list(self.upper_quantizer.parameters()),
            'target_vqs': list(self.target_quantizers.parameters()),
        }
        return groups

    def get_debug_values(self, step, __):
        """Sync the step counter and return the recorded upper-code histogram data.

        ``internal_step`` is set from ``step`` (presumably the trainer's global
        step — confirm against the caller) so that the ``freeze_upper_until``
        check in ``forward`` can eventually release the upper quantizer; it was
        previously reset to 0 on every call, which kept it frozen forever.
        """
        self.internal_step = step
        if self.upper_quantizer.total_codes > 0:
            return {'histogram_upper_codes': self.upper_quantizer.codes[:self.upper_quantizer.total_codes]}
        else:
            return {}
|
|
|
|
|
|
class GptMusicUpper(nn.Module):
    """GPT-2 next-token model over the grouped codes of a frozen ``MusicQuantizer2``.

    The quantizer turns ``mel`` into ``num_upper_groups`` parallel code streams.
    Every stream has its own embedding table (``dim // num_upper_groups`` wide)
    and its own prediction head; an ``UpperConditioningEncoder`` vector is
    prepended to the embedded sequence as the start token.

    Args:
        dim: Model width (GPT-2 ``n_embd``); the head count is ``dim // 64``.
        layers: Number of GPT-2 layers.
        dropout: Attention and residual dropout probability.
        num_upper_vectors: Codebook size per group.
        num_upper_groups: Number of codebook groups / code streams.
        fp16: Whether ``forward`` runs the transformer under ``torch.autocast``.
    """

    def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
        super().__init__()
        self.internal_step = 0
        self.num_groups = num_upper_groups
        self.fp16 = fp16
        self.config = GPT2Config(
            vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
            n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
            use_cache=False,
        )

        # Widths handed to the quantizer: the full width first, then widths
        # shrinking in steps of 128 with a floor of 512.
        quantizer_widths = [dim] + [max(512, dim - shrink) for shrink in (128, 256, 384, 512, 512)]
        self.upper_quantizer = MusicQuantizer2(
            inp_channels=256, inner_dim=quantizer_widths, codevector_dim=dim,
            codebook_size=num_upper_vectors, codebook_groups=num_upper_groups,
            expressive_downsamples=True,
        )
        # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
        del self.upper_quantizer.up
        # Freeze the quantizer.
        for frozen in self.upper_quantizer.parameters():
            frozen.DO_NOT_TRAIN = True
            frozen.requires_grad = False

        self.conditioning_encoder = UpperConditioningEncoder(256, dim, attn_blocks=4, num_attn_heads=dim//64)

        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.

        group_width = dim // num_upper_groups
        self.embeddings = nn.ModuleList(
            [nn.Embedding(num_upper_vectors, group_width) for _ in range(num_upper_groups)])
        self.heads = nn.ModuleList(
            [nn.Linear(dim, num_upper_vectors) for _ in range(num_upper_groups)])

    def forward(self, mel, conditioning, return_latent=False):
        """Compute the next-token loss (or the transformer latents) for ``mel``.

        Args:
            mel: Input handed to the frozen quantizer's ``get_codes``.
            conditioning: Input for the conditioning encoder.
            return_latent: When true, return the final GPT hidden states as
                float32 instead of the loss.

        Returns:
            ``h.float()`` when ``return_latent`` is set, otherwise the
            cross-entropy averaged over the code groups.
        """
        with torch.no_grad():
            self.upper_quantizer.eval()
            codes = self.upper_quantizer.get_codes(mel)

        # Inputs drop the last position; the prepended conditioning token
        # restores the length, so the targets are the full code sequence.
        shifted = codes[:, :-1]
        per_group = [table(shifted[:, :, g]) for g, table in enumerate(self.embeddings)]
        h = torch.cat(per_group, dim=-1)

        with torch.autocast(mel.device.type, enabled=self.fp16):
            # Stick the conditioning embedding on the front of the input sequence.
            # The transformer will learn how to integrate it.
            # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
            start_token = self.conditioning_encoder(conditioning).unsqueeze(1)
            h = torch.cat([start_token, h], dim=1)

            h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

            if return_latent:
                return h.float()

            # The heads stay inside autocast: they consume the (possibly
            # low-precision) hidden states produced above.
            total = 0
            for g, head in enumerate(self.heads):
                group_logits = head(h).permute(0,2,1)
                total = total + F.cross_entropy(group_logits, codes[:,:,g])

        return total / self.num_groups

    def get_grad_norm_parameter_groups(self):
        """Return named parameter lists for per-group gradient norm reporting."""
        return {
            'gpt': list(self.gpt.parameters()),
            'conditioning': list(self.conditioning_encoder.parameters()),
        }

    def get_debug_values(self, step, __):
        """Return the quantizer's recorded codes for histogram logging, if any."""
        quantizer = self.upper_quantizer
        if quantizer.total_codes > 0:
            return {'histogram_upper_codes': quantizer.codes[:quantizer.total_codes]}
        return {}
|
|
|
|
|
|
@register_model
def register_music_gpt_lower(opt_net, opt):
    """Build a ``GptMusicLower`` from the ``kwargs`` entry of ``opt_net``.

    A missing ``kwargs`` entry means the model is built with its defaults.
    ``opt`` is accepted for the registration interface but not used.
    """
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return GptMusicLower(**model_kwargs)
|
|
|
|
@register_model
def register_music_gpt_upper(opt_net, opt):
    """Build a ``GptMusicUpper`` from the ``kwargs`` entry of ``opt_net``.

    A missing ``kwargs`` entry means the model is built with its defaults.
    ``opt`` is accepted for the registration interface but not used.
    """
    model_kwargs = opt_get(opt_net, ['kwargs'], {})
    return GptMusicUpper(**model_kwargs)
|
|
|
|
|
|
def test_lower():
    """Manual smoke test for ``GptMusicLower``.

    Loads four pretrained quantizer checkpoints from hard-coded local paths,
    saves the assembled model to ``sample.pth``, runs one forward pass on random
    data and prints the parameter count (in millions) of every gradient-norm
    group followed by the total.
    """
    vq_config = {
        'positional_dims': 1, 'channels': 64,
        'hidden_dim': 512, 'num_resnet_blocks': 3, 'codebook_dim': 512, 'num_tokens': 8192,
        'num_layers': 0, 'record_codes': True, 'kernel_size': 3, 'use_transposed_convs': False,
    }
    model = GptMusicLower(dim=512, layers=12, fp16=False, freeze_upper_until=1000,
                          num_target_vectors=8192, num_upper_vectors=8192, num_vaes=4,
                          vqargs=vq_config)

    # One checkpoint per target quantizer, in quantizer order.
    weight_files = [
        'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_low\\models\\7500_generator.pth',
        'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_low\\models\\11000_generator.pth',
        'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_mid_high\\models\\11500_generator.pth',
        'X:\\dlas\\experiments\\music_vqvaes\\train_lrdvae_music_high\\models\\11500_generator.pth',
    ]
    for quantizer, weight_file in zip(model.target_quantizers, weight_files):
        quantizer.load_state_dict(torch.load(weight_file), strict=True)
    torch.save(model.state_dict(), 'sample.pth')
    print_network(model)

    mel = torch.randn(2,256,400)
    model(mel, mel)

    # Report element counts per parameter group, then overall.
    grand_total = 0
    for group_name, params in model.get_grad_norm_parameter_groups().items():
        group_total = sum(p.numel() for p in params)
        grand_total += group_total
        print(group_name, group_total/1000000)
    print(grand_total/1000000)
|
|
|
|
|
|
def test_upper():
    """Manual smoke test for ``GptMusicUpper``.

    Restores a ``GptMusicLower`` from a hard-coded local checkpoint, copies its
    upper quantizer weights into a fresh ``GptMusicUpper``, saves that model to
    ``sample.pth`` and runs one forward pass on random data.

    NOTE(review): ``GptMusicLower.upper_quantizer`` is an ``UpperQuantizer`` while
    ``GptMusicUpper.upper_quantizer`` is a ``MusicQuantizer2`` — confirm their
    state dicts are actually compatible.
    """
    lower = GptMusicLower(512, 12)
    lower_weights = torch.load('D:\\dlas\\experiments\\train_music_gpt\\models\\44500_generator_ema.pth')
    lower.load_state_dict(lower_weights)

    model = GptMusicUpper(512, 12)
    quantizer_weights = lower.upper_quantizer.state_dict()
    model.upper_quantizer.load_state_dict(quantizer_weights)
    torch.save(model.state_dict(), 'sample.pth')

    mel = torch.randn(2,256,2500)
    model(mel, mel)
    model.get_grad_norm_parameter_groups()
|
|
|
|
|
|
# Script entry point: running this file directly executes the lower-model smoke test.
if __name__ == '__main__':
    test_lower()