m2v stuff
commit e9fb2ead9a (parent c9c16e3b01)
codes/models/audio/music/mel2vec_codes_gpt.py (new file, +50)
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import GPT2Config, GPT2Model
+
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class Mel2VecCodesGpt(nn.Module):
+    def __init__(self, dim, layers, num_groups=8, num_vectors=8):
+        super().__init__()
+
+        self.num_groups = num_groups
+
+        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
+                                 n_inner=dim*2)
+        self.gpt = GPT2Model(self.config)
+        del self.gpt.wte  # Unused, we'll do our own embeddings.
+        self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
+        self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])
+
+    def forward(self, codes):
+        assert codes.shape[-1] == self.num_groups
+
+        inputs = codes[:, :-1]
+        targets = codes[:, 1:]
+
+        h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
+        h = torch.cat(h, dim=-1)
+        h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
+
+        losses = 0
+        for i, head in enumerate(self.heads):
+            logits = head(h).permute(0, 2, 1)
+            loss = F.cross_entropy(logits, targets[:, :, i])
+            losses = losses + loss
+
+        return losses / self.num_groups
+
+
+@register_model
+def register_music_gpt(opt_net, opt):
+    return Mel2VecCodesGpt(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = Mel2VecCodesGpt(512, 8)
+    codes = torch.randint(0, 8, (2, 300, 8))
+    model(codes)
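Note: the model embeds each of the num_groups code streams separately (dim//num_groups channels each, concatenated back to dim, which is why dim must be divisible by num_groups and, via n_head=dim//64, by 64), then trains one linear head per group and averages the per-group next-step cross-entropy losses. As an illustrative sketch (not part of the commit), greedy prediction of the next code can reuse exactly the modules defined above:

    # Hedged sketch: greedy next-code prediction with Mel2VecCodesGpt.
    # Assumes only the class as defined in this file.
    import torch

    model = Mel2VecCodesGpt(512, 8)
    model.eval()
    prompt = torch.randint(0, 8, (1, 16, 8))  # (batch, time, groups)

    with torch.no_grad():
        h = torch.cat([emb(prompt[:, :, i]) for i, emb in enumerate(model.embeddings)], dim=-1)
        h = model.gpt(inputs_embeds=h, return_dict=True).last_hidden_state
        next_codes = torch.stack([head(h[:, -1]).argmax(-1) for head in model.heads], dim=-1)
    print(next_codes.shape)  # torch.Size([1, 8]): one predicted code per group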
codes/scripts/audio/gen/use_mel2vec_codes.py (new file, +39)
@@ -0,0 +1,39 @@
+import torch
+
+from models.audio.mel2vec import ContrastiveTrainingWrapper
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from utils.util import load_audio
+
+def collapse_codegroups(codes):
+    codes = codes.clone()
+    groups = codes.shape[-1]
+    for k in range(groups):
+        codes[:, :, k] = codes[:, :, k] * groups ** k
+    codes = codes.sum(-1)
+    return codes
+
+
+def recover_codegroups(codes, groups):
+    codes = codes.clone()
+    output = torch.zeros(codes.shape[0], codes.shape[1], groups, dtype=torch.long, device=codes.device)
+    for k in range(groups):
+        output[:, :, k] = codes % groups
+        codes = codes // groups
+    return output
+
+
+if __name__ == '__main__':
+    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
+                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
+    model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
+    model.eval()
+
+    wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
+    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
+
+    codes = model.get_codes(mel)
+
+    collapsed = collapse_codegroups(codes)
+    recovered = recover_codegroups(collapsed, 8)
+
+    print(codes)
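Note: collapse_codegroups packs the eight per-frame codes as the digits of a single base-8 integer (digit k weighted by groups**k) and recover_codegroups unpacks them with mod/floor-division. The round trip is lossless only because each code comes from a codebook of size 8, which does not exceed the base (groups=8); a larger codebook_size would make digits collide. A small self-contained check using the two functions above (illustrative, not part of the commit):

    # Hedged sketch: verify the radix round-trip with codebook_size == groups == 8.
    import torch

    codes = torch.randint(0, 8, (2, 300, 8))
    collapsed = collapse_codegroups(codes)        # (2, 300), values in [0, 8**8)
    recovered = recover_codegroups(collapsed, 8)  # (2, 300, 8)
    assert torch.equal(recovered, codes)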
@@ -13,7 +13,7 @@ import torch
 import torchaudio
 from tqdm import tqdm
 
-from data.util import find_audio_files
+from data.util import find_audio_files, find_files_of_type
 from utils.util import load_audio
 
 
@@ -36,27 +36,16 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
     for i, spl in enumerate(splits):
         if spl.shape[-1] != duration_per_clip*sampling_rate:
             continue  # In general, this just means "skip the last item".
-        # Perform some checks on subclips within this clip.
-        passed_checks = True
-        for s in range(duration_per_clip // 2):
-            subclip = spl[s*2*sampling_rate:(s+1)*2*sampling_rate]
-            # Are significant parts of any of this clip just silence?
-            if subclip.var() < .001:
-                passed_checks=False
-            if not passed_checks:
-                break
-        if not passed_checks:
-            continue
         torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate, encoding="PCM_S")
     report_progress(progress_file, file)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\slakh2100_flac_redux')
-    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\slakh2100_flac_redux\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\slakh2100')
-    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4')
+    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4')
+    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()
 
@@ -66,7 +55,8 @@ if __name__ == '__main__':
         for line in f.readlines():
             processed_files.add(line.strip())
 
-    files = set(find_audio_files(args.path, include_nonwav=True))
+    files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0])
+    #files = set(find_audio_files(args.path, include_nonwav=True))
     orig_len = len(files)
     files = files - processed_files
     print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.")
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_mel2vec.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
@@ -320,3 +320,26 @@ class AudioUnshuffleInjector(Injector):
     def forward(self, state):
         inp = state[self.input]
         return {self.output: pixel_unshuffle_1d(inp, self.compression)}
+
+
+class Mel2vecCodesInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        for_what = opt_get(opt, ['for'], 'music')
+
+        from models.audio.mel2vec import ContrastiveTrainingWrapper
+        self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
+                                              mask_time_prob=0,
+                                              mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8,
+                                              disable_custom_linear_init=True)
+        self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu')))
+        del self.m2v.m2v.encoder  # This is a big memory sink which will not get used.
+        self.needs_move = True
+
+    def forward(self, state):
+        mels = state[self.input]
+        with torch.no_grad():
+            if self.needs_move:
+                self.m2v = self.m2v.to(mels.device)
+            codes = self.m2v.get_codes(mels)
+        return {self.output: codes}
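Note: an illustrative way to exercise the new injector outside the trainer, patterned on the TorchMelSpectrogramInjector call in use_mel2vec_codes.py. The import path is an assumption (the diff's file header was not captured here), and the sketch presumes the Injector base class reads the 'in'/'out' opt keys as the other injector invocations in this commit suggest, and that ../experiments/m2v_music.pth exists:

    # Hedged sketch, not part of the commit.
    import torch
    from trainer.injectors.audio_injectors import Mel2vecCodesInjector  # assumed module path

    inj = Mel2vecCodesInjector({'for': 'music', 'in': 'mel', 'out': 'codes'}, {})
    state = {'mel': torch.randn(1, 256, 600)}  # stand-in for a normalized 256-bin mel
    codes = inj.forward(state)['codes']
    print(codes.shape)  # expected (1, frames, 8): one code per frame per group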