diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py
index 3f7c59af..80cbbb2a 100644
--- a/codes/models/audio/music/music_quantizer.py
+++ b/codes/models/audio/music/music_quantizer.py
@@ -1,10 +1,12 @@
+import functools
+
 import torch
 from torch import nn
 import torch.nn.functional as F
 
 from models.arch_util import zero_module
 from trainer.networks import register_model
-from utils.util import checkpoint, ceil_multiple
+from utils.util import checkpoint, ceil_multiple, print_network
 
 
 class Downsample(nn.Module):
@@ -152,33 +154,37 @@ class MusicQuantizer(nn.Module):
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
                  codebook_size=16, codebook_groups=4):
         super().__init__()
+        if not isinstance(inner_dim, list):
+            inner_dim = [inner_dim // 2 ** x for x in range(down_steps+1)]
         self.max_gumbel_temperature = max_gumbel_temperature
         self.min_gumbel_temperature = min_gumbel_temperature
         self.gumbel_temperature_decay = gumbel_temperature_decay
-        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, codevector_dim=codevector_dim,
+        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim,
                                                        num_codevector_groups=codebook_groups,
                                                        num_codevectors_per_group=codebook_size)
+        self.codebook_size = codebook_size
+        self.codebook_groups = codebook_groups
         self.num_losses_record = []
 
         if down_steps == 0:
-            self.down = nn.Conv1d(inp_channels, inner_dim, kernel_size=3, padding=1)
-            self.up = nn.Conv1d(inner_dim, inp_channels, kernel_size=3, padding=1)
+            self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1)
+            self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1)
         elif down_steps == 2:
-            self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim//4, kernel_size=3, padding=1),
-                                      Downsample(inner_dim//4, inner_dim//2),
-                                      Downsample(inner_dim//2, inner_dim))
-            self.up = nn.Sequential(Upsample(inner_dim, inner_dim//2),
-                                    Upsample(inner_dim//2, inner_dim//4),
-                                    nn.Conv1d(inner_dim//4, inp_channels, kernel_size=3, padding=1))
+            self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1),
+                                      Downsample(inner_dim[-1], inner_dim[-2]),
+                                      Downsample(inner_dim[-2], inner_dim[-3]))
+            self.up = nn.Sequential(Upsample(inner_dim[-3], inner_dim[-2]),
+                                    Upsample(inner_dim[-2], inner_dim[-1]),
+                                    nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1))
 
-        self.encoder = nn.Sequential(ResBlock(inner_dim),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim))
-        self.enc_norm = nn.LayerNorm(inner_dim, eps=1e-5)
-        self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim, kernel_size=3, padding=1),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim),
-                                     ResBlock(inner_dim))
+        self.encoder = nn.Sequential(ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))
+        self.enc_norm = nn.LayerNorm(inner_dim[0], eps=1e-5)
+        self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim[0], kernel_size=3, padding=1),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]),
+                                     ResBlock(inner_dim[0]))
 
         self.codes = torch.zeros((3000000,), dtype=torch.long)
         self.internal_step = 0
@@ -210,7 +216,7 @@ class MusicQuantizer(nn.Module):
         if return_decoder_latent:
             return h, diversity
 
-        reconstructed = self.up(h)
+        reconstructed = self.up(h.float())
         reconstructed = reconstructed[:, :, :orig_mel.shape[-1]]
 
         mse = F.mse_loss(reconstructed, orig_mel)
@@ -219,7 +225,10 @@
     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
             codes = torch.argmax(codes, dim=-1)
-            codes = codes[:,:,0] + codes[:,:,1] * 16 + codes[:,:,2] * 16 ** 2 + codes[:,:,3] * 16 ** 3
+            ccodes = codes[:,:,0]
+            for j in range(1,codes.shape[-1]):
+                ccodes += codes[:,:,j] * self.codebook_size ** j
+            codes = ccodes
             codes = codes.flatten()
             l = codes.shape[0]
             i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
@@ -242,6 +251,7 @@ def register_music_quantizer(opt_net, opt):
 
 
 if __name__ == '__main__':
-    model = MusicQuantizer()
+    model = MusicQuantizer(inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
+    print_network(model)
     mel = torch.randn((2,256,782))
     model(mel)
\ No newline at end of file
diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index 865ea8c2..531a0868 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -60,6 +60,7 @@ class TransformerDiffusion(nn.Module):
     def __init__(
             self,
             prenet_channels=256,
+            prenet_layers=3,
             model_channels=512,
             block_channels=256,
             num_layers=8,
@@ -108,7 +109,7 @@ class TransformerDiffusion(nn.Module):
         self.input_converter = nn.Linear(input_vec_dim, prenet_channels)
         self.code_converter = Encoder(
                 dim=prenet_channels,
-                depth=3,
+                depth=prenet_layers,
                 heads=prenet_heads,
                 ff_dropout=dropout,
                 attn_dropout=dropout,
@@ -205,7 +206,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         self.internal_step = 0
         self.freeze_quantizer_until = freeze_quantizer_until
         self.diff = TransformerDiffusion(**kwargs)
-        self.m2v = MusicQuantizer(inp_channels=256, inner_dim=2048, codevector_dim=1024)
+        self.m2v = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=1024, codebook_size=512, codebook_groups=2)
         self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature
         del self.m2v.up
 
@@ -270,14 +271,14 @@ if __name__ == '__main__':
     clip = torch.randn(2, 256, 400)
     cond = torch.randn(2, 256, 400)
     ts = torch.LongTensor([600, 600])
-    model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=2048, num_layers=16)
+    model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6)
 
-    #quant_weights = torch.load('X:\\dlas\\experiments\\train_music_quant\\models\\1000_generator.pth')
+    quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth')
     #diff_weights = torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth')
-    #model.m2v.load_state_dict(quant_weights, strict=False)
+    model.m2v.load_state_dict(quant_weights, strict=False)
     #model.diff.load_state_dict(diff_weights)
 
-    #torch.save(model.state_dict(), 'sample.pth')
+    torch.save(model.state_dict(), 'sample.pth')
     print_network(model)
     o = model(clip, ts, clip, cond)
 
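
Note: the following is an illustrative, standalone sketch (not part of the patch) of the generalized code packing that the updated log_codes() performs: per-group codebook indices are combined into a single integer with codebook_size as the base, replacing the previous hard-coded four groups in base 16. The pack_codes helper name is hypothetical.

import torch

def pack_codes(codes: torch.Tensor, codebook_size: int) -> torch.Tensor:
    # codes: (batch, sequence, groups) tensor of per-group codebook indices.
    # Treat each group index as a digit in base `codebook_size` and combine
    # them into one integer per (batch, sequence) position.
    packed = codes[:, :, 0].clone()
    for j in range(1, codes.shape[-1]):
        packed += codes[:, :, j] * codebook_size ** j
    return packed

# Example with the new configuration (codebook_size=512, codebook_groups=2):
codes = torch.randint(0, 512, (2, 100, 2))
print(pack_codes(codes, 512).shape)  # torch.Size([2, 100])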