diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index 098fafc0..5972c8ac 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -542,7 +542,6 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
         idxs = codevector_idx.view(batch_size, sequence_length, self.num_groups)
         return idxs
 
-
     def forward(self, hidden_states, mask_time_indices=None):
         batch_size, sequence_length, hidden_size = hidden_states.shape
 
@@ -660,6 +659,14 @@ class ContrastiveTrainingWrapper(nn.Module):
         codes = self.quantizer.get_codes(proj)
         return codes
 
+    def reconstruct(self, mel):
+        proj = self.m2v.input_blocks(mel).permute(0,2,1)
+        _, proj = self.m2v.projector(proj)
+        quantized_features, codevector_perplexity = self.quantizer(proj)
+        quantized_features = self.project_q(quantized_features)
+        reconstruction = self.reconstruction_net(quantized_features.permute(0,2,1))
+        return reconstruction
+
     def forward(self, mel, inp_lengths=None):
         mel = mel[:, :, :-1]  # The MEL computation always pads with 1, throwing off optimal tensor math.
         features_shape = (mel.shape[0], mel.shape[-1]//self.m2v.dim_reduction_mult)
diff --git a/codes/models/audio/tts/ctc_code_generator.py b/codes/models/audio/tts/ctc_code_generator.py
new file mode 100644
index 00000000..68905115
--- /dev/null
+++ b/codes/models/audio/tts/ctc_code_generator.py
@@ -0,0 +1,121 @@
+from random import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.audio.tts.unet_diffusion_tts7 import CheckpointedLayer
+from models.lucidrains.x_transformers import Encoder, XTransformer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+class CheckpointedXTransformerEncoder(nn.Module):
+    """
+    Wraps an XTransformer, applying CheckpointedLayer to each of its attention layers to enable gradient
+    checkpointing. Inputs should already be channels-last, as XTransformer expects.
+    """
+    def __init__(self, **xtransformer_kwargs):
+        super().__init__()
+        self.transformer = XTransformer(**xtransformer_kwargs)
+
+        for xform in [self.transformer.encoder, self.transformer.decoder.net]:
+            for i in range(len(xform.attn_layers.layers)):
+                n, b, r = xform.attn_layers.layers[i]
+                xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r])
+
+    def forward(self, *args, **kwargs):
+        return self.transformer(*args, **kwargs)
+
+
+class CtcCodeGenerator(nn.Module):
+    def __init__(self, model_dim=512, layers=10, max_length=2048, dropout=.1, ctc_codes=256, max_pad=120, max_repeat=30):
+        super().__init__()
+        self.max_pad = max_pad
+        self.max_repeat = max_repeat
+        self.ctc_codes = ctc_codes
+        pred_codes = (max_pad+1)*(max_repeat+1)
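+        # A pad count and a repeat count are predicted jointly as one class index:
+        #   label = pad + repeat * (max_pad + 1)
+        # e.g. with the default max_pad=120, pad=5 and repeat=2 encode to
+        # 5 + 2 * 121 = 247, giving (max_pad+1)*(max_repeat+1) possible classes.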
+ """ + def __init__(self, **xtransformer_kwargs): + super().__init__() + self.transformer = XTransformer(**xtransformer_kwargs) + + for xform in [self.transformer.encoder, self.transformer.decoder.net]: + for i in range(len(xform.attn_layers.layers)): + n, b, r = xform.attn_layers.layers[i] + xform.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, *args, **kwargs): + return self.transformer(*args, **kwargs) + + +class CtcCodeGenerator(nn.Module): + def __init__(self, model_dim=512, layers=10, max_length=2048, dropout=.1, ctc_codes=256, max_pad=120, max_repeat=30): + super().__init__() + self.max_pad = max_pad + self.max_repeat = max_repeat + self.ctc_codes = ctc_codes + pred_codes = (max_pad+1)*(max_repeat+1) + + self.position_embedding = nn.Embedding(max_length, model_dim) + self.codes_embedding = nn.Embedding(ctc_codes, model_dim) + self.recursive_embedding = nn.Embedding(pred_codes, model_dim) + self.mask_embedding = nn.Parameter(torch.randn(model_dim)) + self.encoder = Encoder( + dim=model_dim, + depth=layers, + heads=model_dim//64, + ff_dropout=dropout, + attn_dropout=dropout, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ) + self.pred_head = nn.Linear(model_dim, pred_codes) + self.confidence_head = nn.Linear(model_dim, 1) + + def inference(self, codes, pads, repeats): + position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) + codes_h = self.codes_embedding(codes) + + labels = pads + repeats * self.max_pad + mask = labels == 0 + recursive_h = self.recursive_embedding(labels) + recursive_h[mask] = self.mask_embedding + + h = self.encoder(position_h + codes_h + recursive_h) + pred_logits = self.pred_head(h) + confidences = self.confidence_head(h).squeeze(-1) + confidences = F.softmax(confidences * mask, dim=-1) + return pred_logits, confidences + + def forward(self, codes, pads, repeats, unpadded_lengths): + if unpadded_lengths is not None: + max_len = unpadded_lengths.max() + codes = codes[:, :max_len] + pads = pads[:, :max_len] + repeats = repeats[:, :max_len] + + if pads.max() > self.max_pad: + print(f"Got unexpectedly long pads. Max: {pads.max()}, {pads}") + pads = torch.clip(pads, 0, self.max_pad) + if repeats.max() > self.max_repeat: + print(f"Got unexpectedly long repeats. Max: {repeats.max()}, {repeats}") + repeats = torch.clip(repeats, 0, self.max_repeat) + assert codes.max() < self.ctc_codes, codes.max() + + labels = pads + repeats * self.max_pad + + position_h = self.position_embedding(torch.arange(0, codes.shape[-1], device=codes.device)) + codes_h = self.codes_embedding(codes) + recursive_h = self.recursive_embedding(labels) + + mask_prob = random() + mask = torch.rand_like(labels.float()) > mask_prob + for b in range(codes.shape[0]): + mask[b, unpadded_lengths[b]:] = False + recursive_h[mask.logical_not()] = self.mask_embedding + + h = self.encoder(position_h + codes_h + recursive_h) + pred_logits = self.pred_head(h) + loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduce=False) + + confidences = self.confidence_head(h).squeeze(-1) + confidences = F.softmax(confidences * mask, dim=-1) + confidence_loss = loss * confidences + loss = loss / loss.shape[-1] # This balances the confidence_loss and loss. 
+        mask_prob = random()
+        mask = torch.rand_like(labels.float()) > mask_prob
+        if unpadded_lengths is not None:
+            for b in range(codes.shape[0]):
+                mask[b, unpadded_lengths[b]:] = False
+        recursive_h[mask.logical_not()] = self.mask_embedding
+
+        h = self.encoder(position_h + codes_h + recursive_h)
+        pred_logits = self.pred_head(h)
+        loss = F.cross_entropy(pred_logits.permute(0,2,1), labels, reduction='none')
+
+        confidences = self.confidence_head(h).squeeze(-1)
+        confidences = F.softmax(confidences * mask, dim=-1)
+        confidence_loss = loss * confidences
+        loss = loss / loss.shape[-1]  # This balances the confidence_loss and loss.
+
+        return loss.mean(), confidence_loss.mean()
+
+
+@register_model
+def register_ctc_code_generator(opt_net, opt):
+    return CtcCodeGenerator(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator()
+    inps = torch.randint(0,36, (4, 300))
+    pads = torch.randint(0,100, (4,300))
+    repeats = torch.randint(0,20, (4,300))
+    loss, confidence_loss = model(inps, pads, repeats)
+    print(loss, confidence_loss)
\ No newline at end of file
diff --git a/codes/scripts/audio/gen/ctc_codes.py b/codes/scripts/audio/gen/ctc_codes.py
new file mode 100644
index 00000000..3f0444cd
--- /dev/null
+++ b/codes/scripts/audio/gen/ctc_codes.py
@@ -0,0 +1,63 @@
+from itertools import groupby
+
+import torch
+from transformers import Wav2Vec2CTCTokenizer
+
+from models.audio.tts.ctc_code_generator import CtcCodeGenerator
+
+
+def get_ctc_metadata(codes):
+    if isinstance(codes, torch.Tensor):
+        codes = codes.tolist()
+    grouped = groupby(codes)
+    rcodes, repeats, pads = [], [], [0]
+    for val, group in grouped:
+        if val == 0:
+            # This is a very important distinction! It means the padding belongs to the character following it.
+            pads[-1] = len(list(group))
+        else:
+            rcodes.append(val)
+            repeats.append(len(list(group)))
+            pads.append(0)
+
+    rcodes = torch.tensor(rcodes)
+    # These clip values are sane maximums which I did not see exceeded in the datasets I have access to.
+    repeats = torch.clip(torch.tensor(repeats), min=1, max=30)
+    pads = torch.clip(torch.tensor(pads[:-1]), max=120)
+
+    return rcodes, pads, repeats
+
+
+if __name__ == '__main__':
+    model = CtcCodeGenerator(model_dim=512, layers=16, dropout=0).eval().cuda()
+    model.load_state_dict(torch.load('../experiments/train_encoder_build_ctc_alignments_toy/models/76000_generator_ema.pth'))
+
+    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
+    text = "and now, what do you want."
+    seq = [0, 0, 0, 38, 51, 51, 41, 11, 11, 51, 51, 0, 0, 0, 0, 52, 0, 60, 0, 0, 0, 0, 0, 0, 6, 11, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60, 45, 0, 38, 57, 57, 11, 0, 41, 52, 52, 11, 11, 62, 52, 52, 58, 0, 11, 11, 60, 0, 0, 0, 0, 38, 0, 0, 51, 51, 0, 0, 57, 0, 0, 7, 7, 0, 0, 0]
+    codes, pads, repeats = get_ctc_metadata(seq)
+
+    with torch.no_grad():
+        codes = codes.cuda().unsqueeze(0)
+        pads = pads.cuda().unsqueeze(0)
+        repeats = repeats.cuda().unsqueeze(0)
+
+        ppads = pads.clone()
+        prepeats = repeats.clone()
+        mask = torch.zeros_like(pads)
+        conf_str = tokenizer.decode(codes[0].tolist())
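+        # Iteratively decode: each step predicts (pad, repeat) for every still-masked
+        # position, but commits only the one the model is most confident about. The
+        # committed position is unmasked and fed back in as a visible label next pass.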
+        for s in range(codes.shape[-1]):
+            logits, confidences = model.inference(codes, pads * mask, repeats * mask)
+
+            confidences = confidences * mask.logical_not()  # prevent prediction of tokens that have already been predicted.
+            i = confidences.argmax(dim=-1).item()
+            pred = logits[0,i].argmax()
+
+            pred_pads = pred % (model.max_pad + 1)
+            pred_repeats = pred // (model.max_pad + 1)
+            ppads[0,i] = pred_pads
+            prepeats[0,i] = pred_repeats
+            mask[0,i] = 1
+
+            conf_str = conf_str[:i] + conf_str[i].upper() + conf_str[i+1:]
+            print(f"conf: {conf_str} pads={pred_pads}:{pads[0,i].item()} repeats={pred_repeats}:{repeats[0,i].item()}")
\ No newline at end of file
diff --git a/codes/scripts/audio/gen/use_mel2vec_codes.py b/codes/scripts/audio/gen/use_mel2vec_codes.py
index 74b8a783..f2d576de 100644
--- a/codes/scripts/audio/gen/use_mel2vec_codes.py
+++ b/codes/scripts/audio/gen/use_mel2vec_codes.py
@@ -1,7 +1,8 @@
 import torch
+import torchvision
 
 from models.audio.mel2vec import ContrastiveTrainingWrapper
-from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
+from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector, normalize_mel
 from utils.util import load_audio
 
 def collapse_codegroups(codes):
@@ -24,17 +25,22 @@
 
 if __name__ == '__main__':
     model = ContrastiveTrainingWrapper(mel_input_channels=256, inner_dim=1024, layers=24, dropout=0, mask_time_prob=0,
-                                       mask_time_length=6, num_negatives=100, codebook_size=8, codebook_groups=8, disable_custom_linear_init=True)
-    model.load_state_dict(torch.load("../experiments/m2v_music.pth"))
+                                       mask_time_length=6, num_negatives=100, codebook_size=16, codebook_groups=4,
+                                       disable_custom_linear_init=True, feature_producer_type='standard',
+                                       freq_mask_percent=0, do_reconstruction_loss=True)
+    model.load_state_dict(torch.load("../experiments/m2v_music2.pth"))
     model.eval()
 
     wav = load_audio("Y:/separated/bt-music-1/100 Hits - Running Songs 2014 CD 2/100 Hits - Running Songs 2014 Cd2 - 02 - 7Th Heaven - Ain't Nothin' Goin' On But The Rent/00001/no_vocals.wav", 22050)
-    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 22000, 'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
-
+    mel = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
+                                       'normalize': True, 'in': 'in', 'out': 'out'}, {})({'in': wav.unsqueeze(0)})['out']
     codes = model.get_codes(mel)
-    codes2 = model.get_codes(mel)
+    reconstruction = model.reconstruct(mel)
+
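+    # normalize_mel is assumed to return values in [-1, 1]; the shift-and-scale
+    # below maps them to the [0, 1] range that save_image expects.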
+    torchvision.utils.save_image((normalize_mel(mel).unsqueeze(1)+1)/2, 'mel.png')
+    torchvision.utils.save_image((normalize_mel(reconstruction).unsqueeze(1)+1)/2, 'reconstructed.png')
 
     collapsed = collapse_codegroups(codes)
-    recovered = recover_codegroups(collapsed, 8)
+    recovered = recover_codegroups(collapsed, 4)
 
     print(codes)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index 186b5fe9..7f9e11a2 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -332,7 +332,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_diffusion_flat.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index f17afdf4..a6cf584f 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -69,7 +69,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         self.diffusion_fn = self.perform_diffusion_from_codes
         self.local_modules['codegen'] = get_music_codegen()
         self.spec_fn = TorchMelSpectrogramInjector({'n_mel_channels': 256, 'mel_fmax': 11000, 'filter_length': 16000,
-                                                    'normalize': True, 'do_normalization': True, 'in': 'in', 'out': 'out'}, {})
+                                                    'normalize': True, 'in': 'in', 'out': 'out'}, {})
 
     def load_data(self, path):
         return list(glob(f'{path}/*.wav'))
@@ -86,7 +86,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
             model_kwargs={'aligned_conditioning': mel})
         gen = pixel_shuffle_1d(gen, 16)
 
-        return gen, real_resampled, self.spec_fn({'in': gen})['out'], mel, sample_rate
+        return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
 
     def gen_freq_gap(self, mel, band_range=(60,100)):
         gap_start, gap_end = band_range