From c127c4e488e417d6ace58300e0850f3f814b684e Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 7 Nov 2024 21:19:14 -0600
Subject: [PATCH] 'borrowed' a sampling scheduler for NAR-len's RVQ level 0 (better than before, but still not good enough)

---
 docs/models.md          |  11 ++--
 vall_e/models/ar_nar.py |   3 +-
 vall_e/models/base.py   |  65 +++++++++-----------
 vall_e/models/nar.py    | 131 ++++++++++++++++++++++++++++++----------
 vall_e/samplers.py      | 104 ++++++++++++++++++++++++++++++-
 5 files changed, 236 insertions(+), 78 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index 0466d3a..ed8e054 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -41,7 +41,6 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final waveform
 * `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior RVQ level to simulate wrong tokens being predicted.
 * `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to try and have the model not strongly rely on the given input.
 
-
 ### Pure NAR
 
 The pure NAR (`nar-len`) model is a model type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, the duration is still inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
@@ -50,10 +49,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provide the duration and generate the output sequence from nothing in parallel.
 * The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
 * The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
   * diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
-  * however, it's possible to have a similar paradigm to diffusers, but instead iterating upon random noise, masked tokens are iterated per step, and each step picks the most confident tokens per step.
-  * incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
 * the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
 
+The implemented solution follows a paradigm similar to diffusion, but iterating over masked tokens instead of random noise: each step fills in the masked positions, keeps the most confident predictions, and re-masks the rest for the next step.
+* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
+
+To-do: fill this out more when it works.
+
 ## Embeddings
 
 The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
@@ -99,7 +101,8 @@ However, the `resp` requires some extra care, as the model needs to both causally predict the next token (for the AR) and predict the token at the same position of the next RVQ level (for the NAR).
 * The first embedding level pertains to RVQ level 0 for the AR.
 * The remaining embedding levels map to RVQ level 0 + n for the NAR.
 * In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
-* I believe this is because the model needs to "know" whether to predict the next token in the sequence, or the token in the same position of the next RVQ level.
+* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
+  * In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
 * Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.
 
 The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
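For context (not part of the patch): the masked-decoding paradigm described in `docs/models.md` above boils down to a MaskGIT-style loop. A minimal sketch, assuming a hypothetical `forward` that maps a `(seq_len,)` tensor of token ids to `(seq_len, vocab)` logits, and using greedy argmax where the patch's actual implementation uses Gumbel sampling:

```python
# Sketch of confidence-based iterative demasking with a cosine schedule.
# `forward`, `mask_token`, and `steps` are illustrative assumptions.
import math
import torch

def demask_sample(forward, seq_len: int, mask_token: int, steps: int = 25):
    ids = torch.full((seq_len,), mask_token, dtype=torch.long)
    scores = torch.zeros(seq_len)  # higher score = less confident
    for step in range(steps):
        # cosine schedule: the fraction of tokens left masked shrinks per step
        ratio = math.cos((step / steps) * math.pi * 0.5)
        n_masked = max(int(ratio * seq_len), 1)
        # re-mask the least confident positions from the previous step
        ids[scores.topk(n_masked).indices] = mask_token
        probs = forward(ids).softmax(dim=-1)
        sampled = probs.argmax(dim=-1)  # greedy for brevity; the patch samples
        # keep prior commitments, fill only the masked positions
        ids = torch.where(ids == mask_token, sampled, ids)
        # confidence of each kept token; 1 - p so topk() finds the shakiest
        scores = 1.0 - probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    return ids
```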
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 98fc2e2..4b9be09 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -391,7 +391,8 @@ class AR_NAR(Base):
 			if sampled.entropy:
 				metrics.append( sampled.entropy )
 			elif sampled.scores:
-				metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
+				#metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
+				metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
 
 		if mirostat is not None:
 			mirostat = sampled.scores
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 61ecb64..88e6784 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -47,8 +47,13 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 
-def _dropout_mask( input, p=0.8 ):
-	return torch.tensor( [ 0 if random.random() < p else 1 for _ in range( input.shape[0] ) ], dtype=torch.uint8, device=input.device )
+def _dropout_mask( input, p=None ):
+	# cosine scheduling
+	if p is None:
+		t = random.random()
+		p = math.cos(t * math.pi * 0.5)
+
+	return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
 
 def clamp(n, lo, hi):
 	return max(lo, min(n, hi))
@@ -1004,7 +1009,9 @@ class Base(nn.Module):
 
 			# store dropout mask
 			if "len" in self.capabilities and quant_level == 0:
-				dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
+				t = random.random()
+				p = math.cos(t * math.pi * 0.5)
+				dropout_mask = _dropout_mask( resps_list[i], p=p )
 				inputs[i].append( ("dropout_mask", dropout_mask ) )
 
 		# Audio length prediction task
@@ -1145,36 +1152,14 @@
 					) for l in range( input.shape[-1] ) ]
 					embedding = _interleave_sequence_reshape( embeddings )
-				elif "len" in self.capabilities and quant_level == 0:
-					mask_token = self.resps_emb(
-						torch.tensor( [ self.stop_token ], dtype=torch.int16, device=input.device ),
+
+				# if training NAR-len RVQ level 0
+				elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
+					embedding = self.resps_emb(
+						torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
 						offset = 0,
-						quant_level = 0
+						quant_level = 0,
 					)
-
-					# if training
-					if not input.is_floating_point():
-						# get original sequence
-						embedding = self.resps_emb(
-							input,
-							offset = 0,
-							quant_level = 0,
-						)
-
-						# create dropout mask if one is not provided
-						if dropout_mask is None:
-							dropout_mask = _dropout_mask( input )
-
-						# replace with masked tokens
-						for i, token in enumerate( dropout_mask ):
-							if token == 0:
-								embedding[i] = mask_token
-
-					# if inferencing
-					else:
-						# fill with mask tokens for now
-						embedding = torch.concat([ mask_token for _ in range( input.shape[0] ) ])
-
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1331,9 +1316,7 @@
 				elif name == "resp":
 					# mask found, apply it
 					if dropout_mask is not None:
-						seq = input if input.dim() == 1 else input[:, 0]
-						masked = torch.tensor([ token if dropout_mask[i] == 1 else self.ignore_index for i, token in enumerate( seq ) ], dtype=torch.int16, device=input.device)
-						target.append( masked )
+						target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
 					elif self.interleave:
 						target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1778,9 +1761,15 @@
 			res = [ Categorical(logits=logit).sample() for logit in logits ]
 
 		# calculate token probabilities
-		scores = [
-			[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
-			for logit, tokens in zip(logits, res)
-		]
+		if "len" in self.capabilities:
+			scores = [
+				[ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
+				for logit, tokens in zip(logits, res)
+			]
+		else:
+			scores = [
+				[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
+				for logit, tokens in zip(logits, res)
+			]
 
 		return Sampled(res, scores, entropy)
\ No newline at end of file
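A standalone sketch of the cosine-scheduled masking that `_dropout_mask` now performs (illustration only; the vocabulary size, the mask token id, and the `-100` ignore index are assumptions, the last matching `F.cross_entropy`'s default). Drawing `t` uniformly and masking with probability `cos(t·π/2)` exposes training to everything from nearly-fully-masked sequences (t ≈ 0) to nearly-clean ones (t ≈ 1), mirroring the steps the scheduler walks through at inference:

```python
# Illustrative re-implementation of the cosine-scheduled dropout mask.
import math
import random
import torch

def dropout_mask(seq_len, p=None):
    # when no probability is given, draw a random point on the cosine schedule
    if p is None:
        t = random.random()
        p = math.cos(t * math.pi * 0.5)  # p sweeps from 1 (t=0) down to 0 (t=1)
    return torch.tensor([random.random() < p for _ in range(seq_len)], dtype=torch.bool)

resps = torch.randint(0, 1024, (100,))         # hypothetical RVQ level 0 tokens
mask = dropout_mask(resps.shape[0])
masked_input = torch.where(mask, 1024, resps)  # 1024 standing in for the mask token
target = torch.where(mask, resps, -100)        # loss only on the masked positions
```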
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 7fe7cb2..ef26e3d 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -6,21 +6,22 @@ It *does* have to inference the initial length in an autoregressive-ish manner
 Initial experiments show this only really "works" for a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
 """
 
-from .base import Base, list_to_tensor, Categorical
-from ..config import cfg
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
 import random
 import math
+import numpy as np
+import logging
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
 
+from .base import Base, list_to_tensor, Categorical, _dropout_mask
+from ..config import cfg
 from ..emb.qnt import trim, repeat_extend_audio
-
-import logging
+from ..samplers import SampleScheduler
 
 def clamp(n, lo, hi):
 	return max(lo, min(n, hi))
""" -from .base import Base, list_to_tensor, Categorical -from ..config import cfg - -import torch -from torch.nn.utils.rnn import pad_sequence import random import math +import numpy as np +import logging +import torch +from torch.nn.utils.rnn import pad_sequence + from einops import rearrange from torch import Tensor from tqdm import trange +from .base import Base, list_to_tensor, Categorical, _dropout_mask +from ..config import cfg from ..emb.qnt import trim, repeat_extend_audio - -import logging +from ..samplers import SampleScheduler def clamp(n, lo, hi): return max(lo, min(n, hi)) @@ -211,23 +212,91 @@ class NAR(Base): if len_list is not None: - # is NAR + sampling_layer_skip_variables = {} if sampling_layer_skip else None + if max_levels == 0: - max_levels = self.n_resp_levels - - # fill with mock tokens - #prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ] - #prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ] - #prev_list = [ None for resp_len in len_list ] # this breaks the position ID calc - + max_levels = self.n_max_levels - 1 + + if sampling_layer_skip: + if sampling_layer_skip_entropy_threshold >= 0: + sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold + if sampling_layer_skip_varentropy_threshold >= 0: + sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold + if sampling_layer_skip_exit_layer >= 0: + sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer + + # initial condition + len_list = [ min(l, 500) for l in len_list ] + metrics = [] + mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device) prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ] - # to-do: special "scheduling" to inference RVQ-level 0 + # special "scheduling" to inference RVQ-level 0 + level = 0 + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) - # to-do: figure out why this fails when I copy some things from ar_nar - for n in trange( max_levels, desc="NAR", disable=disable_tqdm ): - level = 0 if n == 0 else prev_list[0].shape[-1] + _super = super() + def forward_lambda( ids, step, temperature ): + quant_levels = [ level for _ in range(batch_size) ] + prev_list = [ ids[0] ] + seq_len = ids.shape[-1] + + inputs = _super.inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=prev_list, + lang_list=lang_list, + tone_list=tone_list, + quant_levels=quant_levels, + ) + + output = _super.forward( + inputs=inputs, + quant_levels=quant_levels, + + layer_skip_variables=sampling_layer_skip_variables, + ) + logits = output.logits + + sampled = _super.sample( + logits=logits, + prev_list=prev_list, + quant_levels=quant_levels, + + temperature=temperature, + min_temperature=sampling_min_temperature, + top_p=sampling_top_p, + top_k=sampling_top_k, + min_p=sampling_min_p, + repetition_penalty=sampling_repetition_penalty, + repetition_penalty_decay=sampling_repetition_penalty_decay, + length_penalty=sampling_length_penalty, + #beam_width=sampling_beam_width, + #mirostat=mirostat, + ) + + ids = sampled[0] + + return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0) + + scheduler = SampleScheduler( + device=device, + mask_token=self.stop_token, + max_steps=30, + forward_lambda=forward_lambda, + sampling_temperature=sampling_temperature, + ) + 
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index c149ed4..697425b 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -5,7 +5,7 @@ import numpy as np
 import time
 
 from torch import Tensor, einsum, nn
-
+from einops import rearrange
 from dataclasses import asdict, dataclass, field
 
 # Simple filter to modify a token's probability if it shows up in the past
@@ -520,4 +520,104 @@ def sample_entropix(
 		metrics["min_p"] = min_p
 	"""
 
-	return res, metrics
\ No newline at end of file
+	return res, metrics
+
+"""
+def add_gumbel_noise(t, temperature, device):
+	return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
+"""
+
+def log(t, eps = 1e-20):
+	return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+	noise = torch.zeros_like(t).uniform_(0, 1)
+	return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+	return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
+
+def top_k(logits, thres = 0.9):
+	k = math.ceil((1 - thres) * logits.shape[-1])
+	val, ind = logits.topk(k, dim = -1)
+	probs = torch.full_like(logits, float('-inf'))
+	probs.scatter_(2, ind, val)
+	return probs
+
+# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
+class SampleScheduler:
+	def __init__(
+		self,
+		forward_lambda = None,
+		mask_token = -1,
+		max_steps = 25,
+		device = "cuda",
+		sampling_temperature=1.0,
+	):
+		self.forward_lambda = forward_lambda
+		self.max_steps = max_steps
+		self.mask_token = mask_token
+		self.device = device
+
+		"""
+		self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
+		self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
+		self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
+		"""
+	# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
+	def sample( self, seq_len ):
+		ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
+		scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
+
+		for step in range( self.max_steps ):
+			t = step / self.max_steps
+			mask_ratio = math.cos(t * math.pi * 0.5)
+			sampling_temperature = 1.0
+			annealed_temperature = sampling_temperature * (1.0 - t)
+
+			num_token_masked = max(int(mask_ratio * seq_len), 1)
+			masked_indices = scores.topk(num_token_masked, dim = -1).indices
+
+			ids = ids.scatter(1, masked_indices, self.mask_token)
+
+			logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
+			filtered_logits = top_k( logits )
+			sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
+
+			is_masked = ids == self.mask_token
+			ids = torch.where( is_masked, sampled_ids, ids )
+
+			probs_without_temperature = logits.softmax(dim = -1)
+
+			scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
+			scores = rearrange(scores, '... 1 -> ...')
+			#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
+
+			"""
+			if step + 1 == self.max_steps:
+				break
+
+			# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
+			# create next input sequence
+			mask = (ids == self.mask_token)
+			mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
+			mask_len = torch.maximum(
+				torch.Tensor([1]).to(self.device),
+				torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
+			)[0].squeeze()
+
+			logits = torch.log_softmax(logits, dim=-1)
+			sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
+			sampled_ids = torch.where(mask, sampled_ids, ids)
+			sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
+
+			confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
+			sorted_confidence, _ = torch.sort(confidence, axis=-1)
+			cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
+			masking = (confidence <= cut_off)
+
+			ids = torch.where(masking, self.mask_token, sampled_ids)
+			"""
+
+		return sampled_ids[0]
\ No newline at end of file
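A side note on the `gumbel_sample` helper above (illustration only, not part of the patch): adding Gumbel noise to temperature-scaled logits and taking the argmax draws from the same distribution as sampling the softmax (the Gumbel-max trick), which is why the scheduler can sample ids and read per-token confidences off the same logits. A quick empirical check:

```python
# Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)).
import torch

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5])

def gumbel_noise(shape):
    u = torch.rand(shape).clamp(min=1e-20)  # guard against log(0)
    return -torch.log(-torch.log(u))

draws = torch.stack([
    (logits + gumbel_noise(logits.shape)).argmax() for _ in range(20_000)
])
print(torch.bincount(draws, minlength=3) / draws.numel())  # ≈ softmax below
print(logits.softmax(dim=-1))  # tensor([0.2312, 0.6285, 0.1402])
```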