misc updates

This commit is contained in:
James Betker 2022-05-19 13:39:32 -06:00
parent 10378fc37f
commit c9c16e3b01
8 changed files with 100 additions and 17 deletions

View File

@ -540,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
class ContrastiveTrainingWrapper(nn.Module):
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
codebook_size=320, codebook_groups=2,
**kwargs):

View File

@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from trainer.networks import register_model
@ -51,6 +52,7 @@ class BasicBlock(nn.Module):
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BottleNeck(nn.Module):
"""Residual block for resnet over 50 layers
@ -80,6 +82,7 @@ class BottleNeck(nn.Module):
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNet(nn.Module):
def __init__(self, block, num_block, num_classes=100):
@ -137,11 +140,81 @@ class ResNet(nn.Module):
return output
class SymbolicLoss:
def __init__(self, category_depths=[3,5,5,3], convergence_weighting=[1,.6,.3,.1], divergence_weighting=[.1,.3,.6,1]):
self.depths = category_depths
self.total_classes = 1
for c in category_depths:
self.total_classes *= c
self.elements_per_level = []
m = 1
for c in category_depths[1:]:
m *= c
self.elements_per_level.append(self.total_classes // m)
self.elements_per_level = self.elements_per_level + [1]
self.convergence_weighting = convergence_weighting
self.divergence_weighting = divergence_weighting
# TODO: improve the above logic, I'm sure it can be done better.
def __call__(self, logits, collaboratorLabels):
"""
Computes the symbolic loss.
:param logits: Nested level scores for the network under training.
:param collaboratorLabels: level labels from the collaborator network.
:return: Convergence loss & divergence loss.
"""
b, l = logits.shape
assert l == self.total_classes, f"Expected {self.total_classes} predictions, got {l}"
convergence_loss = 0
divergence_loss = 0
for epc, cw, dw in zip(self.elements_per_level, self.convergence_weighting, self.divergence_weighting):
level_logits = logits.view(b, l//epc, epc)
level_logits = level_logits.sum(dim=-1)
level_labels = collaboratorLabels.div(epc, rounding_mode='trunc')
# Convergence
convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw
# Divergence
div_label_indices = level_logits.argmax(dim=-1)
# TODO: find the torch-y way of doing this.
dp = []
for bi, i in enumerate(div_label_indices):
dp.append(level_logits[:, i])
div_preds = torch.stack(dp, dim=0)
div_labels = torch.arange(0, b, device=logits.device)
divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels)
return convergence_loss, divergence_loss
if __name__ == '__main__':
sl = SymbolicLoss()
logits = torch.randn(5, sl.total_classes)
labels = torch.randint(0, sl.total_classes, (5,))
sl(logits, labels)
class TwinnedCifar(nn.Module):
def __init__(self):
super().__init__()
self.loss = SymbolicLoss()
self.netA = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)
self.netB = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=self.loss.total_classes)
def forward(self, x):
y1 = self.netA(x)
y2 = self.netB(x)
b = x.shape[0]
convergenceA, divergenceA = self.loss(y1[:b//2], y2.argmax(dim=-1)[:b//2])
convergenceB, divergenceB = self.loss(y2[b//2:], y1.argmax(dim=-1)[b//2:])
return convergenceA + convergenceB, divergenceA + divergenceB
@register_model
def register_cifar_resnet18(opt_net, opt):
def register_twin_cifar(opt_net, opt):
""" return a ResNet 18 object
"""
return ResNet(BasicBlock, [2, 2, 2, 2])
return TwinnedCifar()
def resnet34():
""" return a ResNet 34 object

View File

@ -0,0 +1,8 @@
My custom demucs library is used for batch source separation:
https://github.com/neonbjb/demucs
```
conda activate demucs
python setup.py install
CUDA_VISIBLE_DEVICES=0 python -m demucs /y/split/bt-music-5 --out=/y/separated/bt-music-5 --num_workers=2 --device cuda
```

View File

@ -26,6 +26,7 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
try:
audio = load_audio(file, sampling_rate)
except:
print(f"Error loading file {file}")
report_progress(progress_file, file)
return
@ -52,9 +53,9 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music2')
parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\yt-music-1\\already_processed.txt')
parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\music\\bigdump')
parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\\slakh2100_flac_redux')
parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\\slakh2100_flac_redux\\already_processed.txt')
parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\\slakh2100')
parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
args = parser.parse_args()

View File

@ -34,17 +34,14 @@ if __name__ == '__main__':
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
file, with a hard-coded set of modifications.
"""
base_opt = '../experiments/train_diffusion_tts9_sweep.yml'
base_opt = '../experiments/sweep_music_mel2vec.yml'
modifications = {
'baseline': {},
'more_filters': {'networks': {'generator': {'kwargs': {'model_channels': 96}}}},
'more_kern': {'networks': {'generator': {'kwargs': {'kernel_size': 5}}}},
'less_heads': {'networks': {'generator': {'kwargs': {'num_heads': 2}}}},
'eff_off': {'networks': {'generator': {'kwargs': {'efficient_convs': False}}}},
'more_time': {'networks': {'generator': {'kwargs': {'time_embed_dim_multiplier': 8}}}},
'scale_shift_off': {'networks': {'generator': {'kwargs': {'use_scale_shift_norm': False}}}},
'shallow_res': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 2, 2]}}}},
'lr1e3': {'steps': {'generator': {'optimizer_params': {'lr': {.001}}}}},
'lr1e5': {'steps': {'generator': {'optimizer_params': {'lr': {.00001}}}}},
'no_warmup': {'train': {'warmup_steps': 0}},
}
base_rank = 4
opt = option.parse(base_opt, is_train=True)
all_opts = []
for i, (mod, mod_dict) in enumerate(modifications.items()):
@ -65,4 +62,4 @@ if __name__ == '__main__':
break
else:
rank = 0
launch_trainer(all_opts[rank], base_opt, rank)
launch_trainer(all_opts[rank], base_opt, rank+base_rank)

View File

@ -219,10 +219,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
if __name__ == '__main__':
diffusion = load_model_from_config('D:\\dlas\\options\\train_music_waveform_gen3.yml', 'generator',
also_load_savepoint=False,
load_path='D:\\dlas\\experiments\\train_music_waveform_gen\\models\\59000_generator_ema.pth').cuda()
load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\75500_generator_ema.pth').cuda()
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400,
'conditioning_free': False, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 20, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 23, 'device': 'cuda', 'opt': {}}
eval = MusicDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())

View File

@ -567,6 +567,10 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
else:
if audiopath[-4:] == '.wav':
audio, lsr = load_wav_to_torch(audiopath)
elif audiopath[-5:] == '.flac':
import soundfile as sf
audio, lsr = sf.read(audiopath)
audio = torch.FloatTensor(audio)
else:
audio, lsr = open_audio(audiopath)
audio = torch.FloatTensor(audio)