segregated experimental changes into their own streamlined file to avoid breaking the existing model; it can pivot to the cleaned-up code if that actually works (nothing is working yet)

mrq 2025-02-26 21:26:13 -06:00
parent 7d2e64630c
commit 2ea387c08a
9 changed files with 2427 additions and 795 deletions

View File

@@ -382,6 +382,12 @@ class Model:
 def text_tokens(self):
 if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
 return self.size['text_tokens']
+return 8575
+@property
+def phoneme_tokens(self):
+if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"):
+return self.size['phoneme_tokens']
 return 256
 @property

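For orientation, a minimal sketch (not the repository's full Model config class) of what the two properties above mean after this commit: text_tokens now refers to the raw-text vocabulary with a default of 8575, while the new phoneme_tokens carries the old 256-entry phoneme vocabulary. The sketch checks dict keys with "in" rather than the hasattr call used in the diff.

# Hypothetical stand-in for the config object; field names match the diff above.
class ModelSketch:
    def __init__(self, size=None):
        self.size = size  # may be a dict with explicit overrides

    @property
    def text_tokens(self):  # raw-text vocabulary
        if isinstance(self.size, dict) and "text_tokens" in self.size:
            return self.size["text_tokens"]
        return 8575

    @property
    def phoneme_tokens(self):  # phoneme vocabulary (the old default)
        if isinstance(self.size, dict) and "phoneme_tokens" in self.size:
            return self.size["phoneme_tokens"]
        return 256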
View File

@@ -1611,10 +1611,11 @@ class Dataset(_Dataset):
 task=task,
 lang=lang,
 tone=tone,
-text=text,
 proms=proms,
 resps=resps,
-raw_text=raw_text,
+phns=text,
+text=raw_text,
 metadata=metadata,
 )

View File

@@ -306,7 +306,7 @@ class TTS():
 seed = set_seed(seed)
 batch_size = len(texts)
 input_kwargs = dict(
-text_list=texts,
+phns_list=texts,
 proms_list=proms,
 lang_list=langs,
 disable_tqdm=not use_tqdm,
@@ -421,8 +421,8 @@
 with torch.autocast(self.device, dtype=dtype, enabled=amp):
 model = model_ar if model_ar is not None else model_nar
 if model is not None:
-text_list = model(
-text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
+phns_list = model(
+phns_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
 disable_tqdm=not use_tqdm,
 use_lora=use_lora,
 **sampling_kwargs,
@@ -430,9 +430,9 @@
 else:
 raise Exception("!")
-text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
-return text_list[0]
+phns_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]
+return phns_list[0]
 elif task in ["phn", "un-phn"]:
 lang = self.encode_lang( language )
 lang = to_device(lang, device=self.device, dtype=torch.uint8)
@@ -440,17 +440,17 @@
 with torch.autocast(self.device, dtype=dtype, enabled=amp):
 model = model_ar if model_ar is not None else model_nar
 if task == "phn":
-text_list = None
-raw_text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
+phns_list = None
+text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
 output_tokenizer = cfg.tokenizer
 else:
-text_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
-raw_text_list = None
+phns_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
+text_list = None
 output_tokenizer = cfg.text_tokenizer
 if model is not None:
-text_list = model(
-text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
+phns_list = model(
+phns_list=phns_list, text_list=text_list, lang_list=[lang], task_list=[task],
 disable_tqdm=not use_tqdm,
 use_lora=use_lora,
 **sampling_kwargs,
@@ -458,9 +458,9 @@
 else:
 raise Exception("!")
-text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
-return text_list[0]
+phns_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]
+return phns_list[0]
 # stuff for rolling context
@@ -504,8 +504,8 @@
 with torch.autocast(self.device, dtype=dtype, enabled=amp):
 input_kwargs = dict(
-text_list=[phns] if phonemize else None,
-raw_text_list=[phns] if not phonemize else None,
+phns_list=[phns] if phonemize else None,
+text_list=[phns] if not phonemize else None,
 proms_list=[prom],
 lang_list=[lang],
 disable_tqdm=not use_tqdm,

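In short, the inference-side convention after this rename is that phonemized symbols travel as phns_list and un-phonemized raw text as text_list. A hedged sketch of the new kwargs, with prom, lang, and phns as placeholder values rather than the full call from the class above:

# Sketch only: mirrors the renamed keys from the diff, placeholder values.
input_kwargs = dict(
    phns_list=[phns] if phonemize else None,      # was text_list before this commit
    text_list=[phns] if not phonemize else None,  # was raw_text_list before this commit
    proms_list=[prom],
    lang_list=[lang],
)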
View File

@@ -59,11 +59,18 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
 def get_model(config, training=True, **model_kwargs):
-from .ar_nar import AR_NAR # import here because reasons
-name = config.name
-model = AR_NAR(
-n_text_tokens=config.text_tokens,
+# crunge
+if config.version < 7:
+from .ar_nar import AR_NAR
+ModelClass = AR_NAR
+else:
+from .ar_nar_v2 import AR_NAR_V2
+ModelClass = AR_NAR_V2
+cfg_kwargs = dict(
+n_phn_tokens=config.phoneme_tokens,
 n_audio_tokens=config.audio_tokens,
+n_text_tokens=config.text_tokens,
 d_model=config.dim,
 n_heads=config.heads,
 n_layers=config.layers,
@@ -75,9 +82,11 @@ def get_model(config, training=True, **model_kwargs):
 training = training,
 config = config,
-**model_kwargs
 )
+name = config.name
+model = ModelClass(**(cfg_kwargs | model_kwargs))
 _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
 return model

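A simplified sketch of the dispatch introduced above, with the config fields reduced to the few shown in the diff: config.version selects between the stable AR_NAR and the experimental AR_NAR_V2, and because the right-hand operand of the dict union wins, any caller-supplied model_kwargs override the config-derived defaults.

# Simplified sketch of get_model(); not the full argument list.
def get_model_sketch(config, **model_kwargs):
    if config.version < 7:
        from .ar_nar import AR_NAR as ModelClass        # existing model, untouched
    else:
        from .ar_nar_v2 import AR_NAR_V2 as ModelClass  # experimental, segregated file

    cfg_kwargs = dict(
        n_phn_tokens=config.phoneme_tokens,
        n_audio_tokens=config.audio_tokens,
        n_text_tokens=config.text_tokens,
    )
    # right operand of | wins, so explicit model_kwargs override config defaults
    return ModelClass(**(cfg_kwargs | model_kwargs))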
View File

@@ -42,22 +42,22 @@ class AR_NAR(Base):
 self,
 task_list: list[Tensor] | None = None,
-text_list: list[Tensor] | None = None,
+phns_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
 lang_list: list[Tensor] | None = None,
 tone_list: list[Tensor] | None = None,
 len_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
+text_list: list[Tensor] | None = None,
 ):
 # deduce batch_size
-if text_list:
+if phns_list:
+device = phns_list[0].device
+batch_size = len(phns_list)
+elif text_list:
 device = text_list[0].device
 batch_size = len(text_list)
-elif raw_text_list:
-device = raw_text_list[0].device
-batch_size = len(raw_text_list)
 elif proms_list:
 device = proms_list[0].device
 batch_size = len(proms_list)
@@ -75,9 +75,6 @@ class AR_NAR(Base):
 token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
 # RVQ levels to apply masking training on
 masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
-if self.version >= 7:
-masking_train_rvq_levels = [0,self.n_resp_levels]
 if cfg.audio_backend == "nemo":
 rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
@@ -134,13 +131,12 @@ class AR_NAR(Base):
 timesteps[i] = (timesteps[i] * 0.6) + 0.2
 # trim resps to only contain all levels below the target level
-if self.version < 7:
-resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
+resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 # tensor to cat for RVQ level 0
 text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
 text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
-audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)
+audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 # final validations and stuff
 for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
@@ -175,10 +171,10 @@ class AR_NAR(Base):
 """
 # only apply stop token for RVQ level 0
-if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
+if (quant_level <= 0 and timesteps[i] is None) or (self.predict_causally):
 # append stop tokens for AR
 if task not in text_task:
-resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
+resps_list[i] = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ])
 if task == "len":
 quant_levels[i] = 0
@@ -196,26 +192,26 @@ class AR_NAR(Base):
 drop_audio = True
 drop_text = True
-if random.random() < use_raw_text_p and raw_text_list[i] is not None:
+if random.random() < use_raw_text_p and text_list[i] is not None:
 swap_text = True
 if drop_text:
-text_list[i] = text_start_stop_sequence
+phns_list[i] = text_start_stop_sequence
 if drop_audio:
 proms_list[i] = None
 if swap_text and not drop_text:
-text_list[i] = None
+phns_list[i] = None
 inputs = self.inputs(
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 task_list=task_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 time_list=timesteps,
 quant_levels=quant_levels,
@@ -231,22 +227,22 @@ class AR_NAR(Base):
 task_list: list[Tensor] | None = None,
-text_list: list[Tensor] | None = None,
+phns_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
 lang_list: list[Tensor] | None = None,
 tone_list: list[Tensor] | None = None,
 len_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
+text_list: list[Tensor] | None = None,
 quant_levels: list[int] | None = None,
 disable_tqdm=False,
 use_lora=None,
 **sampling_kwargs,
 ):
-device = text_list[0].device
-batch_size = len(text_list)
+device = phns_list[0].device
+batch_size = len(phns_list)
 if quant_levels is None:
 level = 0
@@ -304,7 +300,7 @@ class AR_NAR(Base):
 prefix_context = sampling_kwargs.get("prefix_context", None)
 # we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
 if prefix_context is not None:
-text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ]
+phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], phns_list ) ]
 prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ]
 # if we're denoising from an existing sequence
@@ -379,7 +375,7 @@ class AR_NAR(Base):
 # setup inputs
 inputs = super().inputs(
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=input_resps_list,
 lang_list=lang_list,
@@ -396,7 +392,7 @@ class AR_NAR(Base):
 if cfg_strength > 0:
 null_inputs = super().inputs(
-text_list=null_text,
+phns_list=null_text,
 proms_list=null_prom,
 resps_list=input_resps_list,
 lang_list=lang_list,
@@ -446,188 +442,11 @@ class AR_NAR(Base):
 return resps_list
# handles doing demasking inferencing in parallel to inference all tokens
# it works if the underlying model is trained properly (which is a pain)
def forward_nar_masked_parallel(
self,
task_list: list[Tensor] | None = None,
text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
device = text_list[0].device
batch_size = len(text_list)
level = 0
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
# convert (N)AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
min_length = sampling_kwargs.pop("min_duration", 1)
max_length = sampling_kwargs.pop("max_duration", 500)
max_steps = sampling_kwargs.get("max_steps", 25)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used
vc_list = sampling_kwargs.pop("vc_list", None)
vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
# force set CFG because too low / no CFG causes issues
original_cfg_strength = cfg_strength
cfg_strength = max( cfg_strength, minimum_cfg_strength )
prefix_context = sampling_kwargs.get("prefix_context", None)
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
# fill scores
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
quant_levels = [ level for _ in range(batch_size) ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
for timestep in iterator:
# update previous list of tokens
prev_list = resps_list
# ramp down over time
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask
remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# normal masking
# mask off inputs
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
# boolean mask
is_masked = [ resps == self.mask_token for resps in resps_list ]
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
sampling_temperature = temperature * annealing if annealed_sampling else temperature
sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
input_resps_list = resps_list
# setup inputs
inputs = super().inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
output = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
quant_levels=quant_levels,
)
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
l_scores = []
l_resps_list = []
# cringe hack because we're able to sample multiple levels at once
for l in range(self.n_resp_levels):
# sample with sampler settings
filtered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=sampling_temperature,
**sampling_kwargs,
)
# retrieves unfiltered logits
unfiltered_sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in prev_list ],
quant_levels=quant_levels,
temperature=0.0,
**sampling_kwargs,
)
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
# get probability scores
l_scores.append([
# conjugate to have worse scoring tokens picked for topk
1.0 -
# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
# use unmodified logit scores for this, as it offers better stability
for scores, masked in zip( unfiltered_sampled.scores, is_masked )
])
resps_list = []
scores = []
for batch_index in range(batch_size):
score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
scores.append( score )
resps_list.append( resp )
return resps_list
 def forward_nar(
 self,
 task_list: list[Tensor] | None = None,
-text_list: list[Tensor] | None = None,
+phns_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
@@ -635,7 +454,7 @@ class AR_NAR(Base):
 tone_list: list[Tensor] | None = None,
 len_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
+text_list: list[Tensor] | None = None,
 disable_tqdm=False,
 use_lora=None,
@@ -644,7 +463,7 @@ class AR_NAR(Base):
 # inference NAR level 0
 if len_list is not None:
 resps_list = self.forward_nar_masked(
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 task_list=task_list,
@@ -655,12 +474,12 @@ class AR_NAR(Base):
 )
 # deduce batch_size
-if text_list:
+if phns_list:
+device = phns_list[0].device
+batch_size = len(phns_list)
+elif text_list:
 device = text_list[0].device
 batch_size = len(text_list)
-elif raw_text_list:
-device = raw_text_list[0].device
-batch_size = len(raw_text_list)
 elif proms_list:
 device = proms_list[0].device
 batch_size = len(proms_list)
@@ -701,7 +520,7 @@ class AR_NAR(Base):
 quant_levels = [ level for _ in range(batch_size) ]
 inputs = self.inputs(
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=prev_list,
 lang_list=lang_list,
@@ -717,7 +536,7 @@ class AR_NAR(Base):
 if cfg_strength > 0:
 null_inputs = super().inputs(
-text_list=null_text,
+phns_list=null_text,
 proms_list=null_prom,
 resps_list=prev_list,
 lang_list=lang_list,
@@ -748,8 +567,8 @@ class AR_NAR(Base):
 task_list: list[Tensor],
+phns_list: list[Tensor] | None = None,
 text_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
 lang_list: list[Tensor] | None = None,
@@ -761,12 +580,12 @@ class AR_NAR(Base):
 **sampling_kwargs,
 ):
 # deduce batch_size
-if text_list:
+if phns_list:
+device = phns_list[0].device
+batch_size = len(phns_list)
+elif text_list:
 device = text_list[0].device
 batch_size = len(text_list)
-elif raw_text_list:
-device = raw_text_list[0].device
-batch_size = len(raw_text_list)
 elif proms_list:
 device = proms_list[0].device
 batch_size = len(proms_list)
@@ -810,14 +629,14 @@ class AR_NAR(Base):
 inputs = self.inputs(
 task_list=task_list,
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 len_list=len_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 quant_levels=quant_levels,
 )
@@ -895,7 +714,7 @@ class AR_NAR(Base):
 if prefix_context is not None:
 prefix_text, prefix_resps, _ = prefix_context
 # to-do: check if we actually need to drop the middle "<eos><bos>"
-text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
+phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, phns_list ) ]
 # feeding this into the NAR-len should automatically handle things
 sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ]
@@ -906,13 +725,13 @@ class AR_NAR(Base):
 iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
 for n in iterator:
 if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
-text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
-raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
+phns_list = [ sequence_list[i] if task in ["phn"] else phns_list[i] for i, task in enumerate(task_list) ]
+text_list = [ sequence_list[i] if task in ["un-phn"] else text_list[i] for i, task in enumerate(task_list) ]
 else:
-if raw_text_list is not None:
-raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
-else:
+if text_list is not None:
 text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+else:
+phns_list = [ sequence_list[i] if task in text_task else phns_list[i] for i, task in enumerate(task_list) ]
 resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
@@ -920,14 +739,14 @@ class AR_NAR(Base):
 inputs = self.inputs(
 task_list=task_list,
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 len_list=len_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 quant_levels=quant_levels,
 )
@@ -942,7 +761,7 @@ class AR_NAR(Base):
 if cfg_strength > 0:
 null_inputs = super().inputs(
-text_list=null_text,
+phns_list=null_text,
 proms_list=null_prom,
 resps_list=resps_list,
 lang_list=lang_list,
@@ -960,7 +779,7 @@ class AR_NAR(Base):
 sampled = super().sample(
 logits=logits,
-prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+prev_list=[ resps_list[i] if task not in text_task else phns_list[i] for i, task in enumerate( task_list ) ],
 **(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
 )
@@ -981,7 +800,7 @@ class AR_NAR(Base):
 # first step, expand batch
 if batch_size == 1:
 batch_size = beam_width
-text_list = text_list * beam_width
+phns_list = phns_list * beam_width
 proms_list = proms_list * beam_width
 sequence_list = sequence_list * beam_width
 task_list = task_list * beam_width
@@ -1046,175 +865,18 @@ class AR_NAR(Base):
 return sequence_list
def forward_ar_parallel(
self,
task_list: list[Tensor],
text_list: list[Tensor] | None = None,
raw_text_list: list[Tensor] | None = None,
proms_list: list[Tensor] | None = None,
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
disable_tqdm=False,
use_lora=None,
**sampling_kwargs,
):
# deduce batch_size
if text_list:
device = text_list[0].device
batch_size = len(text_list)
elif raw_text_list:
device = raw_text_list[0].device
batch_size = len(raw_text_list)
elif proms_list:
device = proms_list[0].device
batch_size = len(proms_list)
elif resps_list:
device = resps_list[0].device
batch_size = len(resps_list)
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
# convert AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
temperature = sampling_kwargs.get("temperature", 1.0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
max_duration = sampling_kwargs.get("max_duration", 500)
beam_width = sampling_kwargs.get("beam_width", 0)
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
layer_skip = sampling_kwargs.get("layer_skip", False)
prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
start_slice = [ 0 for _ in range(batch_size) ]
sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
audio_stop_token = self.stop_token
text_stop_token = 2
state = None
mirostat = [
{"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if mirostat_tau > 0.0 else None
scores = [ 1.0 ] * beam_width
metrics = []
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
# get next in sequence
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
for n in iterator:
if raw_text_list is not None:
raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
else:
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
inputs = self.inputs(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
quant_levels=quant_levels,
)
# to-do: find an elegant way to write this
output = super().forward(
inputs=inputs,
state=state,
#layer_skip_variables=sampling_layer_skip_variables,
output_attentions=entropix_sampling,
)
if cfg_strength > 0:
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
null_output = super().forward(
inputs=null_inputs,
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
logits, state = output.logits, output.state
l_resps_list = [ [] for _ in range(batch_size) ]
for l in range(self.n_resp_levels):
sampled = super().sample(
logits=[ logit[l] for logit in logits ],
prev_list=[ resp[..., l] for resp in resps_list ],
**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
)
ids = sampled.ids
# append tokens
for i, token in enumerate(ids):
if audio_stop_token in token:
stopped[i] = True
l_resps_list[i].append(token.to(device))
for i, l in enumerate(l_resps_list):
sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
# stop token found
# stopped |= r == stop_token
if stopped.all().item():
iterator.close()
break
for i, l in enumerate( sequence_list ):
index = (l == audio_stop_token).nonzero()
# kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment
try:
index = index[:, 0].min()
sequence_list[i] = sequence_list[i][:index]
except Exception as e:
pass
return sequence_list
 def forward(
 self,
 task_list: list[Tensor] | None = None,
-text_list: list[Tensor] | None = None,
+phns_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
 lang_list: list[Tensor] | None = None,
 tone_list: list[Tensor] | None = None,
 len_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
+text_list: list[Tensor] | None = None,
 training: bool | None = None,
@@ -1224,12 +886,12 @@ class AR_NAR(Base):
 ):
 # deduce batch_size
 # deduce batch_size
-if text_list:
+if phns_list:
+device = phns_list[0].device
+batch_size = len(phns_list)
+elif text_list:
 device = text_list[0].device
 batch_size = len(text_list)
-elif raw_text_list:
-device = raw_text_list[0].device
-batch_size = len(raw_text_list)
 elif proms_list:
 device = proms_list[0].device
 batch_size = len(proms_list)
@@ -1238,7 +900,7 @@ class AR_NAR(Base):
 batch_size = len(resps_list)
 # implicitly set for training
-if training is None and text_list is not None and resps_list is not None:
+if training is None and phns_list is not None and resps_list is not None:
 n_levels_set = {r.shape[-1] for r in resps_list}
 n_levels = next(iter(n_levels_set))
@@ -1249,118 +911,47 @@ class AR_NAR(Base):
 return self.forward_train(
 task_list=task_list,
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 len_list=len_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 )
 # is NAR
-if (len_list is not None or resps_list is not None) and text_list is not None:
+if (len_list is not None or resps_list is not None) and phns_list is not None:
-if self.version >= 7:
-return self.forward_nar_masked_parallel(
-task_list=task_list,
-text_list=text_list,
-proms_list=proms_list,
-resps_list=resps_list,
-lang_list=lang_list,
-tone_list=tone_list,
-len_list=len_list,
-raw_text_list=raw_text_list,
-disable_tqdm=disable_tqdm,
-use_lora=use_lora,
-**sampling_kwargs,
-)
-# NAR demasking for all levels
-"""
-resps_lists = [ None for _ in range(batch_size) ]
-for level in range(self.n_resp_levels):
-resp_list = self.forward_nar_masked(
-task_list=task_list,
-text_list=text_list,
-proms_list=proms_list,
-resps_list=resps_list,
-lang_list=lang_list,
-tone_list=tone_list,
-len_list=len_list,
-raw_text_list=raw_text_list,
-disable_tqdm=disable_tqdm,
-use_lora=use_lora,
-quant_levels=[ level for _ in range(batch_size) ],
-**sampling_kwargs,
-)
-for batch_index, resp in enumerate(resp_list):
-if resps_lists[batch_index] is None:
-resps_lists[batch_index] = []
-resps_lists[batch_index].append( resp )
-for batch_index, resps in enumerate(resps_lists):
-resps_lists[batch_index] = torch.stack( resps, dim=-1 )
-return resps_lists
-"""
 return self.forward_nar(
 task_list=task_list,
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 len_list=len_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 disable_tqdm=disable_tqdm,
 use_lora=use_lora,
 **sampling_kwargs,
 )
-if self.version >= 7:
-if task_list is None or task_list[0] != "len":
-return self.forward_ar_parallel(
-task_list=task_list,
-text_list=text_list,
-proms_list=proms_list,
-resps_list=resps_list,
-lang_list=lang_list,
-tone_list=tone_list,
-len_list=len_list,
-raw_text_list=raw_text_list,
-disable_tqdm=disable_tqdm,
-use_lora=use_lora,
-**sampling_kwargs,
-)
 # is AR
 return self.forward_ar(
 task_list=task_list,
-text_list=text_list,
+phns_list=phns_list,
 proms_list=proms_list,
 resps_list=resps_list,
 lang_list=lang_list,
 tone_list=tone_list,
 len_list=len_list,
-raw_text_list=raw_text_list,
+text_list=text_list,
 disable_tqdm=disable_tqdm,
 use_lora=use_lora,
@@ -1402,12 +993,11 @@ def example_usage():
 text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}")
 batch_size = cfg.hyperparameters.batch_size
-text_list = [ text ] * batch_size
+phns_list = [ text ] * batch_size
 proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size
 resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size
 kwargs = {
-'n_text_tokens': cfg.model.text_tokens,
 'n_audio_tokens': cfg.model.audio_tokens,
 'd_model': 1024, # 256, # 1024, # 1536
@@ -1545,7 +1135,7 @@ def example_usage():
 def sample_data(t=None):
 if isinstance(t, list):
 tasks = t
-texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
+texts = [ phns_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
 proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ]
 resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ]
@@ -1559,7 +1149,7 @@ def example_usage():
 for i in range(batch_size):
 task = random.choice(available_tasks) if t is None else t
-text = text_list[i].to(cfg.device)
+text = phns_list[i].to(cfg.device)
 prom = proms_list[i].to(cfg.device)
 resp = resps_list[i].to(cfg.device)
@@ -1580,16 +1170,16 @@ def example_usage():
 def sample( name, steps=500, task=None ):
 engine.eval()
-text_list, proms_list, resp_list, task_list = sample_data( task )
+phns_list, proms_list, resp_list, task_list = sample_data( task )
 if task == "tts-nar":
-len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
 len_list = [ r.shape[0] for r in resp_list ]
-resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
+resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
 else:
-resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
+resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
 if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
-resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
+resps_list = engine( phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
 for i, o in enumerate(resps_list):
 print( o.shape, o )
@@ -1604,7 +1194,7 @@ def example_usage():
 texts, proms, resps, tasks = sample_data()
 stats = {"step": i}
-stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
+stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 stats |= {"grad_norm": engine.get_global_grad_norm()}
 tqdm.write(f"{stats}")

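The deleted forward_nar_masked_parallel above (like the surviving masked-NAR path) drives demasking with a cosine noise schedule: at timestep t in [0, 1], roughly cos(t * pi / 2) of the sequence stays masked, plus a small constant re-mask term, and the worst-scoring positions are the ones re-masked. A standalone sketch of that schedule, using the same formula as the deleted code; the function name is illustrative only.

import math

# Sketch of the demasking schedule from the removed parallel NAR path.
def masked_count(timestep: float, seq_len: int, max_steps: int, remasking: bool = True) -> int:
    noise_p = math.cos(timestep * math.pi * 0.5)             # cosine schedule, 1 -> 0
    remask_p = 1.0 / (max_steps * 2) if remasking else 0.0   # small constant re-mask term
    return max(1, min(seq_len, int(noise_p * seq_len + remask_p * seq_len)))

# e.g. over 25 steps on a 100-token sequence, the mask count decays from the full
# length down to a handful of tokens on the final step
schedule = [masked_count(t / 24, 100, 25) for t in range(25)]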
vall_e/models/ar_nar_v2.py (new file, 1069 lines): file diff suppressed because it is too large.

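The AudioEncoder/AudioDecoder pair removed from base.py below (and, per the commit message, presumably carried over into ar_nar_v2.py) predicts all RVQ levels from a single hidden state: one linear head emits vocab_size * resp_levels logits per position, which are then viewed and permuted into per-level logits. A shape-only sketch of that reshape, with made-up sizes:

import torch

# Shape-only sketch of the AudioDecoder reshape shown in the removed base.py code.
batch, seq, d_model, vocab, levels = 2, 10, 16, 5, 8
head = torch.nn.Linear(d_model, vocab * levels)
x = head(torch.randn(batch, seq, d_model))           # (batch, seq, vocab * levels)
x = x.view(batch, seq, levels, -1).permute(0, 2, 1, 3)
assert x.shape == (batch, levels, seq, vocab)         # one logit slab per codebook level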
View File

@@ -9,6 +9,8 @@ This should handle all the "low" level things such as:
 Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inheret the base model
 """
+# to-do: clean this whole mess up
 import math
 import torch
 import torch.nn.functional as F
@@ -50,22 +52,22 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 summed_embeddings_task = [ "stt" ]
-special_tasks = [ "len", "stt", "phn", "un-phn" ]
+special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
 non_tokened_names = ["task", "dropout_mask", "classifier_level"]
 task_outputs = {
 "tts": "resp",
 "ns": "resp",
 "sr": "resp",
-"stt": "text",
+"stt": "phn",
 "len": "len",
-"phn": "text",
-"un-phn": "raw_text",
+"phn": "phn",
+"un-phn": "text",
 }
 # yuck
 def _get_offsets():
 return {
-"text": (0, 256),
+"phn": (0, 256),
 "quant_level": (256, 264),
 "lang": (264, 270),
 "task": (270, 279),
@@ -126,15 +128,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
 def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 return torch.concat( [ i.t() for i in input ] ).t().flatten()
-# automagically parses a batch-list and returns it as a list
-"""
-class Embedding(ml.Embedding):
-def forward(self, x_list: list[Tensor]) -> list[Tensor]:
-if len(x_list) == 0:
-return []
-return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
-"""
 # Deprecated implementation
 class MultiEmbedding(nn.Module):
 def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
@@ -334,92 +327,6 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
 x[..., level] = torch.where( dropout_mask, lhs, rhs )
 return x
# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
):
super().__init__()
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
# old way
# this probably is a tried and true good way to go about it
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
# encode by interleaving
# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
"""
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
x = x.view(x.shape[0], -1)
x = self.proj(x)
"""
return x
class AudioDecoder(nn.Module):
def __init__(
self,
d_model,
hidden_size,
vocab_size,
resp_levels,
):
super().__init__()
self.resp_levels = resp_levels
self.head = nn.Linear( d_model, vocab_size * resp_levels )
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# prior way up-projected then down-projected, but that's silly
x = self.head( x )
# interleave by reshaping / permuting
# at least I hope this does it properly, it checks out against my OCR classifier
batch_size, seq_len, dim = x.shape
x = x.view( batch_size, seq_len, self.resp_levels, -1 )
x = x.permute( 0, 2, 1, 3 )
return x
"""
"""
# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model
# this doesn't work well
"""
class ClassifiersParallel(nn.Module):
def __init__(
self,
n_levels: int, # codebook levels
n_tokens: int, # output token count
token_dim: int, # dimensionality of the embedding
bias: bool = False,
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)])
def forward(self, xi: Tensor, stack: bool = True ) -> Tensor:
dtype = xi.dtype
device = xi.device
xi = [ proj( xi ) for l, proj in enumerate(self.proj) ]
xi = torch.stack( xi )
xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
return xi
"""
 class Metrics(nn.Module):
 def __init__(
 self,
@@ -508,9 +415,9 @@ class Base(nn.Module):
 def __init__(
 self,
-n_text_tokens: int = 256,
+n_phn_tokens: int = 256,
 n_audio_tokens: int = 1024,
-n_raw_text_tokens: int = 8575,
+n_text_tokens: int = 8575,
 d_model: int = 512,
 d_ffn: int = 4,
@@ -531,9 +438,9 @@ class Base(nn.Module):
 self.teaching = False
 self.config = config
-self.n_text_tokens = n_text_tokens
+self.n_phn_tokens = n_phn_tokens
 self.n_audio_tokens = n_audio_tokens
-self.n_raw_text_tokens = n_raw_text_tokens
+self.n_text_tokens = n_text_tokens
 self.d_model = d_model
 self.n_heads = n_heads
@@ -595,36 +502,30 @@ class Base(nn.Module):
 n_tones = self.config.tones if self.config is not None else 1
 # pure AR
-if self.version < 7:
-if "nar" not in self.capabilities:
-n_resp_tokens = n_audio_tokens + 1
-l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
-l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
-# NAR-len model
-elif "len" in self.capabilities:
-# +1 to include the stop or mask token
-n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-if "ar" in self.capabilities:
-l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
-l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
-l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
-else:
-l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-# AR+NAR model
-else:
-# +1 to include the stop or mask token
-n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-else:
-n_resp_tokens = n_audio_tokens + 2
+if "nar" not in self.capabilities:
+n_resp_tokens = n_audio_tokens + 1
 l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
-l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
+l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
+l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
+# NAR-len model
+elif "len" in self.capabilities:
+# +1 to include the stop or mask token
+n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+if "ar" in self.capabilities:
+l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
+l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
+else:
+l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+# AR+NAR model
+else:
+# +1 to include the stop or mask token
+n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
@@ -632,7 +533,7 @@ class Base(nn.Module):
 # STT
 l_classifier_names += [ "stt" ]
-l_classifier_tokens += [ n_text_tokens ]
+l_classifier_tokens += [ n_phn_tokens ]
 # LEN
 if "len" in self.capabilities:
@@ -641,8 +542,8 @@ class Base(nn.Module):
 # TEXT => PHN / PHN => TEXT
 if self.version >= 6:
-l_classifier_tokens += [ n_raw_text_tokens ]
-l_classifier_names = l_embedding_names + [ "raw_text" ]
+l_classifier_tokens += [ n_text_tokens ]
+l_classifier_names = l_embedding_names + [ "text" ]
 self.n_vocab = n_vocab
 self.unified_position_ids = unified_position_ids
@@ -651,7 +552,7 @@ class Base(nn.Module):
 self.ignore_inputs_for_loss = ignore_inputs_for_loss
 self.noncausal_masks = noncausal_masks
-self.text_emb = Embedding(n_text_tokens, d_model)
+self.text_emb = Embedding(n_phn_tokens, d_model)
 self.raw_text_emb = None
 self.langs_emb = None
 self.tones_emb = None
@@ -680,7 +581,7 @@ class Base(nn.Module):
 levels=self.n_resp_levels if self.version > 3 else None,
 )
 self.audio_emb = None
-elif self.version < 7:
+else:
 self.proms_emb = AudioEmbedding(
 [n_audio_tokens] * self.n_resp_levels, d_model,
 sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@@ -691,10 +592,6 @@ class Base(nn.Module):
 l_embedding_names=l_embedding_names,
 )
 self.audio_emb = None
-else:
-self.proms_emb = None
-self.resps_emb = None
-self.audio_emb = None
 if self.version >= 3:
 self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@@ -716,34 +613,7 @@ class Base(nn.Module):
 # experimental NAR-only mode
 self.len_emb = Embedding(11, d_model)
 if self.version >= 6:
-self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
+self.raw_text_emb = Embedding(self.n_text_tokens, d_model)
-if self.version >= 7:
-self.mask_token = self.stop_token + 1
-if monolithic_audio_encoder:
-self.audio_emb = AudioEncoder(
-n_tokens=n_audio_tokens + 2, # stop + masked token
-n_levels=self.n_resp_levels,
-token_dim=d_model,
-)
-else:
-self.proms_emb = AudioEncoder(
-n_tokens=n_audio_tokens,
-n_levels=self.n_resp_levels,
-token_dim=d_model,
-)
-self.resps_emb = AudioEncoder(
-n_tokens=n_audio_tokens + 2, # stop + masked token
-n_levels=self.n_resp_levels,
-token_dim=d_model,
-)
-self.audio_decoder = AudioDecoder(
-d_model,
-d_model * 2,
-(n_audio_tokens + 1),
-self.n_resp_levels,
-)
 if attention_backend == "auto":
 attention_backend = "sdpa"
@@ -1009,8 +879,8 @@ class Base(nn.Module):
 # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 def inputs(
 self,
+phns_list: list[Tensor] | None = None,
 text_list: list[Tensor] | None = None,
-raw_text_list: list[Tensor] | None = None,
 proms_list: list[Tensor] | None = None,
 resps_list: list[Tensor] | None = None,
@@ -1023,12 +893,12 @@ class Base(nn.Module):
 quant_levels: int | list[int] | Tensor | None = None
 ):
-if text_list and text_list[0] is not None:
+if phns_list and phns_list[0] is not None:
+device = phns_list[0].device
+batch_size = len(phns_list)
+elif text_list and text_list[0] is not None:
 device = text_list[0].device
 batch_size = len(text_list)
-elif raw_text_list and raw_text_list[0] is not None:
-device = raw_text_list[0].device
-batch_size = len(raw_text_list)
 elif proms_list and proms_list[0] is not None:
 device = proms_list[0].device
 batch_size = len(proms_list)
@@ -1054,10 +924,10 @@ class Base(nn.Module):
 # prom /may/ include <task> tokens inside to help guide things, per SpeechX
 if task_type in get_task_symmap() and task_type not in special_tasks:
 # insert the text prompt
-if text_list is not None and text_list[i] is not None:
+if phns_list is not None and phns_list[i] is not None:
+inputs[i].append( ( "phn", phns_list[i] ) )
+elif text_list is not None and text_list[i] is not None:
 inputs[i].append( ( "text", text_list[i] ) )
-elif raw_text_list is not None and raw_text_list[i] is not None:
-inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 # insert lang token if we're trained for it
 if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1096,9 +966,6 @@ class Base(nn.Module):
 if resps_list is not None and resps_list[i] is not None:
 inputs[i].append( ( "resp", resps_list[i] ) )
-if self.version >= 7:
-classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
 inputs[i].append( ("classifier_level", classifier_level) )
 # Audio length prediction task
 # Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1108,10 +975,10 @@ class Base(nn.Module):
 raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")
 # insert the text prompt
-if text_list is not None and text_list[i] is not None:
+if phns_list is not None and phns_list[i] is not None:
+inputs[i].append( ( "phn", phns_list[i] ) )
+elif text_list is not None and text_list[i] is not None:
 inputs[i].append( ( "text", text_list[i] ) )
-elif raw_text_list is not None and raw_text_list[i] is not None:
-inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 # insert lang token if we're trained for it
 if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1147,42 +1014,42 @@ class Base(nn.Module):
 if self.rvq_l_emb is not None:
 inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
 # insert the output text prompt
-if text_list is not None and text_list[i] is not None:
-inputs[i].append( ( "text", text_list[i] ) )
-inputs[i].append( ("classifier_level", "stt") )
+if phns_list is not None and phns_list[i] is not None:
+inputs[i].append( ( "phn", phns_list[i] ) )
+inputs[i].append( ("classifier_level", "phn") )
 # Text phonemizing task
-# Sequence: <raw_text><sep><lang><sep><phonemes>
+# Sequence: <text><sep><lang><sep><phonemes>
 elif task_type == "phn":
 # insert the text prompt
-if raw_text_list is not None and raw_text_list[i] is not None:
-inputs[i].append( ( "raw_text", raw_text_list[i] ) )
+if text_list is not None and text_list[i] is not None:
+inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it # insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) ) inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None: if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt # insert the text prompt
if text_list is not None and text_list[i] is not None: if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) ) inputs[i].append( ( "phn", phns_list[i] ) )
inputs[i].append( ("classifier_level", "stt") ) inputs[i].append( ("classifier_level", "phn") )
# Text de-phonemizing task # Text de-phonemizing task
# Sequence: <raw_text><sep><lang><sep><phonemes> # Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "un-phn": elif task_type == "un-phn":
# insert the text prompt # insert the text prompt
if text_list is not None and text_list[i] is not None: if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) ) inputs[i].append( ( "phn", phns_list[i] ) )
# insert lang token if we're trained for it # insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) ) inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None: if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt # insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None: if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) ) inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ("classifier_level", "raw_text") ) inputs[i].append( ("classifier_level", "text") )
else: else:
raise Exception(f'Unrecognized task: {task_type}') raise Exception(f'Unrecognized task: {task_type}')
return inputs return inputs
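
Under the renamed fields, inputs() now keys phonemes as "phn" and raw text as "text" (formerly "text" and "raw_text"). A rough, illustrative reconstruction of one batch item's tuple list for a plain TTS task, using dummy tensors; the real method also inserts quant_level, tone, and len entries depending on the task and model version:

import torch

phns = torch.randint(0, 256, (32,), dtype=torch.int16)      # phoneme ids
lang = torch.tensor([0], dtype=torch.uint8)                  # language id
prom = torch.randint(0, 1024, (150, 8), dtype=torch.int16)   # acoustic prompt codes
resp = torch.randint(0, 1024, (300, 8), dtype=torch.int16)   # target audio codes

inputs_item = [
    ("task", "tts"),
    ("phn", phns),
    ("lang", lang),
    ("prom", prom),
    ("resp", resp),
    ("classifier_level", "AR:0:0"),
]
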
@ -1240,17 +1107,11 @@ class Base(nn.Module):
input if quant_level == 0 else input[:, :quant_level] input if quant_level == 0 else input[:, :quant_level]
) )
if self.version < 7: return self.proms_emb(
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level offset = 0,
offset = 0, )
)
if self.audio_emb is not None:
return self.audio_emb( input )
return self.proms_emb( input )
# yuck # yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0 token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
@ -1293,11 +1154,11 @@ class Base(nn.Module):
# *maybe* inject a token for specifying task type # *maybe* inject a token for specifying task type
task_type = input task_type = input
continue continue
elif name == "text": elif name == "phn":
embedding = self.text_emb( input ) embedding = self.text_emb( input )
device = embedding.device device = embedding.device
elif name == "raw_text" and self.raw_text_emb is not None: elif name == "text" and self.raw_text_emb is not None:
embedding = self.raw_text_emb( input ) embedding = self.raw_text_emb( input )
device = embedding.device device = embedding.device
@ -1316,13 +1177,8 @@ class Base(nn.Module):
elif name == "tone" and self.tones_emb is not None: elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input ) embedding = self.tones_emb( input )
elif name == "resp": elif name == "resp":
if self.version >= 7:
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
# if training NAR-len RVQ level 0 # if training NAR-len RVQ level 0
elif dropout_mask is not None: if dropout_mask is not None:
embedding = self.resps_emb( embedding = self.resps_emb(
# if masked use masked token, else original token # if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ), torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),
@ -1494,13 +1350,10 @@ class Base(nn.Module):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums): if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index) return torch.full_like(input[..., 0], self.ignore_index)
if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]
return input return input if input.dim() == 1 else input[:, quant_level]
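
A small, hedged illustration of the prompt-target rule above: when prompt embeddings are summed across RVQ levels the prompt cannot be mapped back to single tokens, so its loss targets are all ignore_index (assumed here to be -100, the usual cross-entropy default); otherwise only the codes at the current quant level are kept:

import torch

ignore_index = -100  # assumption; the real value comes from the model config
prom = torch.randint(0, 1024, (150, 8), dtype=torch.int16)  # [frames, levels]

mock_targets = torch.full_like(prom[..., 0], ignore_index)          # summed-embedding case
quant_level = 2
level_targets = prom if prom.dim() == 1 else prom[:, quant_level]   # per-level case
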
def _calc_loss( logit, sequence, causal = True ): def _calc_loss( logit, sequence, causal = True ):
# filter tokens that exceed the vocab size # filter tokens that exceed the vocab size
@ -1579,18 +1432,11 @@ class Base(nn.Module):
token = token[..., 0] token = token[..., 0]
elif name == "resp": elif name == "resp":
# mask found, apply it # mask found, apply it
if self.version < 7: token = input if input.dim() == 1 else input[:, quant_level]
token = input if input.dim() == 1 else input[:, quant_level]
# mask found, apply it
# mask found, apply it if dropout_mask is not None:
if dropout_mask is not None: token = torch.where( dropout_mask, token, self.ignore_index )
token = torch.where( dropout_mask, token, self.ignore_index )
else:
token = input
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
# not a special input, inject as-is # not a special input, inject as-is
else: else:
token = input token = input
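
The NAR-len masking in this and the earlier resp-embedding hunk works in two halves: the input side swaps masked positions for the mask/stop token, and the loss side keeps only those masked positions as targets. A minimal sketch with hypothetical token values:

import torch

ignore_index = -100   # assumption
stop_token = 1024     # assumption: id used as the mask/stop token

codes = torch.tensor([5, 9, 2, 7])                       # RVQ level-0 codes
dropout_mask = torch.tensor([True, False, True, False])

masked_input = torch.where(dropout_mask, stop_token, codes)   # fed to the model
targets = torch.where(dropout_mask, codes, ignore_index)      # only masked slots count
# masked_input -> [1024, 9, 1024, 7]; targets -> [5, -100, 2, -100]
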
@ -1785,7 +1631,7 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs # needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" ) classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ] causal_levels = [ "phn", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
# right now limit to new versions because I need to retrain the model for noncausal masks... # right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
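
A quick sketch of how the per-item causal flags fall out of the classifier levels under the new naming. Note that the list as committed repeats "phn" and does not include "text", so the un-phn head stays non-causal here:

n_resp_levels = 8  # example value
causal_levels = ["phn", "len", "phn"] + [f"AR:{l}:{l}" for l in range(n_resp_levels)]

classifier_levels = ["phn", "AR:0:0", "NAR:0:0", "text"]
is_causal = [level in causal_levels for level in classifier_levels]
# -> [True, True, False, False]
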
@ -1802,36 +1648,14 @@ class Base(nn.Module):
logits = output.logits logits = output.logits
hidden_states = output.hidden_states hidden_states = output.hidden_states
# split between the two logit tasks, as audio logits become expanded # output projection layer
if self.version >= 7: # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
logits = [ logit for logit in logits ] if self.classifier is not None:
logits = self.classifier(logits) # * m
audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ] # to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ] elif self.classifiers is not None:
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ] logits = self.classifiers(logits, levels = classifier_levels )
if decoders_indices:
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = self.audio_decoder( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
if classifiers_indices:
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
logits[batch_index] = logit
else:
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# Remove padding # Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
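
With the version >= 7 audio-decoder routing gone from this file, logits go through either the single classifier or the per-level classifier heads. A minimal sketch of the per-level routing, assuming Classifiers simply projects each batch item through the head named by its classifier level; head names and vocab sizes below are example values only:

import torch
import torch.nn as nn

class ClassifiersSketch(nn.Module):
    def __init__(self, d_model: int, vocab_sizes: dict[str, int]):
        super().__init__()
        self.heads = nn.ModuleDict({name: nn.Linear(d_model, n) for name, n in vocab_sizes.items()})

    def forward(self, hidden: list[torch.Tensor], levels: list[str]) -> list[torch.Tensor]:
        # one projection per batch item, chosen by its classifier level
        return [self.heads[level](h) for h, level in zip(hidden, levels)]

heads = ClassifiersSketch(1024, {"phn": 256, "text": 8575, "AR:0:0": 1025})
logits = heads([torch.randn(32, 1024), torch.randn(300, 1024)], ["phn", "AR:0:0"])
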
@ -2206,7 +2030,7 @@ if __name__ == "__main__":
seq = seq[:rvq_l, :] if rvq_l > 0 else seq seq = seq[:rvq_l, :] if rvq_l > 0 else seq
sep_embd = embds["sep"](zero) sep_embd = embds["sep"](zero)
phn_embd = embds["text"](phn) phn_embd = embds["phn"](phn)
rvq_l_embd = embds["rvq_l"](rvq_l) rvq_l_embd = embds["rvq_l"](rvq_l)
lang_embd = embds["lang"](lang) lang_embd = embds["lang"](lang)
prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype) prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype)

vall_e/models/base_v2.py: new file, 1147 additions (file diff suppressed because it is too large)

@ -27,26 +27,27 @@ _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch, teacher=None): def train_feeder(engine, batch, teacher=None):
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ]) engine.tokens_processed += sum([ text.shape[0] for text in batch["phns"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ]) engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"]) batch_size = len(batch["phns"])
engine.current_batch_size = batch_size engine.current_batch_size = batch_size
output = engine( output = engine(
text_list=batch["text"], phn_list=batch["phns"],
proms_list=batch["proms"], proms_list=batch["proms"],
resps_list=batch["resps"], resps_list=batch["resps"],
lang_list=batch["lang"], lang_list=batch["lang"],
tone_list=batch["tone"], tone_list=batch["tone"],
task_list=batch["task"], task_list=batch["task"],
raw_text_list=batch["raw_text"], text_list=batch["text"],
training=True, training=True,
) )
# get soft targets from teacher # get soft targets from teacher
"""
if teacher is not None: if teacher is not None:
# extract inputs forwarded to model # extract inputs forwarded to model
inputs = output.inputs inputs = output.inputs
@ -99,6 +100,7 @@ def train_feeder(engine, batch, teacher=None):
for k in engine.module.loss.keys(): for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A) engine.module.loss[k] *= (1.0 - A)
engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
"""
losses = engine.gather_attribute("loss") losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats") stat = engine.gather_attribute("stats")
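
The train_feeder changes above track the dataset-side rename: the batch now carries phoneme ids under "phns" and raw text ids under "text" (formerly "text" / "raw_text"). An illustrative single-item batch with dummy tensors, matching how batch_size and tokens_processed are derived:

import torch

batch = {
    "phns":  [torch.randint(0, 256, (24,), dtype=torch.int16)],   # phoneme ids
    "text":  [torch.randint(0, 8575, (64,), dtype=torch.int16)],  # raw text ids
    "proms": [torch.randint(0, 1024, (150, 8), dtype=torch.int16)],
    "resps": [torch.randint(0, 1024, (300, 8), dtype=torch.int16)],
    "lang":  [torch.tensor([0], dtype=torch.uint8)],
    "tone":  [torch.tensor([0], dtype=torch.uint8)],
    "task":  ["tts"],
}
batch_size = len(batch["phns"])
tokens_processed = sum(t.shape[0] for t in batch["phns"]) + sum(r.shape[0] for r in batch["resps"])
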
@ -174,7 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
for key in batch.keys(): for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size] batch[key] = batch[key][:cfg.evaluation.batch_size]
batch_size = len(batch["text"]) batch_size = len(batch["phns"])
""" """
# to-do: eval for text tasks # to-do: eval for text tasks
@ -190,8 +192,8 @@ def run_eval(engines, eval_name, dl, args=None):
# random prompts requested # random prompts requested
if args and args.eval_random_text_prompts and eval_name == "subtrain": if args and args.eval_random_text_prompts and eval_name == "subtrain":
for i, _ in enumerate(batch["text"]): for i, _ in enumerate(batch["phns"]):
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device) batch["phns"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
batch["resps"][i] = None batch["resps"][i] = None
""" """
@ -200,7 +202,7 @@ def run_eval(engines, eval_name, dl, args=None):
engine = engines[name] engine = engines[name]
base_kwargs = dict( base_kwargs = dict(
text_list=batch["text"], phns_list=batch["phns"],
proms_list=batch["proms"], proms_list=batch["proms"],
lang_list=batch["lang"], lang_list=batch["lang"],
task_list=batch["task"], task_list=batch["task"],
@ -242,22 +244,6 @@ def run_eval(engines, eval_name, dl, args=None):
process( name, batch, resps_list ) process( name, batch, resps_list )
"""
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )
kwargs["text_list"] = None
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
kwargs["resps_list"] = batch["resps"]
text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]
_logger.info(f"Validation Metrics (STT): {text_list}")
"""
stats = {k: sum(v) / len(v) for k, v in stats.items() if v} stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = { engines_stats = {
f'{name}.{eval_name}': stats, f'{name}.{eval_name}': stats,