misc updates
This commit is contained in:
parent
10378fc37f
commit
c9c16e3b01
|
@ -540,7 +540,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ContrastiveTrainingWrapper(nn.Module):
|
class ContrastiveTrainingWrapper(nn.Module):
|
||||||
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.5, mask_time_length=6, num_negatives=100,
|
def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=6, num_negatives=100,
|
||||||
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
|
max_gumbel_temperature=2.0, min_gumbel_temperature=.5, gumbel_temperature_decay=.999995,
|
||||||
codebook_size=320, codebook_groups=2,
|
codebook_size=320, codebook_groups=2,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
|
|
||||||
|
@ -51,6 +52,7 @@ class BasicBlock(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||||
|
|
||||||
|
|
||||||
class BottleNeck(nn.Module):
|
class BottleNeck(nn.Module):
|
||||||
"""Residual block for resnet over 50 layers
|
"""Residual block for resnet over 50 layers
|
||||||
|
|
||||||
|
@ -80,6 +82,7 @@ class BottleNeck(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
|
||||||
|
|
||||||
|
|
||||||
class ResNet(nn.Module):
|
class ResNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, block, num_block, num_classes=100):
|
def __init__(self, block, num_block, num_classes=100):
|
||||||
|
@ -137,11 +140,81 @@ class ResNet(nn.Module):
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class SymbolicLoss:
    """
    Hierarchical loss computed over a flat vector of class logits.

    The flat class index is interpreted as a path through a category tree whose branching factor at each level
    is given by `category_depths` (coarsest level first). For every level, the logits of all leaf classes that fall
    under the same node are summed, and two cross-entropy terms are accumulated:

    - convergence: cross entropy between the level logits and the collaborator's labels reduced to that level.
    - divergence: for each sample, the column of level logits belonging to that sample's own argmax class is taken
      across the whole batch, and a cross entropy is computed against the sample's own batch index.
    """
    def __init__(self, category_depths=(3, 5, 5, 3), convergence_weighting=(1, .6, .3, .1),
                 divergence_weighting=(.1, .3, .6, 1)):
        """
        :param category_depths: Branching factor of each hierarchy level, coarsest first. The product of all
                                entries is the number of flat classes (`total_classes`).
        :param convergence_weighting: Per-level multiplier applied to the convergence term.
        :param divergence_weighting: Per-level multiplier applied to the divergence term.
        """
        # Copy into fresh lists so instances never share (or mutate) the default argument objects.
        self.depths = list(category_depths)
        self.convergence_weighting = list(convergence_weighting)
        self.divergence_weighting = list(divergence_weighting)

        self.total_classes = 1
        for depth in self.depths:
            self.total_classes *= depth

        # elements_per_level[i] is the number of leaf classes grouped under a single node at level i.
        # Level i has depths[0] * ... * depths[i] nodes, so the cumulative product runs over the *leading*
        # depths; the final level is the leaves themselves (one element per node).
        self.elements_per_level = []
        nodes = 1
        for depth in self.depths[:-1]:
            nodes *= depth
            self.elements_per_level.append(self.total_classes // nodes)
        self.elements_per_level.append(1)

    def __call__(self, logits, collaboratorLabels):
        """
        Computes the symbolic loss.
        :param logits: Nested level scores for the network under training, shape (batch, total_classes).
        :param collaboratorLabels: Flat class labels from the collaborator network, shape (batch,).
        :return: Convergence loss & divergence loss, each summed over levels with its per-level weighting.
        """
        b, l = logits.shape
        assert l == self.total_classes, f"Expected {self.total_classes} predictions, got {l}"

        # The divergence target for sample k is always its own batch position; identical at every level.
        div_labels = torch.arange(0, b, device=logits.device)

        convergence_loss = 0
        divergence_loss = 0
        for epc, cw, dw in zip(self.elements_per_level, self.convergence_weighting, self.divergence_weighting):
            # Collapse every run of `epc` consecutive leaf logits into a single score for their parent node.
            level_logits = logits.reshape(b, l // epc, epc).sum(dim=-1)
            level_labels = collaboratorLabels.div(epc, rounding_mode='trunc')

            # Convergence
            convergence_loss = convergence_loss + F.cross_entropy(level_logits, level_labels) * cw

            # Divergence
            div_label_indices = level_logits.argmax(dim=-1)
            # Row k holds, for every sample in the batch, the logit of the class that sample k predicted.
            div_preds = level_logits[:, div_label_indices].t()
            divergence_loss = divergence_loss + F.cross_entropy(div_preds, div_labels) * dw
        return convergence_loss, divergence_loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: push a random batch of 5 samples through the loss with its default hierarchy.
    sl = SymbolicLoss()
    logits = torch.randn((5, sl.total_classes))
    labels = torch.randint(low=0, high=sl.total_classes, size=(5,))
    sl(logits, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class TwinnedCifar(nn.Module):
    """
    Two ResNet classifiers (BasicBlock, [2, 2, 2, 2]) that are trained against each other's predictions
    using SymbolicLoss. Each network outputs `SymbolicLoss.total_classes` logits.
    """
    def __init__(self):
        super().__init__()
        self.loss = SymbolicLoss()
        num_outputs = self.loss.total_classes

        def build_net():
            return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_outputs)

        self.netA = build_net()
        self.netB = build_net()

    def forward(self, x):
        """
        Runs both networks on `x` and scores each one against the other's argmax predictions.
        :param x: Input batch; its first dimension is treated as the batch size.
        :return: Summed convergence loss & summed divergence loss across both networks.
        """
        out_a = self.netA(x)
        out_b = self.netB(x)
        half = x.shape[0] // 2

        # Each network plays "collaborator" for the other via its hard (argmax) predictions.
        labels_from_a = out_a.argmax(dim=-1)
        labels_from_b = out_b.argmax(dim=-1)

        # netA is scored on the first half of the batch, netB on the remainder.
        conv_a, div_a = self.loss(out_a[:half], labels_from_b[:half])
        conv_b, div_b = self.loss(out_b[half:], labels_from_a[half:])
        return conv_a + conv_b, div_a + div_b
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_cifar_resnet18(opt_net, opt):
|
def register_twin_cifar(opt_net, opt):
|
||||||
""" return a ResNet 18 object
|
""" return a ResNet 18 object
|
||||||
"""
|
"""
|
||||||
return ResNet(BasicBlock, [2, 2, 2, 2])
|
return TwinnedCifar()
|
||||||
|
|
||||||
def resnet34():
|
def resnet34():
|
||||||
""" return a ResNet 34 object
|
""" return a ResNet 34 object
|
||||||
|
|
8
codes/scripts/audio/prep_music/demucs_notes.txt
Normal file
8
codes/scripts/audio/prep_music/demucs_notes.txt
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
My custom demucs library is used for batch source separation:
|
||||||
|
https://github.com/neonbjb/demucs
|
||||||
|
|
||||||
|
```
|
||||||
|
conda activate demucs
|
||||||
|
python setup.py install
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python -m demucs /y/split/bt-music-5 --out=/y/separated/bt-music-5 --num_workers=2 --device cuda
|
||||||
|
```
|
|
@ -26,6 +26,7 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
|
||||||
try:
|
try:
|
||||||
audio = load_audio(file, sampling_rate)
|
audio = load_audio(file, sampling_rate)
|
||||||
except:
|
except:
|
||||||
|
print(f"Error loading file {file}")
|
||||||
report_progress(progress_file, file)
|
report_progress(progress_file, file)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -52,9 +53,9 @@ def process_file(file, base_path, output_path, progress_file, duration_per_clip,
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\sources\\music\\bt-music2')
|
parser.add_argument('-path', type=str, help='Path to search for files', default='Y:\\\slakh2100_flac_redux')
|
||||||
parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\sources\\yt-music-1\\already_processed.txt')
|
parser.add_argument('-progress_file', type=str, help='Place to store all files that have already been processed', default='Y:\\\slakh2100_flac_redux\\already_processed.txt')
|
||||||
parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\music\\bigdump')
|
parser.add_argument('-output_path', type=str, help='Path for output files', default='Y:\\split\\\slakh2100')
|
||||||
parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
|
parser.add_argument('-num_threads', type=int, help='Number of concurrent workers processing files.', default=8)
|
||||||
parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
|
parser.add_argument('-duration', type=int, help='Duration per clip in seconds', default=30)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -34,17 +34,14 @@ if __name__ == '__main__':
|
||||||
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
Ad-hoc script (hard coded; no command-line parameters) that spawns multiple separate trainers from a single options
|
||||||
file, with a hard-coded set of modifications.
|
file, with a hard-coded set of modifications.
|
||||||
"""
|
"""
|
||||||
base_opt = '../experiments/train_diffusion_tts9_sweep.yml'
|
base_opt = '../experiments/sweep_music_mel2vec.yml'
|
||||||
modifications = {
|
modifications = {
|
||||||
'baseline': {},
|
'baseline': {},
|
||||||
'more_filters': {'networks': {'generator': {'kwargs': {'model_channels': 96}}}},
|
'lr1e3': {'steps': {'generator': {'optimizer_params': {'lr': {.001}}}}},
|
||||||
'more_kern': {'networks': {'generator': {'kwargs': {'kernel_size': 5}}}},
|
'lr1e5': {'steps': {'generator': {'optimizer_params': {'lr': {.00001}}}}},
|
||||||
'less_heads': {'networks': {'generator': {'kwargs': {'num_heads': 2}}}},
|
'no_warmup': {'train': {'warmup_steps': 0}},
|
||||||
'eff_off': {'networks': {'generator': {'kwargs': {'efficient_convs': False}}}},
|
|
||||||
'more_time': {'networks': {'generator': {'kwargs': {'time_embed_dim_multiplier': 8}}}},
|
|
||||||
'scale_shift_off': {'networks': {'generator': {'kwargs': {'use_scale_shift_norm': False}}}},
|
|
||||||
'shallow_res': {'networks': {'generator': {'kwargs': {'num_res_blocks': [1, 1, 1, 1, 1, 2, 2]}}}},
|
|
||||||
}
|
}
|
||||||
|
base_rank = 4
|
||||||
opt = option.parse(base_opt, is_train=True)
|
opt = option.parse(base_opt, is_train=True)
|
||||||
all_opts = []
|
all_opts = []
|
||||||
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
for i, (mod, mod_dict) in enumerate(modifications.items()):
|
||||||
|
@ -65,4 +62,4 @@ if __name__ == '__main__':
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
rank = 0
|
rank = 0
|
||||||
launch_trainer(all_opts[rank], base_opt, rank)
|
launch_trainer(all_opts[rank], base_opt, rank+base_rank)
|
||||||
|
|
|
@ -219,10 +219,10 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
diffusion = load_model_from_config('D:\\dlas\\options\\train_music_waveform_gen3.yml', 'generator',
|
diffusion = load_model_from_config('D:\\dlas\\options\\train_music_waveform_gen3.yml', 'generator',
|
||||||
also_load_savepoint=False,
|
also_load_savepoint=False,
|
||||||
load_path='D:\\dlas\\experiments\\train_music_waveform_gen\\models\\59000_generator_ema.pth').cuda()
|
load_path='X:\\dlas\\experiments\\train_music_waveform_gen\\models\\75500_generator_ema.pth').cuda()
|
||||||
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400,
|
opt_eval = {'path': 'Y:\\split\\yt-music-eval', 'diffusion_steps': 400,
|
||||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||||
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
|
'diffusion_schedule': 'linear', 'diffusion_type': 'spec_decode'}
|
||||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 20, 'device': 'cuda', 'opt': {}}
|
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 23, 'device': 'cuda', 'opt': {}}
|
||||||
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
eval = MusicDiffusionFid(diffusion, opt_eval, env)
|
||||||
print(eval.perform_eval())
|
print(eval.perform_eval())
|
||||||
|
|
|
@ -567,6 +567,10 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
|
||||||
else:
|
else:
|
||||||
if audiopath[-4:] == '.wav':
|
if audiopath[-4:] == '.wav':
|
||||||
audio, lsr = load_wav_to_torch(audiopath)
|
audio, lsr = load_wav_to_torch(audiopath)
|
||||||
|
elif audiopath[-5:] == '.flac':
|
||||||
|
import soundfile as sf
|
||||||
|
audio, lsr = sf.read(audiopath)
|
||||||
|
audio = torch.FloatTensor(audio)
|
||||||
else:
|
else:
|
||||||
audio, lsr = open_audio(audiopath)
|
audio, lsr = open_audio(audiopath)
|
||||||
audio = torch.FloatTensor(audio)
|
audio = torch.FloatTensor(audio)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user