'borrowed' a sampling scheduler for NAR-len's RVQ level 0 (better than before, but still not good enough)
This commit is contained in:
parent
e108c54daf
commit
c127c4e488
|
@ -41,7 +41,6 @@ One problem exhibited from a NAR is producing arfifacts ("crust") in the final w
|
|||
* `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior RVQ level to simulate wrong tokens being predicted.
|
||||
* `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to try and have the model not-strongly-rely on the given input.
|
||||
|
||||
|
||||
### Pure NAR
|
||||
|
||||
The pure NAR (`nar-len`) model is a model-type that inferences audio tokens purely non-autoregressively. Despite being called a pure NAR, duration is then inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
|
||||
|
@ -50,10 +49,13 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
|
|||
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
|
||||
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
|
||||
* diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
|
||||
* however, it's possible to have a similar paradigm to diffusers, but instead iterating upon random noise, masked tokens are iterated per step, and each step picks the most confident tokens per step.
|
||||
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
|
||||
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
|
||||
|
||||
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
|
||||
* incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this in the use of a NAR transformer for image generation
|
||||
|
||||
To-do: fill out this more when it works.
|
||||
|
||||
## Embeddings
|
||||
|
||||
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can geta way with treating each sequence as separate ranges within a total token sequence.
|
||||
|
@ -99,7 +101,8 @@ Howver, the `resp` requires some extra care, as the model needs to both causally
|
|||
* The first embedding level pertains to RVQ level 0 for the AR.
|
||||
* The remaining embedding levels maps to RVQ level 0 + n for the NAR.
|
||||
* In other words, embedding level 1 => RVQ level 0, embedding level 2 => RVQ level 1, etc...
|
||||
* I believe this is because the model needs to "know" whether to predict the next token in the sequence, or the token in the same position of the next RVQ level.
|
||||
* I believe this is because the model needs to "know" whether to predict ~~the next token in the sequence, or the token in the same position of the next RVQ level~~ which tokens of a given embedding.
|
||||
* In other words, the AR's RVQ level 0 embedding predicts itself, while the NAR's embeddings predict the next level's embeddings.
|
||||
* Unfortunately, providing a token for the current/target RVQ level within the input sequence doesn't seem to help? I don't remember if I experimented with this or not, but testing of a "sane" `resp` embedding proved to be unfruitful.
|
||||
|
||||
The `prom` and `resp` are split since, in theory, it helps the model know better what audio to source from, and what audio is part of the output sequence. In theory.
|
||||
|
|
|
@ -391,7 +391,8 @@ class AR_NAR(Base):
|
|||
if sampled.entropy:
|
||||
metrics.append( sampled.entropy )
|
||||
elif sampled.scores:
|
||||
metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
|
||||
#metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
|
||||
metrics.append( [ { "p": p[0] } for p in sampled.scores ] )
|
||||
|
||||
if mirostat is not None:
|
||||
mirostat = sampled.scores
|
||||
|
|
|
@ -47,8 +47,13 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
|||
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
|
||||
"""
|
||||
|
||||
def _dropout_mask( input, p=0.8 ):
|
||||
return torch.tensor( [ 0 if random.random() < p else 1 for _ in range( input.shape[0] ) ], dtype=torch.uint8, device=input.device )
|
||||
def _dropout_mask( input, p=None ):
|
||||
# cosine scheduling
|
||||
if p is None:
|
||||
t = random.random()
|
||||
p = math.cos(t * math.pi * 0.5)
|
||||
|
||||
return torch.tensor( [ random.random() < p for _ in range( input.shape[0] ) ], dtype=torch.bool, device=input.device )
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
@ -1004,7 +1009,9 @@ class Base(nn.Module):
|
|||
|
||||
# store dropout mask
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
|
||||
t = random.random()
|
||||
p = math.cos(t * math.pi * 0.5)
|
||||
dropout_mask = _dropout_mask( resps_list[i], p=p )
|
||||
inputs[i].append( ("dropout_mask", dropout_mask ) )
|
||||
|
||||
# Audio length prediction task
|
||||
|
@ -1145,36 +1152,14 @@ class Base(nn.Module):
|
|||
) for l in range( input.shape[-1] ) ]
|
||||
|
||||
embedding = _interleave_sequence_reshape( embeddings )
|
||||
elif "len" in self.capabilities and quant_level == 0:
|
||||
mask_token = self.resps_emb(
|
||||
torch.tensor( [ self.stop_token ], dtype=torch.int16, device=input.device ),
|
||||
offset = 0,
|
||||
quant_level = 0
|
||||
)
|
||||
|
||||
# if training
|
||||
if not input.is_floating_point():
|
||||
# get original sequence
|
||||
# if training NAR-len RVQ level 0
|
||||
elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
|
||||
embedding = self.resps_emb(
|
||||
input,
|
||||
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
|
||||
offset = 0,
|
||||
quant_level = 0,
|
||||
)
|
||||
|
||||
# create dropout mask if one is not provided
|
||||
if dropout_mask is None:
|
||||
dropout_mask = _dropout_mask( input )
|
||||
|
||||
# replace with masked tokens
|
||||
for i, token in enumerate( dropout_mask ):
|
||||
if token == 0:
|
||||
embedding[i] = mask_token
|
||||
|
||||
# if inferencing
|
||||
else:
|
||||
# fill with mask tokens for now
|
||||
embedding = torch.concat([ mask_token for _ in range( input.shape[0] ) ])
|
||||
|
||||
# cheat-y way to handle performing STT across all levels
|
||||
elif task_type in summed_embeddings_task:
|
||||
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
|
||||
|
@ -1331,9 +1316,7 @@ class Base(nn.Module):
|
|||
elif name == "resp":
|
||||
# mask found, apply it
|
||||
if dropout_mask is not None:
|
||||
seq = input if input.dim() == 1 else input[:, 0]
|
||||
masked = torch.tensor([ token if dropout_mask[i] == 1 else self.ignore_index for i, token in enumerate( seq ) ], dtype=torch.int16, device=input.device)
|
||||
target.append( masked )
|
||||
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
|
||||
elif self.interleave:
|
||||
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
|
||||
|
||||
|
@ -1778,6 +1761,12 @@ class Base(nn.Module):
|
|||
res = [ Categorical(logits=logit).sample() for logit in logits ]
|
||||
|
||||
# calculate token probabilities
|
||||
if "len" in self.capabilities:
|
||||
scores = [
|
||||
[ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
|
||||
for logit, tokens in zip(logits, res)
|
||||
]
|
||||
else:
|
||||
scores = [
|
||||
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
|
||||
for logit, tokens in zip(logits, res)
|
||||
|
|
|
@ -6,21 +6,22 @@ It *does* have to inference the initial length in an autoregresssive-ish manner
|
|||
Initial experiments show this only really "works" for the a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
|
||||
"""
|
||||
|
||||
from .base import Base, list_to_tensor, Categorical
|
||||
from ..config import cfg
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import logging
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
||||
from .base import Base, list_to_tensor, Categorical, _dropout_mask
|
||||
from ..config import cfg
|
||||
from ..emb.qnt import trim, repeat_extend_audio
|
||||
|
||||
import logging
|
||||
from ..samplers import SampleScheduler
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
@ -211,23 +212,91 @@ class NAR(Base):
|
|||
|
||||
|
||||
if len_list is not None:
|
||||
# is NAR
|
||||
if max_levels == 0:
|
||||
max_levels = self.n_resp_levels
|
||||
sampling_layer_skip_variables = {} if sampling_layer_skip else None
|
||||
|
||||
# fill with mock tokens
|
||||
#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
|
||||
#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
|
||||
#prev_list = [ None for resp_len in len_list ] # this breaks the position ID calc
|
||||
if max_levels == 0:
|
||||
max_levels = self.n_max_levels - 1
|
||||
|
||||
if sampling_layer_skip:
|
||||
if sampling_layer_skip_entropy_threshold >= 0:
|
||||
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
|
||||
if sampling_layer_skip_varentropy_threshold >= 0:
|
||||
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
|
||||
if sampling_layer_skip_exit_layer >= 0:
|
||||
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
|
||||
|
||||
# initial condition
|
||||
len_list = [ min(l, 500) for l in len_list ]
|
||||
metrics = []
|
||||
|
||||
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
|
||||
prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
|
||||
|
||||
# to-do: special "scheduling" to inference RVQ-level 0
|
||||
# special "scheduling" to inference RVQ-level 0
|
||||
level = 0
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
|
||||
|
||||
_super = super()
|
||||
def forward_lambda( ids, step, temperature ):
|
||||
quant_levels = [ level for _ in range(batch_size) ]
|
||||
prev_list = [ ids[0] ]
|
||||
seq_len = ids.shape[-1]
|
||||
|
||||
inputs = _super.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
output = _super.forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
logits = output.logits
|
||||
|
||||
sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
temperature=temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
min_p=sampling_min_p,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
ids = sampled[0]
|
||||
|
||||
return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
|
||||
|
||||
scheduler = SampleScheduler(
|
||||
device=device,
|
||||
mask_token=self.stop_token,
|
||||
max_steps=30,
|
||||
forward_lambda=forward_lambda,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
|
||||
|
||||
# expand if given a raw 1D tensor
|
||||
for i, resp in enumerate(prev_list):
|
||||
if resp.dim() == 1:
|
||||
prev_list[i] = resp.unsqueeze(-1)
|
||||
|
||||
# to-do: figure out why this fails when I copy some things from ar_nar
|
||||
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
|
||||
level = 0 if n == 0 else prev_list[0].shape[-1]
|
||||
level = prev_list[0].shape[-1]
|
||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||
break
|
||||
|
||||
|
@ -249,7 +318,7 @@ class NAR(Base):
|
|||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
# layer_skip_variables=sampling_layer_skip_variables,
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
logits, state = output.logits, output.state
|
||||
|
||||
|
@ -258,24 +327,20 @@ class NAR(Base):
|
|||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
#temperature=sampling_temperature,
|
||||
temperature=1.0 if n == 0 else sampling_temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
min_p=sampling_min_p,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
temperature=0.0, # sampling_temperature,
|
||||
#min_temperature=sampling_min_temperature,
|
||||
#top_p=sampling_top_p,
|
||||
#top_k=sampling_top_k,
|
||||
#min_p=sampling_min_p,
|
||||
#repetition_penalty=sampling_repetition_penalty,
|
||||
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
resps_list = sampled[0]
|
||||
|
||||
if n == 0:
|
||||
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
||||
else:
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
resps_list = sampled[0]
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import time
|
||||
|
||||
from torch import Tensor, einsum, nn
|
||||
|
||||
from einops import rearrange
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
# Simple filter to modify a token's probability if it shows up in the past
|
||||
|
@ -521,3 +521,103 @@ def sample_entropix(
|
|||
"""
|
||||
|
||||
return res, metrics
|
||||
|
||||
"""
|
||||
def add_gumbel_noise(t, temperature, device):
|
||||
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
|
||||
"""
|
||||
|
||||
def log(t, eps = 1e-20):
|
||||
return torch.log(t.clamp(min = eps))
|
||||
|
||||
def gumbel_noise(t):
|
||||
noise = torch.zeros_like(t).uniform_(0, 1)
|
||||
return -log(-log(noise))
|
||||
|
||||
def gumbel_sample(t, temperature = 1., dim = -1):
|
||||
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
|
||||
|
||||
def top_k(logits, thres = 0.9):
|
||||
k = math.ceil((1 - thres) * logits.shape[-1])
|
||||
val, ind = logits.topk(k, dim = -1)
|
||||
probs = torch.full_like(logits, float('-inf'))
|
||||
probs.scatter_(2, ind, val)
|
||||
return probs
|
||||
|
||||
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
|
||||
class SampleScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
forward_lambda = None,
|
||||
mask_token = -1,
|
||||
max_steps = 25,
|
||||
device = "cuda",
|
||||
sampling_temperature=1.0,
|
||||
):
|
||||
self.forward_lambda = forward_lambda
|
||||
self.max_steps = max_steps
|
||||
self.mask_token = mask_token
|
||||
self.device = device
|
||||
|
||||
"""
|
||||
self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
|
||||
self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
|
||||
self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
|
||||
"""
|
||||
|
||||
# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
|
||||
def sample( self, seq_len ):
|
||||
ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
|
||||
scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
|
||||
|
||||
for step in range( self.max_steps ):
|
||||
t = step / self.max_steps
|
||||
mask_ratio = math.cos(t * math.pi * 0.5)
|
||||
sampling_temperature = 1.0
|
||||
annealed_temperature = sampling_temperature * (1.0 - t)
|
||||
|
||||
num_token_masked = max(int(mask_ratio * seq_len), 1)
|
||||
masked_indices = scores.topk(num_token_masked, dim = -1).indices
|
||||
|
||||
ids = ids.scatter(1, masked_indices, self.mask_token)
|
||||
|
||||
logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
|
||||
filtered_logits = top_k( logits )
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
|
||||
|
||||
is_masked = ids == self.mask_token
|
||||
ids = torch.where( is_masked, sampled_ids, ids )
|
||||
|
||||
probs_without_temperature = logits.softmax(dim = -1)
|
||||
|
||||
scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
|
||||
scores = rearrange(scores, '... 1 -> ...')
|
||||
#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
|
||||
|
||||
"""
|
||||
if step + 1 == self.max_steps:
|
||||
break
|
||||
|
||||
# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
|
||||
# create next input sequence
|
||||
mask = (ids == self.mask_token)
|
||||
mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
|
||||
mask_len = torch.maximum(
|
||||
torch.Tensor([1]).to(self.device),
|
||||
torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
|
||||
)[0].squeeze()
|
||||
|
||||
logits = torch.log_softmax(logits, dim=-1)
|
||||
sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
|
||||
sampled_ids = torch.where(mask, sampled_ids, ids)
|
||||
sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
|
||||
|
||||
confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
|
||||
sorted_confidence, _ = torch.sort(confidence, axis=-1)
|
||||
cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
|
||||
masking = (confidence <= cut_off)
|
||||
|
||||
ids = torch.where(masking, self.mask_token, sampled_ids)
|
||||
"""
|
||||
|
||||
return sampled_ids[0]
|
Loading…
Reference in New Issue
Block a user