diff --git a/docs/models.md b/docs/models.md
index 625e4de..44c3500 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -36,6 +36,7 @@ Non-autoregressive training is performed by having the input tokens from the prev
 However, having a pure NAR is challenging, as you need to both explicitly provide the duration and provide a "good enough" starting sequence of tokens for the initial sequence.
 * The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
 * The latter, however, proves to be a bit of a challenge, as this could be anything from random noise to a unique token.
+  * The current implementation repeats the input prompt's RVQ level 0 as the initial condition, but inferencing fills with stop tokens. This *might* be the problem, but I do not have my `nar-len-llama-8` weights stored anywhere, sadly.
 * Testing showed that it's easy to predict the duration, but decoding the first RVQ level accurately proves to be a chore.
 * Initially, output seemed chaotic and unreliable, but further experiments showed the model will "work" for a brief moment before going silent.
 
@@ -48,6 +49,8 @@ One problem exhibited from a NAR is producing artifacts ("crust") in the final w
 
 The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
 
+While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
+
 ### Text Embeddings
 
 The input text phonemes (or output for STT) are passed through an embedding head (`text`), similar to how a normal text LLM would. Nothing fancy is required, as it's very straightforward.
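To make the train/inference mismatch described above concrete, here is a minimal sketch of the two fill strategies for the pure NAR's initial condition (illustrative only — `fill_initial_condition` is a made-up helper, not the repo's API; the real logic lives in the `vall_e/models/base.py` hunk further down):

```python
import torch

def fill_initial_condition(prom: torch.Tensor, resp_len: int, stop_token: int, training: bool) -> torch.Tensor:
    # prom: (frames, rvq_levels) EnCodec codes of the input prompt
    if training:
        # training: tile the prompt's RVQ level 0 until it covers the target duration
        repeat = (resp_len // prom.shape[0]) + 1
        return prom[:, 0].repeat(repeat)[:resp_len]
    # inferencing: fill with stop tokens instead -- the suspected mismatch
    return torch.full((resp_len,), stop_token, dtype=torch.int16)
```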
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index e17763c..d25df2e 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -61,29 +61,50 @@ class AR(Base):
 		sampling_dry_base=1.75,
 		sampling_dry_allowed_length=2,
 		sampling_entropix=False,
+		sampling_layer_skip: bool = False,
 		sampling_layer_skip_exit_layer: int = -1,
+		sampling_layer_skip_entropy_threshold: float = -1,
+		sampling_layer_skip_varentropy_threshold: float = -1,
+
+		sampling_refine_on_stop: bool = False,
 
 		disable_tqdm=False,
 		use_lora=None,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
-
+		text_task = [ "stt" ]
+
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
+
 		# generate task list if not provided
 		if task_list is None:
-			task_list = [ "tts" for _ in range(batch_size) ]
+			task_list = [ default_task for _ in range(batch_size) ]
+
+		has_none = resps_list is None or text_list is None
+		if not has_none:
+			for i, task in enumerate( task_list ):
+				if resps_list[i] is None or text_list[i] is None:
+					has_none = True
+					break
 
 		# is training or NAR
-		if resps_list is not None:
+		if not has_none:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
+			# implicit
 			if training is None:
-				training = n_levels == self.n_resp_levels
+				training = 0 if n_levels == self.n_resp_levels else None
 
 			# is training
-			if training:
+			if training is not None:
 				# specifies how to sample probabilities of which RVQ levels to train against
 				rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
 				# determines which RVQ level to target per batch
@@ -107,16 +128,19 @@ class AR(Base):
 					rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
 
 				# input RVQ levels
-				if not self.interleave:
-					quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
-					# trim resps to only contain all levels below the target level
-					resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
-				else:
-					quant_levels = [ 0 for i in range(batch_size) ]
+				quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
+				for i, task in enumerate( task_list ):
+					if task in text_task:
+						quant_levels[i] = 0 # self.n_resp_levels - 1
+
+				# trim resps to only contain all levels below the target level
+				resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 
 				# tensor to cat for RVQ level 0
+				text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+				audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 				# I hate python's value/reference semantics so much
-				for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
+				for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
 					# cap quant_level if it exceeds its corresponding resp/prom
 					if quant_level >= resps.shape[-1]:
 						quant_levels[i] = resps.shape[-1] - 1
@@ -130,9 +154,8 @@ class AR(Base):
 						for j, prom in enumerate( proms ):
 							if not isinstance( prom, torch.Tensor ):
 								continue
-
-							if quant_level >= prom.shape[-1]:
-								quant_levels[i] = prom.shape[-1] - 1
+							if quant_level >= prom.shape[-1]:
+								quant_levels[i] = prom.shape[-1] - 1
 
 					# apply token dropout error compensation
 					if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
@@ -146,9 +169,13 @@ class AR(Base):
 								resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 					# only apply stop token for RVQ level 0
-					stop_sequence = torch.tensor([[self.stop_token] * resps.shape[-1]], device=device, dtype=torch.int16)
-					resps_list[i] = torch.cat([ resps, stop_sequence ])
-
+					if quant_level <= 0:
+						# append stop tokens for AR
+						if task in text_task:
+							#text_list[i] = torch.cat([ resps, text_stop_sequence ])
+							...
+						else:
+							resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -157,21 +184,26 @@ class AR(Base):
 				lang_list=lang_list,
 				tone_list=tone_list,
 				task_list=task_list,
+
+				quant_levels=quant_levels,
 			)
 
 			return super().forward(
 				inputs=inputs,
+				quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token
 			)
 
 		# is AR
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
 
+		# STT
+		start_slice = [ 0 for _ in range(batch_size) ]
+
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		stop_token = self.stop_token
-
+		audio_stop_token = self.stop_token
+		text_stop_token = 2
 
 		state = None
 		mirostat = [
@@ -179,10 +211,59 @@ class AR(Base):
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
 
 		scores = [ 1.0 ] * sampling_beam_width
+		metrics = []
+
+		# ick
+		"""
+		low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
+		low_temperature_range = cfg.dataset.frames_per_second * 5
+
+		original_sampling_temperature = sampling_temperature
+		original_sampling_repetition_penalty = sampling_repetition_penalty
+		original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
+		"""
+
+		sampling_layer_skip_variables = {} if sampling_layer_skip else None
+
+		if sampling_layer_skip:
+			if sampling_layer_skip_entropy_threshold >= 0:
+				sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
+			if sampling_layer_skip_varentropy_threshold >= 0:
+				sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
+			if sampling_layer_skip_exit_layer >= 0:
+				sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
+
+		for i, sequence in enumerate( sequence_list ):
+			# add <bos> to text for STT
+			if task_list[i] in text_task:
+				start_slice[i] = 1
+				sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
+			# treat input prompt as initial resp (by prefixing with the prompt instead)
+			elif input_prompt_prefix:
+				start_slice[i] = proms_list[i].shape[0]
+				sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
+			elif prefix_silence > 0:
+				sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
+				sequence_list[i] = sequence_list[i][:, 0]
+				# start_slice[i] = sequence_list[i].shape[0]
 
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
+			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
+
+			# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence
+			# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
+			# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
+			# to-do: tune these values, maybe have it factor based on confidence scores or something
+			"""
+			if low_temperature:
+				enabled = n < low_temperature_range
+				sampling_repetition_penalty = 1.125 if enabled else 1.25
+				#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
+				#sampling_temperature = original_sampling_temperature if enabled else 1.0
+			"""
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -195,15 +276,20 @@ class AR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
 
+			# to-do: find an elegant way to write this
 			output = super().forward(
 				inputs=inputs,
 				state=state,
+
+				layer_skip_variables=sampling_layer_skip_variables,
+
+				output_attentions=sampling_entropix,
 			)
 			logits, state = output.logits, output.state
 
 			sampled = super().sample(
 				logits=logits,
-				prev_list=resps_list,
+				prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -220,47 +306,81 @@ class AR(Base):
 				dry_multiplier=sampling_dry_multiplier,
 				dry_base=sampling_dry_base,
 				dry_allowed_length=sampling_dry_allowed_length,
+
+				attentions=output.attentions if sampling_entropix else None,
 			)
 
 			r = sampled[0]
 
+			if cfg.experimental:
+				if sampled.entropy:
+					metrics.append( sampled.entropy )
+				elif sampled.scores:
+					metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
+
 			if mirostat is not None:
 				mirostat = sampled.scores
 			elif sampling_beam_width > 0:
 				# expand tuple
-				scores = sampled.scores
+				s = sampled.scores
 				# first step, expand batch
 				if batch_size == 1:
 					batch_size = sampling_beam_width
 					text_list = text_list * sampling_beam_width
 					proms_list = proms_list * sampling_beam_width
 					sequence_list = sequence_list * sampling_beam_width
+					task_list = task_list * sampling_beam_width
+					start_slice = start_slice * sampling_beam_width
 					stopped = torch.zeros(batch_size, device=device).bool()
 
-				scores = [ scores[i] + score for i, score in enumerate(scores) ]
+				scores = [ scores[i] + score for i, score in enumerate(s) ]
 
 			# append tokens
 			for i, ri in enumerate(r):
+				task = task_list[i]
+				stop_token = audio_stop_token if task not in text_task else text_stop_token
 				if stop_token in ri:
					stopped[i] = True
 				sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 
 			# stop token found
-			stopped |= r == stop_token
+			# stopped |= r == stop_token
 			if stopped.all().item():
 				break
 
+		# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
+
+		if metrics:
+			from ..plot import plot_sample_metrics
+			filename = "metrics"
+			if sampling_entropix:
+				filename += f'[entropix]'
+			if sampling_layer_skip_exit_layer >= 0:
+				filename += f'[{sampling_layer_skip_exit_layer+1}]'
+
+			plot_sample_metrics( metrics, filename=f'{filename}.png' )
+
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
 		if sampling_beam_width:
-			sequence_list = [ sequence_list[0] ]
+			sequence_list = sequence_list[:1]
+			task_list = task_list[:1]
 
-		sequence_list = [self._prune(r, stop_token) for r in sequence_list]
+		# remove stop token
+		sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
+		# remove <bos>
+		sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
 
-		for i, seq in enumerate( sequence_list ):
-			steps = seq.shape[0] // self.n_resp_levels
-			nearest_steps = steps * self.n_resp_levels
-			sequence_list[i] = seq[:nearest_steps].view(( steps, self.n_resp_levels ))
+		if sampling_refine_on_stop:
+			# get how much we need to slice from the end
+			slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
+			# -1 for the stop token
+			logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
+			# greedy sample from the sequence
+			refined_list = [ logit.argmax(dim=-1) for logit in logits ]
+			# to-do: compare scores
+			# set the "refined" list as the output
+			sequence_list = refined_list
 
 		return sequence_list
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 80078ec..39a9965 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1126,25 +1126,27 @@ class Base(nn.Module):
 					embedding = _interleave_sequence_reshape( embeddings )
 
 				elif "len" in self.capabilities and quant_level == 0:
-					if input_prom is not None:
-						# fill with the prom as the initial condition
-						repeat = (input.shape[0] // input_prom.shape[0]) + 1
-						repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
+					assert input_prom is not None, "Guru meditation during training"
+					# fill with the prom as the initial condition
+					repeat = (input.shape[0] // input_prom.shape[0]) + 1
+					repeated = input_prom[:, :1].repeat((repeat, 1))[:input.shape[0], :1]
+
+					embedding = self.resps_emb(
+						repeated,
+						offset = 0,
+						quant_level = 0,
+					)
+					"""
+					# fill with "stop" token from the len layer for the NAR-only model
+					filler_token = 12
+					embedding = self.resps_emb(
+						# self.dropout_token.repeat((input.shape[0], 1)),
+						torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
+						offset = 0,
+						quant_level = 0,
+					)
+					"""
-
-						embedding = self.resps_emb(
-							repeated,
-							offset = 0,
-							quant_level = 0,
-						)
-					else:
-						# fill with "stop" token from the len layer for the NAR-only model
-						filler_token = 12
-						embedding = self.resps_emb(
-							# self.dropout_token.repeat((input.shape[0], 1)),
-							torch.full_like(input if input.dim() == 1 else input[..., 0], filler_token),
-							offset = 0,
-							quant_level = 0,
-						)
 
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
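For what it's worth, the `sampling_refine_on_stop` path added above boils down to one greedy re-decode of the emitted window from the final forward pass's logits; a standalone sketch of just that slice-and-argmax (shapes assumed, not the repo's API):

```python
import torch

def refine_on_stop(logits: torch.Tensor, sequence: torch.Tensor) -> torch.Tensor:
    # logits: (seq_len, vocab) from the last forward pass
    # sequence: (n,) tokens sampled so far, stop token already pruned
    length = sequence.shape[-1]
    # take the logits that produced the window, excluding the position that emitted the stop token
    window = logits[-length - 1 : -1]
    # greedily re-pick each token from those same distributions
    return window.argmax(dim=-1)
```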
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index e2a4734..8479d2e 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -21,6 +21,9 @@ from tqdm import trange
 from ..emb.qnt import trim
 import logging
 
+def clamp(n, lo, hi):
+	return max(lo, min(n, hi))
+
 _logger = logging.getLogger(__name__)
 
 class NAR(Base):
@@ -58,119 +61,172 @@ class NAR(Base):
 		sampling_dry_base=1.75,
 		sampling_dry_allowed_length=2,
 		sampling_entropix=False,
+
+		sampling_layer_skip: bool = False,
 		sampling_layer_skip_exit_layer: int = -1,
+		sampling_layer_skip_entropy_threshold: float = -1,
+		sampling_layer_skip_varentropy_threshold: float = -1,
+
+		sampling_refine_on_stop: bool = False,
 
 		disable_tqdm=False,
 		use_lora=None,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		text_task = [ "stt" ]
 
-		# is training
-		if resps_list is not None:
-			len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
+		# generate task list if not provided
+		if task_list is None:
+			task_list = [ default_task for _ in range(batch_size) ]
+
+		has_none = resps_list is None or text_list is None
+		if not has_none:
+			for i, task in enumerate( task_list ):
+				if resps_list[i] is None or text_list[i] is None:
+					has_none = True
+					break
+
+		# is training or NAR
+		if not has_none:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
-			# assert n_levels == self.n_resp_levels
+			# implicit
+			if training is None:
+				training = 0 if n_levels == self.n_resp_levels else None
 
-			# to-do: make this YAML configurable
-			def sample_task():
-				return "len" if random.random() < len_train_p else "tts"
+			# is training
+			if training is not None:
+				len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05
 
-			# generate task list to train against
-			task_list = [ sample_task() for _ in range(batch_size) ]
+				n_levels_set = {r.shape[-1] for r in resps_list}
+				n_levels = next(iter(n_levels_set))
 
-			# specifies how to sample probabilities of which RVQ levels to train against
-			rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
-			# determines which RVQ level to target per batch
-			quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
-			# rate to perform token dropout errors
-			token_dropout_error = self.config.experimental.token_dropout_error
-			# RVQ levels to apply token dropout on
-			token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
-			# implicitly set it to all levels
-			if not token_dropout_rvq_levels:
-				token_dropout_rvq_levels = [0, self.resp_levels - 1]
-			# allow passing a specific distribution of RVQ levels
-			rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
-			if not rvq_levels_p:
-				lo, hi = quant_level_range[0], quant_level_range[1] + 1
-				# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-				if rvq_levels_p == "equal":
-					rvq_levels_p = [ i for i in range( lo, hi ) ]
-				else:
-					# yuck
-					rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
+				# assert n_levels == self.n_resp_levels
 
-			# input RVQ levels
-			quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
-			# trim resps to only contain all levels below the target level
-			resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
+				# to-do: make this YAML configurable
+				def sample_task():
+					return "len" if random.random() < len_train_p else "tts"
 
-			# I hate python's value/reference semantics so much
-			for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
-				# cap quant_level if it exceeds its corresponding resp/prom
-				if quant_level >= resps.shape[-1]:
-					quant_levels[i] = resps.shape[-1] - 1
+				# generate task list to train against
+				task_list = [ sample_task() for _ in range(batch_size) ]
 
-			# proms could be a Tensor, list[Tensor], or None
-			if isinstance( proms, torch.Tensor ):
-				if quant_level >= proms.shape[-1]:
-					quant_levels[i] = proms.shape[-1] - 1
+				# specifies how to sample probabilities of which RVQ levels to train against
+				rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
+				# determines which RVQ level to target per batch
+				quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
+				# rate to perform token dropout errors
+				token_dropout_error = self.config.experimental.token_dropout_error
+				# RVQ levels to apply token dropout on
+				token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+				# implicitly set it to all levels
+				if not token_dropout_rvq_levels:
+					token_dropout_rvq_levels = [0, self.resp_levels - 1]
+				# allow passing a specific distribution of RVQ levels
+				rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
+				if not rvq_levels_p:
+					lo, hi = quant_level_range[0], quant_level_range[1] + 1
+					# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+					if rvq_levels_p == "equal":
+						rvq_levels_p = [ i for i in range( lo, hi ) ]
+					else:
+						# yuck
+						rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
 
-			elif isinstance( proms, list ):
-				for j, prom in enumerate( proms ):
-					if not isinstance( prom, torch.Tensor ):
-						continue
-
-					if quant_level >= prom.shape[-1]:
-						quant_levels[i] = prom.shape[-1] - 1
+				# input RVQ levels
+				quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
+				for i, task in enumerate( task_list ):
+					if task in text_task:
+						quant_levels[i] = 0 # self.n_resp_levels - 1
+
+				# trim resps to only contain all levels below the target level
+				resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 
-			# apply token dropout error compensation
-			if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
-				steps = resps.shape[0]
-				for l in range( quant_level ):
-					for t in range( steps ):
-						token = resps[t, l].item()
+				# tensor to cat for RVQ level 0
+				text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+				audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+				# I hate python's value/reference semantics so much
+				for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
+					# cap quant_level if it exceeds its corresponding resp/prom
+					if quant_level >= resps.shape[-1]:
+						quant_levels[i] = resps.shape[-1] - 1
 
-						if random.random() < token_dropout_error:
-							offset = 1 * ( 1 if random.random() < 0.5 else -1 )
-							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+					# proms could be a Tensor, list[Tensor], or None
+					if isinstance( proms, torch.Tensor ):
+						if quant_level >= proms.shape[-1]:
+							quant_levels[i] = proms.shape[-1] - 1
 
-		inputs = self.inputs(
-			text_list=text_list,
-			proms_list=proms_list,
-			resps_list=resps_list,
-			lang_list=lang_list,
-			tone_list=tone_list,
-			task_list=task_list,
+					elif isinstance( proms, list ):
+						for j, prom in enumerate( proms ):
+							if not isinstance( prom, torch.Tensor ):
+								continue
+							if quant_level >= prom.shape[-1]:
+								quant_levels[i] = prom.shape[-1] - 1
 
-			quant_levels=quant_levels,
-		)
+					# apply token dropout error compensation
+					if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+						steps = resps.shape[0]
+						for l in range( quant_level ):
+							for t in range( steps ):
+								token = resps[t, l].item()
+
+								if random.random() < token_dropout_error:
+									offset = 1 * ( 1 if random.random() < 0.5 else -1 )
+									resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+
+					# only apply stop token for RVQ level 0
+					if quant_level <= 0:
+						# append stop tokens for AR
+						if task in text_task:
+							#text_list[i] = torch.cat([ resps, text_stop_sequence ])
+							...
+						else:
+							#resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
+							...
+
+				inputs = self.inputs(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					task_list=task_list,
+
+					quant_levels=quant_levels,
+				)
+
+				return super().forward(
+					inputs=inputs,
+					quant_levels=quant_levels,
+				)
-
-		return super().forward(
-			inputs=inputs,
-			quant_levels=quant_levels,
-		)
 
 		# NAR
 		if len_list is not None:
 			# is NAR
 			if max_levels == 0:
 				max_levels = self.n_resp_levels
 
 			# fill with mock tokens
+			# to-do: repeat with the input prompt, as per training
 			prev_list = [ torch.tensor([ self.stop_token for _ in range(resp_len) ], device=device, dtype=torch.int16) for resp_len in len_list ]
-			start = True
 
+			# to-do: figure out why this fails when I copy some things from ar_nar
 			for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
 				level = 0 if n == 0 else prev_list[0].shape[-1]
 				if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 					break
 
+				if cfg.lora is not None:
+					enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
+
 				quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
 
 				inputs = self.inputs(
@@ -185,19 +241,17 @@ class NAR(Base):
 				output = super().forward(
 					inputs=inputs,
 					quant_levels=quant_levels,
-				)
-				logits = output.logits
-				"""
-				resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
-				"""
+
+					# layer_skip_variables=sampling_layer_skip_variables,
+				)
+				logits, state = output.logits, output.state
 
 				sampled = super().sample(
 					logits=logits,
 					prev_list=prev_list,
 					quant_levels=quant_levels,
 
-					temperature=1.0 if n == 0 else sampling_temperature,
+					temperature=sampling_temperature,
 					min_temperature=sampling_min_temperature,
 					top_p=sampling_top_p,
 					top_k=sampling_top_k,
@@ -218,6 +272,9 @@ class NAR(Base):
 			return prev_list
 
 		# is AR
+		if cfg.lora is not None:
+			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
+
 		sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
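Stepping back, the NAR inference loop above amounts to decoding one whole RVQ level per forward pass, each pass conditioned on every level decoded so far; a rough standalone sketch of that control flow (the `model(...)` call signature here is hypothetical, standing in for `super().forward` plus `super().sample`):

```python
import torch

def nar_decode(model, level0: torch.Tensor, n_resp_levels: int) -> torch.Tensor:
    # level0: (frames,) codes for RVQ level 0 (from the AR, or a filled initial condition)
    prev = level0.unsqueeze(-1)  # (frames, 1)
    for level in range(1, n_resp_levels):
        # one non-autoregressive pass predicts the entire next level at once
        logits = model(prev, quant_level=level)  # (frames, vocab); hypothetical signature
        sampled = logits.argmax(dim=-1)          # (frames,)
        prev = torch.cat([prev, sampled.unsqueeze(-1)], dim=-1)
    return prev  # (frames, n_resp_levels)
```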