some stuff

James Betker 2022-05-27 11:40:31 -06:00
parent 5efeee6b97
commit 490d39b967
3 changed files with 62 additions and 6 deletions


@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.arch_util import ResBlock, AttentionBlock
from models.audio.music.flat_diffusion import MultiGroupEmbedding
from trainer.networks import register_model
from utils.util import checkpoint


class Code2Mel(nn.Module):
    """Upsamples grouped discrete codes into a mel spectrogram and scores the prediction against a target with MSE."""
    def __init__(self, out_dim=256, base_dim=1024, num_tokens=16, num_groups=4, dropout=.1):
        super().__init__()
        self.emb = MultiGroupEmbedding(num_tokens, num_groups, base_dim)
        self.base_blocks = nn.Sequential(ResBlock(base_dim, dropout, dims=1),
                                         AttentionBlock(base_dim, num_heads=base_dim//64),
                                         ResBlock(base_dim, dropout, dims=1))
        l2dim = base_dim - 256
        self.l2_up_block = nn.Conv1d(base_dim, l2dim, kernel_size=5, padding=2)
        self.l2_blocks = nn.Sequential(ResBlock(l2dim, dropout, kernel_size=5, dims=1),
                                       AttentionBlock(l2dim, num_heads=base_dim//64),
                                       ResBlock(l2dim, dropout, kernel_size=5, dims=1),
                                       AttentionBlock(l2dim, num_heads=base_dim//64),
                                       ResBlock(l2dim, dropout, dims=1),
                                       ResBlock(l2dim, dropout, dims=1))
        l3dim = l2dim - 256
        self.l3_up_block = nn.Conv1d(l2dim, l3dim, kernel_size=5, padding=2)
        self.l3_blocks = nn.Sequential(ResBlock(l3dim, dropout, kernel_size=5, dims=1),
                                       AttentionBlock(l3dim, num_heads=base_dim//64),
                                       ResBlock(l3dim, dropout, kernel_size=5, dims=1),
                                       ResBlock(l3dim, dropout, dims=1))
        self.final_block = nn.Conv1d(l3dim, out_dim, kernel_size=3, padding=1)

    def forward(self, codes, target):
        with torch.autocast(codes.device.type):
            h = self.emb(codes).permute(0, 2, 1)  # (b, seq, groups) -> (b, base_dim, seq)
            h = checkpoint(self.base_blocks, h)
            h = F.interpolate(h, scale_factor=2, mode='linear')
            h = self.l2_up_block(h)
            h = checkpoint(self.l2_blocks, h)
            h = F.interpolate(h, size=target.shape[-1], mode='linear')  # match the target mel frame count
            h = self.l3_up_block(h)
        # Final stage and loss run in full precision.
        h = checkpoint(self.l3_blocks, h.float())
        pred = self.final_block(h)
        return F.mse_loss(pred, target), pred


@register_model
def register_code2mel(opt_net, opt):
    return Code2Mel(**opt_net['kwargs'])


if __name__ == '__main__':
    model = Code2Mel()
    codes = torch.randint(0, 16, (2, 200, 4))
    target = torch.randn(2, 256, 804)
    model(codes, target)
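
For orientation, here is a minimal, self-contained sketch of the temporal shape flow that `Code2Mel.forward` traces, with plain `Conv1d` layers standing in for the repo's `ResBlock`/`AttentionBlock` stacks (assumed here to preserve sequence length); the dimensions mirror the `__main__` smoke test above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in dimensions taken from the smoke test: 200 code frames,
# a 1024-dim embedding, and a 256-bin mel target that is 804 frames long.
B, T, base_dim, mel_bins, mel_len = 2, 200, 1024, 256, 804

h = torch.randn(B, base_dim, T)                                  # embedded codes after permute: (2, 1024, 200)
h = F.interpolate(h, scale_factor=2, mode='linear')              # first 2x upsample: (2, 1024, 400)
h = nn.Conv1d(base_dim, base_dim - 256, 5, padding=2)(h)         # l2 stage width: (2, 768, 400)
h = F.interpolate(h, size=mel_len, mode='linear')                # stretch to the target frame count: (2, 768, 804)
h = nn.Conv1d(base_dim - 256, base_dim - 512, 5, padding=2)(h)   # l3 stage width: (2, 512, 804)
pred = nn.Conv1d(base_dim - 512, mel_bins, 3, padding=1)(h)      # mel prediction: (2, 256, 804)
print(pred.shape)                                                # torch.Size([2, 256, 804])
```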


@@ -5,4 +5,4 @@ https://github.com/neonbjb/demucs
 conda activate demucs
 python setup.py install
 CUDA_VISIBLE_DEVICES=0 python -m demucs /y/split/bt-music-5 --out=/y/separated/bt-music-5 --num_workers=2 --device cuda --two-stems=vocals
-```
+``


@@ -42,9 +42,9 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4')
-    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4')
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\silk')
+    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\silk\\already_processed.txt')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\silk')
     parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()
@@ -55,8 +55,7 @@ if __name__ == '__main__':
         for line in f.readlines():
             processed_files.add(line.strip())
-    files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0])
-    #files = set(find_audio_files(args.path, include_nonwav=True))
+    files = set(find_audio_files(args.path, include_nonwav=True))
     orig_len = len(files)
     files = files - processed_files
     print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.")