From c9c16e3b010574f9ac51ad77519ae66cea349d37 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 19 May 2022 13:39:32 -0600
Subject: [PATCH] misc updates

---
 codes/models/audio/mel2vec.py                 |  2 +-
 codes/models/classifiers/twin_cifar_resnet.py | 77 ++++++++++++++++++-
 .../scripts/audio/prep_music/demucs_notes.txt |  8 ++
 .../audio/prep_music/mt3_transcribe.py        |  0
 .../audio/prep_music/phase_1_split_files.py   |  7 +-
 codes/sweep.py                                | 15 ++--
 codes/trainer/eval/music_diffusion_fid.py     |  4 +-
 codes/utils/util.py                           |  4 +
 8 files changed, 100 insertions(+), 17 deletions(-)
 create mode 100644 codes/scripts/audio/prep_music/demucs_notes.txt
 delete mode 100644 codes/scripts/audio/prep_music/mt3_transcribe.py

diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index e02768af..6eea654a 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -540,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
 
 
 class ContrastiveTrainingWrapper(nn.Module):
-    def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
+    def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
                  max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
                  codebook_size=320, codebook_groups=2,
                  **kwargs):
diff --git a/codes/models/classifiers/twin_cifar_resnet.py b/codes/models/classifiers/twin_cifar_resnet.py
index ceb78064..6aa1f938 100644
--- a/codes/models/classifiers/twin_cifar_resnet.py
+++ b/codes/models/classifiers/twin_cifar_resnet.py
@@ -10,6 +10,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from trainer.networks import register_model
@@ -51,6 +52,7 @@ class BasicBlock(nn.Module):
     def forward(self, x):
         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
 
+
 class BottleNeck(nn.Module):
     """Residual block for resnet over 50 layers
 
@@ -80,6 +82,7 @@ class BottleNeck(nn.Module):
     def forward(self, x):
         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
 
+
 class ResNet(nn.Module):
 
     def __init__(self, block, num_block, num_classes=100):
@@ -137,11 +140,81 @@ class ResNet(nn.Module):
 
         return output
 
+
+class SymbolicLoss:
+    def __init__(self, category_depths=[3, 5, 5, 3], convergence_weighting=[1, .6, .3, .1], divergence_weighting=[.1, .3, .6, 1]):
+        self.depths = category_depths
+        self.total_classes = 1
+        for c in category_depths:
+            self.total_classes *= c
+        self.elements_per_level = []
+        m = 1
+        for c in category_depths[1:]:
+            m *= c
+            self.elements_per_level.append(self.total_classes // m)
+        self.elements_per_level = self.elements_per_level + [1]
+        self.convergence_weighting = convergence_weighting
+        self.divergence_weighting = divergence_weighting
+        # TODO: improve the above logic, I'm sure it can be done better.
+
+    def __call__(self, logits, collaboratorLabels):
+        """
+        Computes the symbolic loss.
+        :param logits: Nested level scores for the network under training.
+        :param collaboratorLabels: Level labels from the collaborator network.
+        :return: Convergence loss & divergence loss.
+        """
+        b, l = logits.shape
+        assert l == self.total_classes, f"Expected {self.total_classes} predictions, got {l}"
+
+        convergence_loss = 0
+        divergence_loss = 0
+        for epc, cw, dw in zip(self.elements_per_level, self.convergence_weighting, self.divergence_weighting):
+            level_logits = logits.view(b, l // epc, epc)
+            level_logits = level_logits.sum(dim=-1)
+            level_labels = collaboratorLabels.div(epc, rounding_mode='trunc')
+            # Convergence
+            convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw
+            # Divergence
+            div_label_indices = level_logits.argmax(dim=-1)
+            # TODO: find the torch-y way of doing this.
+            dp = []
+            for bi, i in enumerate(div_label_indices):
+                dp.append(level_logits[:, i])
+            div_preds = torch.stack(dp, dim=0)
+            div_labels = torch.arange(0, b, device=logits.device)
+            divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels) * dw
+        return convergence_loss, divergence_loss
+
+
+if __name__ == '__main__':
+    sl = SymbolicLoss()
+    logits = torch.randn(5, sl.total_classes)
+    labels = torch.randint(0, sl.total_classes, (5,))
+    sl(logits, labels)
+
+
+class TwinnedCifar(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = SymbolicLoss()
+        self.netA = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)
+        self.netB = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)
+
+    def forward(self, x):
+        y1 = self.netA(x)
+        y2 = self.netB(x)
+        b = x.shape[0]
+        convergenceA, divergenceA = self.loss(y1[:b // 2], y2.argmax(dim=-1)[:b // 2])
+        convergenceB, divergenceB = self.loss(y2[b // 2:], y1.argmax(dim=-1)[b // 2:])
+        return convergenceA + convergenceB, divergenceA + divergenceB
+
+
 @register_model
-def register_cifar_resnet18(opt_net, opt):
-    """ return a ResNet 18 object
+def register_twin_cifar(opt_net, opt):
+    """ return a TwinnedCifar object
     """
-    return ResNet(BasicBlock, [2, 2, 2, 2])
+    return TwinnedCifar()
 
 def resnet34():
     """ return a ResNet 34 object
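A note on the new classifier: TwinnedCifar runs two ResNet-18s against each other, with each network's half-batch labeled by the other network's argmax predictions, and SymbolicLoss scoring agreement at each level of the class hierarchy. Below is a minimal smoke-test sketch; the import path assumes `codes/` is the working directory, and the batch size, `.5` divergence weight, and `total_loss` name are illustrative assumptions rather than anything in the patch:

```python
import torch
from models.classifiers.twin_cifar_resnet import TwinnedCifar

# Hypothetical usage of the new TwinnedCifar wrapper. It expects CIFAR-shaped
# inputs (3x32x32), so random tensors are enough to exercise the twin losses.
model = TwinnedCifar()
images = torch.randn(8, 3, 32, 32)  # even batch: each half is scored against the other net

convergence, divergence = model(images)
# Both returns are scalar tensors; a training step might combine them with an
# assumed relative weighting before backprop:
total_loss = convergence + .5 * divergence
total_loss.backward()
```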
+ """ + b, l = logits.shape + assert l == self.total_classes, f"Expected {self.total_classes} predictions, got {l}" + + convergence_loss = 0 + divergence_loss = 0 + for epc, cw, dw in zip(self.elements_per_level, self.convergence_weighting, self.divergence_weighting): + level_logits = logits.view(b, l//epc, epc) + level_logits = level_logits.sum(dim=-1) + level_labels = collaboratorLabels.div(epc, rounding_mode='trunc') + # Convergence + convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw + # Divergence + div_label_indices = level_logits.argmax(dim=-1) + # TODO: find the torch-y way of doing this. + dp = [] + for bi, i in enumerate(div_label_indices): + dp.append(level_logits[:, i]) + div_preds = torch.stack(dp, dim=0) + div_labels = torch.arange(0, b, device=logits.device) + divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels) + return convergence_loss, divergence_loss + + +if __name__ == '__main__': + sl = SymbolicLoss() + logits = torch.randn(5, sl.total_classes) + labels = torch.randint(0, sl.total_classes, (5,)) + sl(logits, labels) + + +class TwinnedCifar(nn.Module): + def __init__(self): + super().__init__() + self.loss = SymbolicLoss() + self.netA = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) + self.netB = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes) + + def forward(self, x): + y1 = self.netA(x) + y2 = self.netB(x) + b = x.shape[0] + convergenceA, divergenceA = self.loss(y1[:b//2], y2.argmax(dim=-1)[:b//2]) + convergenceB, divergenceB = self.loss(y2[b//2:], y1.argmax(dim=-1)[b//2:]) + return convergenceA + convergenceB, divergenceA + divergenceB + + @register_model -def register_cifar_resnet18(opt_net, opt): +def register_twin_cifar(opt_net, opt): """ return a ResNet 18 object """ - return ResNet(BasicBlock, [2, 2, 2, 2]) + return TwinnedCifar() def resnet34(): """ return a ResNet 34 object diff --git a/codes/scripts/audio/prep_music/demucs_notes.txt b/codes/scripts/audio/prep_music/demucs_notes.txt new file mode 100644 index 00000000..1c37b7df --- /dev/null +++ b/codes/scripts/audio/prep_music/demucs_notes.txt @@ -0,0 +1,8 @@ +My custom demucs library is used for batch source separation: +https://github.com/neonbjb/demucs + +``` +conda activate demucs +python setup.py install +CUDA_VISIBLE_DEVICES=0 python -m demucs /y/split/bt-music-5 --out=/y/separated/bt-music-5 --num_workers=2 --device cuda +``` \ No newline at end of file diff --git a/codes/scripts/audio/prep_music/mt3_transcribe.py b/codes/scripts/audio/prep_music/mt3_transcribe.py deleted file mode 100644 index e69de29b..00000000 diff --git a/codes/scripts/audio/prep_music/phase_1_split_files.py b/codes/scripts/audio/prep_music/phase_1_split_files.py index 5e85cbaf..29bc96f3 100644 --- a/codes/scripts/audio/prep_music/phase_1_split_files.py +++ b/codes/scripts/audio/prep_music/phase_1_split_files.py @@ -26,6 +26,7 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip, try: audio = load_audio(file, sampling_rate) except: + print(f"Error loading file {file}") report_progress(progress_file, file) return @@ -52,9 +53,9 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip, if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music2') - parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', 
diff --git a/codes/sweep.py b/codes/sweep.py
index dbe80fe5..03a877b1 100644
--- a/codes/sweep.py
+++ b/codes/sweep.py
@@ -34,17 +34,14 @@ if __name__ == '__main__':
     Ad-hoc script (hard-coded; no command-line parameters) that spawns multiple separate trainers from a single options
     file, with a hard-coded set of modifications.
     """
-    base_opt = '../experiments/train_diffusion_tts9_sweep.yml'
+    base_opt = '../experiments/sweep_music_mel2vec.yml'
     modifications = {
         'baseline': {},
-        'more_filters': {'networks': {'generator': {'kwargs': {'model_channels': 96}}}},
-        'more_kern': {'networks': {'generator': {'kwargs': {'kernel_size': 5}}}},
-        'less_heads': {'networks': {'generator': {'kwargs': {'num_heads': 2}}}},
-        'eff_off': {'networks': {'generator': {'kwargs': {'efficient_convs': False}}}},
-        'more_time': {'networks': {'generator': {'kwargs': {'time_embed_dim_multiplier': 8}}}},
-        'scale_shift_off': {'networks': {'generator': {'kwargs': {'use_scale_shift_norm': False}}}},
-        'shallow_res': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 2, 2]}}}},
+        'lr1e3': {'steps': {'generator': {'optimizer_params': {'lr': .001}}}},
+        'lr1e5': {'steps': {'generator': {'optimizer_params': {'lr': .00001}}}},
+        'no_warmup': {'train': {'warmup_steps': 0}},
     }
+    base_rank = 4
     opt = option.parse(base_opt, is_train=True)
     all_opts = []
     for i, (mod, mod_dict) in enumerate(modifications.items()):
@@ -65,4 +62,4 @@
             break
     else:
         rank = 0
-    launch_trainer(all_opts[rank], base_opt, rank)
+    launch_trainer(all_opts[rank], base_opt, rank + base_rank)
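The sweep works by overlaying each modification dict onto the parsed base options before spawning a trainer per variant. The exact merge lives in the repo's options module; `recursive_update` below is a hypothetical re-implementation, sketched only to show how a key like `'lr1e3'` rewrites one nested value while leaving the rest of the config intact:

```python
import copy

def recursive_update(base, mod):
    # Hypothetical stand-in for the options merging the sweep relies on:
    # nested dicts are merged key by key, scalar leaves are overwritten.
    for k, v in mod.items():
        if isinstance(v, dict) and isinstance(base.get(k), dict):
            recursive_update(base[k], v)
        else:
            base[k] = v
    return base

base = {'steps': {'generator': {'optimizer_params': {'lr': .0001, 'weight_decay': .01}}}}
merged = recursive_update(copy.deepcopy(base),
                          {'steps': {'generator': {'optimizer_params': {'lr': .001}}}})
assert merged['steps']['generator']['optimizer_params'] == {'lr': .001, 'weight_decay': .01}
```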
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 8f463284..3264018d 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -219,10 +219,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
 if __name__ == '__main__':
     diffusion = load_model_from_config('D:\\dlas\\options\\train_music_waveform_gen3.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='D:\\dlas\\experiments\\train_music_waveform_gen\\models\\59000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\75500_generator_ema.pth').cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400, 'conditioning_free': False, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 20, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 23, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
diff --git a/codes/utils/util.py b/codes/utils/util.py
index bbd61533..5d309be5 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -567,6 +567,10 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
     else:
         if audiopath[-4:] == '.wav':
             audio, lsr = load_wav_to_torch(audiopath)
+        elif audiopath[-5:] == '.flac':
+            import soundfile as sf
+            audio, lsr = sf.read(audiopath)
+            audio = torch.FloatTensor(audio)
         else:
             audio, lsr = open_audio(audiopath)
             audio = torch.FloatTensor(audio)
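One caveat on the new `.flac` branch in `load_audio`: `soundfile.read` returns `(data, samplerate)` with `data` shaped `(frames,)` for mono but `(frames, channels)` for multichannel files, whereas the `.wav` path yields mono tensors. A sketch of the behavior, with `'example.flac'` as a placeholder path and the channel collapse as an assumption about what downstream code expects, not something the patch itself does:

```python
import soundfile as sf
import torch

audio, lsr = sf.read('example.flac')  # placeholder path; data comes back as float64 numpy
audio = torch.FloatTensor(audio)
if audio.dim() > 1:
    # The patched branch leaves multichannel audio as (frames, channels);
    # collapsing to mono here is an assumed post-processing step.
    audio = audio.mean(dim=-1)
```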