diff --git a/codes/models/audio/music/mel2vec_codes_gpt.py b/codes/models/audio/music/mel2vec_codes_gpt.py new file mode 100644 index 00000000..9cbf9104 --- /dev/null +++ b/codes/models/audio/music/mel2vec_codes_gpt.py @@ -0,0 +1,50 @@ +import torch +from torch import nn +import torch.nn.functional as F +from transformers import GPT2Config, GPT2Model + +from trainer.networks import register_model +from utils.util import opt_get + + +class Mel2VecCodesGpt(nn.Module): + def __init__(self, dim, layers, num_groups=8, num_vectors=8): + super().__init__() + + self.num_groups = num_groups + + self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64, + n_inner=dim*2) + self.gpt = GPT2Model(self.config) + del self.gpt.wte # Unused, we'll do our own embeddings. + self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)]) + self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)]) + + def forward(self, codes): + assert codes.shape[-1] == self.num_groups + + inputs = codes[:, :-1] + targets = codes[:, 1:] + + h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)] + h = torch.cat(h, dim=-1) + h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state + + losses = 0 + for i, head in enumerate(self.heads): + logits = head(h).permute(0,2,1) + loss = F.cross_entropy(logits, targets[:,:,i]) + losses = losses + loss + + return losses / self.num_groups + + +@register_model +def register_music_gpt(opt_net, opt): + return Mel2VecCodesGpt(**opt_get(opt_net, ['kwargs'], {})) + + +if __name__ == '__main__': + model = Mel2VecCodesGpt(512, 8) + codes = torch.randint(0,8, (2,300,8)) + model(codes) \ No newline at end of file diff --git a/codes/scripts/audio/gen/use_mel2vec_codes.py b/codes/scripts/audio/gen/use_mel2vec_codes.py new file mode 100644 index 00000000..848561a2 --- /dev/null +++ b/codes/scripts/audio/gen/use_mel2vec_codes.py @@ -0,0 +1,39 @@ +import torch + +from models.audio.mel2vec import ContrastiveTrainingWrapper +from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector +from utils.util import load_audio + +def collapse_codegroups(codes): + codes = codes.clone() + groups = codes.shape[-1] + for k in range(groups): + codes[:,:,k] = codes[:,:,k] * groups ** k + codes = codes.sum(-1) + return codes + + +def recover_codegroups(codes, groups): + codes = codes.clone() + output = torch.LongTensor(codes.shape[0], codes.shape[1], groups, device=codes.device) + for k in range(groups): + output[:,:,k] = codes % groups + codes = codes // groups + return output + + +if __name__ == '__main__': + model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0, + mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True) + model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth")) + model.eval() + + wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050) + mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out'] + + codes = model.get_codes(mel) + + collapsed = collapse_codegroups(codes) + recovered = recover_codegroups(collapsed, 8) + + print(codes) \ No newline at end of file diff --git a/codes/scripts/audio/prep_music/phase_1_split_files.py b/codes/scripts/audio/prep_music/phase_1_split_files.py index 29bc96f3..42e9853f 100644 --- a/codes/scripts/audio/prep_music/phase_1_split_files.py +++ b/codes/scripts/audio/prep_music/phase_1_split_files.py @@ -13,7 +13,7 @@ import torch import torchaudio from tqdm import tqdm -from data.util import find_audio_files +from data.util import find_audio_files, find_files_of_type from utils.util import load_audio @@ -36,27 +36,16 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip, for i, spl in enumerate(splits): if spl.shape[-1] != duration_per_clip*sampling_rate: continue # In general, this just means "skip the last item". - # Perform some checks on subclips within this clip. - passed_checks = True - for s in range(duration_per_clip // 2): - subclip = spl[s*2*sampling_rate:(s+1)*2*sampling_rate] - # Are significant parts of any of this clip just silence? - if subclip.var() < .001: - passed_checks=False - if not passed_checks: - break - if not passed_checks: - continue torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate, encoding="PCM_S") report_progress(progress_file, file) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\\slakh2100_flac_redux') - parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\\slakh2100_flac_redux\\already_processed.txt') - parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\\slakh2100') - parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8) + parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4') + parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt') + parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4') + parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6) parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30) args = parser.parse_args() @@ -66,7 +55,8 @@ if __name__ == '__main__': for line in f.readlines(): processed_files.add(line.strip()) - files = set(find_audio_files(args.path, include_nonwav=True)) + files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0]) + #files = set(find_audio_files(args.path, include_nonwav=True)) orig_len = len(files) files = files - processed_files print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.") diff --git a/codes/train.py b/codes/train.py index 721f6a09..bb2067ca 100644 --- a/codes/train.py +++ b/codes/train.py @@ -327,7 +327,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_mel2vec.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index a76910da..4f178d18 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -320,3 +320,26 @@ class AudioUnshuffleInjector(Injector): def forward(self, state): inp = state[self.input] return {self.output: pixel_unshuffle_1d(inp, self.compression)} + + +class Mel2vecCodesInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + for_what = opt_get(opt, ['for'], 'music') + + from models.audio.mel2vec import ContrastiveTrainingWrapper + self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, + mask_time_prob=0, + mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, + disable_custom_linear_init=True) + self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu'))) + del self.m2v.m2v.encoder # This is a big memory sink which will not get used. + self.needs_move = True + + def forward(self, state): + mels = state[self.input] + with torch.no_grad(): + if self.needs_move: + self.m2v = self.m2v.to(mels.device) + codes = self.m2v.get_codes(mels) + return {self.output: codes}