I'm such an idiot, really
commit 5698188824 (parent 77ff23e319)
@@ -35,15 +35,19 @@ Non-autoregressive training is performed by having the input tokens from the prev
However, having a pure NAR is challenging, as you need to both explicitly provide the duration and supply a "good enough" initial sequence of tokens to decode from.
* The former problem is easily "solved" by training a `len` inferencing task, where the given input autoregressively predicts the requested duration for the utterance.
* The latter, however, proves to be a bit of a challenge, as this initial sequence could be anything from random noise to a unique token.
  * The current implementation repeats the input prompt's RVQ level 0 as the initial condition during training, but inferencing fills with stop tokens. This *might* be the problem, but I do not have my `nar-len-llama-8` weights stored anywhere, sadly.
  * Testing showed that it's easy to predict the duration, but decoding the first RVQ level accurately proves to be a chore.
  * Initially, output seemed chaotic and unreliable, but further experiments showed the model will "work" for a brief moment before going silent.
* The latter also proves challenging because generating tokens from nothing in a single step is not possible.
  * diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
  * however, it's possible to follow a paradigm similar to diffusers: instead of iterating on random noise, masked tokens are iterated on, with each step committing the most confident predictions (see the sketch after this list).
  * incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this with a NAR transformer for image generation.
  * the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
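To make the masked-iteration idea concrete, below is a minimal sketch of confidence-based unmasking in the style of the paper above. `model`, `mask_token_id`, and the cosine schedule are illustrative assumptions, not this repo's API; the duration `seq_len` would come from the `len` task.

```python
import math
import torch

def masked_decode(model, seq_len, steps=8, mask_token_id=1024):
	# start fully masked: no "initial condition" is needed at all
	tokens = torch.full((seq_len,), mask_token_id, dtype=torch.long)
	is_masked = torch.ones(seq_len, dtype=torch.bool)

	for step in range(1, steps + 1):
		logits = model(tokens)  # (seq_len, vocab); hypothetical signature
		confidence, candidates = logits.softmax(dim=-1).max(dim=-1)

		# cosine schedule: the fraction of positions left masked decays to zero
		n_still_masked = int(seq_len * math.cos(step / steps * math.pi / 2))

		# already-committed positions must stay committed
		confidence[~is_masked] = float("inf")

		# the least confident positions remain masked; commit the rest
		still_masked = torch.zeros(seq_len, dtype=torch.bool)
		still_masked[confidence.argsort()[:n_still_masked]] = True
		commit = is_masked & ~still_masked
		tokens[commit] = candidates[commit]
		is_masked = still_masked

	return tokens
```

Each step is one forward pass over the whole sequence, so the total step count is fixed regardless of duration, unlike the AR's one pass per token.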
One problem exhibited by a NAR is artifacts ("crust") in the final waveform. I believe this is a confidence problem, where the wrong token gets inferred.
* Unfortunately, one solution is to simply train a separate NAR, as this should help bolster the model's NAR capabilities without the AR influencing things; I imagine having to decode tokens both causally and in parallel harms things.
  * This is backed by the `cfg.model.experimental.rvq_levels_p` distribution affecting the model's AR capabilities: increasing the NAR's share of training causes the AR to perform *worse*.
  * However, this may simply be wrong; it's just that checkpoints trained with such distributions felt lobotomized.
* Another solution that may help is to provide two token dropout methods (sketched after this list):
  * `token_dropout_error`: randomly nudges a small percentage of tokens from the prior RVQ level, to simulate wrong tokens being predicted.
  * `token_dropout_rate`: randomly masks off tokens from the prior RVQ level with a mask token, so the model does not rely too strongly on the given input.
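A rough sketch of what those two knobs could look like when corrupting the prior RVQ level during training; the names match the bullets above, but the bodies here are assumptions, not the repo's actual implementation.

```python
import torch

def corrupt_prior_level(tokens, n_audio_tokens=1024, mask_token_id=1024,
                        token_dropout_error=0.05, token_dropout_rate=0.05):
	tokens = tokens.clone()

	# token_dropout_error: nudge a small percentage of tokens to a different
	# (wrong) value, simulating the prior RVQ level being mispredicted
	error = torch.rand(tokens.shape) < token_dropout_error
	offsets = torch.randint(1, n_audio_tokens, (int(error.sum()),))
	tokens[error] = (tokens[error] + offsets) % n_audio_tokens

	# token_dropout_rate: replace a small percentage with the mask token,
	# so the model cannot lean too hard on the given input
	drop = torch.rand(tokens.shape) < token_dropout_rate
	tokens[drop] = mask_token_id
	return tokens
```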
## Embeddings
@@ -241,8 +241,8 @@ class AudioEmbedding(nn.Module):
		if self.capabilities is None:
			offset = 0
		# resp
		elif "len" in self.capabilities:
			offset = 1
		#elif "len" in self.capabilities:
		#	offset = 1
		elif "nar" not in self.capabilities:
			offset = 0
		elif quant_level > 0:
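For context on what `offset` selects above: a guess at the mechanics, assuming `AudioEmbedding` keeps one embedding table per classifier level and `offset` shifts which table a given `quant_level` maps onto (e.g. to account for an extra table when the `len` task is present). This is a hypothetical reconstruction, not the actual class:

```python
import torch
from torch import nn

class AudioEmbeddingSketch(nn.Module):
	# assumed layout: one embedding table per entry in l_tokens
	def __init__(self, l_tokens, d_model):
		super().__init__()
		self.embeddings = nn.ModuleList([ nn.Embedding(n, d_model) for n in l_tokens ])

	def forward(self, xi, offset = 0, quant_level = 0):
		# `offset` re-maps the quant level onto a shifted table
		return self.embeddings[quant_level + offset](xi)
```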
@@ -460,21 +460,15 @@ class Base(nn.Module):
		if "nar" not in self.capabilities:
			n_resp_tokens = n_audio_tokens + 1
			l_tokens = [n_resp_tokens] * self.n_resp_levels
		# NAR-len model
		# AR+NAR model
		elif "len" not in self.capabilities:
			# +1 to include the stop token
			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
		# AR+NAR model
		# NAR-len model
		else:
			n_resp_tokens = n_audio_tokens
			l_tokens = [n_resp_tokens] * (self.n_resp_levels + (1 if split_classifiers else 0))

		# there seems to be a problem with the NAR-only model with non-unified position IDs.............
		"""
		if "len" in self.capabilities and not unified_position_ids:
			raise Exception("ERROR: model instability for NAR-only model when not using unified position IDs.")
		"""
		l_tokens = [n_resp_tokens] * (self.n_resp_levels + 1)

		self.unified_position_ids = unified_position_ids
		self.interleave = interleave
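As a sanity check on the head sizing above, with assumed defaults of `n_audio_tokens = 1024` and `n_resp_levels = 8`, the three branches work out to:

```python
n_audio_tokens, n_resp_levels = 1024, 8  # assumed defaults

# pure AR ("nar" not in capabilities): stop token on every level
[1025] * 8
# AR+NAR: only RVQ level 0 is causal, so only it carries the stop token
[1025] + [1024] * 7
# NAR-len: no stop token at all; the extra head (when split_classifiers) presumably serves the len task
[1024] * 9
```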
@@ -490,7 +484,8 @@ class Base(nn.Module):

		# it would be nicer for these to be a token or live inside an embedding
		self.sep = nn.Parameter(torch.randn(d_model))
		self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
		self.dropout_token = nn.Parameter(torch.randn(d_model))
		self.mask_token = self.dropout_token # alias (hopefully) to the above

		if self.version == 1: # legacy
			n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
@@ -521,7 +516,6 @@ class Base(nn.Module):
			capabilities=self.capabilities,
		)

		# useless since I actually removed using these with the input processing overhaul...
		if self.version >= 3:
			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
			self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
@@ -533,6 +527,7 @@ class Base(nn.Module):
		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
		# this ***might*** let me also unify the proms_emb and resps_emb
		if self.version >= 5:
			# "len" RVQ level-0 gets an additional token
			self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)

			# experimental NAR-only mode
@@ -555,6 +550,7 @@ class Base(nn.Module):
		if attention_backend not in AVAILABLE_ATTENTIONS:
			raise ValueError(f"Requested attention `{attention_backend}` is not available. Currently available: {AVAILABLE_ATTENTIONS}")

		# override any requested padding size
		if attention_backend == "flash_attn_v100":
			self.l_padding = 32
		elif attention_backend == "fused_attn":
@@ -1126,8 +1122,9 @@ class Base(nn.Module):

			embedding = _interleave_sequence_reshape( embeddings )
		elif "len" in self.capabilities and quant_level == 0:
			assert input_prom is not None, "Guru meditating during training"
			# fill with the prom as the initial condition
			"""
			assert input_prom is not None, "Guru meditation"
			repeat = (input.shape[0] // input_prom.shape[0]) + 1
			repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
@@ -1137,7 +1134,29 @@ class Base(nn.Module):
				quant_level = 0,
			)
			"""
			# fill with "stop" token from the len layer for the NAR-only model

			# if training
			if not input.is_floating_point():
				# get original sequence
				embedding = self.resps_emb(
					input,
					offset = 0,
					quant_level = 0,
				)
				# randomly replace with mask tokens
				for i in range( embedding.shape[0] ):
					# a paper said to do this (each token gets masked with 80% probability)
					if random.random() > 0.8:
						continue
					embedding[i] = self.dropout_token
			# if inferencing
			else:
				# fill with mask tokens
				embedding = torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( input.shape[0] ) ])

			"""
			# fill with filler token from the len layer for the NAR-only model
			filler_token = 12
			embedding = self.resps_emb(
				# self.dropout_token.repeat((input.shape[0], 1)),
@@ -1165,9 +1184,11 @@ class Base(nn.Module):
			)
		else:
			offset = 0
			"""
			if "len" in self.capabilities:
				offset = 1
			elif "nar" not in self.capabilities:
			"""
			if "nar" not in self.capabilities:
				offset = 0
			elif quant_level > 0:
				offset = 1
@@ -1676,8 +1697,10 @@ class Base(nn.Module):
		if quant_levels is not None and "ar" in self.capabilities:
			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
		# (AR-len) disable extraneous tokens
		"""
		if quant_levels is None and "len" in self.capabilities:
			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
		"""

		# perform repetition penalizing
		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
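`ban_tokens` itself isn't shown in this hunk; the usual way to implement such a helper is to push the banned logits to negative infinity so sampling can never select them. A sketch under that assumption:

```python
import torch

def ban_tokens(logit, tokens):
	# assumed behaviour: -inf logits get zero probability after softmax
	logit = logit.clone()
	logit[..., tokens] = float("-inf")
	return logit
```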
@@ -216,9 +216,11 @@ class NAR(Base):
		max_levels = self.n_resp_levels

		# fill with mock tokens
		# to-do: repeat with the input prompt, as per training
		#prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
		prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
		#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]

		# note: these are raw embedding vectors (floats), not token IDs; the embedding path branches on input.is_floating_point() to tell the two apart
		prev_list = [ torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( resp_len ) ]) for resp_len in len_list ]
		#prev_list = [ None for resp_len in len_list ]

		# to-do: figure out why this fails when I copy some things from ar_nar
		for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):