unified nar.py into ar_nar.py

This commit is contained in:
mrq 2024-11-10 12:19:48 -06:00
parent a9d2faf2d7
commit 9cb0b6901b
11 changed files with 689 additions and 1710 deletions

View File

@ -233,34 +233,13 @@ This script aims to implement everything as required per VALL-E agnostically, to
## `models/ar_nar.py`
This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ-level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively.
This is the default model, selected through `cfg.model.capabilities = ["ar", "nar"]`.
This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ-level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively, if requested.
* Since one model can be trained AR-ly and NAR-ly, RVQ-level 0 can also be trained non-autoregressively with diffusion-like masking (a sketch of this follows below).
For training, this model handles preparing the batch provided through the dataloader according to a randomly sampled targeted RVQ level.
For inferencing, this will dynamically infer depending on the arguments provided.
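Roughly, this masked training follows the demasking logic implemented in this script: sample a random timestep, derive a noise level from a cosine schedule, and mask off that fraction of RVQ-level-0 tokens for the model to recover. A minimal sketch of the idea, where `mask_for_training` and its arguments are illustrative names rather than the actual API:

```python
import math
import random
import torch

def mask_for_training( resp: torch.Tensor, mask_token: int ):
	# pick a random timestep, then derive the noise level per cosine scheduling
	timestep = random.random()
	noise_p = math.cos( timestep * math.pi * 0.5 )
	# mask off ~noise_p of the RVQ-level-0 tokens; training targets recovering them
	is_masked = torch.rand( resp.shape[0] ) < noise_p
	return torch.where( is_masked, mask_token, resp ), is_masked
```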
## `models/ar.py`
This script implements VALL-E as a pure autoregressive (AR) model.
If `cfg.model.experimental.interleave=True`, this makes use of interleaving its audio codes, instead of inferencing per-codebook level. If not, this simply attends to RVQ level 0.
This model serves as an experiment that failed, and might be revisited in the future.
Use of this is governed through `cfg.model.capabilities = ["ar"]`
## `models/nar.py`
This script implements VALL-E as a mostly-pure non-autoregressive model, where it infers the duration autoregressively (if `"len" in cfg.model.capabilities`). If not, this simply attends to RVQ levels 1+.
This makes use of training an additional `len` task that can infer the duration of a requested input (see the sketch below), as well as (maybe) using special tokens as the initial input for RVQ-level 0 (the level the AR attends to).
This model serves as an experiment that failed, and might be revisited in the future.
Use of this is governed through `cfg.model.capabilities = ["nar"]`
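For reference, the `len` task emits the duration as a short sequence of digit tokens (0-9, with 10 acting as the stop token), which are concatenated into an integer frame count. A minimal sketch of the decoding, using a hypothetical `decode_len` helper:

```python
def decode_len( tokens: list[int], stop_token: int = 10 ) -> int:
	# digit tokens concatenate into the predicted frame count, e.g. [1, 7, 5] => 175
	digits = [ str(t) for t in tokens if t != stop_token ]
	return int( "".join( digits ) ) if digits else 0
```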
## `models/experimental.py`
This script implements VALL-E as a mostly-HuggingFace compatible model, where it handles processing tokens as a uniform sequence of IDs.
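With `models/nar.py` removed, model selection in `models/__init__.py` reduces to roughly the following (a sketch with the constructor arguments abridged; the non-HF path falls through to the unified `AR_NAR` model):

```python
def get_model( config, training=True, **model_kwargs ):
	# the HF-compatible experimental model takes priority...
	if config.experimental.hf:
		from .experimental import Model as model_cls
	# ...otherwise the unified AR+NAR model is used
	else:
		from .ar_nar import AR_NAR as model_cls
	return model_cls(
		n_text_tokens=config.text_tokens,
		n_audio_tokens=config.audio_tokens,
		training=training,
		config=config,
		**model_kwargs,
	)
```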

View File

@ -255,13 +255,13 @@ class ModelExperimentalSettings:
# it just seems like a bitch to try and train something worthwhile with it, since there are crackles every other token
# RetNet's chunked inferencing might be a better place for this
len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-do: just incorporate this as a task instead
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
# classifier-free guidance shit
cfg_cond_dropout_p: float = 0.2 # probability to drop out both text and audio prompts during training
cfg_text_dropout_p: float = 0.0 # probability to drop out the input text prompt during training
cfg_prom_dropout_p: float = 0.3 # probability to drop out the input audio prompt during training
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out both text and audio prompts during training (see the sketch after this hunk)
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out the input text prompt during training
cfg_prom_dropout_p: float = 0.0 # 0.3 # probability to drop out the input audio prompt during training
layerskip: bool = False # layerskip compatible model (or training for)
#layerskip_rvq_levels: list = field(default_factory=lambda: []) # RVQ levels to train / inference layerskip for (to-do: implement, see if it matters)
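These three probabilities drive classifier-free guidance training: per the per-item logic added to `models/ar_nar.py` in this commit, dropping the condition replaces the text with a bare start/stop sequence and the audio prompt with `None`. A minimal sketch of that logic as a standalone (illustrative) helper:

```python
import random
import torch

def apply_cfg_dropout( text, prom, cfg_cond_dropout_p, cfg_prom_dropout_p, device="cpu" ):
	null_text = torch.tensor([1, 2], device=device, dtype=torch.int16) # bare <bos><eos>
	drop_text = False
	drop_audio = False
	# drop just the input audio prompt
	if random.random() < cfg_prom_dropout_p:
		drop_audio = True
	# drop both the text and the audio prompt
	if random.random() < cfg_cond_dropout_p:
		drop_audio = True
		drop_text = True
	return ( null_text if drop_text else text ), ( None if drop_audio else prom )
```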
@ -757,6 +757,7 @@ class Config(BaseConfig):
device: str = "cuda" # target device
mode: str = "training" # "inferencing"
experimental: bool = False # debug flag
silent_errors: bool = False # if False, raise exceptions on errors that could silently cause problems; if True, ignore them
dataset: Dataset = field(default_factory=lambda: Dataset)
models: dict | list | None = field(default_factory=lambda: [])
@ -879,7 +880,12 @@ class Config(BaseConfig):
if data_parent.exists():
return [ path.parent / child.name for child in Path(data_parent).glob(path.name) ]
return path
# return an empty list
if self.silent_errors:
return []
# raise an error to avoid headaches
raise Exception(f'Cannot unglob requested path: {path}')
def format( self, training=True ):
@ -957,10 +963,6 @@ class Config(BaseConfig):
model["experimental"]["rvq_levels_p"] = model["experimental"]["p_rvq_levels"]
del model["experimental"]["p_rvq_levels"]
if "p_len_train" in model["experimental"]:
model["experimental"]["len_train_p"] = model["experimental"]["p_len_train"]
del model["experimental"]["p_len_train"]
self.models = [ Model(**model) if isinstance(model, dict) else model for model in self.models ]
self.loras = [ LoRA(**lora) if isinstance(lora, dict) else lora for lora in self.loras ]
@ -999,22 +1001,17 @@ class Config(BaseConfig):
if self.tokenizer == "naive":
self.tokenizer = NaiveTokenizer()
else:
# ick...
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = self.rel_path / self.tokenizer_path
if tokenizer_path and not tokenizer_path.exists():
# deduce path if a local copy is not provided
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / self.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
if not self.silent_errors and not tokenizer_path.exists():
raise Exception(f'Tokenizer path not found: {tokenizer_path}')
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
self.tokenizer = NaiveTokenizer()
except Exception as e:
self.tokenizer = NaiveTokenizer()
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass
# Preserves the old behavior
@ -1071,8 +1068,9 @@ cfg = Config.from_cli()
try:
cfg.format()
except Exception as e:
_logger.error(f"Error while parsing config YAML: {str(e)}")
if not cfg.silent_errors:
raise e # throw an error because I'm tired of silent errors messing things up for me
_logger.error(f"Error while parsing config YAML: {str(e)}")
if __name__ == "__main__":
print(cfg)

View File

@ -1199,6 +1199,10 @@ class Dataset(_Dataset):
task
]
# Duration prediction (<text><prompt> => len(<resp>))
elif task == "len":
proms = self.sample_prompts(spkr_name, reference=path)
# noise suppression (<text>? <resp+noise> => <resp>)
# speech removal (<text>?<resp+noise> => <noise>)
elif task == "ns" or task == "sr":

View File

@ -193,7 +193,7 @@ def load_engines(training=True, **model_kwargs):
("text_emb.weight", model.config.text_tokens ),
("tasks_emb.weight", model.config.tasks ),
("langs_emb.weight", model.config.langs ),
("rvq_l_emb.weight", model.config.resp_levels + (1 if "len" in model.config.capabilities else 0) ),
("rvq_l_emb.weight", model.config.resp_levels ),
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),

View File

@ -49,11 +49,8 @@ class TTS():
else:
raise Exception(f"Unknown config passed: {config}")
try:
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
except Exception as e:
raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None:
amp = cfg.inference.amp
@ -268,7 +265,7 @@ class TTS():
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
text_list = model_ar(
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps, task_list=["stt"],
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
@ -318,7 +315,7 @@ class TTS():
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps,
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, task_list=["tts"],
input_prompt_prefix=input_prompt_prefix,
prefix_silence=prefix_silence,
sampling_temperature=ar_temp,
@ -343,7 +340,7 @@ class TTS():
use_lora=use_lora,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
input_prompt_prefix=input_prompt_prefix,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
@ -359,8 +356,8 @@ class TTS():
use_lora=use_lora,
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
len_list = [ clamp(1, max_ar_steps, l) for l in len_list ]
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
len_list = [ clamp(l, 1, max_ar_steps) for l in len_list ]
kwargs = {}
@ -375,7 +372,7 @@ class TTS():
kwargs["resps_list"] = [ resp[:, :1] ]
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
max_steps=max_ar_steps,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,

View File

@ -60,25 +60,7 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
def get_model(config, training=True, **model_kwargs):
name = config.name
if "len" in config.capabilities:
from .nar import NAR
model = NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
**model_kwargs
)
elif config.experimental.hf:
if config.experimental.hf:
from .experimental import Model as Experimental
model = Experimental(
n_text_tokens=config.text_tokens,

View File

@ -1,638 +0,0 @@
"""
# an AR model that (should) handle:
* handling all RVQ levels, but in an autoregressive manner
It's in a mess of a state, because I want this to be an interleaved model, but it just seems better to use the vall_e.models.experimental model.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg
import torch
from torch.nn.utils.rnn import pad_sequence
import random
import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
import logging
_logger = logging.getLogger(__name__)
from ..utils import clamp
from ..emb.qnt import trim, encode_as_embedding
from .lora import enable_lora
class AR(Base):
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
text_task = [ "stt" ]
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]
has_none = resps_list is None or text_list is None
if not has_none:
for i, task in enumerate( task_list ):
if resps_list[i] is None or text_list[i] is None:
has_none = True
break
# is training or NAR
if not has_none:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# implicit
if training is None:
training = 0 if n_levels == self.n_resp_levels else None
# is training
if training is not None:
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
for i, task in enumerate( task_list ):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
# proms could be a Tensor, list[Tensor], or None
if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1
elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()
if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
...
else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
quant_levels=quant_levels,
)
return super().forward(
inputs=inputs,
quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token
)
# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# STT
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
audio_stop_token = self.stop_token
text_stop_token = 2
state = None
mirostat = [
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if sampling_mirostat_tau > 0.0 else None
scores = [ 1.0 ] * sampling_beam_width
metrics = []
# ick
"""
low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
low_temperature_range = cfg.dataset.frames_per_second * 5
original_sampling_temperature = sampling_temperature
original_sampling_repetition_penalty = sampling_repetition_penalty
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
"""
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:
start_slice[i] = 1
sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
# treat input prompt as initial resp (by prefixing with the prompt instead)
elif input_prompt_prefix:
start_slice[i] = proms_list[i].shape[0]
sequence_list[i], proms_list[i] = proms_list[i][:, 0], sequence_list[i]
elif prefix_silence > 0:
sequence_list[i] = get_silence(prefix_silence, device=sequence_list[i].device)
sequence_list[i] = sequence_list[i][:, 0]
# start_slice[i] = sequence_list[i].shape[0]
# get next in sequence
for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage from polluting the rest of the sequence
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
# to-do: tune these values, maybe have it factor based on confidence scores or something
"""
if low_temperature:
enabled = n < low_temperature_range
sampling_repetition_penalty = 1.125 if enabled else 1.25
#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
#sampling_temperature = original_sampling_temperature if enabled else 1.0
"""
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
# to-do: find an elegant way to write this
output = super().forward(
inputs=inputs,
state=state,
layer_skip_variables=sampling_layer_skip_variables,
output_attentions=sampling_entropix,
)
logits, state = output.logits, output.state
sampled = super().sample(
logits=logits,
prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
beam_width=sampling_beam_width,
mirostat=mirostat,
dry_multiplier=sampling_dry_multiplier,
dry_base=sampling_dry_base,
dry_allowed_length=sampling_dry_allowed_length,
attentions=output.attentions if sampling_entropix else None,
)
r = sampled[0]
if cfg.experimental:
if sampled.entropy:
metrics.append( sampled.entropy )
elif sampled.scores:
metrics.append( [ { "p": p[0], "exited_layer": output.exited_layer } for p in sampled.scores ] )
if mirostat is not None:
mirostat = sampled.scores
elif sampling_beam_width > 0:
# expand tuple
s = sampled.scores
# first step, expand batch
if batch_size == 1:
batch_size = sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
task_list = task_list * sampling_beam_width
start_slice = start_slice * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
scores = [ scores[i] + score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
task = task_list[i]
stop_token = audio_stop_token if task not in text_task else text_stop_token
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
# stopped |= r == stop_token
if stopped.all().item():
break
# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
if metrics:
from ..plot import plot_sample_metrics
filename = "metrics"
if sampling_entropix:
filename += f'[entropix]'
if sampling_layer_skip_exit_layer >= 0:
filename += f'[{sampling_layer_skip_exit_layer+1}]'
plot_sample_metrics( metrics, filename=f'{filename}.png' )
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width:
sequence_list = sequence_list[:1]
task_list = task_list[:1]
# remove stop token
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
# remove <bos>
sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
if sampling_refine_on_stop:
# get how much we need to slice from the end
slice_lengths = [ sequence.shape[-1] for sequence in sequence_list ]
# -1 for the stop token
logits = [ logit[-length-1:-1] for logit, length in zip(logits, slice_lengths) ]
# greedy sample from the sequence
refined_list = [ logit.argmax(dim=-1) for logit in logits ]
# to-do: compare scores
# set the "refined" list as the output
sequence_list = refined_list
return sequence_list
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat
from tqdm import tqdm
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
from ..engines import Engine, Engines
from ..utils import wrapper as ml
import numpy as np
import re
device = "cuda"
# mamba seems to ONLY be usable as an AR (any NAR attempt lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
]
proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
batch_size = len(text_list)
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1,
'p_dropout': 0.1,
'l_padding': 8 if cfg.optimizations.fp8 else 0,
'config': cfg.model
}
"""
try:
kwargs['config'] = cfg.model
except Exception as e:
pass
"""
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
tasks = cfg.dataset.tasks_list
model = AR(**kwargs).to(device)
steps = 75 * len(tasks) * cfg.model.experimental.causal_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
"""
cfg.optimizations.model_offloading = {
"devices": ["cuda:0", "cpu"],
# "limits": [ 0.9, -1 ],
"assign": [[ f'layers.{i}.' for i in range(0,10) ], [ f'layers.{i}.' for i in range(11,12) ] + [ "model.norm" ]],
# "limits": [ 256 * (1024 ** 2), -1 ]
}
"""
engine = Engine(model=model, optimizer=optimizer)
engines = Engines({"ar": engine})
engines.setup()
"""
if cfg.optimizations.model_offloading:
model = ml.offload_model( model, policy=cfg.optimizations.model_offloading )
"""
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
_logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.no_grad()
def sample_data(task=None):
texts = []
proms = []
resps = []
for i in range(batch_size):
if task is None:
task = random.choice(tasks)
text = text_list[i]
prom = proms_list[i]
resp = resps_list[i]
# do nothing
if task == "tts":
...
elif task == "tts-c":
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
prom = resp[:trim_length]
resp = resp[trim_length:]
elif task == "ns" or task == "sr":
# extend the noise to fill the target audio
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
# create the input prompt by merging the target audio with the noise
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
# set the target to just be the noise if <sr>
if task == "sr":
resp = noise_ext
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
texts.append( text.to(device) )
proms.append( prom.to(device) )
resps.append( resp.to(device) )
return texts, proms, resps
@torch.inference_mode()
def sample( name, steps=1000, task=None ):
engine.eval()
texts, proms, resps = sample_data( task )
resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
for i, o in enumerate(resps):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
unload_model()
def train():
engine.train()
t = trange(steps)
for i in t:
texts, proms, resps = sample_data()
stats = {"step": i}
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()
"""
if cfg.optimizations.compile:
model = ml.compile_model(model, backend=cfg.optimizations.compile)
"""
for task in tasks:
sample("final", task=task)
engines.quit()
if __name__ == "__main__":
example_usage()

View File

@ -28,54 +28,21 @@ from ..utils import get_devices, setup_logging, timer, clamp
from .lora import enable_lora
text_task = [ "stt" ]
class AR_NAR(Base):
def forward(
def forward_train(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
resps_list: list[Tensor],
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
text_task = [ "stt" ]
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
@ -85,28 +52,6 @@ class AR_NAR(Base):
device = resps_list[0].device
batch_size = len(resps_list)
# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]
has_none = resps_list is None or text_list is None
if not has_none:
for i, task in enumerate( task_list ):
if resps_list[i] is None or text_list[i] is None:
has_none = True
break
# is training or NAR
if not has_none:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# implicit
if training is None:
training = 0 if n_levels == self.n_resp_levels else None
# is training
if training is not None:
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
@ -115,9 +60,27 @@ class AR_NAR(Base):
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# rate to train RVQ level AR-ly or NAR-ly
masking_train_p = self.config.experimental.masking_train_p if self.config is not None else 0.5
# force the mask training rate on/off depending on the capabilities
if "len" not in self.capabilities:
masking_train_p = 0.0
elif "ar" not in self.capabilities:
masking_train_p = 1.0
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, 0]
# allow passing a specific distribution of RVQ levels
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
@ -131,16 +94,23 @@ class AR_NAR(Base):
# input RVQ levels
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
# timestep levels (for TTS NAR)
timesteps = [ None for _ in range(batch_size) ]
for i, task in enumerate( task_list ):
lo, hi = masking_train_rvq_levels[0], masking_train_rvq_levels[1]
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1
elif lo <= quant_levels[i] and quant_levels[i] <= hi and random.random() < masking_train_p:
timesteps[i] = random.random()
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
@ -179,6 +149,24 @@ class AR_NAR(Base):
else:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
# apply CFG (should probably only apply to NAR quant level 0)
if task not in text_task + ["len"]:
drop_text = False
drop_audio = False
if random.random() < cfg_prom_dropout_p:
drop_audio = True
if random.random() < cfg_cond_dropout_p:
drop_audio = True
drop_text = True
if drop_text:
text_list[i] = text_start_stop_sequence
if drop_audio:
proms_list[i] = None
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
@ -186,26 +174,75 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
time_list=timesteps,
quant_levels=quant_levels,
)
return super().forward(
inputs=inputs,
quant_levels=quant_levels, # could technically just grab this from the above inputs since they're included as an RVQ level token
quant_levels=quant_levels,
)
# is NAR
def forward_nar(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
denoise_start: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
if max_levels == 0:
max_levels = self.n_max_levels - 1
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:
resps_list[i] = resp.unsqueeze(-1)
prev_list = resps_list
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if sampling_layer_skip:
@ -216,6 +253,165 @@ class AR_NAR(Base):
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# inference NAR level 0
if len_list is not None:
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
# special "scheduling" to inference RVQ-level 0
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
# numerically-stable log
def log(x, eps = 1e-20):
return torch.log(x.clamp(min = eps))
# Gumbel-max trick: adding Gumbel noise to the logits and taking the argmax samples from the categorical distribution
def gumbel_sample(x, temperature = 1., dim = -1):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
_super = super()
def demask_sampling( batch_index, seq_len ):
# overrides
max_steps = 10
temperature = 0.3
cfg_strength = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied...
sampling_top_p = 0.9 # a lot of demasking samplers keep the top seq_len * 0.9 tokens; top-p 0.9 is the analogue used here
# if we're denoising from an existing sequence
if denoise_start > 0.0 and resps_list is not None:
start_noise = denoise_start
noise_p = math.cos( start_noise * math.pi * 0.5 )
mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
else:
input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ]
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = start_temperature * (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
# pick the worst scoring tokens to mask off
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
# mask off inputs
input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
# boolean mask
is_masked = input_ids == self.stop_token
# setup inputs
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = _super.inputs(
text_list=[ null_text ],
proms_list=[ null_prom ],
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
null_output = _super.forward(
inputs=null_inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
for logit, null_logits in zip(output.logits, null_output.logits):
logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
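# i.e. classifier-free guidance: logits = cond + (cond - uncond) * cfg_strength,
# which sharpens the conditioned distribution away from the unconditioned (null) one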
# sample with sampler settings
filtered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
)
# retrieves unfiltered logits
unfiltered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=0.0,
)
# update previous list of tokens
prev_list = [ input_ids ]
# extract logits
filtered_logits = filtered_sampled.logits[0]
unfiltered_logits = unfiltered_sampled.logits[0]
# extract scores
filtered_scores = filtered_sampled.scores[0]
unfiltered_scores = unfiltered_sampled.scores[0]
# extract sampled tokens
filtered_tokens = filtered_sampled[0][0]
unfiltered_tokens = unfiltered_sampled[0][0]
# sample with gumbelnoise
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
#sampled_ids = filtered_tokens
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (complemented, to put the worst scores at the top for the topk mask selection)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
if cfg.experimental:
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
# perform demasked sampling (mock diffusion)
resps_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
if resp.dim() == 1:
resps_list[i] = resp.unsqueeze(-1)
prev_list = resps_list
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
@ -238,8 +434,7 @@ class AR_NAR(Base):
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits, state = output.logits, output.state
@ -265,10 +460,115 @@ class AR_NAR(Base):
return prev_list
# is AR
def forward_ar(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
denoise_start: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# inference len
if task_list is not None and task_list[0] == "len":
sequence_list = [ torch.tensor([0], device=device, dtype=torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
stop_token = 10
task_list = [ "len" for _ in range(batch_size) ]
quant_levels = [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
for n in trange(10, desc="AR", disable=disable_tqdm):
len_list = sequence_list
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=quant_levels,
)
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
logits = output.logits
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i][0] = stop_token
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
# convert digit tokens into an int (e.g. tokens [1, 7, 5] => 175)
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
# STT
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
@ -352,9 +652,7 @@ class AR_NAR(Base):
output = super().forward(
inputs=inputs,
state=state,
layer_skip_variables=sampling_layer_skip_variables,
#layer_skip_variables=sampling_layer_skip_variables,
output_attentions=sampling_entropix,
)
logits, state = output.logits, output.state
@ -457,10 +755,144 @@ class AR_NAR(Base):
return sequence_list
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
denoise_start: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
kwargs = dict(
max_steps=max_steps,
max_levels=max_levels,
input_prompt_prefix=input_prompt_prefix,
prefix_silence=prefix_silence,
denoise_start=denoise_start,
sampling_temperature=sampling_temperature,
sampling_min_temperature=sampling_min_temperature,
sampling_top_k=sampling_top_k,
sampling_top_p=sampling_top_p,
sampling_min_p=sampling_min_p,
sampling_repetition_penalty=sampling_repetition_penalty,
sampling_repetition_penalty_decay=sampling_repetition_penalty_decay,
sampling_length_penalty=sampling_length_penalty,
sampling_beam_width=sampling_beam_width,
sampling_mirostat_tau=sampling_mirostat_tau,
sampling_mirostat_eta=sampling_mirostat_eta,
sampling_dry_multiplier=sampling_dry_multiplier,
sampling_dry_base=sampling_dry_base,
sampling_dry_allowed_length=sampling_dry_allowed_length,
sampling_entropix=sampling_entropix,
sampling_layer_skip=sampling_layer_skip,
sampling_layer_skip_exit_layer=sampling_layer_skip_exit_layer,
sampling_layer_skip_entropy_threshold=sampling_layer_skip_entropy_threshold,
sampling_layer_skip_varentropy_threshold=sampling_layer_skip_varentropy_threshold,
sampling_refine_on_stop=sampling_refine_on_stop,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
)
# deduce batch_size
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]
# implicitly set for training
if training is None and text_list is not None and resps_list is not None:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
training = n_levels == self.n_resp_levels
# is training
if training:
return self.forward_train(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
)
# is NAR
if (len_list is not None or resps_list is not None) and text_list is not None:
return self.forward_nar(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
**kwargs,
)
# is AR
return self.forward_ar(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
**kwargs,
)
def example_usage():
cfg.device = "cuda"
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
@ -477,33 +909,23 @@ def example_usage():
import re
setup_logging()
device = "cuda"
# mamba seems to ONLY be usable as an AR (any NAR attempt lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
def load_artifact( path ):
artifact = np.load(path, allow_pickle=True)[()]
text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=cfg.device)
audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=cfg.device)
return text, audio
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
batch_size = cfg.hyperparameters.batch_size
cfg.model.experimental.masking_train_p = 0.5
text_list = [ text ]
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
resps_list = [ audio ]
text_list = [ text ] * batch_size
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
resps_list = [ audio ] * batch_size
batch_size = len(text_list)
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_text_tokens': 256,
'n_audio_tokens': 1024,
@ -520,19 +942,11 @@ def example_usage():
'config': cfg.model
}
"""
try:
kwargs['config'] = cfg.model
except Exception as e:
pass
"""
bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
#available_tasks = cfg.dataset.tasks_list
available_tasks = ["tts"] # , "stt"]
available_tasks = ["tts-ar", "tts-nar"]
model = AR_NAR(**kwargs).to(device)
steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@ -620,9 +1034,9 @@ def example_usage():
def sample_data(t=None):
if isinstance(t, list):
tasks = t
texts = [ text_list[0].to(device) if task != "stt" else None for i, task in enumerate( tasks ) ]
proms = [ proms_list[0].to(device) if task != "stt" else [ "stt" ] for i, task in enumerate( tasks ) ]
resps = [ None if task != "stt" else resps_list[0].to(device) for i, task in enumerate( tasks ) ]
texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ]
resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ]
return texts, proms, resps, tasks
@ -634,45 +1048,15 @@ def example_usage():
for i in range(batch_size):
task = random.choice(available_tasks) if t is None else t
text = text_list[i].to(device)
prom = proms_list[i].to(device)
resp = resps_list[i].to(device)
text = text_list[i].to(cfg.device)
prom = proms_list[i].to(cfg.device)
resp = resps_list[i].to(cfg.device)
# do nothing
if task == "tts":
...
elif task == "stt":
prom = [
task
]
# to-do: reimplement this from data.py
"""
elif task == "tts-c":
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
prom = resp[:trim_length]
resp = resp[trim_length:]
prom = prom.to(device)
elif task == "ns" or task == "sr":
# extend the noise to fill the target audio
noise_ext = repeat_extend_audio( noise, resp.shape[0] )
# create the input prompt by merging the target audio with the noise
prom = merge_audio( resp.cpu(), noise_ext, scale=[1, cfg.dataset.noise_scale], device=cfg.dataset.reencode_device )
prom = prom.to(device)
# set the target to just be the noise if <sr>
if task == "sr":
resp = noise_ext
# set the text prompt to empty to train without a guided text prompt
if random.random() < 0.5:
text = torch.tensor([bos_id, eos_id], device=device, dtype=torch.uint8)
prom = [
task,
prom,
]
"""
if task == "stt":
prom = [ task ]
else:
task = "tts"
texts.append( text )
proms.append( prom )
@ -685,27 +1069,18 @@ def example_usage():
def sample( name, steps=500, task=None ):
engine.eval()
texts, proms, resps, tasks = sample_data( task )
text_list, proms_list, resp_list, task_list = sample_data( task )
if "ar" in cfg.model.capabilities:
output = engine( texts, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
text = [ cfg.tokenizer.decode( output[i] ) for i, task in enumerate( tasks ) if task == "stt" ]
texts = [ texts[i] for i, task in enumerate( tasks ) if task != "stt" ]
proms = [ proms[i] for i, task in enumerate( tasks ) if task != "stt" ]
resps = [ output[i] for i, task in enumerate( tasks ) if task != "stt" ]
tasks = [ tasks[i] for i, task in enumerate( tasks ) if task != "stt" ]
print( "STT:", text )
if task == "tts-nar":
len_list = engine(text_list, proms_list, task_list=["len"], max_steps=5, sampling_temperature=0.0 )
len_list = [ resp_list[0].shape[0] for l in len_list ] # override the predicted lengths with the ground-truth duration for this demo
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.0 )
else:
resps = [ resp[:, 0] for resp in resps ]
resps_list = engine( text_list, proms_list, task_list=["tts"], max_steps=steps, sampling_temperature=1.0 )
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.0 )
if "nar" in cfg.model.capabilities:
resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
for i, o in enumerate(resps):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
for i, o in enumerate(resps_list):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device)
unload_model()
@ -716,7 +1091,7 @@ def example_usage():
texts, proms, resps, tasks = sample_data()
stats = {"step": i}
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
@ -735,11 +1110,8 @@ def example_usage():
model = ml.compile_model(model, backend=cfg.optimizations.compile)
"""
"""
for task in available_tasks:
sample("final", task=task)
"""
sample("final", task=available_tasks)
engines.quit()

View File

@@ -246,9 +246,6 @@ class AudioEmbedding(nn.Module):
# prom
if self.capabilities is None:
offset = 0
# resp
#elif "len" in self.capabilities:
# offset = 1
elif "nar" not in self.capabilities:
offset = 0
elif quant_level > 0:
@@ -492,16 +489,6 @@ class Base(nn.Module):
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
"""
elif "len" not in self.capabilities:
# +1 to include the stop token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
# NAR-len model
else:
n_resp_tokens = n_audio_tokens
l_tokens = [n_resp_tokens] * (self.n_resp_levels)
"""
self.unified_position_ids = unified_position_ids
self.interleave = interleave
@@ -561,11 +548,11 @@ class Base(nn.Module):
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
# "len" RVQ level-0 gets an additional token
self.rvq_l_emb = Embedding(self.n_resp_levels + (1 if "len" in self.capabilities else 0), d_model)
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
self.time_emb = TimeEmbedding(d_model) if "len" in self.capabilities else None
self.len_emb = Embedding(11, d_model)
self.time_emb = TimeEmbedding(d_model)
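# the "len" embedding's 11 entries presumably cover duration digits 0-9 plus a stop token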
if attention_backend == "auto":
attention_backend = "sdpa"
@@ -645,7 +632,7 @@ class Base(nn.Module):
use_reentrant=False
))
elif self.arch_type == "llama":
LlamaClass = LlamaModel_Adapted if (self.layerskip or "len" in self.capabilities) else LlamaModel
LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel
if n_experts <= 1:
self.model = LlamaClass(LlamaConfig(
@@ -668,12 +655,6 @@ class Base(nn.Module):
# replace with desired attention
if attention_backend not in HF_ATTENTIONS:
self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
# replace with modified Llama
"""
if "len" in self.capabilities:
self.model = ml.replace_attention( self.model, klass=LlamaDecoderLayer_Adapted, target=LlamaDecoderLayer, mode=attention_backend )
"""
else:
self.model = MixtralModel(MixtralConfig(
vocab_size =n_resp_tokens,
@@ -1012,6 +993,7 @@ class Base(nn.Module):
for i in range(batch_size):
quant_level = quant_levels[i] if quant_levels is not None else 0
task_type = task_list[i] if task_list is not None else "tts"
timestep = time_list[i] if time_list is not None else None
# insert task type as a string
inputs[i].append( ( "task", task_type ) )
@@ -1023,12 +1005,6 @@ class Base(nn.Module):
# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
# pick a random timestep
if "len" in self.capabilities and quant_level == 0:
timestep = random.random()
else:
timestep = 1.0
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
@@ -1045,7 +1021,7 @@ class Base(nn.Module):
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if "len" in self.capabilities and quant_level == 0:
if timestep is not None:
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
# insert the current output response
@@ -1053,7 +1029,7 @@ class Base(nn.Module):
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask
if "len" in self.capabilities and quant_level == 0:
if timestep is not None:
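# assuming the MaskGIT-style cosine schedule used throughout this commit: p = cos(t * pi/2),
# so an early timestep (t ~ 0) masks nearly every token while a late one (t ~ 1) masks almost none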
dropout_mask = _dropout_mask( resps_list[i], p=math.cos(timestep * math.pi * 0.5) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
@@ -1072,9 +1048,7 @@ class Base(nn.Module):
inputs[i].append( ( "lang", lang_list[i] ) )
# technically will always be level 0, but for the sake of keeping the input formatting coherent...
if self.rvq_l_emb is not None:
# override to 0 (I don't know if this change propagates, I'm not familiar with when python passes by (copied) value or reference)
quant_levels[i] = 0
inputs[i].append( ( "quant_level", torch.tensor([ self.n_resp_levels ], device=device, dtype=torch.int16) ) )
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert input audio prompt
if proms_list is not None and proms_list[i] is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
@@ -1195,7 +1169,7 @@ class Base(nn.Module):
embedding = _interleave_sequence_reshape( embeddings )
# if training NAR-len RVQ level 0
elif "len" in self.capabilities and quant_level == 0 and dropout_mask is not None:
elif dropout_mask is not None:
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
@@ -1220,10 +1194,6 @@ class Base(nn.Module):
)
else:
offset = 0
"""
if "len" in self.capabilities:
offset = 1
"""
if "nar" not in self.capabilities:
offset = 0
elif quant_level > 0:
@@ -1264,14 +1234,21 @@ class Base(nn.Module):
name,
at=None,
):
find_all = at is None
res = [] if at is None else None
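# when `at` is None, gather every matching input across the batch; otherwise return the first match for that batch index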
for batch_index, batch_input in enumerate(inputs):
if at is not None and batch_index != at:
if not find_all and batch_index != at:
continue
for n, input in batch_input:
if n == name:
if n != name:
continue
if not find_all:
return input
return None
res.append( input )
return res
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
@@ -1401,15 +1378,7 @@ class Base(nn.Module):
for i in range(batch_size):
quant_level = quant_levels[i]
task_name = task_list[i]
causal = False
if "len" in self.capabilities:
causal = task_name == "len"
if quant_level >= self.n_resp_levels:
quant_level = 0
else:
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
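# i.e. the "len" and "stt" tasks are now always treated as causal (inferenced AR-ly), regardless of RVQ level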
if causal:
l = self.causal_size
@@ -1488,13 +1457,7 @@ class Base(nn.Module):
logit = logits[i][it:it+seq_len]
it += seq_len + 1 # +1 to incorporate the separator
causal = False
if "len" in self.capabilities:
causal = task_name == "len"
if quant_level >= self.n_resp_levels:
quant_level = 0
else:
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
@@ -1854,15 +1817,9 @@ class Base(nn.Module):
res = [ Categorical(logits=logit).sample() for logit in logits ]
# calculate token probabilities
if "len" in self.capabilities:
scores = [
[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
for logit, tokens in zip(logits, res)
]
else:
scores = [
[ F.softmax(logit[-1, :], dim=-1)[token].item() for token in tokens ]
for logit, tokens in zip(logits, res)
]
return Sampled(res, logits, scores, entropy)

View File

@@ -1,672 +0,0 @@
"""
A (mostly) NAR model that handles inferencing all RVQ levels in parallel.
I believe Meta's Voicebox does this too (predict the utterance length, then decode in parallel).
It *does* have to inference the initial length in an autoregressive-ish manner (though that can technically also be done in parallel).
Initial experiments show this only really "works" for a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
"""
import random
import math
import numpy as np
import logging
import torch
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
from tqdm import trange
from .base import Base, list_to_tensor, Categorical, _dropout_mask
from ..config import cfg
from ..emb.qnt import trim, repeat_extend_audio
from ..utils import clamp
_logger = logging.getLogger(__name__)
class NAR(Base):
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
task_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
training: bool | int | None = None,
max_steps: int = 1000,
max_levels: int = 0,
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
denoise_start: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_min_p: float = 0.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=False,
sampling_layer_skip: bool = False,
sampling_layer_skip_exit_layer: int = -1,
sampling_layer_skip_entropy_threshold: float = -1,
sampling_layer_skip_varentropy_threshold: float = -1,
sampling_refine_on_stop: bool = False,
disable_tqdm=False,
use_lora=None,
):
text_task = [ "stt" ]
if text_list is not None:
default_task = "tts"
device = text_list[0].device
batch_size = len(text_list)
else:
default_task = "stt"
device = resps_list[0].device
batch_size = len(resps_list)
# generate task list if not provided
if task_list is None:
task_list = [ default_task for _ in range(batch_size) ]
has_none = resps_list is None or text_list is None
if not has_none:
for i, task in enumerate( task_list ):
if resps_list[i] is None or text_list[i] is None:
has_none = True
break
# is training or NAR
if not has_none:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# implicit
if training is None:
training = 0 if n_levels == self.n_resp_levels else None
# is training
if training is not None:
len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# assert n_levels == self.n_resp_levels
# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < len_train_p else "tts"
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
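# e.g. with len_train_p=0.05, roughly one in twenty samples trains the duration ("len") task instead of "tts"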
# specifies how to sample probabilities of which RVQ levels to train against
rvq_levels_p = self.config.experimental.rvq_levels_p if self.config is not None else "equal"
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ]
# rate to perform token dropout errors
token_dropout_error = self.config.experimental.token_dropout_error
# RVQ levels to apply token dropout on
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
cfg_prom_dropout_p = self.config.experimental.cfg_prom_dropout_p if self.config is not None else 0.0
# implicitly set it to all levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels - 1]
# allow passing a specific distribution of RVQ levels
rvq_levels_p = rvq_levels_p if isinstance(rvq_levels_p, list) else []
if not rvq_levels_p:
lo, hi = quant_level_range[0], quant_level_range[1] + 1
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
# yuck
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels
quant_levels = [ random.choice( rvq_levels_p ) for i in range(batch_size) ]
for i, task in enumerate( task_list ):
if task in text_task:
quant_levels[i] = 0 # self.n_resp_levels - 1
# trim resps to only contain all levels below the target level
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
# empty string for CFG
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
# I hate python's value/reference semantics so much
for i, quant_level, text, resps, proms, task in zip(range(batch_size), quant_levels, text_list, resps_list, proms_list, task_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
# proms could be a Tensor, list[Tensor], or None
if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1
elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()
if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
...
else:
#resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
...
# apply CFG (should probably only apply to NAR quant level 0)
if task not in text_task + ["len"]:
drop_text = False
drop_audio = False
if random.random() < cfg_prom_dropout_p:
drop_audio = True
if random.random() < cfg_cond_dropout_p:
drop_audio = True
drop_text = True
if drop_text:
text_list[i] = text_start_stop_sequence
if drop_audio:
proms_list[i] = None
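# dropping the text/prom conditioning during training is what later lets inference contrast conditioned vs. unconditioned logits (classifier-free guidance)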
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
quant_levels=quant_levels,
)
return super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
if len_list is not None:
sampling_layer_skip_variables = {} if sampling_layer_skip else None
if max_levels == 0:
max_levels = self.n_max_levels - 1
if sampling_layer_skip:
if sampling_layer_skip_entropy_threshold >= 0:
sampling_layer_skip_variables["entropy_threshold"] = sampling_layer_skip_entropy_threshold
if sampling_layer_skip_varentropy_threshold >= 0:
sampling_layer_skip_variables["varentropy_threshold"] = sampling_layer_skip_varentropy_threshold
if sampling_layer_skip_exit_layer >= 0:
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# initial condition
"""
print( len_list )
len_list = [ clamp(1, max_steps, l) for l in len_list ]
print( len_list )
"""
metrics = []
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
prev_list = [ torch.concat([ mask_token for _ in range( resp_len ) ]) for resp_len in len_list ]
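# every target sequence starts as all mask tokens; demasking then fills it in over max_steps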
# special "scheduling" to inference RVQ-level 0
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
def log(x, eps = 1e-20):
return torch.log(x.clamp(min = eps))
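# Gumbel-max trick: adding -log(-log(U)) noise to logits/T and taking the argmax
# is equivalent to drawing from softmax(logits/T), without an explicit multinomial sample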
def gumbel_sample(x, temperature = 1., dim = -1):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
_super = super()
def demask_sampling( batch_index, seq_len ):
# overrides
max_steps = 10
temperature = 0.3
cfg_strength = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied...
sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9; top-p 0.9 stands in for that here
# if we're denoising from an existing sequence
if denoise_start > 0.0 and resps_list is not None:
start_noise = denoise_start
noise_p = math.cos( start_noise * math.pi * 0.5 )
mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
else:
input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ]
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = start_temperature * (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
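# e.g. with seq_len=100 over 10 steps, this tapers from re-masking all 100 tokens at the first step down to a single token at the last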
# pick the worst scoring tokens to mask off
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
# mask off inputs
input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
# boolean mask
is_masked = input_ids == self.stop_token
# setup inputs
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = _super.inputs(
text_list=[ null_text ],
proms_list=[ null_prom ],
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
null_output = _super.forward(
inputs=null_inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
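# standard classifier-free guidance: guided = cond + strength * (cond - uncond),
# applied only over the response segment (the last seq_len positions of each logit)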
for logit, null_logits in zip(output.logits, null_output.logits):
logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
# sample with sampler settings
filtered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
)
# retrieves unfiltered logits
unfiltered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=0.0,
)
# update previous list of tokens
prev_list = [ input_ids ]
# extract logits
filtered_logits = filtered_sampled.logits[0]
unfiltered_logits = unfiltered_sampled.logits[0]
# extract scores
filtered_scores = filtered_sampled.scores[0]
unfiltered_scores = unfiltered_sampled.scores[0]
# extract sampled tokens
filtered_tokens = filtered_sampled[0][0]
unfiltered_tokens = unfiltered_sampled[0][0]
# sample with gumbelnoise
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
#sampled_ids = filtered_tokens
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
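# (topk selects the largest entries, so storing 1 - p makes the least-confident tokens the ones re-masked next step)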
if cfg.experimental:
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
# perform demasked sampling (mock diffusion)
prev_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]
# expand if given a raw 1D tensor
for i, resp in enumerate(prev_list):
if resp.dim() == 1:
prev_list[i] = resp.unsqueeze(-1)
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
logits, state = output.logits, output.state
sampled = super().sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=0.0, # sampling_temperature,
#min_temperature=sampling_min_temperature,
#top_p=sampling_top_p,
#top_k=sampling_top_k,
#min_p=sampling_min_p,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
resps_list = sampled[0]
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
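# each pass appends one more RVQ level as a new column, so iteration n leaves levels 0..n filled in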
return prev_list
# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
stop_token = 10
task_list = [ "len" for _ in range(batch_size) ]
for n in trange(10, desc="AR", disable=disable_tqdm):
len_list = sequence_list
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
output = super().forward(
inputs=inputs,
)
logits = output.logits
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i][0] = stop_token
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
# convert tokens into int
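# each sampled token is a decimal digit (0-9, with 10 as the stop token),
# so e.g. a digit sequence of [2, 5, 0] decodes to a predicted length of 250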
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100
from functools import partial
from einops import repeat
from tqdm import tqdm
from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from ..utils import wrapper as ml
import numpy as np
import re
device = "cuda"
def load_artifact( path ):
artifact = np.load(path, allow_pickle=True)[()]
text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
return text, audio
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [ text ]
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
resps_list = [ audio ]
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1,
'p_dropout': 0.1,
'l_padding': 8 if cfg.optimizations.fp8 else 0,
'config': cfg.model
}
"""
try:
kwargs['config'] = cfg.model
except Exception as e:
pass
"""
model = NAR(**kwargs).to(device)
steps = 250
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.yaml_path is not None else None
if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
engine = Engine(model=model, optimizer=optimizer)
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
_logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=1000 ):
if cfg.audio_backend == "dac" and name == "init":
return
engine.eval()
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
len_list = [ min(l, 500) for l in len_list ]
for i, o in enumerate(resps_list):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@@ -47,8 +47,8 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
start = i + 1
# apply either up to limit tokens, or to the end
end = start + limit if limit > 0 else seq_len
start = clamp(0, seq_len - 1, start)
end = clamp(0, seq_len - 1, end)
start = clamp(start, 0, seq_len - 1)
end = clamp(end, 0, seq_len - 1)
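# assuming clamp() takes (value, lo, hi): the old argument order passed the bounds as the value, pinning start/end incorrectly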
for j in range( start, end ):
distance = j - i
logits[j, token] /= factor * (distance ** decay)