diff --git a/codes/models/audio/music/m2v_code_to_mel.py b/codes/models/audio/music/m2v_code_to_mel.py
new file mode 100644
index 00000000..91e25363
--- /dev/null
+++ b/codes/models/audio/music/m2v_code_to_mel.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.arch_util import ResBlock, AttentionBlock
+from models.audio.music.flat_diffusion import MultiGroupEmbedding
+from trainer.networks import register_model
+from utils.util import checkpoint
+
+
+class Code2Mel(nn.Module):
+    """Decodes grouped discrete music codes into a mel spectrogram."""
+    def __init__(self, out_dim=256, base_dim=1024, num_tokens=16, num_groups=4, dropout=.1):
+        super().__init__()
+        self.emb = MultiGroupEmbedding(num_tokens, num_groups, base_dim)
+        self.base_blocks = nn.Sequential(ResBlock(base_dim, dropout, dims=1),
+                                         AttentionBlock(base_dim, num_heads=base_dim//64),
+                                         ResBlock(base_dim, dropout, dims=1))
+        # Each level sheds 256 channels while the temporal resolution grows toward the target length.
+        l2dim = base_dim-256
+        self.l2_up_block = nn.Conv1d(base_dim, l2dim, kernel_size=5, padding=2)
+        self.l2_blocks = nn.Sequential(ResBlock(l2dim, dropout, kernel_size=5, dims=1),
+                                       AttentionBlock(l2dim, num_heads=base_dim//64),
+                                       ResBlock(l2dim, dropout, kernel_size=5, dims=1),
+                                       AttentionBlock(l2dim, num_heads=base_dim//64),
+                                       ResBlock(l2dim, dropout, dims=1),
+                                       ResBlock(l2dim, dropout, dims=1))
+        l3dim = l2dim-256
+        self.l3_up_block = nn.Conv1d(l2dim, l3dim, kernel_size=5, padding=2)
+        self.l3_blocks = nn.Sequential(ResBlock(l3dim, dropout, kernel_size=5, dims=1),
+                                       AttentionBlock(l3dim, num_heads=base_dim//64),
+                                       ResBlock(l3dim, dropout, kernel_size=5, dims=1),
+                                       ResBlock(l3dim, dropout, dims=1))
+        self.final_block = nn.Conv1d(l3dim, out_dim, kernel_size=3, padding=1)
+
+    def forward(self, codes, target):
+        # codes: (batch, sequence, groups) token indices; target: (batch, out_dim, length) mel.
+        with torch.autocast(codes.device.type):
+            h = self.emb(codes).permute(0,2,1)
+            h = checkpoint(self.base_blocks, h)
+            h = F.interpolate(h, scale_factor=2, mode='linear')
+            h = self.l2_up_block(h)
+            h = checkpoint(self.l2_blocks, h)
+            h = F.interpolate(h, size=target.shape[-1], mode='linear')
+            h = self.l3_up_block(h)
+        # Leave autocast and run the final blocks in full precision.
+        h = checkpoint(self.l3_blocks, h.float())
+        pred = self.final_block(h)
+        return F.mse_loss(pred, target), pred
+
+
+@register_model
+def register_code2mel(opt_net, opt):
+    return Code2Mel(**opt_net['kwargs'])
+
+
+if __name__ == '__main__':
+    model = Code2Mel()
+    codes = torch.randint(0,16, (2,200,4))
+    target = torch.randn(2,256,804)
+    model(codes, target)
\ No newline at end of file
diff --git a/codes/scripts/audio/prep_music/demucs_notes.txt b/codes/scripts/audio/prep_music/demucs_notes.txt
index 064e791f..f54f03be 100644
--- a/codes/scripts/audio/prep_music/demucs_notes.txt
+++ b/codes/scripts/audio/prep_music/demucs_notes.txt
@@ -5,4 +5,4 @@ https://github.com/neonbjb/demucs
 conda activate demucs
 python setup.py install
 CUDA_VISIBLE_DEVICES=0 python -m demucs /y/split/bt-music-5 --out=/y/separated/bt-music-5 --num_workers=2 --device cuda --two-stems=vocals
-```
\ No newline at end of file
+```
diff --git a/codes/scripts/audio/prep_music/phase_1_split_files.py b/codes/scripts/audio/prep_music/phase_1_split_files.py
index 42e9853f..21dc6fed 100644
--- a/codes/scripts/audio/prep_music/phase_1_split_files.py
+++ b/codes/scripts/audio/prep_music/phase_1_split_files.py
@@ -42,9 +42,9 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music4')
-    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\music\\bt-music4\\already_processed.txt')
-    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\flacced\\bt-music-4')
+    parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\silk')
+    parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\silk\\already_processed.txt')
+    parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\silk')
     parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=6)
     parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
     args = parser.parse_args()
@@ -55,8 +55,7 @@ if __name__ == '__main__':
         for line in f.readlines():
             processed_files.add(line.strip())
 
-    files = set(find_files_of_type(None, args.path, qualifier=lambda p: p.endswith('.flac'))[0])
-    #files = set(find_audio_files(args.path, include_nonwav=True))
+    files = set(find_audio_files(args.path, include_nonwav=True))
    orig_len = len(files)
     files = files - processed_files
     print(f"Found {len(files)} files to process. Total processing is {100*(orig_len-len(files))/orig_len}% complete.")
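
Usage note (editorial, not part of the patch above): `register_code2mel` hands `opt_net['kwargs']` straight to the `Code2Mel` constructor, so a training config only needs to supply that dict. Below is a minimal sketch of the same call path, assuming the repo modules are importable; the kwargs shown are just the constructor defaults, not values from any real config.

```python
import torch

from models.audio.music.m2v_code_to_mel import Code2Mel

# Mirrors what register_code2mel does with opt_net['kwargs'];
# these values are the constructor defaults, not from a real config.
kwargs = {'out_dim': 256, 'base_dim': 1024, 'num_tokens': 16,
          'num_groups': 4, 'dropout': 0.1}
model = Code2Mel(**kwargs)

# codes: (batch, sequence, groups) indices in [0, num_tokens);
# target: (batch, out_dim, length) mel spectrogram.
codes = torch.randint(0, 16, (2, 200, 4))
target = torch.randn(2, 256, 804)

# forward() returns (mse_loss, prediction); the prediction is interpolated
# to the target's length, so it comes back as torch.Size([2, 256, 804]).
loss, pred = model(codes, target)
```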