Something's Wrong
This commit is contained in:
parent
dcd5fecff3
commit
8b3d1cf70a
|
@ -51,7 +51,6 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
|
|||
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
|
||||
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
|
||||
* diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
|
||||
* embedding the current timestep is *required*, despite this technically being encoded in how many masked tokens exist within a sequence.
|
||||
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
|
||||
|
||||
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
|
||||
|
|
|
@ -1739,7 +1739,14 @@ class Base(nn.Module):
|
|||
# perform repetition penalizing
|
||||
if prev_list is not None and repetition_penalty != 1.0:
|
||||
# to-do: figure out a faster way to handle tolist()
|
||||
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
|
||||
|
||||
# penalize non-autoregressively
|
||||
if quant_levels is not None:
|
||||
#logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
|
||||
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
|
||||
# penalize autoregressively
|
||||
else:
|
||||
logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
|
||||
|
||||
# (AR) perform length penalizing
|
||||
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
|
||||
|
|
|
@ -252,12 +252,13 @@ class NAR(Base):
|
|||
test_artifact = np.load(path, allow_pickle=True)[()]
|
||||
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
|
||||
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
|
||||
proms_list = [ resps for resps in resps_list ]
|
||||
proms_list = [ resps[:75*3, :] for resps in resps_list ]
|
||||
#proms_list = [ resps for resps in resps_list ]
|
||||
len_list = [ resps.shape[0] for resps in resps_list ]
|
||||
"""
|
||||
|
||||
_super = super()
|
||||
def demask_sampling( seq_len, max_steps=20, temperature=0.3 ):
|
||||
def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
|
||||
starting_temperature = temperature
|
||||
|
||||
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
|
||||
|
@ -268,12 +269,13 @@ class NAR(Base):
|
|||
|
||||
start_noise = 0.0
|
||||
end_noise = 1.0
|
||||
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
|
||||
|
||||
# use hardcoded reference file to test inference capabilities
|
||||
if test_artifact is not None:
|
||||
# because we "set" it later on, it's not implicitly captured
|
||||
nonlocal resps_list
|
||||
start_noise = 0.0
|
||||
start_noise = 0.5
|
||||
noise_p = math.cos( start_noise * math.pi * 0.5 )
|
||||
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
|
||||
|
||||
|
@ -348,15 +350,15 @@ class NAR(Base):
|
|||
|
||||
# sample with gumbelnoise
|
||||
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
|
||||
#sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
|
||||
sampled_ids = filtered_tokens
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
|
||||
#sampled_ids = filtered_tokens
|
||||
|
||||
# keep unmasked tokens
|
||||
input_ids = torch.where( is_masked, sampled_ids, input_ids )
|
||||
# update scores (conjugated to put the worst scores at the top)
|
||||
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
|
||||
|
||||
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
|
||||
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
|
||||
|
||||
return input_ids
|
||||
|
||||
|
|
|
@ -8,29 +8,52 @@ from torch import Tensor, einsum, nn
|
|||
from einops import rearrange
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
||||
# Simple filter to modify a token's probability if it shows up in the past
|
||||
# `one_time` will only apply the penalty once
|
||||
# `decay` is a factor that will exponentially apply to how far away it is
|
||||
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||
|
||||
# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
|
||||
def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
|
||||
if factor == 1.0 or previous is None:
|
||||
return logits
|
||||
|
||||
unique = set()
|
||||
priors = reversed(previous)
|
||||
for distance, token in enumerate(priors):
|
||||
# rep-pen range
|
||||
if limit and distance >= limit:
|
||||
continue
|
||||
# skip if we're only applying the decay once
|
||||
if one_time and token in unique:
|
||||
continue
|
||||
seq_len = logits.shape[0]
|
||||
prev_len = len( previous )
|
||||
|
||||
# apply autoregressively
|
||||
if prev_len < seq_len:
|
||||
unique = set()
|
||||
priors = reversed(previous)
|
||||
for i, token in enumerate(priors):
|
||||
# rep-pen range
|
||||
if limit and i >= limit:
|
||||
continue
|
||||
# skip if we're only applying the decay once
|
||||
if one_time and token in unique:
|
||||
continue
|
||||
|
||||
distance = i + 1
|
||||
logits[-1, token] /= factor * (distance ** decay)
|
||||
|
||||
# add to set if we care about it
|
||||
if one_time:
|
||||
unique.add(token)
|
||||
# apply non-autoregressively
|
||||
else:
|
||||
for i, token in enumerate( previous ):
|
||||
# apply to next token
|
||||
start = i + 1
|
||||
# apply either up to limit tokens, or to the end
|
||||
end = start + limit if limit > 0 else seq_len
|
||||
start = clamp(0, seq_len - 1, start)
|
||||
end = clamp(0, seq_len - 1, end)
|
||||
for j in range( start, end ):
|
||||
distance = j - i
|
||||
logits[j, token] /= factor * (distance ** decay)
|
||||
|
||||
distance += 1
|
||||
logits[:, token] /= factor * (distance ** decay)
|
||||
|
||||
# add to set if we care about it
|
||||
if one_time:
|
||||
unique.add(token)
|
||||
|
||||
return logits
|
||||
|
||||
|
@ -379,8 +402,6 @@ def _sample_entropix(
|
|||
min_p=0.0,
|
||||
cfg=EntropixSamplerConfig(),
|
||||
):
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
||||
if top_k == 0:
|
||||
top_k = logits.shape[-1]
|
||||
|
|
Loading…
Reference in New Issue
Block a user