diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py new file mode 100644 index 00000000..f7385b31 --- /dev/null +++ b/codes/models/audio/music/gpt_music.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +import torch.nn.functional as F +from transformers import GPT2Config, GPT2Model + +from models.audio.music.music_quantizer import MusicQuantizer +from models.audio.music.music_quantizer2 import MusicQuantizer2 +from trainer.networks import register_model +from utils.util import opt_get + + +class GptMusic(nn.Module): + def __init__(self, dim, layers, num_target_vectors=512, num_target_groups=2, cv_dim=1024, num_upper_vectors=64, num_upper_groups=4): + super().__init__() + self.num_groups = num_target_groups + self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, + n_inner=dim*2) + self.target_quantizer = MusicQuantizer(inp_channels=256, inner_dim=[1024,1024,512], codevector_dim=cv_dim, codebook_size=num_target_vectors, codebook_groups=num_target_groups) + del self.target_quantizer.decoder + del self.target_quantizer.up + self.upper_quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024,896,768,640,512,384], codevector_dim=cv_dim, codebook_size=num_upper_vectors, codebook_groups=num_upper_groups) + del self.upper_quantizer.up + self.gpt = GPT2Model(self.config) + del self.gpt.wte # Unused, we'll do our own embeddings. + self.embeddings = nn.ModuleList([nn.Embedding(num_target_vectors, dim // num_target_groups) for _ in range(num_target_groups)]) + self.upper_proj = nn.Conv1d(cv_dim, dim, kernel_size=1) + self.heads = nn.ModuleList([nn.Linear(dim, num_target_vectors) for _ in range(num_target_groups)]) + + + def forward(self, mel): + with torch.no_grad(): + self.target_quantizer.eval() + codes = self.target_quantizer.get_codes(mel) + upper_vector, upper_diversity = self.upper_quantizer(mel, return_decoder_latent=True) + upper_vector = self.upper_proj(upper_vector) + upper_vector = F.interpolate(upper_vector, size=codes.shape[1], mode='linear') + upper_vector = upper_vector.permute(0,2,1) + + inputs = codes[:, :-1] + upper_vector = upper_vector[:, :-1] + targets = codes[:, 1:] + + h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = torch.cat(h, dim=-1) + upper_vector + h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state + + losses = 0 + for i, head in enumerate(self.heads): + logits = head(h).permute(0,2,1) + loss = F.cross_entropy(logits, targets[:,:,i]) + losses = losses + loss + + return losses / self.num_groups + + +@register_model +def register_music_gpt(opt_net, opt): + return GptMusic(**opt_get(opt_net, ['kwargs'], {})) + + +if __name__ == '__main__': + model = GptMusic(512, 12) + mel = torch.randn(2,256,400) + model(mel) \ No newline at end of file diff --git a/codes/models/audio/music/music_gen_fill_gaps.py b/codes/models/audio/music/music_gen_fill_gaps.py deleted file mode 100644 index 14eaf563..00000000 --- a/codes/models/audio/music/music_gen_fill_gaps.py +++ /dev/null @@ -1,262 +0,0 @@ -import random - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import autocast -from torchaudio.transforms import TimeMasking, FrequencyMasking - -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock -from trainer.networks import register_model -from utils.util import checkpoint 
- -def is_sequence(t): - return t.dtype == torch.long - - -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - efficient_config=True, - use_scale_shift_norm=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_scale_shift_norm = use_scale_shift_norm - padding = {1: 0, 3: 1, 5: 2}[kernel_size] - eff_kernel = 1 if efficient_config else 3 - eff_padding = 0 if efficient_config else 1 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionLayer(TimestepBlock): - def __init__(self, model_channels, dropout, num_heads): - super().__init__() - self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) - - def forward(self, x, time_emb): - y = self.resblk(x, time_emb) - return self.attn(y) - - -class MusicGenerator(nn.Module): - def __init__( - self, - model_channels=512, - num_layers=8, - in_channels=100, - out_channels=200, # mean and variance - dropout=0, - use_fp16=False, - num_heads=16, - # Parameters for regularization. - layer_drop=.1, - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - # Masking parameters. 
- frequency_mask_percent_max=0, - time_mask_percent_max=0, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.dropout = dropout - self.num_heads = num_heads - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - self.layer_drop = layer_drop - self.time_mask_percent_max = time_mask_percent_max - self.frequency_mask_percent_mask = frequency_mask_percent_max - - self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) - self.time_embed = nn.Sequential( - linear(model_channels, model_channels), - nn.SiLU(), - linear(model_channels, model_channels), - ) - - self.conditioner = nn.Sequential( - nn.Conv1d(in_channels, model_channels, 3, padding=1), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - ) - self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1)) - self.conditioning_timestep_integrator = TimestepEmbedSequential( - DiffusionLayer(model_channels, dropout, num_heads), - DiffusionLayer(model_channels, dropout, num_heads), - ) - self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) - self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + - [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()), - 'conditioner': list(self.conditioner.parameters()) + list(self.conditioner.parameters()), - 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def do_masking(self, truth): - b, c, s = truth.shape - mask = torch.ones_like(truth) - if self.random() > .5: - # Frequency mask - cs = random.randint(0, c-10) - ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c))) - mask[:, cs:ce] = 0 - else: - # Time mask - cs = random.randint(0, s-5) - ce = min(s-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*s))) - mask[:, :, cs:ce] = 0 - return truth * mask - - - def timestep_independent(self, truth): - truth_emb = self.conditioner(truth) - # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. - if self.training and self.unconditioned_percentage > 0: - unconditioned_batches = torch.rand((truth_emb.shape[0], 1, 1), - device=truth_emb.device) < self.unconditioned_percentage - truth_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(truth.shape[0], 1, 1), - truth_emb) - return truth_emb - - - def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. 
- :param truth: Input value is either pre-masked (in inference), or unmasked (during training) - :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() - :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. - :return: an [N x C x ...] Tensor of outputs. - """ - assert precomputed_aligned_embeddings is not None or truth is not None - - unused_params = [] - if conditioning_free: - truth_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.conditioner.parameters())) - else: - if precomputed_aligned_embeddings is not None: - truth_emb = precomputed_aligned_embeddings - else: - if self.training: - truth = self.do_masking(truth) - truth_emb = self.timestep_independent(truth) - unused_params.append(self.unconditioned_embedding) - - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - truth_emb = self.conditioning_timestep_integrator(truth_emb, time_emb) - x = self.inp_block(x) - x = torch.cat([x, truth_emb], dim=1) - x = self.integrating_conv(x) - for i, lyr in enumerate(self.layers): - # Do layer drop where applicable. Do not drop first and last layers. - if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop: - unused_params.extend(list(lyr.parameters())) - else: - # First and last blocks will have autocast disabled for improved precision. - with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): - x = lyr(x, time_emb) - - x = x.float() - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -@register_model -def register_music_gap_gen(opt_net, opt): - return MusicGenerator(**opt_net['kwargs']) - - -if __name__ == '__main__': - clip = torch.randn(2, 100, 400) - aligned_latent = torch.randn(2,100,388) - ts = torch.LongTensor([600, 600]) - model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5) - o = model(clip, ts, aligned_latent) - diff --git a/codes/models/audio/music/music_gen_fill_gaps_v2.py b/codes/models/audio/music/music_gen_fill_gaps_v2.py deleted file mode 100644 index e483b1bc..00000000 --- a/codes/models/audio/music/music_gen_fill_gaps_v2.py +++ /dev/null @@ -1,266 +0,0 @@ -import random - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import autocast -from torchaudio.transforms import TimeMasking, FrequencyMasking - -from models.audio.tts.unified_voice2 import ConditioningEncoder -from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear -from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock -from models.lucidrains.x_transformers import Encoder -from trainer.networks import register_model -from utils.util import checkpoint - -def is_sequence(t): - return t.dtype == torch.long - - -class ResBlock(TimestepBlock): - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - dims=2, - kernel_size=3, - efficient_config=True, - use_scale_shift_norm=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_scale_shift_norm = 
use_scale_shift_norm - padding = {1: 0, 3: 1, 5: 2}[kernel_size] - eff_kernel = 1 if efficient_config else 3 - eff_padding = 0 if efficient_config else 1 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding), - ) - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, eff_kernel, padding=eff_padding) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, x, emb - ) - - def _forward(self, x, emb): - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class DiffusionLayer(TimestepBlock): - def __init__(self, model_channels, dropout, num_heads): - super().__init__() - self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) - - def forward(self, x, time_emb): - y = self.resblk(x, time_emb) - return self.attn(y) - - -class ConditioningEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - attn_blocks=6): - super().__init__() - attn = [] - self.init = nn.Sequential(nn.Conv1d(spec_dim, embedding_dim//2, kernel_size=3, padding=1, stride=2), - nn.Conv1d(embedding_dim//2, embedding_dim, kernel_size=3, padding=1, stride=2)) - self.attn = Encoder(dim=embedding_dim, depth=attn_blocks, use_scalenorm=True, rotary_pos_emb=True, - heads=embedding_dim//64, ff_mult=1) - self.dim = embedding_dim - - def forward(self, x): - h = self.init(x) - h = self.attn(h.permute(0,2,1)) - return h.mean(dim=1) - - -class MusicGenerator(nn.Module): - def __init__( - self, - model_channels=512, - num_layers=8, - in_channels=100, - out_channels=200, # mean and variance - dropout=0, - use_fp16=False, - num_heads=16, - # Parameters for regularization. - layer_drop=.1, - unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training. - # Masking parameters. 
- frequency_mask_percent_max=0.2, - time_mask_percent_max=0.2, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.dropout = dropout - self.num_heads = num_heads - self.unconditioned_percentage = unconditioned_percentage - self.enable_fp16 = use_fp16 - self.layer_drop = layer_drop - self.time_mask_percent_max = time_mask_percent_max - self.frequency_mask_percent_mask = frequency_mask_percent_max - - self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1) - self.time_embed = nn.Sequential( - linear(model_channels, model_channels), - nn.SiLU(), - linear(model_channels, model_channels), - ) - - self.conditioner = ConditioningEncoder(in_channels, model_channels) - self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels)) - self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + - [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)]) - - self.out = nn.Sequential( - normalization(model_channels), - nn.SiLU(), - zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)), - ) - - def get_grad_norm_parameter_groups(self): - groups = { - 'layers': list(self.layers.parameters()), - 'conditioner': list(self.conditioner.parameters()), - 'time_embed': list(self.time_embed.parameters()), - } - return groups - - def do_masking(self, truth): - b, c, s = truth.shape - - # Frequency mask - mask_freq = torch.ones_like(truth) - cs = random.randint(0, c-10) - ce = min(c-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*c))) - mask_freq[:, cs:ce] = 0 - - # Time mask - mask_time = torch.ones_like(truth) - cs = random.randint(0, s-5) - ce = min(s-1, cs+random.randint(1, int(self.frequency_mask_percent_mask*s))) - mask_time[:, :, cs:ce] = 0 - - return truth * mask_time * mask_freq - - - def timestep_independent(self, truth): - if self.training: - truth = self.do_masking(truth) - truth_emb = self.conditioner(truth) - return truth_emb - - - def forward(self, x, timesteps, truth=None, precomputed_aligned_embeddings=None, conditioning_free=False): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param truth: Input value is either pre-masked (in inference), or unmasked (during training) - :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() - :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. - :return: an [N x C x ...] Tensor of outputs. - """ - assert precomputed_aligned_embeddings is not None or truth is not None - - unused_params = [] - if conditioning_free: - truth_emb = self.unconditioned_embedding - unused_params.extend(list(self.conditioner.parameters())) - else: - if precomputed_aligned_embeddings is not None: - truth_emb = precomputed_aligned_embeddings - else: - truth_emb = self.timestep_independent(truth) - unused_params.append(self.unconditioned_embedding) - time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + truth_emb - - x = self.inp_block(x) - for i, lyr in enumerate(self.layers): - # Do layer drop where applicable. Do not drop first and last layers. 
- if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop: - unused_params.extend(list(lyr.parameters())) - else: - # First and last blocks will have autocast disabled for improved precision. - with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): - x = lyr(x, time_emb) - - x = x.float() - out = self.out(x) - - # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. - extraneous_addition = 0 - for p in unused_params: - extraneous_addition = extraneous_addition + p.mean() - out = out + extraneous_addition * 0 - - return out - - -@register_model -def register_music_gap_gen2(opt_net, opt): - return MusicGenerator(**opt_net['kwargs']) - - -if __name__ == '__main__': - clip = torch.randn(2, 100, 400) - aligned_latent = torch.randn(2,100,388) - ts = torch.LongTensor([600, 600]) - model = MusicGenerator(512, layer_drop=.3, unconditioned_percentage=.5) - o = model(clip, ts, aligned_latent) - diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py index 19429e41..dd508048 100644 --- a/codes/models/audio/music/music_quantizer.py +++ b/codes/models/audio/music/music_quantizer.py @@ -197,14 +197,11 @@ class MusicQuantizer(nn.Module): self.code_ind = 0 self.total_codes = 0 - def get_codes(self, mel, project=False): - proj = self.m2v.input_blocks(mel).permute(0,2,1) - _, proj = self.m2v.projector(proj) - if project: - proj, _ = self.quantizer(proj) - return proj - else: - return self.quantizer.get_codes(proj) + def get_codes(self, mel): + h = self.down(mel) + h = self.encoder(h) + h = self.enc_norm(h.permute(0,2,1)) + return self.quantizer.get_codes(h) def forward(self, mel, return_decoder_latent=False): orig_mel = mel diff --git a/codes/models/audio/music/music_quantizer2.py b/codes/models/audio/music/music_quantizer2.py new file mode 100644 index 00000000..5b2a7138 --- /dev/null +++ b/codes/models/audio/music/music_quantizer2.py @@ -0,0 +1,262 @@ +import functools + +import torch +from torch import nn +import torch.nn.functional as F + +from models.arch_util import zero_module +from models.vqvae.vqvae import Quantize +from trainer.networks import register_model +from utils.util import checkpoint, ceil_multiple, print_network + + +class Downsample(nn.Module): + def __init__(self, chan_in, chan_out): + super().__init__() + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=.5, mode='linear') + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, chan_in, chan_out): + super().__init__() + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='linear') + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, 3, padding = 1), + nn.GroupNorm(8, chan), + nn.SiLU(), + nn.Conv1d(chan, chan, 3, padding = 1), + nn.GroupNorm(8, chan), + nn.SiLU(), + zero_module(nn.Conv1d(chan, chan, 3, padding = 1)), + ) + + def forward(self, x): + return checkpoint(self._forward, x) + x + + def _forward(self, x): + return self.net(x) + + +class Wav2Vec2GumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH + GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. 
+ """ + + def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320): + super().__init__() + self.codevector_dim = codevector_dim + self.num_groups = num_codevector_groups + self.num_vars = num_codevectors_per_group + self.num_codevectors = num_codevector_groups * num_codevectors_per_group + + if codevector_dim % self.num_groups != 0: + raise ValueError( + f"`codevector_dim {codevector_dim} must be divisible " + f"by `num_codevector_groups` {num_codevector_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + # Parameters init. + self.weight_proj.weight.data.normal_(mean=0.0, std=1) + self.weight_proj.bias.data.zero_() + nn.init.uniform_(self.codevectors) + + @staticmethod + def _compute_perplexity(probs, mask=None): + if mask is not None: + mask_extended = mask.flatten()[:, None, None].expand(probs.shape) + probs = torch.where(mask_extended, probs, torch.zeros_like(probs)) + marginal_probs = probs.sum(dim=0) / mask.sum() + else: + marginal_probs = probs.mean(dim=0) + + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def get_codes(self, hidden_states): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + codevector_idx = hidden_states.argmax(dim=-1) + idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups) + return idxs + + def forward(self, hidden_states, mask_time_indices=None, return_probs=False): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + else: + # take argmax in non-differentiable way + # compute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = ( + codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + .sum(-2) + .view(batch_size, sequence_length, -1) + ) + + if return_probs: + return codevectors, perplexity, codevector_probs.view(batch_size, sequence_length, 
self.num_groups, self.num_vars) + return codevectors, perplexity + + +class MusicQuantizer2(nn.Module): + def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2, + max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995, + codebook_size=16, codebook_groups=4): + super().__init__() + if not isinstance(inner_dim, list): + inner_dim = [inner_dim // 2 ** x for x in range(down_steps+1)] + self.max_gumbel_temperature = max_gumbel_temperature + self.min_gumbel_temperature = min_gumbel_temperature + self.gumbel_temperature_decay = gumbel_temperature_decay + self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim[0], codevector_dim=codevector_dim, + num_codevector_groups=codebook_groups, + num_codevectors_per_group=codebook_size) + self.codebook_size = codebook_size + self.codebook_groups = codebook_groups + self.num_losses_record = [] + + if down_steps == 0: + self.down = nn.Conv1d(inp_channels, inner_dim[0], kernel_size=3, padding=1) + self.up = nn.Conv1d(inner_dim[0], inp_channels, kernel_size=3, padding=1) + elif down_steps == 2: + self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim[-1], kernel_size=3, padding=1), + *[Downsample(inner_dim[-i], inner_dim[-i-1]) for i in range(1,len(inner_dim))]) + self.up = nn.Sequential(*[Upsample(inner_dim[i], inner_dim[i+1]) for i in range(len(inner_dim)-1)] + + [nn.Conv1d(inner_dim[-1], inp_channels, kernel_size=3, padding=1)]) + + self.encoder = nn.Sequential(ResBlock(inner_dim[0]), + ResBlock(inner_dim[0]), + ResBlock(inner_dim[0])) + self.enc_norm = nn.LayerNorm(inner_dim[0], eps=1e-5) + self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim[0], kernel_size=3, padding=1), + ResBlock(inner_dim[0]), + ResBlock(inner_dim[0]), + ResBlock(inner_dim[0])) + + self.codes = torch.zeros((3000000,), dtype=torch.long) + self.internal_step = 0 + self.code_ind = 0 + self.total_codes = 0 + + def get_codes(self, mel, project=False): + h = self.down(mel) + h = self.encoder(h) + h = self.enc_norm(h.permute(0,2,1)) + if project: + h, _ = self.quantizer(h) + return h + return self.quantizer.get_codes(h) + + def forward(self, mel, return_decoder_latent=False): + orig_mel = mel + cm = ceil_multiple(mel.shape[-1], 2 ** (len(self.down)-1)) + if cm != 0: + mel = F.pad(mel, (0,cm-mel.shape[-1])) + + h = self.down(mel) + h = self.encoder(h) + h = self.enc_norm(h.permute(0,2,1)) + codevectors, perplexity, codes = self.quantizer(h, return_probs=True) + diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors + self.log_codes(codes) + h = self.decoder(codevectors.permute(0,2,1)) + if return_decoder_latent: + return h, diversity + + reconstructed = self.up(h.float()) + reconstructed = reconstructed[:, :, :orig_mel.shape[-1]] + + mse = F.mse_loss(reconstructed, orig_mel) + return mse, diversity + + def log_codes(self, codes): + if self.internal_step % 5 == 0: + codes = torch.argmax(codes, dim=-1) + ccodes = codes[:,:,0] + for j in range(1,codes.shape[-1]): + ccodes += codes[:,:,j] * self.codebook_size ** j + codes = ccodes + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i:i+l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + + def get_debug_values(self, step, __): + if self.total_codes > 0: + return {'histogram_codes': self.codes[:self.total_codes]} +
else: + return {} + + def update_for_step(self, step, *args): + self.quantizer.temperature = max( + self.max_gumbel_temperature * self.gumbel_temperature_decay**step, + self.min_gumbel_temperature, + ) + + +@register_model +def register_music_quantizer2(opt_net, opt): + return MusicQuantizer2(**opt_net['kwargs']) + + +if __name__ == '__main__': + model = MusicQuantizer2(inner_dim=[1024], codevector_dim=1024, codebook_size=256, codebook_groups=2) + print_network(model) + mel = torch.randn((2,256,782)) + model(mel) \ No newline at end of file diff --git a/codes/models/audio/music/transformer_diffusion6.py b/codes/models/audio/music/transformer_diffusion8.py similarity index 68% rename from codes/models/audio/music/transformer_diffusion6.py rename to codes/models/audio/music/transformer_diffusion8.py index a86af9ba..e6189b60 100644 --- a/codes/models/audio/music/transformer_diffusion6.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from models.audio.music.music_quantizer2 import MusicQuantizer2 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear from models.diffusion.unet_diffusion import TimestepBlock from models.lucidrains.x_transformers import Encoder, Attention, FeedForward, RMSScaleShiftNorm, RotaryEmbedding @@ -39,15 +40,16 @@ class TimestepRotaryEmbedSequential(nn.Sequential, TimestepBlock): class DietAttentionBlock(TimestepBlock): def __init__(self, in_dim, dim, heads, dropout): super().__init__() + self.rms_scale_norm = RMSScaleShiftNorm(in_dim) self.proj = nn.Linear(in_dim, dim) - self.rms_scale_norm = RMSScaleShiftNorm(dim) self.attn = Attention(dim, heads=heads, causal=False, dropout=dropout) self.ff = FeedForward(dim, in_dim, mult=1, dropout=dropout, zero_init_output=True) def forward(self, x, timestep_emb, rotary_emb): - h = self.proj(x) - h = self.rms_scale_norm(h, norm_scale_shift_inp=timestep_emb) - h, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) + h = self.rms_scale_norm(x, norm_scale_shift_inp=timestep_emb) + h = self.proj(h) + k, _, _, _ = checkpoint(self.attn, h, None, None, None, None, None, rotary_emb) + h = k + h h = checkpoint(self.ff, h) return h + x @@ -59,6 +61,7 @@ class TransformerDiffusion(nn.Module): def __init__( self, prenet_channels=256, + prenet_layers=3, model_channels=512, block_channels=256, num_layers=8, @@ -107,7 +110,7 @@ class TransformerDiffusion(nn.Module): self.input_converter = nn.Linear(input_vec_dim, prenet_channels) self.code_converter = Encoder( dim=prenet_channels, - depth=3, + depth=prenet_layers, heads=prenet_heads, ff_dropout=dropout, attn_dropout=dropout, @@ -120,7 +123,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,prenet_channels)) self.rotary_embeddings = RotaryEmbedding(rotary_emb_dim) - self.cond_intg = nn.Linear(prenet_channels*2, block_channels) + self.cond_intg = nn.Linear(prenet_channels*2, model_channels) self.intg = nn.Linear(prenet_channels*2, model_channels) self.layers = TimestepRotaryEmbedSequential(*[DietAttentionBlock(model_channels, block_channels, block_channels // 64, dropout) for _ in range(num_layers)]) @@ -164,8 +167,10 @@ class TransformerDiffusion(nn.Module): unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + code_emb 
= self.unconditioned_embedding.repeat(x.shape[0], x.shape[-1], 1) + cond_emb = self.conditioning_embedder(conditioning_input).permute(0,2,1) + cond_emb = self.conditioning_encoder(cond_emb)[:, 0] + unused_params.extend(list(self.code_converter.parameters())) else: if precomputed_code_embeddings is not None: code_emb = precomputed_code_embeddings @@ -195,18 +200,87 @@ class TransformerDiffusion(nn.Module): return out +class TransformerDiffusionWithQuantizer(nn.Module): + def __init__(self, freeze_quantizer_until=20000, **kwargs): + super().__init__() + + self.internal_step = 0 + self.freeze_quantizer_until = freeze_quantizer_until + self.diff = TransformerDiffusion(**kwargs) + self.quantizer = MusicQuantizer2(inp_channels=256, inner_dim=[1024], codevector_dim=1024, codebook_size=256, + codebook_groups=2, max_gumbel_temperature=4, min_gumbel_temperature=.5) + self.quantizer.quantizer.temperature = self.quantizer.min_gumbel_temperature + del self.quantizer.up + + def update_for_step(self, step, *args): + self.internal_step = step + qstep = max(0, self.internal_step - self.freeze_quantizer_until) + self.quantizer.quantizer.temperature = max( + self.quantizer.max_gumbel_temperature * self.quantizer.gumbel_temperature_decay ** qstep, + self.quantizer.min_gumbel_temperature, + ) + + def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): + quant_grad_enabled = self.internal_step > self.freeze_quantizer_until + with torch.set_grad_enabled(quant_grad_enabled): + proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True) + proj = proj.permute(0,2,1) + + # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing. + if not quant_grad_enabled: + unused = 0 + for p in self.quantizer.parameters(): + unused = unused + p.mean() * 0 + proj = proj + unused + diversity_loss = diversity_loss * 0 + + diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) + if disable_diversity: + return diff + return diff, diversity_loss + + def get_debug_values(self, step, __): + if self.quantizer.total_codes > 0: + return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes]} + else: + return {} + + @register_model -def register_transformer_diffusion6(opt_net, opt): +def register_transformer_diffusion8(opt_net, opt): return TransformerDiffusion(**opt_net['kwargs']) +@register_model +def register_transformer_diffusion8_with_quantizer(opt_net, opt): + return TransformerDiffusionWithQuantizer(**opt_net['kwargs']) + + +""" +# For TFD5 if __name__ == '__main__': clip = torch.randn(2, 256, 400) aligned_sequence = torch.randn(2,100,512) cond = torch.randn(2, 256, 400) ts = torch.LongTensor([600, 600]) - model = TransformerDiffusion(model_channels=4096, block_channels=2048, prenet_channels=1024, num_layers=16) + model = TransformerDiffusion(model_channels=3072, block_channels=1536, prenet_channels=1536) torch.save(model, 'sample.pth') print_network(model) o = model(clip, ts, aligned_sequence, cond) +""" + +if __name__ == '__main__': + clip = torch.randn(2, 256, 400) + cond = torch.randn(2, 256, 400) + ts = torch.LongTensor([600, 600]) + model = TransformerDiffusionWithQuantizer(model_channels=2048, block_channels=1024, prenet_channels=1024, input_vec_dim=1024, num_layers=16, prenet_layers=6) + + #quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth') + #diff_weights = 
torch.load('X:\\dlas\\experiments\\train_music_diffusion_tfd5\\models\\48000_generator_ema.pth') + #model.quantizer.load_state_dict(quant_weights, strict=False) + #model.diff.load_state_dict(diff_weights) + + torch.save(model.state_dict(), 'sample.pth') + print_network(model) + o = model(clip, ts, clip, cond) diff --git a/codes/models/audio/music/unet_diffusion_music_codes.py b/codes/models/audio/music/unet_diffusion_music_codes.py index 1533b324..e75f826e 100644 --- a/codes/models/audio/music/unet_diffusion_music_codes.py +++ b/codes/models/audio/music/unet_diffusion_music_codes.py @@ -530,12 +530,12 @@ class UNetMusicModel(nn.Module): ch, time_embed_dim, dropout, - out_channels=mult * model_channels, + out_channels=int(mult * model_channels), dims=dims, use_scale_shift_norm=use_scale_shift_norm, ) ] - ch = mult * model_channels + ch = int(mult * model_channels) if ds in attention_resolutions: layers.append( AttentionBlock( @@ -605,12 +605,12 @@ class UNetMusicModel(nn.Module): ch + ich, time_embed_dim, dropout, - out_channels=model_channels * mult, + out_channels=int(model_channels * mult), dims=dims, use_scale_shift_norm=use_scale_shift_norm, ) ] - ch = model_channels * mult + ch = int(model_channels * mult) if ds in attention_resolutions: layers.append( AttentionBlock( @@ -749,9 +749,9 @@ if __name__ == '__main__': clip = torch.randn(2, 256, 782) cond = torch.randn(2, 256, 782) ts = torch.LongTensor([600, 600]) - model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=640, num_res_blocks=3, input_vec_dim=1024, - attention_resolutions=(2,4), channel_mult=(1,2,3), dims=1, - use_scale_shift_norm=True, dropout=.1, num_heads=8, unconditioned_percentage=.4) + model = UNetMusicModelWithQuantizer(in_channels=256, out_channels=512, model_channels=1024, num_res_blocks=3, input_vec_dim=1024, + attention_resolutions=(2,4), channel_mult=(1,1.5,2), dims=1, + use_scale_shift_norm=True, dropout=.1, num_heads=16, unconditioned_percentage=.4) print_network(model) quant_weights = torch.load('D:\\dlas\\experiments\\train_music_quant\\models\\18000_generator_ema.pth') diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index ea1daed4..b6d8529c 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -328,7 +328,7 @@ class Mel2vecCodesInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env) self.m2v = get_music_codegen() - del self.m2v.m2v.encoder # This is a big memory sink which will not get used. + del self.m2v.quantizer.encoder # This is a big memory sink which will not get used. self.needs_move = True self.inj_vector = opt_get(opt, ['vector'], False)
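
Illustrative sketch (not part of the patch): how GptMusic.forward consumes the grouped code indices produced by MusicQuantizer.get_codes -- one nn.Embedding per codebook group, with the group embeddings concatenated along the channel dimension to form the GPT input sequence. The shapes below are assumptions picked to match the defaults in this diff (num_target_groups=2, num_target_vectors=512, dim=512); the codes are random stand-ins rather than real mel-derived indices.

import torch
from torch import nn

num_groups, vocab, dim = 2, 512, 512   # num_target_groups, num_target_vectors, GPT n_embd

# One embedding table per quantizer group, each covering dim // num_groups channels.
embeddings = nn.ModuleList([nn.Embedding(vocab, dim // num_groups) for _ in range(num_groups)])

# get_codes() yields integer indices shaped (batch, sequence, groups).
codes = torch.randint(0, vocab, (2, 100, num_groups))

# Embed each group separately and concatenate, mirroring GptMusic.forward.
h = torch.cat([emb(codes[:, :, i]) for i, emb in enumerate(embeddings)], dim=-1)
print(h.shape)  # torch.Size([2, 100, 512]) -- one dim-wide token embedding per frame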