new NAR-len training paradigm......
commit e108c54daf (parent ed174c589e)

@@ -33,14 +33,6 @@ As decoding is done non-autoregressively, the model can process tokens "in place"

Non-autoregressive training is performed by having the input tokens from the previous RVQ level predict the next level's tokens in place. The output logits are in the same positions, and do not require the further modifications that the AR requires.
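
As a minimal sketch of what "in place" means here (module names, sizes, and tensor layout are illustrative assumptions, not the repo's actual API): the embeddings of all prior RVQ levels are summed per timestep, and the logits for the next level are read off at those same timesteps, so no causal shift is needed.

```python
# Illustrative sketch only: sum prior-level embeddings, read next-level logits at the same positions.
import torch
import torch.nn as nn

class TinyNARLevel(nn.Module):
    def __init__(self, n_audio_tokens=1024, n_levels=8, d_model=512):
        super().__init__()
        self.resp_embs = nn.ModuleList([nn.Embedding(n_audio_tokens, d_model) for _ in range(n_levels)])
        self.backbone = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, 8, batch_first=True), 2)
        self.heads = nn.ModuleList([nn.Linear(d_model, n_audio_tokens) for _ in range(n_levels)])

    def forward(self, resps: torch.Tensor, quant_level: int) -> torch.Tensor:
        # resps: [B, T, quant_level] holds tokens for RVQ levels 0 .. quant_level-1
        x = sum(self.resp_embs[l](resps[..., l]) for l in range(quant_level))
        h = self.backbone(x)
        # logits for level `quant_level`, one per input timestep, no shifting required
        return self.heads[quant_level](h)  # [B, T, n_audio_tokens]

logits = TinyNARLevel()(torch.randint(0, 1024, (1, 75, 2)), quant_level=2)
```

The training target is simply the level-`quant_level` tokens at the same positions, so the loss needs no re-alignment.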

However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter, however, proves to be challenging, as generating tokens from nothing in one step is not possible.
* Diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* However, it's possible to follow a paradigm similar to diffusers, except that instead of iterating on random noise, masked tokens are iterated on per step, and each step keeps the most confident tokens.
* Incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this with a NAR transformer for image generation.
* The normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.

One problem exhibited by a NAR is artifacts ("crust") in the final waveform. I believe this is a confidence problem where the wrong token is inferred.
* Unfortunately, one solution is to simply train a separate NAR, as this should help bolster the model's NAR capabilities without the AR influencing things; I imagine being able to decode tokens both causally and in parallel harms things.
* This is backed by the `cfg.model.experimental.rvq_levels_p` distribution affecting the model's AR capabilities: increasing the NAR's share in training causes the AR to perform *worse* (see the sketch below).
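
As a loose illustration of what that distribution controls (a sketch only; the exact values `cfg.model.experimental.rvq_levels_p` accepts live in the config code): each training step samples which RVQ level a sample trains against, so listing the NAR levels more often starves the AR objective.

```python
# Sketch only: weighting which RVQ level gets trained each step shifts the AR/NAR balance.
import random

# level 0 is the causal (AR) objective; levels 1..7 are the parallel (NAR) objective.
rvq_levels_p = [0] + [level for level in range(1, 8) for _ in range(2)]  # NAR-heavy weighting

def pick_training_level() -> int:
    return random.choice(rvq_levels_p)
```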

@@ -49,6 +41,19 @@ One problem exhibited by a NAR is artifacts ("crust") in the final w
* `token_dropout_error`: This will randomly nudge a small percentage of tokens from the prior RVQ level, to simulate wrong tokens being predicted.
* `token_dropout_rate`: This will randomly mask off tokens from the prior RVQ level with a mask token, to try and keep the model from relying too strongly on the given input.
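
A hedged sketch of what these two knobs amount to when applied to the prior level's tokens before embedding; the function and parameter names here just mirror the bullets above, not the repo's actual hooks.

```python
# Illustrative corruption of the prior RVQ level's tokens (not the repo's exact implementation).
import random
import torch

def corrupt_prior_level(tokens: torch.Tensor, mask_token: int,
                        token_dropout_error: float = 0.05,
                        token_dropout_rate: float = 0.05,
                        n_audio_tokens: int = 1024) -> torch.Tensor:
    out = tokens.clone()
    for i in range(out.shape[0]):
        r = random.random()
        if r < token_dropout_error:
            # nudge the token to a neighboring codebook entry to simulate a wrong prediction
            out[i] = (int(out[i]) + random.choice([-1, 1])) % n_audio_tokens
        elif r < token_dropout_error + token_dropout_rate:
            # mask the token so the model cannot rely too strongly on it
            out[i] = mask_token
    return out
```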

### Pure NAR

The pure NAR (`nar-len`) model is a model type that infers audio tokens purely non-autoregressively. Despite being called a pure NAR, the duration is still inferred by autoregressively decoding for its length (as the AR+NAR model shows that you can mix both types).
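
In practice the flow is the two-pass call pattern added to `example_usage()` further down in this commit (variable names are taken from that example; the ordering and the clamp here are a compacted sketch):

```python
# AR pass for duration, then a NAR pass for the audio tokens (names from example_usage()).
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95)
len_list = [min(l, 500) for l in len_list]  # clamp runaway duration predictions
resps_list = engine(text_list, proms_list, len_list=len_list, sampling_temperature=0.2)
```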

However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter, however, proves to be challenging, as generating tokens from nothing in one step is not possible.
* Diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* However, it's possible to follow a paradigm similar to diffusers, except that instead of iterating on random noise, masked tokens are iterated on per step, and each step keeps the most confident tokens (see the sketch after this list).
* Incidentally, [this paper](https://arxiv.org/abs/2406.05478) demonstrates this with a NAR transformer for image generation.
* The normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
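
To make the masked-token iteration concrete, below is a minimal sketch of confidence-based unmasking for RVQ level 0. The `model(cond, tokens)` signature, the cosine schedule, and the step count are illustrative assumptions rather than the repo's actual scheduler.

```python
# Illustrative confidence-based iterative unmasking (MaskGIT-style), not the repo's exact scheduler.
import math
import torch

def masked_decode(model, cond, seq_len: int, mask_token: int, steps: int = 8, device="cpu"):
    # start from an all-masked sequence of the requested duration
    tokens = torch.full((seq_len,), mask_token, dtype=torch.long, device=device)
    is_masked = torch.ones(seq_len, dtype=torch.bool, device=device)

    for step in range(steps):
        logits = model(cond, tokens)                 # assumed shape: [seq_len, n_audio_tokens]
        probs = logits.softmax(dim=-1)
        confidence, candidates = probs.max(dim=-1)   # best token and its confidence per position

        # cosine schedule: commit a growing share of positions as steps progress
        share = math.cos((1.0 - (step + 1) / steps) * math.pi / 2)
        target_unmasked = seq_len if step == steps - 1 else int(share * seq_len)
        n_commit = max(0, target_unmasked - int((~is_masked).sum()))
        n_commit = min(n_commit, int(is_masked.sum()))
        if n_commit == 0:
            continue

        # among still-masked positions, keep this step's most confident predictions
        scores = confidence.masked_fill(~is_masked, float("-inf"))
        commit = scores.topk(n_commit).indices
        tokens[commit] = candidates[commit]
        is_masked[commit] = False

    return tokens
```

The normal NAR levels skip this loop entirely, since the previous level already provides a full-length starting sequence and one step suffices.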

## Embeddings

The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but an HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
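
A short sketch of the contrast this paragraph draws, with illustrative names and sizes: an ensemble of dedicated embeddings joined by a learned separator (compare `self.sep` in the diff below), versus a single flat vocabulary where each input type merely occupies its own offset range.

```python
# Illustrative contrast: an ensemble of per-type embeddings vs. one flat offset-ranged vocabulary.
import torch
import torch.nn as nn

d_model, n_text, n_audio = 512, 256, 1024
text = torch.randint(0, n_text, (32,))
prom = torch.randint(0, n_audio, (150,))
resp = torch.randint(0, n_audio, (200,))

# (a) ensemble of embeddings: each input type gets its own table, joined with a learned separator
text_emb, prom_emb, resp_emb = nn.Embedding(n_text, d_model), nn.Embedding(n_audio, d_model), nn.Embedding(n_audio, d_model)
sep = nn.Parameter(torch.randn(d_model))
x = torch.cat([text_emb(text), sep[None], prom_emb(prom), sep[None], resp_emb(resp)])

# (b) HF-style flat vocabulary: one table, each input type shifted into its own token range
offsets = {"text": 0, "prom": n_text, "resp": n_text + n_audio}
flat_emb = nn.Embedding(n_text + 2 * n_audio, d_model)
x_flat = flat_emb(torch.cat([text + offsets["text"], prom + offsets["prom"], resp + offsets["resp"]]))
```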

@@ -188,7 +188,7 @@ def load_engines(training=True, **model_kwargs):
# resize modules if I'm doing experiments and can't be assed to manually trim things
if cfg.trainer.resize_modules:
uses_stop_token = 1 if "len" not in model.capabilities and model.causal_size > 0 else 0
uses_stop_token = 1 if model.causal_size > 0 else 0
keys = [
("text_emb.weight", model.config.text_tokens ),
("tasks_emb.weight", model.config.tasks ),

@@ -47,6 +47,9 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
def _dropout_mask( input, p=0.8 ):
return torch.tensor( [ 0 if random.random() < p else 1 for _ in range( input.shape[0] ) ], dtype=torch.uint8, device=input.device )
def clamp(n, lo, hi):
return max(lo, min(n, hi))

@@ -460,7 +463,12 @@ class Base(nn.Module):
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_tokens = [n_resp_tokens] * self.n_resp_levels
# AR+NAR model
# AR+NAR model / NAR-len model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
"""
elif "len" not in self.capabilities:
# +1 to include the stop token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )

@@ -468,7 +476,8 @@
# NAR-len model
else:
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * (self.n_resp_levels + 1)
l_tokens = [n_resp_tokens] * (self.n_resp_levels)
"""
self.unified_position_ids = unified_position_ids
self.interleave = interleave

@@ -485,7 +494,6 @@
# it would be nicer for these to be a token or live inside an embedding
self.sep = nn.Parameter(torch.randn(d_model))
self.dropout_token = nn.Parameter(torch.randn(d_model))
self.mask_token = self.dropout_token # alias (hopefully) to the above
if self.version == 1: # legacy
n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom

@@ -993,6 +1001,11 @@
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask
if "len" in self.capabilities and quant_level == 0:
dropout_mask = _dropout_mask( resps_list[i], p=0.8 )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>

@@ -1087,6 +1100,18 @@
task_type = "tts"
input_prom = None
dropout_mask = None
# pre-iterate
for name, input in batch_input:
"""
if name == "prop":
proms = [ input ] if isinstance(input, torch.Tensor) else input
input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
"""
if name == "dropout_mask":
dropout_mask = input
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
embedding = None

@@ -1107,8 +1132,7 @@
embedding = self.langs_emb( input )
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
# to-do: probably insert separators if task requires it?
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )

@@ -1122,19 +1146,11 @@
embedding = _interleave_sequence_reshape( embeddings )
elif "len" in self.capabilities and quant_level == 0:
# fill with the prom as the initial condition
"""
assert input_prom is not None, "Guru mediation"
repeat = (input.shape[0] // input_prom.shape[0]) + 1
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
embedding = self.resps_emb(
repeated,
mask_token = self.resps_emb(
torch.tensor( [ self.stop_token ], dtype=torch.int16, device=input.device ),
offset = 0,
quant_level = 0,
quant_level = 0
)
"""
# if training
if not input.is_floating_point():

@@ -1144,27 +1160,20 @@
offset = 0,
quant_level = 0,
)
# randomly replace with mask tokens
for i in range( embedding.shape[0] ):
# a paper said to do this
if random.random() > 0.8:
continue
embedding[i] = self.dropout_token
# create dropout mask if one is not provided
if dropout_mask is None:
dropout_mask = _dropout_mask( input )
# replace with masked tokens
for i, token in enumerate( dropout_mask ):
if token == 0:
embedding[i] = mask_token
# if inferencing
else:
# fill with mask tokens
embedding = torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( input.shape[0] ) ])
"""
# fill with filler token from the len layer for the NAR-only model
filler_token = 12
embedding = self.resps_emb(
# self.dropout_token.repeat((input.shape[0], 1)),
torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
offset = 0,
quant_level = 0,
)
"""
# fill with mask tokens for now
embedding = torch.concat([ mask_token for _ in range( input.shape[0] ) ])
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:

@@ -1241,6 +1250,10 @@
if isinstance(input, str):
return 1
# a mask
if name == "dropout_mask":
return 0
# list of tokens
if not isinstance(input, torch.Tensor):
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1

@@ -1302,6 +1315,12 @@
quant_level = quant_levels[batch_index]
target = []
task_type = "tts"
dropout_mask = None
for name, input in batch:
if name == "dropout_mask":
dropout_mask = input
for name, input in batch:
if name == "task":
task_type = input

@@ -1310,8 +1329,14 @@
proms = [ input ] if isinstance(input, torch.Tensor) else input
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
elif name == "resp":
if self.interleave:
# mask found, apply it
if dropout_mask is not None:
seq = input if input.dim() == 1 else input[:, 0]
masked = torch.tensor([ token if dropout_mask[i] == 1 else self.ignore_index for i, token in enumerate( seq ) ], dtype=torch.int16, device=input.device)
target.append( masked )
elif self.interleave:
target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
elif task_type in summed_embeddings_task:
target.append( torch.full_like(input[..., 0], self.ignore_index) )
else:

@@ -1544,7 +1569,7 @@
# needs to be done here as we still have our raw inputs
#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
output = self._forward(

@@ -220,7 +220,8 @@ class NAR(Base):
#prev_list = [ repeat_extend_audio( prom, resp_len ) for resp_len, prom in zip(len_list, proms_list) ]
#prev_list = [ None for resp_len in len_list ] # this breaks the position ID calc
prev_list = [ torch.concat([ self.dropout_token.unsqueeze(0) for _ in range( resp_len ) ]) for resp_len in len_list ]
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
# to-do: special "scheduling" to inference RVQ-level 0

@@ -484,6 +485,8 @@ def example_usage():
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
len_list = [ min(l, 500) for l in len_list ]
for i, o in enumerate(resps_list):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)