From c0db85bf4f745c5cb8e6586acfb3abc436f335f5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 31 May 2022 21:06:54 -0600
Subject: [PATCH] music quantizer

---
 .../data/audio/unsupervised_audio_dataset.py |  11 +-
 codes/data/util.py                           |   2 +-
 codes/models/audio/music/music_quantizer.py  | 236 ++++++++++++++++++
 codes/train.py                               |   2 +-
 4 files changed, 245 insertions(+), 6 deletions(-)
 create mode 100644 codes/models/audio/music/music_quantizer.py

diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 3277e1ef..f4d2fedb 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -182,9 +182,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
 if __name__ == '__main__':
     params = {
         'mode': 'unsupervised_audio',
-        'path': ['Y:\\split\\yt-music'],
+        'path': ['Y:\\separated\\yt-music-0', 'Y:\\separated\\yt-music-1',
+                 'Y:\\separated\\bt-music-1', 'Y:\\separated\\bt-music-2',
+                 'Y:\\separated\\bt-music-3', 'Y:\\separated\\bt-music-4',
+                 'Y:\\separated\\bt-music-5'],
         'cache_path': 'Y:\\separated\\no-vocals-cache-win.pth',
-        'endswith': 'no_vocals.wav',
+        'endswith': ['no_vocals.wav'],
         'sampling_rate': 22050,
         'pad_to_samples': 200000,
         'resample_clip': False,
@@ -202,6 +205,6 @@ if __name__ == '__main__':
     for b in tqdm(dl):
         for b_ in range(b['clip'].shape[0]):
             #pass
-            torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
-            torchaudio.save(f'{i}_alt_clip_{b_}.wav', b['alt_clips'][b_], ds.sampling_rate)
+            #torchaudio.save(f'{i}_clip_{b_}.wav', b['clip'][b_], ds.sampling_rate)
+            #torchaudio.save(f'{i}_alt_clip_{b_}.wav', b['alt_clips'][b_], ds.sampling_rate)
     i += 1
diff --git a/codes/data/util.py b/codes/data/util.py
index 5ea94575..529cba29 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -599,7 +599,7 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=[], not_endswith=[]):
         before = len(output)
         def filter_fn(p):
             for e in endswith:
-                if not p.endswith(endswith):
+                if not p.endswith(e):
                     return False
             for e in not_endswith:
                 if p.endswith(e):
diff --git a/codes/models/audio/music/music_quantizer.py b/codes/models/audio/music/music_quantizer.py
new file mode 100644
index 00000000..d304ae35
--- /dev/null
+++ b/codes/models/audio/music/music_quantizer.py
@@ -0,0 +1,236 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.arch_util import zero_module
+from trainer.networks import register_model
+
+
+class Downsample(nn.Module):
+    def __init__(self, chan_in, chan_out):
+        super().__init__()
+        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=.5, mode='linear')
+        x = self.conv(x)
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, chan_in, chan_out):
+        super().__init__()
+        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2, mode='linear')
+        x = self.conv(x)
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, chan):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv1d(chan, chan, 3, padding=1),
+            nn.GroupNorm(8, chan),
+            nn.SiLU(),
+            nn.Conv1d(chan, chan, 3, padding=1),
+            nn.GroupNorm(8, chan),
+            nn.SiLU(),
+            zero_module(nn.Conv1d(chan, chan, 3, padding=1)),
+        )
+
+    def forward(self, x):
+        return self.net(x) + x
+
+
+class Wav2Vec2GumbelVectorQuantizer(nn.Module):
+    """
+    Vector quantization using gumbel softmax.
+    See [CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf)
+    for more information.
+    """
+
+    def __init__(self, proj_dim=1024, codevector_dim=512, num_codevector_groups=2, num_codevectors_per_group=320):
+        super().__init__()
+        self.codevector_dim = codevector_dim
+        self.num_groups = num_codevector_groups
+        self.num_vars = num_codevectors_per_group
+        self.num_codevectors = num_codevector_groups * num_codevectors_per_group
+
+        if codevector_dim % self.num_groups != 0:
+            raise ValueError(
+                f"`codevector_dim` {codevector_dim} must be divisible "
+                f"by `num_codevector_groups` {num_codevector_groups} for concatenation"
+            )
+
+        # storage for codebook variables (codewords)
+        self.codevectors = nn.Parameter(
+            torch.FloatTensor(1, self.num_groups * self.num_vars, codevector_dim // self.num_groups)
+        )
+        self.weight_proj = nn.Linear(proj_dim, self.num_groups * self.num_vars)
+
+        # can be decayed for training
+        self.temperature = 2
+
+        # Parameters init.
+        self.weight_proj.weight.data.normal_(mean=0.0, std=1)
+        self.weight_proj.bias.data.zero_()
+        nn.init.uniform_(self.codevectors)
+
+    @staticmethod
+    def _compute_perplexity(probs, mask=None):
+        if mask is not None:
+            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+            marginal_probs = probs.sum(dim=0) / mask.sum()
+        else:
+            marginal_probs = probs.mean(dim=0)
+
+        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+        return perplexity
+
+    def get_codes(self, hidden_states):
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+
+        # project to per-group codebook logits
+        hidden_states = self.weight_proj(hidden_states)
+        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+        codevector_idx = hidden_states.argmax(dim=-1)
+        idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
+        return idxs
+
+    def forward(self, hidden_states, mask_time_indices=None, return_probs=False):
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+
+        # project to per-group codebook logits
+        hidden_states = self.weight_proj(hidden_states)
+        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+        if self.training:
+            # sample code vector probs via gumbel in differentiable way
+            codevector_probs = nn.functional.gumbel_softmax(
+                hidden_states.float(), tau=self.temperature, hard=True
+            ).type_as(hidden_states)
+
+            # compute perplexity
+            codevector_soft_dist = torch.softmax(
+                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+            )
+            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+        else:
+            # take argmax in non-differentiable way
+            # compute hard codevector distribution (one hot)
+            codevector_idx = hidden_states.argmax(dim=-1)
+            codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
+                -1, codevector_idx.view(-1, 1), 1.0
+            )
+            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+        # use probs to retrieve codevectors
+        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+        codevectors = (
+            codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+            .sum(-2)
+            .view(batch_size, sequence_length, -1)
+        )
+
+        if return_probs:
+            return codevectors, perplexity, codevector_probs.view(batch_size, sequence_length, self.num_groups, self.num_vars)
+        return codevectors, perplexity
+
+
+class MusicQuantizer(nn.Module):
+    def __init__(self, inp_channels=256, inner_dim=1024, codevector_dim=1024, down_steps=2,
+                 max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
+                 codebook_size=16, codebook_groups=4):
+        super().__init__()
+        self.max_gumbel_temperature = max_gumbel_temperature
+        self.min_gumbel_temperature = min_gumbel_temperature
+        self.gumbel_temperature_decay = gumbel_temperature_decay
+        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim, codevector_dim=codevector_dim,
+                                                       num_codevector_groups=codebook_groups,
+                                                       num_codevectors_per_group=codebook_size)
+        self.num_losses_record = []
+
+        # Only down_steps of 0 or 2 are supported.
+        if down_steps == 0:
+            self.down = nn.Conv1d(inp_channels, inner_dim, kernel_size=3, padding=1)
+            self.up = nn.Conv1d(inner_dim, inp_channels, kernel_size=3, padding=1)
+        elif down_steps == 2:
+            self.down = nn.Sequential(nn.Conv1d(inp_channels, inner_dim//4, kernel_size=3, padding=1),
+                                      Downsample(inner_dim//4, inner_dim//2),
+                                      Downsample(inner_dim//2, inner_dim))
+            self.up = nn.Sequential(Upsample(inner_dim, inner_dim//2),
+                                    Upsample(inner_dim//2, inner_dim//4),
+                                    nn.Conv1d(inner_dim//4, inp_channels, kernel_size=3, padding=1))
+
+        self.encoder = nn.Sequential(ResBlock(inner_dim),
+                                     ResBlock(inner_dim),
+                                     ResBlock(inner_dim))
+        self.enc_norm = nn.LayerNorm(inner_dim, eps=1e-5)
+        self.decoder = nn.Sequential(nn.Conv1d(codevector_dim, inner_dim, kernel_size=3, padding=1),
+                                     ResBlock(inner_dim),
+                                     ResBlock(inner_dim),
+                                     ResBlock(inner_dim))
+
+        self.codes = torch.zeros((3000000,), dtype=torch.long)
+        self.internal_step = 0
+        self.code_ind = 0
+        self.total_codes = 0
+
+    def get_codes(self, mel, project=False):
+        # Encode with the same downsampling path used by forward().
+        h = self.down(mel)
+        h = self.encoder(h)
+        h = self.enc_norm(h.permute(0,2,1))
+        if project:
+            h, _ = self.quantizer(h)
+            return h
+        else:
+            return self.quantizer.get_codes(h)
+
+    def forward(self, mel):
+        h = self.down(mel)
+        h = self.encoder(h)
+        h = self.enc_norm(h.permute(0,2,1))
+        codevectors, perplexity, codes = self.quantizer(h, return_probs=True)
+        h = self.decoder(codevectors.permute(0,2,1))
+        reconstructed = self.up(h)
+
+        mse = F.mse_loss(reconstructed, mel)
+        diversity = (self.quantizer.num_codevectors - perplexity) / self.quantizer.num_codevectors
+
+        self.log_codes(codes)
+
+        return mse, diversity
+
+    def log_codes(self, codes):
+        if self.internal_step % 5 == 0:
+            codes = torch.argmax(codes, dim=-1)
+            # Pack the four group indices into a single code; assumes codebook_size=16, codebook_groups=4.
+            codes = codes[:,:,0] + codes[:,:,1] * 16 + codes[:,:,2] * 16 ** 2 + codes[:,:,3] * 16 ** 3
+            codes = codes.flatten()
+            l = codes.shape[0]
+            i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
+            self.codes[i:i+l] = codes.cpu()
+            self.code_ind = self.code_ind + l
+            if self.code_ind >= self.codes.shape[0]:
+                self.code_ind = 0
+            self.total_codes += 1
+        self.internal_step += 1
+
+    def get_debug_values(self, step, __):
+        if self.total_codes > 0:
+            return {'histogram_codes': self.codes[:self.total_codes]}
+        else:
+            return {}
+
+
+@register_model
+def register_music_quantizer(opt_net, opt):
+    return MusicQuantizer(**opt_net['kwargs'])
+
+
+if __name__ == '__main__':
+    model = MusicQuantizer()
+    mel = torch.randn((2,256,200))
+    model(mel)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index 4717dc68..6e1ea265 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -338,7 +338,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_tfd.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_quant.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
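
For context, here is a minimal sketch of how the new quantizer can be exercised, expanding on the `__main__` smoke test at the bottom of `music_quantizer.py`. It assumes it is run from the `codes/` directory, matching the repo's import convention. The 0.1 diversity-loss weight and the manual temperature decay are illustrative assumptions, not values taken from this patch; the `*_gumbel_temperature` hyperparameters are stored by the model but not yet wired to a schedule here.

    import torch

    from models.audio.music.music_quantizer import MusicQuantizer

    # Defaults: 4 codebook groups of 16 entries each, two 2x downsamples,
    # so a (b, 256, 200) mel yields 50 quantized frames per clip.
    model = MusicQuantizer()
    mel = torch.randn(2, 256, 200)

    # forward() returns the reconstruction MSE and the wav2vec2-style
    # diversity term; the 0.1 weight below is an assumed value.
    mse, diversity = model(mel)
    loss = mse + 0.1 * diversity
    loss.backward()

    # Hypothetical per-step temperature decay using the stored
    # hyperparameters; this patch does not hook the decay into the trainer.
    model.quantizer.temperature = max(
        model.quantizer.temperature * model.gumbel_temperature_decay,
        model.min_gumbel_temperature)

    # Discrete codes for downstream use: a (2, 50, 4) tensor of group indices.
    model.eval()
    with torch.no_grad():
        codes = model.get_codes(mel)

Note that at evaluation time the quantizer switches from Gumbel-softmax sampling to a hard argmax, so repeated calls on the same input return identical codes.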