'borrowed' a sampling scheduler for NAR-len's RVQ level 0 (better than before, but still not good enough)

mrq 2024-11-07 21:19:14 -06:00
parent e108c54daf
commit c127c4e488
5 changed files with 236 additions and 78 deletions

View File

@@ -41,7 +41,6 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final w
* `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior RVQ level to simulate wrong tokens being predicted.
* `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to try and have the model not-strongly-rely on the given input.
### Pure NAR
The pure NAR (`nar-len`) model is a model type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
@@ -50,10 +49,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter, however, proves to be challenging, as generating tokens from nothing in one step is not possible.
* diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* however, it's possible to follow a paradigm similar to diffusion: instead of iterating on random noise, masked tokens are iterated on per step, with each step keeping the most confident tokens.
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this with a NAR transformer for image generation
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this with a NAR transformer for image generation
To-do: fill this out more when it works.
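A minimal sketch of that demasking loop, assuming a hypothetical `model` callable that returns per-position logits (this is an illustration of the idea, not the exact implementation here):

```python
import math
import torch

def demask_sample(model, seq_len: int, mask_token: int, steps: int = 25) -> torch.Tensor:
    # start with every position masked
    ids = torch.full((seq_len,), mask_token, dtype=torch.long)
    for step in range(steps):
        t = (step + 1) / steps
        # cosine schedule: re-mask most positions early on, few near the end
        num_masked = max(int(math.cos(t * math.pi * 0.5) * seq_len), 1)
        probs = model(ids).softmax(dim=-1)  # (seq_len, vocab)
        sampled = probs.argmax(dim=-1)      # greedy for simplicity
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        # masked positions accept the newly sampled tokens
        ids = torch.where(ids == mask_token, sampled, ids)
        # re-mask the least confident positions for the next step
        if step + 1 < steps:
            ids[confidence.topk(num_masked, largest=False).indices] = mask_token
    return ids
```

Each step commits the tokens the model is most confident about and re-rolls the rest, so later steps only have to resolve the uncertain positions.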
## Embeddings
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
@@ -99,7 +101,8 @@ However, the `resp` requires some extra care, as the model needs to both causally
* The first embedding level pertains to RVQ level 0 for the AR.
* The remaining embedding levels map to RVQ level 0 + n for the NAR.
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
* In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved unfruitful.
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
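As a loose illustration of the embedding-level mapping above (the names and sizes below are made up for the sketch, not the actual `resps_emb` internals):

```python
import torch
import torch.nn as nn

n_tokens, d_model, n_resp_levels = 1025, 1024, 8
# embedding level 0 serves the AR's RVQ level 0;
# embedding levels 1..n serve the NAR's RVQ levels 0..n-1
resp_embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_resp_levels + 1)])

def embed_resp(tokens: torch.Tensor, rvq_level: int, is_ar: bool) -> torch.Tensor:
    # the AR embeds RVQ level 0 with embedding level 0 (it predicts itself);
    # the NAR embeds RVQ level n with embedding level n + 1 (it predicts the next level)
    emb_level = 0 if is_ar else rvq_level + 1
    return resp_embs[emb_level](tokens)
```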

View File

@@ -391,7 +391,8 @@ class AR_NAR(Base):
if sampled.entropy:
metrics.append( sampled.entropy )
elif sampled.scores:
metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
#metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
if mirostat is not None:
mirostat = sampled.scores

View File

@@ -47,8 +47,13 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
def _dropout_mask( input, p=0.8 ):
return torch.tensor( [ 0 if random.random() < p else 1 for _ in range( input.shape[0] ) ], dtype=torch.uint8, device=input.device )
def _dropout_mask( input, p=None ):
# cosine scheduling
if p is None:
t = random.random()
p = math.cos(t * math.pi * 0.5)
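# p starts at 1.0 (mask everything) when t is 0 and decays to 0.0 as t approaches 1; True marks a position to mask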
return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
def clamp(n, lo, hi):
return max(lo, min(n, hi))
@@ -1004,7 +1009,9 @@ class Base(nn.Module):
# store dropout mask
if "len" in self.capabilities and quant_level == 0:
dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
t = random.random()
p = math.cos(t * math.pi * 0.5)
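# mask each token with probability p; masked positions are swapped for the stop/mask token at the embedding step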
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task
@@ -1145,36 +1152,14 @@ class Base(nn.Module):
) for l in range( input.shape[-1] ) ]
embedding = _interleave_sequence_reshape( embeddings )
elif "len" in self.capabilities and quant_level == 0:
mask_token = self.resps_emb(
torch.tensor( [ self.stop_token ], dtype=torch.int16, device=input.device ),
offset = 0,
quant_level = 0
)
# if training # if training NAR-len RVQ level 0
if not input.is_floating_point(): elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
# get original sequence
embedding = self.resps_emb( embedding = self.resps_emb(
input, torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
offset = 0, offset = 0,
quant_level = 0, quant_level = 0,
) )
# create dropout mask if one is not provided
if dropout_mask is None:
dropout_mask = _dropout_mask( input )
# replace with masked tokens
for i, token in enumerate( dropout_mask ):
if token == 0:
embedding[i] = mask_token
# if inferencing
else:
# fill with mask tokens for now
embedding = torch.concat([ mask_token for _ in range( input.shape[0] ) ])
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@@ -1331,9 +1316,7 @@ class Base(nn.Module):
elif name == "resp":
# mask found, apply it
if dropout_mask is not None:
seq = input if input.dim() == 1 else input[:, 0]
masked = torch.tensor([ token if dropout_mask[i] == 1 else self.ignore_index for i, token in enumerate( seq ) ], dtype=torch.int16, device=input.device)
target.append( masked )
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
elif self.interleave:
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
@@ -1778,6 +1761,12 @@ class Base(nn.Module):
res = [ Categorical(logits=logit).sample() for logit in logits ]
# calculate token probabilities
if "len" in self.capabilities:
scores = [
[ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
for logit, tokens in zip(logits, res)
]
else:
scores = [
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
for logit, tokens in zip(logits, res)

View File

@@ -6,21 +6,22 @@ It *does* have to inference the initial length in an autoregressive-ish manner
Initial experiments show this only really "works" for a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg
import torch
from torch.nn.utils.rnn import pad_sequence
import random
import math
import numpy as np
import logging
import torch
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
from tqdm import trange
from .base import Base, list_to_tensor, Categorical, _dropout_mask
from ..config import cfg
from ..emb.qnt import trim, repeat_extend_audio
from ..samplers import SampleScheduler
import logging
def clamp(n, lo, hi):
return max(lo, min(n, hi))
@@ -211,23 +212,91 @@ class NAR(Base):
if len_list is not None:
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels
# fill with mock tokens
#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
#prev_list = [ None for resp_len in len_list ] # this breaks the position ID calc
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if max_levels == 0:
max_levels = self.n_max_levels - 1
if sampling_layer_skip:
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# initial condition
len_list = [ min(l, 500) for l in len_list ]
metrics = []
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
# to-do: special "scheduling" to inference RVQ-level 0
# special "scheduling" to inference RVQ-level 0
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
_super = super()
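# capture super() in a local so the nested forward_lambda (handed to SampleScheduler) can reuse the base inputs/forward/sample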
def forward_lambda( ids, step, temperature ):
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ ids[0] ]
seq_len = ids.shape[-1]
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
ids = sampled[0]
return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
scheduler = SampleScheduler(
device=device,
mask_token=self.stop_token,
max_steps=30,
forward_lambda=forward_lambda,
sampling_temperature=sampling_temperature,
)
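# note: only len_list[0] is sampled, so this path effectively assumes a batch size of 1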
prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
# expand if given a raw 1D tensor
for i, resp in enumerate(prev_list):
if resp.dim() == 1:
prev_list[i] = resp.unsqueeze(-1)
# to-do: figure out why this fails when I copy some things from ar_nar
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
level = 0 if n == 0 else prev_list[0].shape[-1]
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
@@ -249,7 +318,7 @@ class NAR(Base):
inputs=inputs,
quant_levels=quant_levels,
# layer_skip_variables=sampling_layer_skip_variables,
layer_skip_variables=sampling_layer_skip_variables,
)
logits, state = output.logits, output.state
@@ -258,24 +327,20 @@ class NAR(Base):
prev_list=prev_list,
quant_levels=quant_levels,
#temperature=sampling_temperature,
temperature=1.0 if n == 0 else sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
temperature=0.0, # sampling_temperature,
#min_temperature=sampling_min_temperature,
#top_p=sampling_top_p,
#top_k=sampling_top_k,
#min_p=sampling_min_p,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
resps_list = sampled[0]
if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
else:
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
resps_list = sampled[0]
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list

View File

@@ -5,7 +5,7 @@ import numpy as np
import time
from torch import Tensor, einsum, nn
from einops import rearrange
from dataclasses import asdict, dataclass, field
# Simple filter to modify a token's probability if it shows up in the past
@@ -521,3 +521,103 @@ def sample_entropix(
"""
return res, metrics
"""
def add_gumbel_noise(t, temperature, device):
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
"""
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
def top_k(logits, thres = 0.9):
k = math.ceil((1 - thres) * logits.shape[-1])
val, ind = logits.topk(k, dim = -1)
probs = torch.full_like(logits, float('-inf'))
probs.scatter_(2, ind, val)
return probs
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
class SampleScheduler:
def __init__(
self,
forward_lambda = None,
mask_token = -1,
max_steps = 25,
device = "cuda",
sampling_temperature=1.0,
):
self.forward_lambda = forward_lambda
self.max_steps = max_steps
self.mask_token = mask_token
self.device = device
"""
self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
"""
# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
def sample( self, seq_len ):
ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
for step in range( self.max_steps ):
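# cosine schedule: the fraction of positions kept masked starts near 1.0 and decays toward 0.0 across steps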
t = step / self.max_steps
mask_ratio = math.cos(t * math.pi * 0.5)
sampling_temperature = 1.0
annealed_temperature = sampling_temperature * (1.0 - t)
num_token_masked = max(int(mask_ratio * seq_len), 1)
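# scores hold (1 - sampled token probability) from the previous step, so topk selects the least confident positions to re-mask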
masked_indices = scores.topk(num_token_masked, dim = -1).indices
ids = ids.scatter(1, masked_indices, self.mask_token)
logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
filtered_logits = top_k( logits )
sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
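# only positions still carrying the mask token accept the newly sampled ids; committed tokens are kept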
is_masked = ids == self.mask_token
ids = torch.where( is_masked, sampled_ids, ids )
probs_without_temperature = logits.softmax(dim = -1)
scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
scores = rearrange(scores, '... 1 -> ...')
#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
"""
if step + 1 == self.max_steps:
break
# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
# create next input sequence
mask = (ids == self.mask_token)
mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
mask_len = torch.maximum(
torch.Tensor([1]).to(self.device),
torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
)[0].squeeze()
logits = torch.log_softmax(logits, dim=-1)
sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
sampled_ids = torch.where(mask, sampled_ids, ids)
sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
sorted_confidence, _ = torch.sort(confidence, axis=-1)
cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
masking = (confidence <= cut_off)
ids = torch.where(masking, self.mask_token, sampled_ids)
"""
return sampled_ids[0]