forked from mrq/DL-Art-School

m2v stuff

This commit is contained in:
parent c9c16e3b01
commit e9fb2ead9a

codes/models/audio/music/mel2vec_codes_gpt.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import torch
from torch import nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model

from trainer.networks import register_model
from utils.util import opt_get


class Mel2VecCodesGpt(nn.Module):
    def __init__(self, dim, layers, num_groups=8, num_vectors=8):
        super().__init__()

        self.num_groups = num_groups

        self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                 n_inner=dim*2)
        self.gpt = GPT2Model(self.config)
        del self.gpt.wte  # Unused, we'll do our own embeddings.
        self.embeddings = nn.ModuleList([nn.Embedding(num_vectors, dim//num_groups) for _ in range(num_groups)])
        self.heads = nn.ModuleList([nn.Linear(dim, num_vectors) for _ in range(num_groups)])

    def forward(self, codes):
        assert codes.shape[-1] == self.num_groups

        inputs = codes[:, :-1]
        targets = codes[:, 1:]

        h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
        h = torch.cat(h, dim=-1)
        h = self.gpt(inputs_embeds=h, return_dict=True).last_hidden_state

        losses = 0
        for i, head in enumerate(self.heads):
            logits = head(h).permute(0,2,1)
            loss = F.cross_entropy(logits, targets[:,:,i])
            losses = losses + loss

        return losses / self.num_groups


@register_model
def register_music_gpt(opt_net, opt):
    return Mel2VecCodesGpt(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    model = Mel2VecCodesGpt(512, 8)
    codes = torch.randint(0,8, (2,300,8))
    model(codes)
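
As committed, Mel2VecCodesGpt only returns the training loss (mean cross-entropy over the code groups); no sampler ships with it. A minimal sketch of how the grouped heads could be used for autoregressive generation, assuming only the class above (sample_codes is illustrative, not part of the repository):

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def sample_codes(model, prime, steps):
        # prime: (batch, time, num_groups) integer codes to condition on.
        codes = prime.clone()
        for _ in range(steps):
            # Embed each group and concatenate, mirroring Mel2VecCodesGpt.forward().
            h = torch.cat([emb(codes[:, :, i]) for i, emb in enumerate(model.embeddings)], dim=-1)
            h = model.gpt(inputs_embeds=h, return_dict=True).last_hidden_state[:, -1]
            # Sample one code per group from that group's head.
            nxt = torch.stack([torch.multinomial(F.softmax(head(h), dim=-1), 1).squeeze(-1)
                               for head in model.heads], dim=-1)
            codes = torch.cat([codes, nxt.unsqueeze(1)], dim=1)
        return codes

    model = Mel2VecCodesGpt(512, 8)
    prime = torch.randint(0, 8, (1, 16, 8))
    generated = sample_codes(model, prime, steps=32)  # (1, 48, 8)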

codes/scripts/audio/gen/use_mel2vec_codes.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import torch

from models.audio.mel2vec import ContrastiveTrainingWrapper
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
from utils.util import load_audio


def collapse_codegroups(codes):
    # Packs (b, t, groups) codes into one integer per timestep, treating each group
    # as a base-`groups` digit. This assumes every code value is < groups, which
    # holds here because codebook_size == codebook_groups == 8.
    codes = codes.clone()
    groups = codes.shape[-1]
    for k in range(groups):
        codes[:,:,k] = codes[:,:,k] * groups ** k
    codes = codes.sum(-1)
    return codes


def recover_codegroups(codes, groups):
    codes = codes.clone()
    # torch.empty is used here because the legacy torch.LongTensor constructor
    # does not accept a device argument.
    output = torch.empty(codes.shape[0], codes.shape[1], groups, dtype=torch.long, device=codes.device)
    for k in range(groups):
        output[:,:,k] = codes % groups
        codes = codes // groups
    return output


if __name__ == '__main__':
    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8,
                                       disable_custom_linear_init=True)
    model.load_state_dict(torch.load("X:\\dlas\\experiments\\train_music_mel2vec\\models\\29000_generator_ema.pth"))
    model.eval()

    wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']

    codes = model.get_codes(mel)

    collapsed = collapse_codegroups(codes)
    recovered = recover_codegroups(collapsed, 8)

    print(codes)
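
collapse_codegroups and recover_codegroups are exact inverses as long as every code value is below the base (here groups == 8, matching codebook_size=8). A quick round-trip check, illustrative rather than part of the commit:

    import torch

    codes = torch.randint(0, 8, (2, 300, 8))      # fake grouped codes, values < 8
    collapsed = collapse_codegroups(codes)        # (2, 300), each value < 8**8
    recovered = recover_codegroups(collapsed, 8)  # (2, 300, 8)
    assert torch.equal(recovered, codes)          # lossless round-trip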

@@ -13,7 +13,7 @@ import torch
 import torchaudio
 from tqdm import tqdm

-from data.util import find_audio_files
+from data.util import find_audio_files, find_files_of_type
 from utils.util import load_audio


@@ -36,27 +36,16 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
     for i, spl in enumerate(splits):
         if spl.shape[-1] != duration_per_clip*sampling_rate:
             continue  # In general, this just means "skip the last item".
-        # Perform some checks on subclips within this clip.
-        passed_checks = True
-        for s in range(duration_per_clip // 2):
-            subclip = spl[s*2*sampling_rate:(s+1)*2*sampling_rate]
-            # Are significant parts of any of this clip just silence?
-            if subclip.var() < .001:
-                passed_checks=False
-            if not passed_checks:
-                break
-        if not passed_checks:
-            continue
         torchaudio.save(f'{outdir}/{i:05d}.wav', spl.unsqueeze(0), sampling_rate, encoding="PCM_S")
     report_progress(progress_file, file)


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\slakh2100_flac_redux')
-    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\slakh2100_flac_redux\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\slakh2100')
-    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4')
+    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4')
+    parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()

@@ -66,7 +55,8 @@ if __name__ == '__main__':
         for line in f.readlines():
             processed_files.add(line.strip())

-    files = set(find_audio_files(args.path, include_nonwav=True))
+    files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0])
+    #files = set(find_audio_files(args.path, include_nonwav=True))
     orig_len = len(files)
     files = files - processed_files
     print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.")
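
The discovery change above restricts the dataset scan to .flac files. find_files_of_type itself lives in data/util.py; functionally, the new call behaves like a qualifier-filtered directory walk along these lines (an illustrative sketch, not the repository implementation):

    import os

    def find_flac_files(path):
        # Mirrors find_files_of_type(None, path, qualifier=lambda p: p.endswith('.flac'))[0]:
        # walk the tree and keep every path the qualifier accepts.
        matches = []
        for root, _, files in os.walk(path):
            for f in files:
                p = os.path.join(root, f)
                if p.endswith('.flac'):
                    matches.append(p)
        return matches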
@@ -327,7 +327,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_mel2vec.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_gpt.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

@@ -320,3 +320,26 @@ class AudioUnshuffleInjector(Injector):
     def forward(self, state):
         inp = state[self.input]
         return {self.output: pixel_unshuffle_1d(inp, self.compression)}
+
+
+class Mel2vecCodesInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        for_what = opt_get(opt, ['for'], 'music')
+
+        from models.audio.mel2vec import ContrastiveTrainingWrapper
+        self.m2v = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
+                                              mask_time_prob=0,
+                                              mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8,
+                                              disable_custom_linear_init=True)
+        self.m2v.load_state_dict(torch.load(f"../experiments/m2v_{for_what}.pth", map_location=torch.device('cpu')))
+        del self.m2v.m2v.encoder  # This is a big memory sink which will not get used.
+        self.needs_move = True
+
+    def forward(self, state):
+        mels = state[self.input]
+        with torch.no_grad():
+            if self.needs_move:
+                self.m2v = self.m2v.to(mels.device)
+            codes = self.m2v.get_codes(mels)
+            return {self.output: codes}
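
The needs_move flag defers moving the frozen m2v wrapper onto the training device until the first forward pass, presumably because injectors are constructed before devices are assigned. Note it is never cleared, so the .to() call repeats on every forward; that is a no-op once the module is already resident. A hypothetical standalone use, assuming the base Injector reads its 'in' and 'out' keys from opt as elsewhere in DL-Art-School:

    # Illustrative only; in training, the YAML 'injectors' section constructs this
    # and the trainer supplies `state`.
    inj = Mel2vecCodesInjector({'in': 'mel', 'out': 'codes', 'for': 'music'}, {})
    state = {'mel': torch.randn(2, 256, 1200).cuda()}
    codes = inj.forward(state)['codes']  # discrete (b, t, 8) codes from the frozen m2v model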