Something's Wrong

mrq 2024-11-09 15:07:43 -06:00
parent dcd5fecff3
commit 8b3d1cf70a
4 changed files with 55 additions and 26 deletions

View File

@@ -51,7 +51,6 @@ However, having a pure NAR is challenging, as you need to both explicitly provide
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter, however, proves to be challenging, as generating tokens from nothing in one step is not possible.
* diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* embedding the current timestep is *required*, despite this technically being encoded in how many masked tokens exist within a sequence.
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
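To make that paradigm concrete, here is a minimal, illustrative sketch of MaskGIT-style iterative demasking. It is not the repo's implementation; `model`, `mask_id`, and the cosine schedule below are assumptions for the example.

import math
import torch

def demask_decode(model, seq_len, mask_id, steps=20, device="cpu"):
    # start fully masked; per-position confidence decides what gets re-masked each step
    ids = torch.full((seq_len,), mask_id, dtype=torch.long, device=device)
    scores = torch.zeros(seq_len, device=device)
    for step in range(steps):
        t = step / steps
        noise_p = math.cos(t * math.pi * 0.5)      # fraction of positions to keep masked
        n_masked = max(1, int(noise_p * seq_len))
        # re-mask the least-confident positions
        remask = scores.topk(n_masked, largest=False).indices
        is_masked = torch.zeros(seq_len, dtype=torch.bool, device=device)
        is_masked[remask] = True
        ids = torch.where(is_masked, torch.full_like(ids, mask_id), ids)
        # predict every position in one pass, but only overwrite the masked ones
        probs = model(ids).softmax(dim=-1)          # assumed shape: (seq_len, n_tokens)
        sampled = probs.argmax(dim=-1)
        ids = torch.where(is_masked, sampled, ids)
        scores = probs.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    return ids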

View File

@@ -1739,7 +1739,14 @@ class Base(nn.Module):
# perform repetition penalizing
if prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# penalize non-autoregressively
if quant_levels is not None:
#logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# penalize autoregressively
else:
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:

View File

@@ -252,12 +252,13 @@ class NAR(Base):
test_artifact = np.load(path, allow_pickle=True)[()]
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
proms_list = [ resps for resps in resps_list ]
proms_list = [ resps[:75*3, :] for resps in resps_list ]
#proms_list = [ resps for resps in resps_list ]
len_list = [ resps.shape[0] for resps in resps_list ]
"""
_super = super()
def demask_sampling( seq_len, max_steps=20, temperature=0.3 ):
def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
starting_temperature = temperature
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
@@ -268,12 +269,13 @@ class NAR(Base):
start_noise = 0.0
end_noise = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
# use hardcoded reference file to test inference capabilities
if test_artifact is not None:
# because we "set" it later on, it's not implicitly captured
nonlocal resps_list
start_noise = 0.0
start_noise = 0.5
noise_p = math.cos( start_noise * math.pi * 0.5 )
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
@@ -348,15 +350,15 @@ class NAR(Base):
# sample with gumbel noise
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
#sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
sampled_ids = filtered_tokens
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
#sampled_ids = filtered_tokens
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
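Since the change above switches back to gumbel_sample() for the demasking steps, here is a minimal stand-in for what that helper is assumed to do: Gumbel-max sampling, i.e. drawing from softmax(logits / temperature) by perturbing the logits with Gumbel noise and taking the argmax. Illustrative only, not the repo's exact helper.

import torch

def gumbel_noise(t):
    # -log(-log(U)) with U ~ Uniform(0, 1), clamped away from 0 to avoid log(0)
    u = torch.rand_like(t).clamp_min(1e-20)
    return -torch.log(-torch.log(u))

def gumbel_sample(logits, temperature=1.0, dim=-1):
    if temperature <= 0.0:
        return logits.argmax(dim=dim)
    return (logits / temperature + gumbel_noise(logits)).argmax(dim=dim)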

View File

@@ -8,29 +8,52 @@ from torch import Tensor, einsum, nn
from einops import rearrange
from dataclasses import asdict, dataclass, field
def clamp(n, lo, hi):
return max(lo, min(n, hi))
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
if factor == 1.0 or previous is None:
return logits
unique = set()
priors = reversed(previous)
for distance, token in enumerate(priors):
# rep-pen range
if limit and distance >= limit:
continue
# skip if we're only applying the decay once
if one_time and token in unique:
continue
seq_len = logits.shape[0]
prev_len = len( previous )
# apply autoregressively
if prev_len < seq_len:
unique = set()
priors = reversed(previous)
for i, token in enumerate(priors):
# rep-pen range
if limit and i >= limit:
continue
# skip if we're only applying the decay once
if one_time and token in unique:
continue
distance = i + 1
logits[-1, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
# apply non-autoregressively
else:
for i, token in enumerate( previous ):
# apply to next token
start = i + 1
# apply either up to limit tokens, or to the end
end = start + limit if limit > 0 else seq_len
start = clamp(0, seq_len - 1, start)
end = clamp(0, seq_len - 1, end)
for j in range( start, end ):
distance = j - i
logits[j, token] /= factor * (distance ** decay)
distance += 1
logits[:, token] /= factor * (distance ** decay)
# add to set if we care about it
if one_time:
unique.add(token)
return logits
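A usage sketch for the rewritten reptition_penalize, assuming it is called from the same module that defines it above. Shapes follow the code: logits is (seq_len, n_tokens) and previous is a flat list of token ids.

import torch

vocab, seq_len = 1024, 8
logits = torch.randn(seq_len, vocab)

# autoregressive case: previous is shorter than the sequence, so only the
# last position's logits are penalized for recently seen tokens
ar_logits = reptition_penalize(logits.clone(), previous=[3, 7, 7, 42], factor=1.3, decay=0.1)

# non-autoregressive case: previous spans the whole sequence, so each token
# occurrence penalizes the positions that come after it
nar_logits = reptition_penalize(logits.clone(), previous=list(range(seq_len)), factor=1.3, decay=0.1)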
@@ -379,8 +402,6 @@ def _sample_entropix(
min_p=0.0,
cfg=EntropixSamplerConfig(),
):
def clamp(n, lo, hi):
return max(lo, min(n, hi))
if top_k == 0:
top_k = logits.shape[-1]
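The local clamp removed here presumably defers to the module-level clamp added at the top of this file. A quick illustrative check of its behavior:

assert clamp(5, 0, 10) == 5    # inside the range: unchanged
assert clamp(-3, 0, 10) == 0   # below the range: pinned to lo
assert clamp(42, 0, 10) == 10  # above the range: pinned to hi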