forked from mrq/DL-Art-School

resurrect ctc code gen with some cool new ideas

This commit is contained in:
parent 65b441d74e
commit 48aab2babe
@@ -542,7 +542,6 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
         idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
         return idxs
 
-
    def forward(self, hidden_states, mask_time_indices=None):
         batch_size, sequence_length, hidden_size = hidden_states.shape
 
@@ -660,6 +659,14 @@ class ContrastiveTrainingWrapper(nn.Module):
         codes = self.quantizer.get_codes(proj)
         return codes
 
+    def reconstruct(self, mel):
+        proj = self.m2v.input_blocks(mel).permute(0,2,1)
+        _, proj = self.m2v.projector(proj)
+        quantized_features, codevector_perplexity = self.quantizer(proj)
+        quantized_features = self.project_q(quantized_features)
+        reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
+        return reconstruction
+
     def forward(self, mel, inp_lengths=None):
         mel = mel[:, :, :-1]  # The MEL computation always pads with 1, throwing off optimal tensor math.
         features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
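
The new reconstruct() method round-trips a mel spectrogram through the quantizer bottleneck: project, quantize, re-project, then decode with reconstruction_net. A minimal usage sketch (the constructor kwargs here mirror the script changed later in this commit; the dummy input shape is an assumption):

    import torch
    from models.audio.mel2vec import ContrastiveTrainingWrapper

    # Sketch only: exercise the new reconstruction path on random data.
    model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0,
                                       mask_time_prob=0, mask_time_length=6, num_negatives=100,
                                       codebook_size=16, codebook_groups=4, disable_custom_linear_init=True,
                                       feature_producer_type='standard', freq_mask_percent=0,
                                       do_reconstruction_loss=True).eval()
    with torch.no_grad():
        mel = torch.randn(1, 256, 96)            # (batch, mel_channels, frames); real inputs come from TorchMelSpectrogramInjector
        reconstruction = model.reconstruct(mel)  # mel rebuilt from the quantized features, channels-first like the input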

codes/models/audio/tts/ctc_code_generator.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+from random import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
+from models.lucidrains.x_transformers import Encoder, XTransformer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class CheckpointedXTransformerEncoder(nn.Module):
+    """
+    Wraps an XTransformer and applies CheckpointedLayer to each of its attention layers.
+    """
+    def __init__(self, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = XTransformer(**xtransformer_kwargs)
+
+        for xform in [self.transformer.encoder, self.transformer.decoder.net]:
+            for i in range(len(xform.attn_layers.layers)):
+                n, b, r = xform.attn_layers.layers[i]
+                xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, *args, **kwargs):
+        return self.transformer(*args, **kwargs)
+
+
+class CtcCodeGenerator(nn.Module):
+    def __init__(self, model_dim=512, layers=10, max_length=2048, dropout=.1, ctc_codes=256, max_pad=120, max_repeat=30):
+        super().__init__()
+        self.max_pad = max_pad
+        self.max_repeat = max_repeat
+        self.ctc_codes = ctc_codes
+        # One output class per (pad, repeat) pair.
+        pred_codes = (max_pad+1)*(max_repeat+1)
+
+        self.position_embedding = nn.Embedding(max_length, model_dim)
+        self.codes_embedding = nn.Embedding(ctc_codes, model_dim)
+        self.recursive_embedding = nn.Embedding(pred_codes, model_dim)
+        self.mask_embedding = nn.Parameter(torch.randn(model_dim))
+        self.encoder = Encoder(
+                dim=model_dim,
+                depth=layers,
+                heads=model_dim//64,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                use_rmsnorm=True,
+                ff_glu=True,
+                rotary_pos_emb=True,
+            )
+        self.pred_head = nn.Linear(model_dim, pred_codes)
+        self.confidence_head = nn.Linear(model_dim, 1)
+
+    def inference(self, codes, pads, repeats):
+        position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
+        codes_h = self.codes_embedding(codes)
+
+        # Fold (pad, repeat) into a single label; the (max_pad+1) stride keeps pads==max_pad
+        # from colliding with (pads=0, repeats+1).
+        labels = pads + repeats * (self.max_pad + 1)
+        mask = labels == 0
+        recursive_h = self.recursive_embedding(labels)
+        recursive_h[mask] = self.mask_embedding
+
+        h = self.encoder(position_h + codes_h + recursive_h)
+        pred_logits = self.pred_head(h)
+        confidences = self.confidence_head(h).squeeze(-1)
+        confidences = F.softmax(confidences * mask, dim=-1)
+        return pred_logits, confidences
+
+    def forward(self, codes, pads, repeats, unpadded_lengths):
+        if unpadded_lengths is not None:
+            max_len = unpadded_lengths.max()
+            codes = codes[:, :max_len]
+            pads = pads[:, :max_len]
+            repeats = repeats[:, :max_len]
+
+        if pads.max() > self.max_pad:
+            print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}")
+            pads = torch.clip(pads, 0, self.max_pad)
+        if repeats.max() > self.max_repeat:
+            print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}")
+            repeats = torch.clip(repeats, 0, self.max_repeat)
+        assert codes.max() < self.ctc_codes, codes.max()
+
+        labels = pads + repeats * (self.max_pad + 1)
+
+        position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device))
+        codes_h = self.codes_embedding(codes)
+        recursive_h = self.recursive_embedding(labels)
+
+        # Hide a random proportion of the label embeddings; the model must reconstruct them.
+        mask_prob = random()
+        mask = torch.rand_like(labels.float()) > mask_prob
+        for b in range(codes.shape[0]):
+            mask[b, unpadded_lengths[b]:] = False
+        recursive_h[mask.logical_not()] = self.mask_embedding
+
+        h = self.encoder(position_h + codes_h + recursive_h)
+        pred_logits = self.pred_head(h)
+        loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduction='none')
+
+        confidences = self.confidence_head(h).squeeze(-1)
+        confidences = F.softmax(confidences * mask, dim=-1)
+        confidence_loss = loss * confidences
+        loss = loss / loss.shape[-1]  # This balances the confidence_loss and loss.
+
+        return loss.mean(), confidence_loss.mean()
+
+
+@register_model
+def register_ctc_code_generator(opt_net, opt):
+    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator()
+    inps = torch.randint(0, 36, (4, 300))
+    pads = torch.randint(0, 100, (4, 300))
+    repeats = torch.randint(0, 20, (4, 300))
+    lengths = torch.full((4,), 300)
+    loss, confidence_loss = model(inps, pads, repeats, lengths)
+    print(loss, confidence_loss)
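
Why pred_codes = (max_pad+1)*(max_repeat+1): each (pad, repeat) pair is folded into one class index with stride max_pad+1 and unfolded again with % and // at decode time. A small round-trip check using the defaults above (illustrative only):

    import torch

    max_pad, max_repeat = 120, 30
    pads = torch.tensor([0, 5, 120])
    repeats = torch.tensor([1, 2, 30])

    labels = pads + repeats * (max_pad + 1)      # unique class per (pad, repeat) pair
    assert torch.equal(labels % (max_pad + 1), pads)
    assert torch.equal(labels // (max_pad + 1), repeats)
    print(labels.max().item())                   # 3750 == (max_pad+1)*(max_repeat+1) - 1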

codes/scripts/audio/gen/ctc_codes.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+from itertools import groupby
+
+import torch
+from transformers import Wav2Vec2CTCTokenizer
+
+from models.audio.tts.ctc_code_generator import CtcCodeGenerator
+
+
+def get_ctc_metadata(codes):
+    if isinstance(codes, torch.Tensor):
+        codes = codes.tolist()
+    grouped = groupby(codes)
+    rcodes, repeats, pads = [], [], [0]
+    for val, group in grouped:
+        if val == 0:
+            # This is a very important distinction! It means the padding belongs to the character following it.
+            pads[-1] = len(list(group))
+        else:
+            rcodes.append(val)
+            repeats.append(len(list(group)))
+            pads.append(0)
+
+    rcodes = torch.tensor(rcodes)
+    # These clip values are sane maximums which I did not see exceeded in the datasets I have access to.
+    repeats = torch.clip(torch.tensor(repeats), min=1, max=30)
+    pads = torch.clip(torch.tensor(pads[:-1]), max=120)
+
+    return rcodes, pads, repeats
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator(model_dim=512, layers=16, dropout=0).eval().cuda()
+    model.load_state_dict(torch.load('../experiments/train_encoder_build_ctc_alignments_toy/models/76000_generator_ema.pth'))
+
+    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
+    text = "and now, what do you want."
+    seq = [0, 0, 0, 38, 51, 51, 41, 11, 11, 51, 51, 0, 0, 0, 0, 52, 0, 60, 0, 0, 0, 0, 0, 0, 6, 11, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60, 45, 0, 38, 57, 57, 11, 0, 41, 52, 52, 11, 11, 62, 52, 52, 58, 0, 11, 11, 60, 0, 0, 0, 0, 38, 0, 0, 51, 51, 0, 0, 57, 0, 0, 7, 7, 0, 0, 0]
+    codes, pads, repeats = get_ctc_metadata(seq)
+
+    with torch.no_grad():
+        codes = codes.cuda().unsqueeze(0)
+        pads = pads.cuda().unsqueeze(0)
+        repeats = repeats.cuda().unsqueeze(0)
+
+        ppads = pads.clone()
+        prepeats = repeats.clone()
+        mask = torch.zeros_like(pads)
+        conf_str = tokenizer.decode(codes[0].tolist())
+        for s in range(codes.shape[-1]):
+            logits, confidences = model.inference(codes, pads * mask, repeats * mask)
+
+            confidences = confidences * mask.logical_not()  # Prevent re-prediction of tokens that have already been predicted.
+            i = confidences.argmax(dim=-1).item()
+            pred = logits[0, i].argmax()
+
+            # Unfold the predicted class back into its (pad, repeat) pair; the stride must match the model's encoding.
+            pred_pads = pred % (model.max_pad + 1)
+            pred_repeats = pred // (model.max_pad + 1)
+            ppads[0, i] = pred_pads
+            prepeats[0, i] = pred_repeats
+            mask[0, i] = 1
+
+            conf_str = conf_str[:i] + conf_str[i].upper() + conf_str[i+1:]
+            print(f"conf: {conf_str} pads={pred_pads}:{pads[0,i].item()} repeats={pred_repeats}:{repeats[0,i].item()}")
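
For intuition, a toy trace of get_ctc_metadata (0 is the CTC blank token; outputs verified by hand against the function above):

    seq = [0, 0, 7, 7, 7, 0, 9]
    codes, pads, repeats = get_ctc_metadata(seq)
    print(codes)    # tensor([7, 9])  blanks removed, repeated symbols collapsed
    print(repeats)  # tensor([3, 1])  run length of each surviving symbol
    print(pads)     # tensor([2, 1])  blanks preceding each symbol; trailing blanks are dropped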

@@ -1,7 +1,8 @@
 import torch
+import torchvision
 
 from models.audio.mel2vec import ContrastiveTrainingWrapper
-from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, normalize_mel
 from utils.util import load_audio
 
 def collapse_codegroups(codes):
@@ -24,17 +25,22 @@ def recover_codegroups(codes, groups):
 
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
-                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
-    model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
+                                       mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
+                                       disable_custom_linear_init=True, feature_producer_type='standard',
+                                       freq_mask_percent=0, do_reconstruction_loss=True)
+    model.load_state_dict(torch.load("../experiments/m2v_music2.pth"))
     model.eval()
 
     wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
-    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
+    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
+                                       'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
     codes = model.get_codes(mel)
-    codes2 = model.get_codes(mel)
+    reconstruction = model.reconstruct(mel)
+
+    torchvision.utils.save_image((normalize_mel(mel).unsqueeze(1)+1)/2, 'mel.png')
+    torchvision.utils.save_image((normalize_mel(reconstruction).unsqueeze(1)+1)/2, 'reconstructed.png')
 
     collapsed = collapse_codegroups(codes)
-    recovered = recover_codegroups(collapsed, 8)
+    recovered = recover_codegroups(collapsed, 4)
 
     print(codes)
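
collapse_codegroups/recover_codegroups themselves are not shown in this hunk, but the codebook change (8 groups of 8 codes, now 4 groups of 16) implies the packing radix changes with it. A plausible base-K packing equivalent, for intuition only (function names and shapes here are assumptions, not the script's actual implementation):

    import torch

    def pack_groups(codes, codebook_size=16):
        # codes: (..., groups) integers in [0, codebook_size); fold the groups into one index.
        packed = torch.zeros_like(codes[..., 0])
        for g in range(codes.shape[-1]):
            packed = packed * codebook_size + codes[..., g]
        return packed

    def unpack_groups(packed, groups=4, codebook_size=16):
        out = []
        for _ in range(groups):
            out.append(packed % codebook_size)
            packed = packed // codebook_size
        return torch.stack(out[::-1], dim=-1)  # restore most-significant-group-first order

    codes = torch.randint(0, 16, (2, 10, 4))
    assert torch.equal(unpack_groups(pack_groups(codes)), codes)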

@@ -332,7 +332,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_flat.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)

@@ -69,7 +69,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
             self.diffusion_fn = self.perform_diffusion_from_codes
             self.local_modules['codegen'] = get_music_codegen()
         self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
-                                                    'normalize': True, 'do_normalization': True, 'in': 'in', 'out': 'out'}, {})
+                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})
 
     def load_data(self, path):
         return list(glob(f'{path}/*.wav'))
@@ -86,7 +86,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                           model_kwargs={'aligned_conditioning': mel})
         gen = pixel_shuffle_1d(gen, 16)
 
-        return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
+        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
 
     def gen_freq_gap(self, mel, band_range=(60,100)):
         gap_start, gap_end = band_range