I guess I'll fall for the NAR-len meme again (I don't know where my previous weights are, so I need to train it again to test something)
This commit is contained in:
parent bcabde3454
commit 105ed51159
@ -36,6 +36,7 @@ Non-autoregressive training is performed by having the input tokens from the prev
However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter, however, proves to be a bit of a challenge, as this could be anything from random noise to a unique token.
* The current implementation repeats the input prompt's RVQ level 0 as the initial condition, but inferencing fills with stop tokens instead (see the sketch after this list). This *might* be the problem, but I do not have my `nar-len-llama-8` weights stored anywhere, sadly.
* Testing showed that it's easy to predict the duration, but decoding the first RVQ level accurately proves to be a chore.
* Initially, the output seemed chaotic and unreliable, but further experiments showed the model will "work" for a brief moment before going silent.
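
To make the suspected train/test mismatch concrete, here is a minimal sketch of the two fill strategies. The helper names are hypothetical; the tiling logic mirrors the training-time condition described above, while inferencing currently fills the canvas with stop tokens instead:

```python
import torch

def fill_from_prompt(prom: torch.Tensor, resp_len: int) -> torch.Tensor:
    # training-time initial condition: tile the prompt's RVQ level 0
    # until it covers the requested duration
    repeat = (resp_len // prom.shape[0]) + 1
    return prom[:, :1].repeat((repeat, 1))[:resp_len, :1]

def fill_with_stop(resp_len: int, stop_token: int) -> torch.Tensor:
    # inference-time initial condition: a canvas of stop tokens,
    # which training never sees
    return torch.full((resp_len, 1), stop_token, dtype=torch.int16)
```

If training only ever sees the former while inferencing feeds the latter, decoding RVQ level 0 from a stop-token canvas is out-of-distribution, which would line up with the model "working" briefly before going silent.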
@ -48,6 +49,8 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final w
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
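
For context, "tied" here means the output head reuses the embedding matrix as its projection weights, as many text LLMs do. A minimal sketch of the idea, with placeholder sizes:

```python
import torch.nn as nn

n_tokens, d_model = 1025, 1024
emb = nn.Embedding(n_tokens, d_model)
head = nn.Linear(d_model, n_tokens, bias=False)
head.weight = emb.weight  # tied: logits project through the embedding matrix
```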
### Text Embeddings

The input text phonemes (or the output text for STT) are passed through an embedding head (`text`), just as a normal text LLM would. Nothing fancy is required, as it's very straightforward.
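
A minimal sketch, assuming a placeholder phoneme vocabulary of 256 (the real size comes from the tokenizer):

```python
import torch
import torch.nn as nn

text_emb = nn.Embedding(256, 1024)     # the `text` embedding head
phonemes = torch.tensor([12, 34, 56])  # tokenized phoneme sequence
x = text_emb(phonemes)                 # (3, 1024), fed to the transformer
```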
@ -61,29 +61,50 @@ class AR(Base):
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,

sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,

sampling_refine_on_stop: bool = False,

disable_tqdm=False,
use_lora=None,
):
device = text_list[0].device
batch_size = len(text_list)

text_task = [ "stt" ]

if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)

# generate task list if not provided
if task_list is None:
task_list = [ "tts" for _ in range(batch_size) ]
task_list = [ default_task for _ in range(batch_size) ]

has_none = resps_list is None or text_list is None
if not has_none:
for i, task in enumerate( task_list ):
if resps_list[i] is None or text_list[i] is None:
has_none = True
break

# is training or NAR
if resps_list is not None:
if not has_none:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))

# implicit
if training is None:
training = n_levels == self.n_resp_levels
training = 0 if n_levels == self.n_resp_levels else None

# is training
if training:
if training is not None:
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch

@ -107,16 +128,19 @@ class AR(Base):
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])

# input RVQ levels
if not self.interleave:
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
# trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
else:
quant_levels = [ 0 for i in range(batch_size) ]
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
for i, task in enumerate( task_list ):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1

# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]

# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1

@ -130,9 +154,8 @@ class AR(Base):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue

if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1

# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):

@ -146,9 +169,13 @@ class AR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

# only apply stop token for RVQ level 0
stop_sequence = torch.tensor([[self.stop_token] * resps.shape[-1]], device=device, dtype=torch.int16)
resps_list[i] = torch.cat([ resps, stop_sequence ])

if quant_level <= 0:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
...
else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])

inputs = self.inputs(
text_list=text_list,

@ -157,21 +184,26 @@ class AR(Base):
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,

quant_levels=quant_levels,
)

return super().forward(
inputs=inputs,
quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token
)

# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )

# STT
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()

stop_token = self.stop_token

audio_stop_token = self.stop_token
text_stop_token = 2

state = None
mirostat = [

@ -179,10 +211,59 @@ class AR(Base):
] * batch_size if sampling_mirostat_tau > 0.0 else None

scores = [ 1.0 ] * sampling_beam_width
metrics = []

# ick
"""
low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
low_temperature_range = cfg.dataset.frames_per_second * 5

original_sampling_temperature = sampling_temperature
original_sampling_repetition_penalty = sampling_repetition_penalty
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
"""

sampling_layer_skip_variables = {} if sampling_layer_skip else None

if sampling_layer_skip:
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer

for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:
start_slice[i] = 1
sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
# treat input prompt as initial resp (by prefixing with the prompt instead)
elif input_prompt_prefix:
start_slice[i] = proms_list[i].shape[0]
sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
elif prefix_silence > 0:
sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
sequence_list[i] = sequence_list[i][:, 0]
# start_slice[i] = sequence_list[i].shape[0]

# get next in sequence
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]

# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to keep the initial burst of garbage from polluting the rest of the sequence
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
# to-do: tune these values, maybe have it factor based on confidence scores or something
"""
if low_temperature:
enabled = n < low_temperature_range
sampling_repetition_penalty = 1.125 if enabled else 1.25
#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
#sampling_temperature = original_sampling_temperature if enabled else 1.0
"""

inputs = self.inputs(
text_list=text_list,

@ -195,15 +276,20 @@ class AR(Base):
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)

# to-do: find an elegant way to write this
output = super().forward(
inputs=inputs,
state=state,

layer_skip_variables=sampling_layer_skip_variables,

output_attentions=sampling_entropix,
)
logits, state = output.logits, output.state

sampled = super().sample(
logits=logits,
prev_list=resps_list,
prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],

temperature=sampling_temperature,
min_temperature=sampling_min_temperature,

@ -220,47 +306,81 @@ class AR(Base):
dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length,

attentions=output.attentions if sampling_entropix else None,
)

r = sampled[0]

if cfg.experimental:
if sampled.entropy:
metrics.append( sampled.entropy )
elif sampled.scores:
metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )

if mirostat is not None:
mirostat = sampled.scores
elif sampling_beam_width > 0:
# expand tuple
scores = sampled.scores
s = sampled.scores
# first step, expand batch
if batch_size == 1:
batch_size = sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
task_list = task_list * sampling_beam_width
start_slice = start_slice * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()

scores = [ scores[i] + score for i, score in enumerate(scores) ]
scores = [ scores[i] + score for i, score in enumerate(s) ]

# append tokens
for i, ri in enumerate(r):
task = task_list[i]
stop_token = audio_stop_token if task not in text_task else text_stop_token
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])

# stop token found
stopped |= r == stop_token
# stopped |= r == stop_token
if stopped.all().item():
break

# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth

if metrics:
from ..plot import plot_sample_metrics
filename = "metrics"
if sampling_entropix:
filename += f'[entropix]'
if sampling_layer_skip_exit_layer >= 0:
filename += f'[{sampling_layer_skip_exit_layer+1}]'

plot_sample_metrics( metrics, filename=f'{filename}.png' )

# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width:
sequence_list = [ sequence_list[0] ]
sequence_list = sequence_list[:1]
task_list = task_list[:1]

sequence_list = [self._prune(r, stop_token) for r in sequence_list]
# remove stop token
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
# remove <bos>
sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]

for i, seq in enumerate( sequence_list ):
steps = seq.shape[0] // self.n_resp_levels
nearest_steps = steps * self.n_resp_levels
sequence_list[i] = seq[:nearest_steps].view(( steps, self.n_resp_levels ))
if sampling_refine_on_stop:
# get how much we need to slice from the end
slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
# -1 for the stop token
logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
# greedy sample from the sequence
refined_list = [ logit.argmax(dim=-1) for logit in logits ]
# to-do: compare scores
# set the "refined" list as the output
sequence_list = refined_list

return sequence_list

@ -1126,25 +1126,27 @@ class Base(nn.Module):

embedding = _interleave_sequence_reshape( embeddings )
elif "len" in self.capabilities and quant_level == 0:
if input_prom is not None:
# fill with the prom as the initial condition
repeat = (input.shape[0] // input_prom.shape[0]) + 1
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
assert input_prom is not None, "Guru meditation during training"
# fill with the prom as the initial condition
repeat = (input.shape[0] // input_prom.shape[0]) + 1
repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]

embedding = self.resps_emb(
repeated,
offset = 0,
quant_level = 0,
)
"""
# fill with "stop" token from the len layer for the NAR-only model
filler_token = 12
embedding = self.resps_emb(
# self.dropout_token.repeat((input.shape[0], 1)),
torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
offset = 0,
quant_level = 0,
)
"""

embedding = self.resps_emb(
repeated,
offset = 0,
quant_level = 0,
)
else:
# fill with "stop" token from the len layer for the NAR-only model
filler_token = 12
embedding = self.resps_emb(
# self.dropout_token.repeat((input.shape[0], 1)),
torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
offset = 0,
quant_level = 0,
)
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......

@ -21,6 +21,9 @@ from tqdm import trange
from ..emb.qnt import trim
import logging

def clamp(n, lo, hi):
return max(lo, min(n, hi))

_logger = logging.getLogger(__name__)

class NAR(Base):

@ -58,119 +61,172 @@ class NAR(Base):
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,

sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,

sampling_refine_on_stop: bool = False,

disable_tqdm=False,
use_lora=None,
):
device = text_list[0].device
batch_size = len(text_list)
text_task = [ "stt" ]

# is training
if resps_list is not None:
len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)

# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]

has_none = resps_list is None or text_list is None
if not has_none:
for i, task in enumerate( task_list ):
if resps_list[i] is None or text_list[i] is None:
has_none = True
break

# is training or NAR
if not has_none:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))

# assert n_levels == self.n_resp_levels
# implicit
if training is None:
training = 0 if n_levels == self.n_resp_levels else None

# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < len_train_p else "tts"
# is training
if training is not None:
len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05

# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))

# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# assert n_levels == self.n_resp_levels

# input RVQ levels
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
# trim resps to only contain all levels below the target level
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < len_train_p else "tts"

# I hate python's value/reference semantics so much
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]

# proms could be a Tensor, list[Tensor], or None
if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])

elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue

if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# input RVQ levels
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
for i, task in enumerate( task_list ):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1

# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]

# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1

if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# proms could be a Tensor, list[Tensor], or None
if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1

inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1

quant_levels=quant_levels,
)
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()

if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

# only apply stop token for RVQ level 0
if quant_level <= 0:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
...
else:
#resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
...

inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,

quant_levels=quant_levels,
)

return super().forward(
inputs=inputs,
quant_levels=quant_levels,
)

return super().forward(
inputs=inputs,
quant_levels=quant_levels,
)

# NAR
if len_list is not None:
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels

# fill with mock tokens
# to-do: repeat with the input prompt, as per training
prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]

start = True
# to-do: figure out why this fails when I copy some things from ar_nar
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
level = 0 if n == 0 else prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break

if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )

quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)

inputs = self.inputs(

@ -185,19 +241,17 @@ class NAR(Base):
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
logits = output.logits

"""
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
"""
# layer_skip_variables=sampling_layer_skip_variables,
)
logits, state = output.logits, output.state

sampled = super().sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,

temperature=1.0 if n == 0 else sampling_temperature,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,

@ -218,6 +272,9 @@ class NAR(Base):
return prev_list

# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )

sequence_list = [ torch.tensor([0], device=device, dtype=torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
