m2v stuff

James Betker 2022-05-20 11:01:17 -06:00
parent c9c16e3b01
commit e9fb2ead9a
5 changed files with 120 additions and 18 deletions

View File

@@ -0,0 +1,50 @@
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model

from trainer.networks import register_model
from utils.util import opt_get


class Mel2VecCodesGpt(nn.Module):
    def __init__(self, dim, layers, num_groups=8, num_vectors=8):
        super().__init__()
        self.num_groups = num_groups
        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                 n_inner=dim*2)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.
        # One embedding table and one prediction head per quantizer code group.
        self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
        self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])

    def forward(self, codes):
        assert codes.shape[-1] == self.num_groups
        # Standard next-token setup: predict frame t+1 from frames <= t.
        inputs = codes[:, :-1]
        targets = codes[:, 1:]
        h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
        h = torch.cat(h, dim=-1)
        h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
        losses = 0
        for i, head in enumerate(self.heads):
            logits = head(h).permute(0, 2, 1)  # cross_entropy expects (b, classes, t)
            loss = F.cross_entropy(logits, targets[:, :, i])
            losses = losses + loss
        return losses / self.num_groups


@register_model
def register_music_gpt(opt_net, opt):
    return Mel2VecCodesGpt(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    model = Mel2VecCodesGpt(512, 8)
    codes = torch.randint(0, 8, (2, 300, 8))
    model(codes)
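
The __main__ block above only smoke-tests the loss computation. For clarity, here is a sketch of how greedy sampling from this model could look at inference time; sample_next is my illustrative name, not part of the commit, and it assumes codes holds the (b, t, num_groups) frames generated so far:

@torch.no_grad()
def sample_next(model, codes):
    # Embed each code group, run the GPT2 trunk, and take each head's argmax
    # at the last position to get the next frame's group codes.
    h = torch.cat([emb(codes[:, :, i]) for i, emb in enumerate(model.embeddings)], dim=-1)
    h = model.gpt(inputs_embeds=h, return_dict=True).last_hidden_state[:, -1]
    return torch.stack([head(h).argmax(dim=-1) for head in model.heads], dim=-1)  # (b, num_groups)

Appending sample_next(model, codes).unsqueeze(1) to codes along dim=1 extends the sequence by one frame.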

View File

@@ -0,0 +1,39 @@
import torch

from models.audio.mel2vec import ContrastiveTrainingWrapper
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
from utils.util import load_audio


def collapse_codegroups(codes):
    # Packs (b, t, groups) code indices into one integer per frame by treating each
    # group as a base-`groups` digit. Lossless only while every code value is < groups
    # (true here: codebook_size == codebook_groups == 8).
    codes = codes.clone()
    groups = codes.shape[-1]
    for k in range(groups):
        codes[:, :, k] = codes[:, :, k] * groups ** k
    codes = codes.sum(-1)
    return codes


def recover_codegroups(codes, groups):
    # Inverse of collapse_codegroups: peel the base-`groups` digits back off.
    codes = codes.clone()
    output = torch.zeros(codes.shape[0], codes.shape[1], groups, dtype=torch.long, device=codes.device)
    for k in range(groups):
        output[:, :, k] = codes % groups
        codes = codes // groups
    return output


if __name__ == '__main__':
    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8,
                                       disable_custom_linear_init=True)
    model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
    model.eval()

    wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
    codes = model.get_codes(mel)
    collapsed = collapse_codegroups(codes)
    recovered = recover_codegroups(collapsed, 8)
    print(codes)
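
With codebook_size == codebook_groups == 8, every code is a base-8 digit, so the collapse/recover pair should round-trip exactly. A cheap sanity check (my addition, assuming get_codes returns integer codes in [0, 8)):

    assert torch.equal(recovered, codes)  # the base-8 pack/unpack must be lossless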

View File

@@ -13,7 +13,7 @@ import torch
 import torchaudio
 from tqdm import tqdm

-from data.util import find_audio_files
+from data.util import find_audio_files, find_files_of_type
 from utils.util import load_audio
@@ -36,27 +36,16 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
     for i, spl in enumerate(splits):
         if spl.shape[-1] != duration_per_clip*sampling_rate:
             continue  # In general, this just means "skip the last item".
-        # Perform some checks on subclips within this clip.
-        passed_checks = True
-        for s in range(duration_per_clip // 2):
-            subclip = spl[s*2*sampling_rate:(s+1)*2*sampling_rate]
-            # Are significant parts of any of this clip just silence?
-            if subclip.var() < .001:
-                passed_checks = False
-            if not passed_checks:
-                break
-        if not passed_checks:
-            continue
         torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate, encoding="PCM_S")
     report_progress(progress_file, file)

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\slakh2100_flac_redux')
-    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\slakh2100_flac_redux\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\slakh2100')
-    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4')
+    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4')
+    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()
@@ -66,7 +55,8 @@ if __name__ == '__main__':
         for line in f.readlines():
             processed_files.add(line.strip())

-    files = set(find_audio_files(args.path, include_nonwav=True))
+    files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0])
+    #files = set(find_audio_files(args.path, include_nonwav=True))
     orig_len = len(files)
     files = files - processed_files
     print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.")

View File

@@ -327,7 +327,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_mel2vec.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

View File

@@ -320,3 +320,26 @@ class AudioUnshuffleInjector(Injector):
    def forward(self, state):
        inp = state[self.input]
        return {self.output: pixel_unshuffle_1d(inp, self.compression)}


class Mel2vecCodesInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        for_what = opt_get(opt, ['for'], 'music')
        from models.audio.mel2vec import ContrastiveTrainingWrapper
        self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
                                              mask_time_prob=0, mask_time_length=6, num_negatives=100,
                                              codebook_size=8, codebook_groups=8,
                                              disable_custom_linear_init=True)
        self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu')))
        del self.m2v.m2v.encoder  # This is a big memory sink which will not get used.
        self.needs_move = True

    def forward(self, state):
        mels = state[self.input]
        with torch.no_grad():
            if self.needs_move:
                # Lazily move the frozen m2v model onto the input's device; this
                # only needs to happen once, so clear the flag afterwards.
                self.m2v = self.m2v.to(mels.device)
                self.needs_move = False
            codes = self.m2v.get_codes(mels)
        return {self.output: codes}
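
Like the other injectors in this file, this one is driven through the state dict, with 'in'/'out' naming the state keys. A minimal usage sketch (my addition; it assumes ../experiments/m2v_music.pth exists and that mels is a (b, 256, t) normalized mel spectrogram like the one TorchMelSpectrogramInjector produces):

    inj = Mel2vecCodesInjector({'for': 'music', 'in': 'mel', 'out': 'codes'}, {})
    codes = inj({'mel': mels})['codes']  # discrete m2v codes, usable as Mel2VecCodesGpt targets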