segregated experimental changes into its own streamlined file to avoid breaking the existing model, and it can pivot to the cleaned up code if it actually works (nothing is working)

This commit is contained in:
mrq 2025-02-26 21:26:13 -06:00
parent 7d2e64630c
commit 2ea387c08a
9 changed files with 2427 additions and 795 deletions

View File

@ -382,6 +382,12 @@ class Model:
def text_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
return self.size['text_tokens']
return 8575
@property
def phoneme_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"):
return self.size['phoneme_tokens']
return 256
@property

View File

@ -1611,10 +1611,11 @@ class Dataset(_Dataset):
task=task,
lang=lang,
tone=tone,
text=text,
proms=proms,
resps=resps,
raw_text=raw_text,
phns=text,
text=raw_text,
metadata=metadata,
)

View File

@ -306,7 +306,7 @@ class TTS():
seed = set_seed(seed)
batch_size = len(texts)
input_kwargs = dict(
text_list=texts,
phns_list=texts,
proms_list=proms,
lang_list=langs,
disable_tqdm=not use_tqdm,
@ -421,8 +421,8 @@ class TTS():
with torch.autocast(self.device, dtype=dtype, enabled=amp):
model = model_ar if model_ar is not None else model_nar
if model is not None:
text_list = model(
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
phns_list = model(
phns_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
disable_tqdm=not use_tqdm,
use_lora=use_lora,
**sampling_kwargs,
@ -430,9 +430,9 @@ class TTS():
else:
raise Exception("!")
text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
phns_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]
return text_list[0]
return phns_list[0]
elif task in ["phn", "un-phn"]:
lang = self.encode_lang( language )
lang = to_device(lang, device=self.device, dtype=torch.uint8)
@ -440,17 +440,17 @@ class TTS():
with torch.autocast(self.device, dtype=dtype, enabled=amp):
model = model_ar if model_ar is not None else model_nar
if task == "phn":
text_list = None
raw_text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
phns_list = None
text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
output_tokenizer = cfg.tokenizer
else:
text_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
raw_text_list = None
phns_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
text_list = None
output_tokenizer = cfg.text_tokenizer
if model is not None:
text_list = model(
text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
phns_list = model(
phns_list=phns_list, text_list=text_list, lang_list=[lang], task_list=[task],
disable_tqdm=not use_tqdm,
use_lora=use_lora,
**sampling_kwargs,
@ -458,9 +458,9 @@ class TTS():
else:
raise Exception("!")
text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
phns_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]
return text_list[0]
return phns_list[0]
# stuff for rolling context
@ -504,8 +504,8 @@ class TTS():
with torch.autocast(self.device, dtype=dtype, enabled=amp):
input_kwargs = dict(
text_list=[phns] if phonemize else None,
raw_text_list=[phns] if not phonemize else None,
phns_list=[phns] if phonemize else None,
text_list=[phns] if not phonemize else None,
proms_list=[prom],
lang_list=[lang],
disable_tqdm=not use_tqdm,

View File

@ -59,11 +59,18 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
def get_model(config, training=True, **model_kwargs):
from .ar_nar import AR_NAR # import here because reasons
name = config.name
model = AR_NAR(
n_text_tokens=config.text_tokens,
# crunge
if config.version < 7:
from .ar_nar import AR_NAR
ModelClass = AR_NAR
else:
from .ar_nar_v2 import AR_NAR_V2
ModelClass = AR_NAR_V2
cfg_kwargs = dict(
n_phn_tokens=config.phoneme_tokens,
n_audio_tokens=config.audio_tokens,
n_text_tokens=config.text_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
@ -75,9 +82,11 @@ def get_model(config, training=True, **model_kwargs):
training = training,
config = config,
**model_kwargs
)
name = config.name
model = ModelClass(**(cfg_kwargs | model_kwargs))
_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model

View File

@ -42,22 +42,22 @@ class AR_NAR(Base):
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
phns_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
):
# deduce batch_size
if text_list:
if phns_list:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
@ -75,9 +75,6 @@ class AR_NAR(Base):
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
if self.version >= 7:
masking_train_rvq_levels = [0,self.n_resp_levels]
if cfg.audio_backend == "nemo":
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
@ -134,13 +131,12 @@ class AR_NAR(Base):
timesteps[i] = (timesteps[i] * 0.6) + 0.2
# trim resps to only contain all levels below the target level
if self.version < 7:
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
# final validations and stuff
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
@ -175,10 +171,10 @@ class AR_NAR(Base):
"""
# only apply stop token for RVQ level 0
if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
if (quant_level <= 0 and timesteps[i] is None) or (self.predict_causally):
# append stop tokens for AR
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
resps_list[i] = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ])
if task == "len":
quant_levels[i] = 0
@ -196,26 +192,26 @@ class AR_NAR(Base):
drop_audio = True
drop_text = True
if random.random() < use_raw_text_p and raw_text_list[i] is not None:
if random.random() < use_raw_text_p and text_list[i] is not None:
swap_text = True
if drop_text:
text_list[i] = text_start_stop_sequence
phns_list[i] = text_start_stop_sequence
if drop_audio:
proms_list[i] = None
if swap_text and not drop_text:
text_list[i] = None
phns_list[i] = None
inputs = self.inputs(
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
raw_text_list=raw_text_list,
text_list=text_list,
time_list=timesteps,
quant_levels=quant_levels,
@ -231,22 +227,22 @@ class AR_NAR(Base):
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
phns_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
quant_levels: list[int] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
device = text_list[0].device
batch_size = len(text_list)
device = phns_list[0].device
batch_size = len(phns_list)
if quant_levels is None:
level = 0
@ -304,7 +300,7 @@ class AR_NAR(Base):
prefix_context = sampling_kwargs.get("prefix_context", None)
# we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
if prefix_context is not None:
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ]
phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], phns_list ) ]
prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ]
# if we're denoising from an existing sequence
@ -379,7 +375,7 @@ class AR_NAR(Base):
# setup inputs
inputs = super().inputs(
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=input_resps_list,
lang_list=lang_list,
@ -396,7 +392,7 @@ class AR_NAR(Base):
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
phns_list=null_text,
proms_list=null_prom,
resps_list=input_resps_list,
lang_list=lang_list,
@ -446,188 +442,11 @@ class AR_NAR(Base):
return resps_list
# handles doing demasking inferencing in parallel to inference all tokens
# it works if the underlying model is trained properly (which is a pain)
def forward_nar_masked_parallel(
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
device = text_list[0].device
batch_size = len(text_list)
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
# convert (N)AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
min_length = sampling_kwargs.pop("min_duration", 1)
max_length = sampling_kwargs.pop("max_duration", 500)
max_steps = sampling_kwargs.get("max_steps", 25)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
original_cfg_strength = cfg_strength
cfg_strength = max( cfg_strength, minimum_cfg_strength )
prefix_context = sampling_kwargs.get("prefix_context", None)
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
# fill scores
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
for timestep in iterator:
# update previous list of tokens
prev_list = resps_list
# ramp down over time
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.mask_token for resps in resps_list ]
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
sampling_temperature = temperature * annealing if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
input_resps_list = resps_list
# setup inputs
inputs = super().inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
l_scores = []
l_resps_list = []
# cringe hack because we're able to sample multiple levels at once
for l in range(self.n_resp_levels):
# sample with sampler settings
filtered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=sampling_temperature,
**sampling_kwargs,
)
# retrieves unfiltered logits
unfiltered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=0.0,
**sampling_kwargs,
)
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
# get probability scores
l_scores.append([
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
])
resps_list = []
scores = []
for batch_index in range(batch_size):
score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
scores.append( score )
resps_list.append( resp )
return resps_list
def forward_nar(
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
phns_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
@ -635,7 +454,7 @@ class AR_NAR(Base):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
@ -644,7 +463,7 @@ class AR_NAR(Base):
# inference NAR level 0
if len_list is not None:
resps_list = self.forward_nar_masked(
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
task_list=task_list,
@ -655,12 +474,12 @@ class AR_NAR(Base):
)
# deduce batch_size
if text_list:
if phns_list:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
@ -701,7 +520,7 @@ class AR_NAR(Base):
quant_levels = [ level for _ in range(batch_size) ]
inputs = self.inputs(
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
@ -717,7 +536,7 @@ class AR_NAR(Base):
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
phns_list=null_text,
proms_list=null_prom,
resps_list=prev_list,
lang_list=lang_list,
@ -748,8 +567,8 @@ class AR_NAR(Base):
task_list: list[Tensor],
phns_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
@ -761,12 +580,12 @@ class AR_NAR(Base):
**sampling_kwargs,
):
# deduce batch_size
if text_list:
if phns_list:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
@ -810,14 +629,14 @@ class AR_NAR(Base):
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
quant_levels=quant_levels,
)
@ -895,7 +714,7 @@ class AR_NAR(Base):
if prefix_context is not None:
prefix_text, prefix_resps, _ = prefix_context
# to-do: check if we actually need to drop the middle "<eos><bos>"
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, phns_list ) ]
# feeding this into the NAR-len should automatically handle things
sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ]
@ -906,13 +725,13 @@ class AR_NAR(Base):
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
for n in iterator:
if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
phns_list = [ sequence_list[i] if task in ["phn"] else phns_list[i] for i, task in enumerate(task_list) ]
text_list = [ sequence_list[i] if task in ["un-phn"] else text_list[i] for i, task in enumerate(task_list) ]
else:
if raw_text_list is not None:
raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
else:
if text_list is not None:
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
else:
phns_list = [ sequence_list[i] if task in text_task else phns_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
@ -920,14 +739,14 @@ class AR_NAR(Base):
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
quant_levels=quant_levels,
)
@ -942,7 +761,7 @@ class AR_NAR(Base):
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
phns_list=null_text,
proms_list=null_prom,
resps_list=resps_list,
lang_list=lang_list,
@ -960,7 +779,7 @@ class AR_NAR(Base):
sampled = super().sample(
logits=logits,
prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
prev_list=[ resps_list[i] if task not in text_task else phns_list[i] for i, task in enumerate( task_list ) ],
**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
)
@ -981,7 +800,7 @@ class AR_NAR(Base):
# first step, expand batch
if batch_size == 1:
batch_size = beam_width
text_list = text_list * beam_width
phns_list = phns_list * beam_width
proms_list = proms_list * beam_width
sequence_list = sequence_list * beam_width
task_list = task_list * beam_width
@ -1046,175 +865,18 @@ class AR_NAR(Base):
return sequence_list
def forward_ar_parallel(
self,
task_list: list[Tensor],
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
# deduce batch_size
if text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# convert AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
temperature = sampling_kwargs.get("temperature", 1.0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
max_duration = sampling_kwargs.get("max_duration", 500)
beam_width = sampling_kwargs.get("beam_width", 0)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
layer_skip = sampling_kwargs.get("layer_skip", False)
prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
audio_stop_token = self.stop_token
text_stop_token = 2
state = None
mirostat = [
{"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if mirostat_tau > 0.0 else None
scores = [ 1.0 ] * beam_width
metrics = []
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
# get next in sequence
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
for n in iterator:
if raw_text_list is not None:
raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
else:
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
quant_levels=quant_levels,
)
# to-do: find an elegant way to write this
output = super().forward(
inputs=inputs,
state=state,
#layer_skip_variables=sampling_layer_skip_variables,
output_attentions=entropix_sampling,
)
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
logits, state = output.logits, output.state
l_resps_list = [ [] for _ in range(batch_size) ]
for l in range(self.n_resp_levels):
sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in resps_list ],
**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
)
ids = sampled.ids
# append tokens
for i, token in enumerate(ids):
if audio_stop_token in token:
stopped[i] = True
l_resps_list[i].append(token.to(device))
for i, l in enumerate(l_resps_list):
sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
# stop token found
# stopped |= r == stop_token
if stopped.all().item():
iterator.close()
break
for i, l in enumerate( sequence_list ):
index = (l == audio_stop_token).nonzero()
# kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment
try:
index = index[:, 0].min()
sequence_list[i] = sequence_list[i][:index]
except Exception as e:
pass
return sequence_list
def forward(
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
phns_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
training: bool | None = None,
@ -1224,12 +886,12 @@ class AR_NAR(Base):
):
# deduce batch_size
# deduce batch_size
if text_list:
if phns_list:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
@ -1238,7 +900,7 @@ class AR_NAR(Base):
batch_size = len(resps_list)
# implicitly set for training
if training is None and text_list is not None and resps_list is not None:
if training is None and phns_list is not None and resps_list is not None:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
@ -1249,118 +911,47 @@ class AR_NAR(Base):
return self.forward_train(
task_list=task_list,
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
)
# is NAR
if (len_list is not None or resps_list is not None) and text_list is not None:
if self.version >= 7:
return self.forward_nar_masked_parallel(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
# NAR demasking for all levels
"""
resps_lists = [ None for _ in range(batch_size) ]
for level in range(self.n_resp_levels):
resp_list = self.forward_nar_masked(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
quant_levels=[ level for _ in range(batch_size) ],
**sampling_kwargs,
)
for batch_index, resp in enumerate(resp_list):
if resps_lists[batch_index] is None:
resps_lists[batch_index] = []
resps_lists[batch_index].append( resp )
for batch_index, resps in enumerate(resps_lists):
resps_lists[batch_index] = torch.stack( resps, dim=-1 )
return resps_lists
"""
if (len_list is not None or resps_list is not None) and phns_list is not None:
return self.forward_nar(
task_list=task_list,
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
if self.version >= 7:
if task_list is None or task_list[0] != "len":
return self.forward_ar_parallel(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
# is AR
return self.forward_ar(
task_list=task_list,
text_list=text_list,
phns_list=phns_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
@ -1402,12 +993,11 @@ def example_usage():
text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}")
batch_size = cfg.hyperparameters.batch_size
text_list = [ text ] * batch_size
phns_list = [ text ] * batch_size
proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size
resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size
kwargs = {
'n_text_tokens': cfg.model.text_tokens,
'n_audio_tokens': cfg.model.audio_tokens,
'd_model': 1024, # 256, # 1024, # 1536
@ -1545,7 +1135,7 @@ def example_usage():
def sample_data(t=None):
if isinstance(t, list):
tasks = t
texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
texts = [ phns_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ]
resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ]
@ -1559,7 +1149,7 @@ def example_usage():
for i in range(batch_size):
task = random.choice(available_tasks) if t is None else t
text = text_list[i].to(cfg.device)
text = phns_list[i].to(cfg.device)
prom = proms_list[i].to(cfg.device)
resp = resps_list[i].to(cfg.device)
@ -1580,16 +1170,16 @@ def example_usage():
def sample( name, steps=500, task=None ):
engine.eval()
text_list, proms_list, resp_list, task_list = sample_data( task )
phns_list, proms_list, resp_list, task_list = sample_data( task )
if task == "tts-nar":
len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
len_list = [ r.shape[0] for r in resp_list ]
resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
else:
resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
resps_list = engine( phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
for i, o in enumerate(resps_list):
print( o.shape, o )
@ -1604,7 +1194,7 @@ def example_usage():
texts, proms, resps, tasks = sample_data()
stats = {"step": i}
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")

1069
vall_e/models/ar_nar_v2.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,8 @@ This should handle all the "low" level things such as:
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inheret the base model
"""
# to-do: clean this whole mess up
import math
import torch
import torch.nn.functional as F
@ -50,22 +52,22 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
"""
summed_embeddings_task = [ "stt" ]
special_tasks = [ "len", "stt", "phn", "un-phn" ]
special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
task_outputs = {
"tts": "resp",
"ns": "resp",
"sr": "resp",
"stt": "text",
"stt": "phn",
"len": "len",
"phn": "text",
"un-phn": "raw_text",
"phn": "phn",
"un-phn": "text",
}
# yuck
def _get_offsets():
return {
"text": (0, 256),
"phn": (0, 256),
"quant_level": (256, 264),
"lang": (264, 270),
"task": (270, 279),
@ -126,15 +128,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
return torch.concat( [ i.t() for i in input ] ).t().flatten()
# automagically parses a batch-list and returns it as a list
"""
class Embedding(ml.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
return []
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
"""
# Deprecated implementation
class MultiEmbedding(nn.Module):
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
@ -334,92 +327,6 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
x[..., level] = torch.where( dropout_mask, lhs, rhs )
return x
# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
):
super().__init__()
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
# old way
# this probably is a tried and true good way to go about it
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
# encode by interleaving
# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
"""
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
x = x.view(x.shape[0], -1)
x = self.proj(x)
"""
return x
class AudioDecoder(nn.Module):
def __init__(
self,
d_model,
hidden_size,
vocab_size,
resp_levels,
):
super().__init__()
self.resp_levels = resp_levels
self.head = nn.Linear( d_model, vocab_size * resp_levels )
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# prior way up-projected then down-projected, but that's silly
x = self.head( x )
# interleave by reshaping / permuting
# at least I hope this does it properly, it checks out against my OCR classifier
batch_size, seq_len, dim = x.shape
x = x.view( batch_size, seq_len, self.resp_levels, -1 )
x = x.permute( 0, 2, 1, 3 )
return x
"""
"""
# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model
# this doesn't work well
"""
class ClassifiersParallel(nn.Module):
def __init__(
self,
n_levels: int, # codebook levels
n_tokens: int, # output token count
token_dim: int, # dimensionality of the embedding
bias: bool = False,
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)])
def forward(self, xi: Tensor, stack: bool = True ) -> Tensor:
dtype = xi.dtype
device = xi.device
xi = [ proj( xi ) for l, proj in enumerate(self.proj) ]
xi = torch.stack( xi )
xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
return xi
"""
class Metrics(nn.Module):
def __init__(
self,
@ -508,9 +415,9 @@ class Base(nn.Module):
def __init__(
self,
n_text_tokens: int = 256,
n_phn_tokens: int = 256,
n_audio_tokens: int = 1024,
n_raw_text_tokens: int = 8575,
n_text_tokens: int = 8575,
d_model: int = 512,
d_ffn: int = 4,
@ -531,9 +438,9 @@ class Base(nn.Module):
self.teaching = False
self.config = config
self.n_text_tokens = n_text_tokens
self.n_phn_tokens = n_phn_tokens
self.n_audio_tokens = n_audio_tokens
self.n_raw_text_tokens = n_raw_text_tokens
self.n_text_tokens = n_text_tokens
self.d_model = d_model
self.n_heads = n_heads
@ -595,36 +502,30 @@ class Base(nn.Module):
n_tones = self.config.tones if self.config is not None else 1
# pure AR
if self.version < 7:
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
else:
n_resp_tokens = n_audio_tokens + 2
if "nar" not in self.capabilities:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
@ -632,7 +533,7 @@ class Base(nn.Module):
# STT
l_classifier_names += [ "stt" ]
l_classifier_tokens += [ n_text_tokens ]
l_classifier_tokens += [ n_phn_tokens ]
# LEN
if "len" in self.capabilities:
@ -641,8 +542,8 @@ class Base(nn.Module):
# TEXT => PHN / PHN => TEXT
if self.version >= 6:
l_classifier_tokens += [ n_raw_text_tokens ]
l_classifier_names = l_embedding_names + [ "raw_text" ]
l_classifier_tokens += [ n_text_tokens ]
l_classifier_names = l_embedding_names + [ "text" ]
self.n_vocab = n_vocab
self.unified_position_ids = unified_position_ids
@ -651,7 +552,7 @@ class Base(nn.Module):
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.text_emb = Embedding(n_text_tokens, d_model)
self.text_emb = Embedding(n_phn_tokens, d_model)
self.raw_text_emb = None
self.langs_emb = None
self.tones_emb = None
@ -680,7 +581,7 @@ class Base(nn.Module):
levels=self.n_resp_levels if self.version > 3 else None,
)
self.audio_emb = None
elif self.version < 7:
else:
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@ -691,10 +592,6 @@ class Base(nn.Module):
l_embedding_names=l_embedding_names,
)
self.audio_emb = None
else:
self.proms_emb = None
self.resps_emb = None
self.audio_emb = None
if self.version >= 3:
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@ -716,34 +613,7 @@ class Base(nn.Module):
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model)
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
if self.version >= 7:
self.mask_token = self.stop_token + 1
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.audio_decoder = AudioDecoder(
d_model,
d_model * 2,
(n_audio_tokens + 1),
self.n_resp_levels,
)
self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
if attention_backend == "auto":
attention_backend = "sdpa"
@ -1009,8 +879,8 @@ class Base(nn.Module):
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
self,
phns_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
@ -1023,12 +893,12 @@ class Base(nn.Module):
quant_levels: int | list[int] | Tensor | None = None
):
if text_list and text_list[0] is not None:
if phns_list and phns_list[0] is not None:
device = phns_list[0].device
batch_size = len(phns_list)
elif text_list and text_list[0] is not None:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list and raw_text_list[0] is not None:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list and proms_list[0] is not None:
device = proms_list[0].device
batch_size = len(proms_list)
@ -1054,10 +924,10 @@ class Base(nn.Module):
# prom /may/ include <task> tokens inside to help guide things, per SpeechX
if task_type in get_task_symmap() and task_type not in special_tasks:
# insert the text prompt
if text_list is not None and text_list[i] is not None:
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
elif text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
elif raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
@ -1096,9 +966,6 @@ class Base(nn.Module):
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
if self.version >= 7:
classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
@ -1108,10 +975,10 @@ class Base(nn.Module):
raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")
# insert the text prompt
if text_list is not None and text_list[i] is not None:
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
elif text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
elif raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
@ -1147,42 +1014,42 @@ class Base(nn.Module):
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the output text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
inputs[i].append( ("classifier_level", "stt") )
inputs[i].append( ("classifier_level", "phn") )
# Text phonemizing task
# Sequence: <raw_text><sep><lang><sep><phonemes>
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "phn":
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
inputs[i].append( ("classifier_level", "stt") )
inputs[i].append( ("classifier_level", "phn") )
# Text de-phonemizing task
# Sequence: <raw_text><sep><lang><sep><phonemes>
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "un-phn":
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ("classifier_level", "raw_text") )
inputs[i].append( ("classifier_level", "text") )
else:
raise Exception(f'Unrecognized task: {task_type}')
return inputs
@ -1240,17 +1107,11 @@ class Base(nn.Module):
input if quant_level == 0 else input[:, :quant_level]
)
if self.version < 7:
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
if self.audio_emb is not None:
return self.audio_emb( input )
return self.proms_emb( input )
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
# yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@ -1293,11 +1154,11 @@ class Base(nn.Module):
# *maybe* inject a token for specifying task type
task_type = input
continue
elif name == "text":
elif name == "phn":
embedding = self.text_emb( input )
device = embedding.device
elif name == "raw_text" and self.raw_text_emb is not None:
elif name == "text" and self.raw_text_emb is not None:
embedding = self.raw_text_emb( input )
device = embedding.device
@ -1316,13 +1177,8 @@ class Base(nn.Module):
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.version >= 7:
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
# if training NAR-len RVQ level 0
elif dropout_mask is not None:
if dropout_mask is not None:
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),
@ -1494,13 +1350,10 @@ class Base(nn.Module):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]
return input
return input if input.dim() == 1 else input[:, quant_level]
def _calc_loss( logit, sequence, causal = True ):
# filter tokens that exceed the vocab size
@ -1579,18 +1432,11 @@ class Base(nn.Module):
token = token[..., 0]
elif name == "resp":
# mask found, apply it
if self.version < 7:
token = input if input.dim() == 1 else input[:, quant_level]
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
else:
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
token = input if input.dim() == 1 else input[:, quant_level]
# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
# not a special input, inject as-is
else:
token = input
@ -1785,7 +1631,7 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
causal_levels = [ "phn", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
@ -1802,36 +1648,14 @@ class Base(nn.Module):
logits = output.logits
hidden_states = output.hidden_states
# split between the two logit tasks, as audio logits become expanded
if self.version >= 7:
logits = [ logit for logit in logits ]
audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
if decoders_indices:
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = self.audio_decoder( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
if classifiers_indices:
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
logits[batch_index] = logit
else:
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@ -2206,7 +2030,7 @@ if __name__ == "__main__":
seq = seq[:rvq_l, :] if rvq_l > 0 else seq
sep_embd = embds["sep"](zero)
phn_embd = embds["text"](phn)
phn_embd = embds["phn"](phn)
rvq_l_embd = embds["rvq_l"](rvq_l)
lang_embd = embds["lang"](lang)
prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype)

1147
vall_e/models/base_v2.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -27,26 +27,27 @@ _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch, teacher=None):
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ text.shape[0] for text in batch["phns"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
batch_size = len(batch["phns"])
engine.current_batch_size = batch_size
output = engine(
text_list=batch["text"],
phn_list=batch["phns"],
proms_list=batch["proms"],
resps_list=batch["resps"],
lang_list=batch["lang"],
tone_list=batch["tone"],
task_list=batch["task"],
raw_text_list=batch["raw_text"],
text_list=batch["text"],
training=True,
)
# get soft targets from teacher
"""
if teacher is not None:
# extract inputs forwarded to model
inputs = output.inputs
@ -99,6 +100,7 @@ def train_feeder(engine, batch, teacher=None):
for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A)
engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
"""
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
@ -174,7 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size]
batch_size = len(batch["text"])
batch_size = len(batch["phns"])
"""
# to-do: eval for text tasks
@ -190,8 +192,8 @@ def run_eval(engines, eval_name, dl, args=None):
# random prompts requested
if args and args.eval_random_text_prompts and eval_name == "subtrain":
for i, _ in enumerate(batch["text"]):
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
for i, _ in enumerate(batch["phns"]):
batch["phns"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
batch["resps"][i] = None
"""
@ -200,7 +202,7 @@ def run_eval(engines, eval_name, dl, args=None):
engine = engines[name]
base_kwargs = dict(
text_list=batch["text"],
phns_list=batch["phns"],
proms_list=batch["proms"],
lang_list=batch["lang"],
task_list=batch["task"],
@ -242,22 +244,6 @@ def run_eval(engines, eval_name, dl, args=None):
process( name, batch, resps_list )
"""
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
kwargs["text_list"] = None
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
kwargs["resps_list"] = batch["resps"]
text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
_logger.info(f"Validation Metrics (STT): {text_list}")
"""
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {
f'{name}.{eval_name}': stats,