Segregated experimental changes into their own streamlined file to avoid breaking the existing model; it can pivot to the cleaned-up code if it actually works (nothing is working)
This commit is contained in:
parent
7d2e64630c
commit
2ea387c08a
@@ -382,6 +382,12 @@ class Model:
 	def text_tokens(self):
 		if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
 			return self.size['text_tokens']
 		return 8575

+	@property
+	def phoneme_tokens(self):
+		if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"):
+			return self.size['phoneme_tokens']
+		return 256
+
 	@property
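Note: `hasattr(...)` on a plain dict checks attributes, not keys, so both checks above always fall through to the constant. A key-membership test is presumably what was intended; a minimal sketch of that assumption (hypothetical fix, not in the commit):

    # hypothetical corrected check, assuming `size` is a plain dict
    def text_tokens(size):
        if isinstance(size, dict) and "text_tokens" in size:
            return size["text_tokens"]
        return 8575

    assert text_tokens({"text_tokens": 9000}) == 9000
    assert text_tokens("large") == 8575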
@@ -1611,10 +1611,11 @@ class Dataset(_Dataset):
 			task=task,
 			lang=lang,
 			tone=tone,
-			text=text,
 			proms=proms,
 			resps=resps,
-			raw_text=raw_text,
+
+			phns=text,
+			text=raw_text,

 			metadata=metadata,
 		)
@@ -306,7 +306,7 @@ class TTS():
 		seed = set_seed(seed)
 		batch_size = len(texts)
 		input_kwargs = dict(
-			text_list=texts,
+			phns_list=texts,
 			proms_list=proms,
 			lang_list=langs,
 			disable_tqdm=not use_tqdm,
@@ -421,8 +421,8 @@ class TTS():
 			with torch.autocast(self.device, dtype=dtype, enabled=amp):
 				model = model_ar if model_ar is not None else model_nar
 				if model is not None:
-					text_list = model(
-						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
+					phns_list = model(
+						phns_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=[task],
 						disable_tqdm=not use_tqdm,
 						use_lora=use_lora,
 						**sampling_kwargs,
@@ -430,9 +430,9 @@ class TTS():
 				else:
 					raise Exception("!")

-				text_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
+				phns_list = [ cfg.tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]

-				return text_list[0]
+				return phns_list[0]
 			elif task in ["phn", "un-phn"]:
 				lang = self.encode_lang( language )
 				lang = to_device(lang, device=self.device, dtype=torch.uint8)
@@ -440,17 +440,17 @@ class TTS():
 			with torch.autocast(self.device, dtype=dtype, enabled=amp):
 				model = model_ar if model_ar is not None else model_nar
 				if task == "phn":
-					text_list = None
-					raw_text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
+					phns_list = None
+					text_list = [ self.encode_text( text, phonemize=False ).to(device=self.device, dtype=torch.int16) ]
 					output_tokenizer = cfg.tokenizer
 				else:
-					text_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
-					raw_text_list = None
+					phns_list = [ self.encode_text( text ).to(device=self.device, dtype=torch.int16) ]
+					text_list = None
 					output_tokenizer = cfg.text_tokenizer

 				if model is not None:
-					text_list = model(
-						text_list=text_list, raw_text_list=raw_text_list, lang_list=[lang], task_list=[task],
+					phns_list = model(
+						phns_list=phns_list, text_list=text_list, lang_list=[lang], task_list=[task],
 						disable_tqdm=not use_tqdm,
 						use_lora=use_lora,
 						**sampling_kwargs,
@@ -458,9 +458,9 @@ class TTS():
 				else:
 					raise Exception("!")

-				text_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in text_list ]
+				phns_list = [ output_tokenizer.decode( text ).replace(" ", "_").replace(" ", "").replace("_", " ") for text in phns_list ]

-				return text_list[0]
+				return phns_list[0]

 			# stuff for rolling context
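For orientation: the phn/un-phn path decodes with whichever tokenizer matches the output domain. A condensed sketch of the flow above, with `encode_text`, `tokenizer`, and `text_tokenizer` as stand-ins for the objects in the diff (not the verbatim method):

    # condensed sketch of the tokenizer selection for phn / un-phn
    def pick_io(task, text, encode_text, tokenizer, text_tokenizer):
        if task == "phn":    # plain text in -> phonemes out
            inputs = dict(phns_list=None, text_list=[encode_text(text, phonemize=False)])
            return inputs, tokenizer          # phoneme tokenizer decodes the output
        else:                # "un-phn": phonemes in -> plain text out
            inputs = dict(phns_list=[encode_text(text, phonemize=True)], text_list=None)
            return inputs, text_tokenizer     # raw-text tokenizer decodes the output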
@@ -504,8 +504,8 @@ class TTS():

 			with torch.autocast(self.device, dtype=dtype, enabled=amp):
 				input_kwargs = dict(
-					text_list=[phns] if phonemize else None,
-					raw_text_list=[phns] if not phonemize else None,
+					phns_list=[phns] if phonemize else None,
+					text_list=[phns] if not phonemize else None,
 					proms_list=[prom],
 					lang_list=[lang],
 					disable_tqdm=not use_tqdm,
@@ -59,11 +59,18 @@ def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):


 def get_model(config, training=True, **model_kwargs):
-	from .ar_nar import AR_NAR # import here because reasons
-	name = config.name
-	model = AR_NAR(
-		n_text_tokens=config.text_tokens,
+	# crunge
+	if config.version < 7:
+		from .ar_nar import AR_NAR
+		ModelClass = AR_NAR
+	else:
+		from .ar_nar_v2 import AR_NAR_V2
+		ModelClass = AR_NAR_V2
+
+	cfg_kwargs = dict(
+		n_phn_tokens=config.phoneme_tokens,
 		n_audio_tokens=config.audio_tokens,
+		n_text_tokens=config.text_tokens,
 		d_model=config.dim,
 		n_heads=config.heads,
 		n_layers=config.layers,
@@ -75,9 +82,11 @@ def get_model(config, training=True, **model_kwargs):

 		training = training,
 		config = config,
-		**model_kwargs
 	)

+	name = config.name
+	model = ModelClass(**(cfg_kwargs | model_kwargs))
+
 	_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

 	return model
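The refactor keeps one shared kwargs dict and only swaps the class by `config.version`, so callers are unchanged. The `cfg_kwargs | model_kwargs` merge uses Python 3.9 dict union, where the right-hand side wins on key conflicts:

    # dict-union semantics used above: caller overrides beat shared defaults
    cfg_kwargs = dict(d_model=1024, n_heads=16)
    model_kwargs = dict(d_model=512)            # caller override
    merged = cfg_kwargs | model_kwargs
    assert merged == {'d_model': 512, 'n_heads': 16}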
@@ -42,22 +42,22 @@ class AR_NAR(Base):
 		self,
 		task_list: list[Tensor] | None = None,

-		text_list: list[Tensor] | None = None,
+		phns_list: list[Tensor] | None = None,
 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,

 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,
 	):
 		# deduce batch_size
-		if text_list:
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		elif raw_text_list:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
 		elif proms_list:
 			device = proms_list[0].device
 			batch_size = len(proms_list)
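The same first-non-empty deduction recurs in several methods below; a generic helper would capture it (hypothetical, not part of the commit):

    # hypothetical helper capturing the repeated batch-size deduction
    def deduce_batch(*candidates):
        for lst in candidates:
            if lst:  # first non-empty list wins
                return lst[0].device, len(lst)
        raise ValueError("cannot deduce batch size: all inputs are empty")

    # usage mirroring the diff's ordering: phonemes, then text, then audio prompts
    # device, batch_size = deduce_batch(phns_list, text_list, proms_list)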
@@ -75,9 +75,6 @@ class AR_NAR(Base):
 		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
 		# RVQ levels to apply masking training on
 		masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels

-		if self.version >= 7:
-			masking_train_rvq_levels = [0,self.n_resp_levels]
-
 		if cfg.audio_backend == "nemo":
 			rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
@@ -134,13 +131,12 @@ class AR_NAR(Base):
 				timesteps[i] = (timesteps[i] * 0.6) + 0.2

 		# trim resps to only contain all levels below the target level
-		if self.version < 7:
-			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
+		resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]

 		# tensor to cat for RVQ level 0
 		text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
 		text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
-		audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)
+		audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)

 		# final validations and stuff
 		for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
@@ -175,10 +171,10 @@ class AR_NAR(Base):
 			"""

 			# only apply stop token for RVQ level 0
-			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None) or (self.predict_causally):
+			if (quant_level <= 0 and timesteps[i] is None) or (self.predict_causally):
 				# append stop tokens for AR
 				if task not in text_task:
-					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
+					resps_list[i] = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ])

 			if task == "len":
 				quant_levels[i] = 0
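Note: after this change the stop marker is stored 1-wide and broadcast across however many codebook levels the trimmed `resps` actually carries. A small shape check with illustrative values:

    import torch

    stop_token = 1024
    audio_stop_sequence = torch.tensor([[stop_token] * 1], dtype=torch.int16)  # shape (1, 1)
    resps = torch.zeros((75, 4), dtype=torch.int16)                            # 75 frames, 4 levels kept

    out = torch.cat([ resps, audio_stop_sequence.repeat((1, resps.shape[-1])) ])
    assert out.shape == (76, 4)   # one stop row appended across all kept levels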
@@ -196,26 +192,26 @@ class AR_NAR(Base):
 				drop_audio = True
 				drop_text = True

-			if random.random() < use_raw_text_p and raw_text_list[i] is not None:
+			if random.random() < use_raw_text_p and text_list[i] is not None:
 				swap_text = True

 			if drop_text:
-				text_list[i] = text_start_stop_sequence
+				phns_list[i] = text_start_stop_sequence

 			if drop_audio:
 				proms_list[i] = None

 			if swap_text and not drop_text:
-				text_list[i] = None
+				phns_list[i] = None

 		inputs = self.inputs(
-			text_list=text_list,
+			phns_list=phns_list,
 			proms_list=proms_list,
 			resps_list=resps_list,
 			lang_list=lang_list,
 			tone_list=tone_list,
 			task_list=task_list,
-			raw_text_list=raw_text_list,
+			text_list=text_list,
 			time_list=timesteps,

 			quant_levels=quant_levels,
@@ -231,22 +227,22 @@ class AR_NAR(Base):

 		task_list: list[Tensor] | None = None,

-		text_list: list[Tensor] | None = None,
+		phns_list: list[Tensor] | None = None,
 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,

 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,
 		quant_levels: list[int] | None = None,

 		disable_tqdm=False,
 		use_lora=None,
 		**sampling_kwargs,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		device = phns_list[0].device
+		batch_size = len(phns_list)

 		if quant_levels is None:
 			level = 0
@@ -304,7 +300,7 @@ class AR_NAR(Base):
 		prefix_context = sampling_kwargs.get("prefix_context", None)
 		# we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
 		if prefix_context is not None:
-			text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ]
+			phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], phns_list ) ]
 			prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ]

 		# if we're denoising from an existing sequence
@@ -379,7 +375,7 @@ class AR_NAR(Base):

 			# setup inputs
 			inputs = super().inputs(
-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=input_resps_list,
 				lang_list=lang_list,
@@ -396,7 +392,7 @@ class AR_NAR(Base):

 			if cfg_strength > 0:
 				null_inputs = super().inputs(
-					text_list=null_text,
+					phns_list=null_text,
 					proms_list=null_prom,
 					resps_list=input_resps_list,
 					lang_list=lang_list,
@@ -446,188 +442,11 @@ class AR_NAR(Base):

 		return resps_list

-	# handles doing demasking inferencing in parallel to inference all tokens
-	# it works if the underlying model is trained properly (which is a pain)
-	def forward_nar_masked_parallel(
-		self,
-
-		task_list: list[Tensor] | None = None,
-
-		text_list: list[Tensor] | None = None,
-		proms_list: list[Tensor] | None = None,
-		resps_list: list[Tensor] | None = None,
-
-		lang_list: list[Tensor] | None = None,
-		tone_list: list[Tensor] | None = None,
-		len_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
-
-		disable_tqdm=False,
-		use_lora=None,
-		**sampling_kwargs,
-	):
-		device = text_list[0].device
-		batch_size = len(text_list)
-
-		level = 0
-		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
-
-		# convert (N)AR specific args
-		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
-
-		min_length = sampling_kwargs.pop("min_duration", 1)
-		max_length = sampling_kwargs.pop("max_duration", 500)
-		max_steps = sampling_kwargs.get("max_steps", 25)
-		refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
-		entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
-		annealed_sampling = sampling_kwargs.get("annealed_sampling", True)
-
-		# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
-		temperature = sampling_kwargs.pop("temperature", 0.0)
-		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
-		# this really helps keep audio coherent so far
-		cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
-		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
-		start_noise = sampling_kwargs.get("denoise_start", 0.0)
-		end_noise = sampling_kwargs.get("denoise_end", 1.0)
-		remasking = sampling_kwargs.get("remasking", True)
-		max_steps = math.floor(max_steps * (end_noise - start_noise))
-
-		# to specify the initial mask used
-		vc_list = sampling_kwargs.pop("vc_list", None)
-		vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25)
-		vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25)
-
-		len_list = [ clamp(l, min_length, max_length) for l in len_list ]
-
-		# force set CFG because too low / no CFG causes issues
-		original_cfg_strength = cfg_strength
-		cfg_strength = max( cfg_strength, minimum_cfg_strength )
-
-		prefix_context = sampling_kwargs.get("prefix_context", None)
-		# fill with masked tokens (even though they get masked anyways)
-		resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
-		# fill scores
-		scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
-
-		quant_levels = [ level for _ in range(batch_size) ]
-		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
-		null_prom = [ None for _ in range(batch_size) ]
-
-		iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
-		for timestep in iterator:
-			# update previous list of tokens
-			prev_list = resps_list
-			# ramp down over time
-			annealing = 1.0 - timestep
-			# get noise level, per cosine scheduling
-			noise_p = math.cos( timestep * math.pi * 0.5 )
-			# proportion of tokens to remask
-			remask_p = 1.0 / (max_steps * 2) if remasking else 0
-			# pick the worst scoring tokens to mask off
-			masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
-			# normal masking
-			# mask off inputs
-			resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
-			# boolean mask
-			is_masked = [ resps == self.mask_token for resps in resps_list ]
-			# timestep inputs
-			time_list = [ timestep for _ in range(batch_size) ]
-
-			sampling_temperature = temperature * annealing if annealed_sampling else temperature
-			sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength
-
-			input_resps_list = resps_list
-
-			# setup inputs
-			inputs = super().inputs(
-				text_list=text_list,
-				proms_list=proms_list,
-				resps_list=input_resps_list,
-				lang_list=lang_list,
-				tone_list=tone_list,
-				time_list=time_list,
-				quant_levels=quant_levels,
-			)
-			output = super().forward(
-				inputs=inputs,
-				quant_levels=quant_levels,
-			)
-
-			logits = output.logits
-			if cfg_strength > 0:
-				null_inputs = super().inputs(
-					text_list=null_text,
-					proms_list=null_prom,
-					resps_list=input_resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					time_list=time_list,
-					quant_levels=quant_levels,
-				)
-				null_output = super().forward(
-					inputs=null_inputs,
-					quant_levels=quant_levels,
-				)
-
-				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
-
-			l_scores = []
-			l_resps_list = []
-			# cringe hack because we're able to sample multiple levels at once
-			for l in range(self.n_resp_levels):
-				# sample with sampler settings
-				filtered_sampled = super().sample(
-					logits=[ logit[l] for logit in logits ],
-					prev_list=[ resp[..., l] for resp in prev_list ],
-					quant_levels=quant_levels,
-
-					temperature=sampling_temperature,
-					**sampling_kwargs,
-				)
-
-				# retrieves unfiltered logits
-				unfiltered_sampled = super().sample(
-					logits=[ logit[l] for logit in logits ],
-					prev_list=[ resp[..., l] for resp in prev_list ],
-					quant_levels=quant_levels,
-
-					temperature=0.0,
-					**sampling_kwargs,
-				)
-
-				# get sampled tokens
-				sampled_ids = filtered_sampled.ids
-				# keep unmasked tokens
-				l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
-				# get probability scores
-				l_scores.append([
-					# conjugate to have worse scoring tokens picked for topk
-					1.0 -
-						# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
-						torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
-					# use unmodified logit scores for this, as it offers better stability
-					for scores, masked in zip( unfiltered_sampled.scores, is_masked )
-				])
-
-			resps_list = []
-			scores = []
-
-			for batch_index in range(batch_size):
-				score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
-				resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
-
-				scores.append( score )
-				resps_list.append( resp )
-
-		return resps_list
-
 	def forward_nar(
 		self,
 		task_list: list[Tensor] | None = None,

-		text_list: list[Tensor] | None = None,
+		phns_list: list[Tensor] | None = None,
 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,

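The deleted `forward_nar_masked_parallel` (which, per the commit message, continues in `ar_nar_v2.py`) is a MaskGIT-style iterative demasker: each step re-masks the worst-scoring fraction of tokens on a cosine schedule and re-predicts them in parallel across all codebook levels. A stripped-down sketch of just the schedule/masking core (a hypothetical simplification, not the file's code):

    import math, torch

    def demask_step(scores, seq_len, timestep, max_steps, mask_token, resps):
        # cosine noise schedule: early timesteps re-mask nearly everything
        noise_p = math.cos( timestep * math.pi * 0.5 )
        # small constant remask fraction keeps late steps from freezing
        remask_p = 1.0 / (max_steps * 2)
        n = max(1, min(int((noise_p + remask_p) * seq_len), seq_len))
        # highest score == least confident here (scores store 1 - confidence)
        masked_indices = scores.topk(n, dim=-1).indices
        # mask whole frames across every codebook level
        return resps.index_fill(0, masked_indices, mask_token)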
@@ -635,7 +454,7 @@ class AR_NAR(Base):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,

-		raw_text_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,

 		disable_tqdm=False,
 		use_lora=None,
@@ -644,7 +463,7 @@ class AR_NAR(Base):
 		# inference NAR level 0
 		if len_list is not None:
 			resps_list = self.forward_nar_masked(
-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=resps_list,
 				task_list=task_list,
@@ -655,12 +474,12 @@ class AR_NAR(Base):
 			)

 		# deduce batch_size
-		if text_list:
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		elif raw_text_list:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
 		elif proms_list:
 			device = proms_list[0].device
 			batch_size = len(proms_list)
@@ -701,7 +520,7 @@ class AR_NAR(Base):
 			quant_levels = [ level for _ in range(batch_size) ]

 			inputs = self.inputs(
-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=prev_list,
 				lang_list=lang_list,
@@ -717,7 +536,7 @@ class AR_NAR(Base):

 			if cfg_strength > 0:
 				null_inputs = super().inputs(
-					text_list=null_text,
+					phns_list=null_text,
 					proms_list=null_prom,
 					resps_list=prev_list,
 					lang_list=lang_list,
@@ -748,8 +567,8 @@ class AR_NAR(Base):

 		task_list: list[Tensor],

+		phns_list: list[Tensor] | None = None,
 		text_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
 		lang_list: list[Tensor] | None = None,
@@ -761,12 +580,12 @@ class AR_NAR(Base):
 		**sampling_kwargs,
 	):
 		# deduce batch_size
-		if text_list:
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		elif raw_text_list:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
 		elif proms_list:
 			device = proms_list[0].device
 			batch_size = len(proms_list)
@@ -810,14 +629,14 @@ class AR_NAR(Base):
 		inputs = self.inputs(
 			task_list=task_list,

-			text_list=text_list,
+			phns_list=phns_list,
 			proms_list=proms_list,
 			resps_list=resps_list,

 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			raw_text_list=raw_text_list,
+			text_list=text_list,

 			quant_levels=quant_levels,
 		)
@@ -895,7 +714,7 @@ class AR_NAR(Base):
 		if prefix_context is not None:
 			prefix_text, prefix_resps, _ = prefix_context
 			# to-do: check if we actually need to drop the middle "<eos><bos>"
-			text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
+			phns_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, phns_list ) ]
 			# feeding this into the NAR-len should automatically handle things
 			sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ]
@@ -906,13 +725,13 @@ class AR_NAR(Base):
 		iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
 		for n in iterator:
 			if batch_size == 1 and task_list[0] in ["phn", "un-phn"]:
-				text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
-				raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
+				phns_list = [ sequence_list[i] if task in ["phn"] else phns_list[i] for i, task in enumerate(task_list) ]
+				text_list = [ sequence_list[i] if task in ["un-phn"] else text_list[i] for i, task in enumerate(task_list) ]
 			else:
-				if raw_text_list is not None:
-					raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
-				else:
+				if text_list is not None:
 					text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+				else:
+					phns_list = [ sequence_list[i] if task in text_task else phns_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]

 			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
@@ -920,14 +739,14 @@ class AR_NAR(Base):
 			inputs = self.inputs(
 				task_list=task_list,

-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=resps_list,

 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				raw_text_list=raw_text_list,
+				text_list=text_list,

 				quant_levels=quant_levels,
 			)
@@ -942,7 +761,7 @@ class AR_NAR(Base):

 			if cfg_strength > 0:
 				null_inputs = super().inputs(
-					text_list=null_text,
+					phns_list=null_text,
 					proms_list=null_prom,
 					resps_list=resps_list,
 					lang_list=lang_list,
@@ -960,7 +779,7 @@ class AR_NAR(Base):

 			sampled = super().sample(
 				logits=logits,
-				prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+				prev_list=[ resps_list[i] if task not in text_task else phns_list[i] for i, task in enumerate( task_list ) ],
 				**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
 			)

@@ -981,7 +800,7 @@ class AR_NAR(Base):
 				# first step, expand batch
 				if batch_size == 1:
 					batch_size = beam_width
-					text_list = text_list * beam_width
+					phns_list = phns_list * beam_width
 					proms_list = proms_list * beam_width
 					sequence_list = sequence_list * beam_width
 					task_list = task_list * beam_width
@@ -1046,175 +865,18 @@ class AR_NAR(Base):

 		return sequence_list

-	def forward_ar_parallel(
-		self,
-
-		task_list: list[Tensor],
-
-		text_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
-		proms_list: list[Tensor] | None = None,
-		resps_list: list[Tensor] | None = None,
-		lang_list: list[Tensor] | None = None,
-		tone_list: list[Tensor] | None = None,
-		len_list: list[Tensor] | None = None,
-
-		disable_tqdm=False,
-		use_lora=None,
-		**sampling_kwargs,
-	):
-		# deduce batch_size
-		if text_list:
-			device = text_list[0].device
-			batch_size = len(text_list)
-		elif raw_text_list:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
-		elif proms_list:
-			device = proms_list[0].device
-			batch_size = len(proms_list)
-		elif resps_list:
-			device = resps_list[0].device
-			batch_size = len(resps_list)
-
-		if cfg.lora is not None:
-			enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
-
-		# convert AR specific args
-		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
-
-		temperature = sampling_kwargs.get("temperature", 1.0)
-		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
-		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
-		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
-		max_duration = sampling_kwargs.get("max_duration", 500)
-		beam_width = sampling_kwargs.get("beam_width", 0)
-		entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
-		refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
-		input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
-		layer_skip = sampling_kwargs.get("layer_skip", False)
-		prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
-		mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
-		mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
-
-		start_slice = [ 0 for _ in range(batch_size) ]
-		sequence_list = [ torch.zeros((0, 8), device=device).to(torch.int16) for _ in range(batch_size) ]
-		stopped = torch.zeros(batch_size, device=device).bool()
-
-		audio_stop_token = self.stop_token
-		text_stop_token = 2
-
-		state = None
-		mirostat = [
-			{"n": 1024, "tau": mirostat_tau, "eta": mirostat_eta, "max_surprise": mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
-		] * batch_size if mirostat_tau > 0.0 else None
-
-		scores = [ 1.0 ] * beam_width
-		metrics = []
-
-		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
-		null_prom = [ None for _ in range(batch_size) ]
-
-		# get next in sequence
-		iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
-		for n in iterator:
-			if raw_text_list is not None:
-				raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
-			else:
-				text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
-			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
-
-			quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
-
-			inputs = self.inputs(
-				task_list=task_list,
-
-				text_list=text_list,
-				proms_list=proms_list,
-				resps_list=resps_list,
-
-				lang_list=lang_list,
-				tone_list=tone_list,
-				len_list=len_list,
-				raw_text_list=raw_text_list,
-
-				quant_levels=quant_levels,
-			)
-
-			# to-do: find an elegant way to write this
-			output = super().forward(
-				inputs=inputs,
-				state=state,
-				#layer_skip_variables=sampling_layer_skip_variables,
-				output_attentions=entropix_sampling,
-			)
-
-			if cfg_strength > 0:
-				null_inputs = super().inputs(
-					text_list=null_text,
-					proms_list=null_prom,
-					resps_list=resps_list,
-					lang_list=lang_list,
-					tone_list=tone_list,
-					quant_levels=quant_levels,
-				)
-				null_output = super().forward(
-					inputs=null_inputs,
-					quant_levels=quant_levels,
-					#layer_skip_variables=sampling_layer_skip_variables,
-				)
-				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
-
-			logits, state = output.logits, output.state
-
-			l_resps_list = [ [] for _ in range(batch_size) ]
-			for l in range(self.n_resp_levels):
-				sampled = super().sample(
-					logits=[ logit[l] for logit in logits ],
-					prev_list=[ resp[..., l] for resp in resps_list ],
-					**(sampling_kwargs | {"attentions": output.attentions if entropix_sampling else None}),
-				)
-
-				ids = sampled.ids
-
-				# append tokens
-				for i, token in enumerate(ids):
-					if audio_stop_token in token:
-						stopped[i] = True
-					l_resps_list[i].append(token.to(device))
-
-			for i, l in enumerate(l_resps_list):
-				sequence_list[i] = torch.cat([sequence_list[i], torch.stack(l, dim=-1)])
-
-			# stop token found
-			# stopped |= r == stop_token
-			if stopped.all().item():
-				iterator.close()
-				break
-
-		for i, l in enumerate( sequence_list ):
-			index = (l == audio_stop_token).nonzero()
-			# kludge for when it doesnt actually hit a stop token but i cant be bothered to properly address it right now since it only came up in test training at the moment
-			try:
-				index = index[:, 0].min()
-				sequence_list[i] = sequence_list[i][:index]
-			except Exception as e:
-				pass
-
-		return sequence_list
-
 	def forward(
 		self,
 		task_list: list[Tensor] | None = None,

-		text_list: list[Tensor] | None = None,
+		phns_list: list[Tensor] | None = None,
 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,

 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,
+		text_list: list[Tensor] | None = None,

 		training: bool | None = None,

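The removed `forward_ar_parallel` sampled every codebook level per decode step (one token stack per frame) instead of one level at a time; per the commit message this experiment now lives in `ar_nar_v2.py`. A toy sketch of the per-step level loop, a hypothetical simplification rather than the file's code:

    import torch

    def sample_frame(logits_per_level, stop_token):
        # greedy-pick one token per codebook level, forming a (n_levels,) frame
        frame = torch.stack([ logits.argmax(dim=-1) for logits in logits_per_level ])
        # any level emitting the stop token ends decoding for this batch item
        stopped = bool((frame == stop_token).any())
        return frame, stopped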
@@ -1224,12 +886,12 @@ class AR_NAR(Base):
 	):
-		# deduce batch_size
-		if text_list:
+		# deduce batch_size
+		if phns_list:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		elif raw_text_list:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
 		elif proms_list:
 			device = proms_list[0].device
 			batch_size = len(proms_list)
@@ -1238,7 +900,7 @@ class AR_NAR(Base):
 			batch_size = len(resps_list)

 		# implicitly set for training
-		if training is None and text_list is not None and resps_list is not None:
+		if training is None and phns_list is not None and resps_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))

@@ -1249,118 +911,47 @@ class AR_NAR(Base):
 			return self.forward_train(
 				task_list=task_list,

-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=resps_list,

 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				raw_text_list=raw_text_list,
+				text_list=text_list,
 			)

 		# is NAR
-		if (len_list is not None or resps_list is not None) and text_list is not None:
-			if self.version >= 7:
-				return self.forward_nar_masked_parallel(
-					task_list=task_list,
-
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-
-					lang_list=lang_list,
-					tone_list=tone_list,
-					len_list=len_list,
-					raw_text_list=raw_text_list,
-
-					disable_tqdm=disable_tqdm,
-					use_lora=use_lora,
-					**sampling_kwargs,
-				)
-
-			# NAR demasking for all levels
-			"""
-			resps_lists = [ None for _ in range(batch_size) ]
-			for level in range(self.n_resp_levels):
-				resp_list = self.forward_nar_masked(
-					task_list=task_list,
-
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-
-					lang_list=lang_list,
-					tone_list=tone_list,
-					len_list=len_list,
-					raw_text_list=raw_text_list,
-
-					disable_tqdm=disable_tqdm,
-					use_lora=use_lora,
-					quant_levels=[ level for _ in range(batch_size) ],
-					**sampling_kwargs,
-				)
-
-				for batch_index, resp in enumerate(resp_list):
-					if resps_lists[batch_index] is None:
-						resps_lists[batch_index] = []
-
-					resps_lists[batch_index].append( resp )
-
-			for batch_index, resps in enumerate(resps_lists):
-				resps_lists[batch_index] = torch.stack( resps, dim=-1 )
-
-			return resps_lists
-			"""
-
+		if (len_list is not None or resps_list is not None) and phns_list is not None:
 			return self.forward_nar(
 				task_list=task_list,

-				text_list=text_list,
+				phns_list=phns_list,
 				proms_list=proms_list,
 				resps_list=resps_list,

 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				raw_text_list=raw_text_list,
+				text_list=text_list,

 				disable_tqdm=disable_tqdm,
 				use_lora=use_lora,
 				**sampling_kwargs,
 			)

-		if self.version >= 7:
-			if task_list is None or task_list[0] != "len":
-				return self.forward_ar_parallel(
-					task_list=task_list,
-
-					text_list=text_list,
-					proms_list=proms_list,
-					resps_list=resps_list,
-
-					lang_list=lang_list,
-					tone_list=tone_list,
-					len_list=len_list,
-					raw_text_list=raw_text_list,
-
-					disable_tqdm=disable_tqdm,
-					use_lora=use_lora,
-					**sampling_kwargs,
-				)
-
 		# is AR
 		return self.forward_ar(
 			task_list=task_list,

-			text_list=text_list,
+			phns_list=phns_list,
 			proms_list=proms_list,
 			resps_list=resps_list,

 			lang_list=lang_list,
 			tone_list=tone_list,
 			len_list=len_list,
-			raw_text_list=raw_text_list,
+			text_list=text_list,

 			disable_tqdm=disable_tqdm,
 			use_lora=use_lora,
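After the pruning, dispatch in `forward` reduces to three cases; condensed paraphrase of the logic above (hypothetical helper, not verbatim):

    # condensed dispatch predicate, paraphrasing the pruned forward()
    def choose_path(training, phns_list, resps_list, len_list):
        if training:
            return "forward_train"
        if (len_list is not None or resps_list is not None) and phns_list is not None:
            return "forward_nar"      # a duration or an audio draft exists
        return "forward_ar"           # otherwise plain autoregressive decoding

    assert choose_path(False, ["phns"], None, [120]) == "forward_nar"
    assert choose_path(False, ["phns"], None, None) == "forward_ar"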
@@ -1402,12 +993,11 @@ def example_usage():
 	text, audio = load_artifact(f"./data/qnt.{cfg.audio_backend_extension}")
 	batch_size = cfg.hyperparameters.batch_size

-	text_list = [ text ] * batch_size
+	phns_list = [ text ] * batch_size
 	proms_list = [ audio[:int(cfg.dataset.frames_per_second), :] ] * batch_size
 	resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size

 	kwargs = {
-		'n_text_tokens': cfg.model.text_tokens,
 		'n_audio_tokens': cfg.model.audio_tokens,

 		'd_model': 1024, # 256, # 1024, # 1536
@@ -1545,7 +1135,7 @@ def example_usage():
 	def sample_data(t=None):
 		if isinstance(t, list):
 			tasks = t
-			texts = [ text_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
+			texts = [ phns_list[0].to(cfg.device) if task not in text_task else None for i, task in enumerate( tasks ) ]
 			proms = [ proms_list[0].to(cfg.device) if task not in text_task else [ "stt" ] for i, task in enumerate( tasks ) ]
 			resps = [ None if task not in text_task else resps_list[0].to(cfg.device) for i, task in enumerate( tasks ) ]

@@ -1559,7 +1149,7 @@ def example_usage():
 		for i in range(batch_size):
 			task = random.choice(available_tasks) if t is None else t

-			text = text_list[i].to(cfg.device)
+			text = phns_list[i].to(cfg.device)
 			prom = proms_list[i].to(cfg.device)
 			resp = resps_list[i].to(cfg.device)

@@ -1580,16 +1170,16 @@ def example_usage():
 	def sample( name, steps=500, task=None ):
 		engine.eval()

-		text_list, proms_list, resp_list, task_list = sample_data( task )
+		phns_list, proms_list, resp_list, task_list = sample_data( task )

 		if task == "tts-nar":
-			len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+			len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
 			len_list = [ r.shape[0] for r in resp_list ]
-			resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list )
+			resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
 		else:
-			resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
+			resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
 			if resps_list[0].dim() == 1 or resps_list[0].shape[-1] == 1:
-				resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )
+				resps_list = engine( phns_list=phns_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 )

 		for i, o in enumerate(resps_list):
 			print( o.shape, o )
|
@ -1604,7 +1194,7 @@ def example_usage():
|
|||
texts, proms, resps, tasks = sample_data()
|
||||
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
||||
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
|
1069	vall_e/models/ar_nar_v2.py	Normal file
File diff suppressed because it is too large	Load Diff
@@ -9,6 +9,8 @@ This should handle all the "low" level things such as:
 Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inheret the base model
 """

+# to-do: clean this whole mess up
+
 import math
 import torch
 import torch.nn.functional as F
@@ -50,22 +52,22 @@ from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """

 summed_embeddings_task = [ "stt" ]
-special_tasks = [ "len", "stt", "phn", "un-phn" ]
+special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
 non_tokened_names = ["task", "dropout_mask", "classifier_level"]
 task_outputs = {
 	"tts": "resp",
 	"ns": "resp",
 	"sr": "resp",
-	"stt": "text",
+	"stt": "phn",
 	"len": "len",
-	"phn": "text",
-	"un-phn": "raw_text",
+	"phn": "phn",
+	"un-phn": "text",
 }

 # yuck
 def _get_offsets():
 	return {
-		"text": (0, 256),
+		"phn": (0, 256),
 		"quant_level": (256, 264),
 		"lang": (264, 270),
 		"task": (270, 279),
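The rename scheme running through the whole commit: the old `text` (phoneme tokens) becomes `phn`, and the old `raw_text` (raw text tokens) becomes `text`. As a quick reference, with vocab sizes taken from the config defaults elsewhere in the diff (terminology summary, not repo code):

    # old name   -> new name   meaning
    # text       -> phn        phoneme token sequence (vocab ~256)
    # raw_text   -> text       raw text token sequence (vocab ~8575)
    RENAMES = { "text": "phn", "raw_text": "text" }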
@@ -126,15 +128,6 @@ def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
 def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 	return torch.concat( [ i.t() for i in input ] ).t().flatten()

-# automagically parses a batch-list and returns it as a list
-"""
-class Embedding(ml.Embedding):
-	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
-		if len(x_list) == 0:
-			return []
-		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
-"""
-
 # Deprecated implementation
 class MultiEmbedding(nn.Module):
 	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
@@ -334,92 +327,6 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
 		x[..., level] = torch.where( dropout_mask, lhs, rhs )
 	return x

-# naively embeds each level of a codebook, then merges the embeddings with a Linear
-class AudioEncoder(nn.Module):
-	def __init__(
-		self,
-		n_tokens: int,
-		n_levels: int,
-		token_dim: int,
-	):
-		super().__init__()
-		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
-
-	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
-		# empty
-		if xi.shape[0] == 0:
-			dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
-			return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
-		if dropout_mask is not None:
-			xi = _dropout_codes( xi, dropout_mask, dropout_token )
-
-		# old way
-		# this probably is a tried and true good way to go about it
-		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
-
-		# encode by interleaving
-		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
-		"""
-		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
-		x = x.view(x.shape[0], -1)
-		x = self.proj(x)
-		"""
-
-		return x
-
-class AudioDecoder(nn.Module):
-	def __init__(
-		self,
-		d_model,
-		hidden_size,
-		vocab_size,
-		resp_levels,
-	):
-		super().__init__()
-
-		self.resp_levels = resp_levels
-		self.head = nn.Linear( d_model, vocab_size * resp_levels )
-
-	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
-		# prior way up-projected then down-projected, but that's silly
-		x = self.head( x )
-
-		# interleave by reshaping / permuting
-		# at least I hope this does it properly, it checks out against my OCR classifier
-		batch_size, seq_len, dim = x.shape
-		x = x.view( batch_size, seq_len, self.resp_levels, -1 )
-		x = x.permute( 0, 2, 1, 3 )
-
-		return x
-"""
-
-"""
-# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model
-# this doesn't work well
-"""
-class ClassifiersParallel(nn.Module):
-	def __init__(
-		self,
-		n_levels: int, # codebook levels
-		n_tokens: int, # output token count
-		token_dim: int, # dimensionality of the embedding
-		bias: bool = False,
-	):
-		super().__init__()
-		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)])
-
-	def forward(self, xi: Tensor, stack: bool = True ) -> Tensor:
-		dtype = xi.dtype
-		device = xi.device
-
-		xi = [ proj( xi ) for l, proj in enumerate(self.proj) ]
-		xi = torch.stack( xi )
-		xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
-
-		return xi
-"""
-
 class Metrics(nn.Module):
 	def __init__(
 		self,
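The removed `AudioEncoder` (the experiment continues in `ar_nar_v2.py`) embeds each RVQ level with its own table and sums the results into one frame vector. A self-contained sketch of that core idea, with illustrative sizes:

    import torch
    import torch.nn as nn

    class SummedCodebookEmbedding(nn.Module):
        # minimal sketch of the sum-over-levels encoding; sizes are illustrative
        def __init__(self, n_tokens=1026, n_levels=8, token_dim=1024):
            super().__init__()
            self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

        def forward(self, xi):          # xi: (seq_len, n_levels) integer codes
            return sum(emb(xi[:, l]) for l, emb in enumerate(self.embs))  # (seq_len, token_dim)

    enc = SummedCodebookEmbedding()
    codes = torch.randint(0, 1024, (75, 8))
    assert enc(codes).shape == (75, 1024)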
@@ -508,9 +415,9 @@ class Base(nn.Module):
 	def __init__(
 		self,

-		n_text_tokens: int = 256,
+		n_phn_tokens: int = 256,
 		n_audio_tokens: int = 1024,
-		n_raw_text_tokens: int = 8575,
+		n_text_tokens: int = 8575,

 		d_model: int = 512,
 		d_ffn: int = 4,
@@ -531,9 +438,9 @@ class Base(nn.Module):
 		self.teaching = False
 		self.config = config

-		self.n_text_tokens = n_text_tokens
+		self.n_phn_tokens = n_phn_tokens
 		self.n_audio_tokens = n_audio_tokens
-		self.n_raw_text_tokens = n_raw_text_tokens
+		self.n_text_tokens = n_text_tokens

 		self.d_model = d_model
 		self.n_heads = n_heads
@@ -595,36 +502,30 @@ class Base(nn.Module):
 		n_tones = self.config.tones if self.config is not None else 1

 		# pure AR
-		if self.version < 7:
-			if "nar" not in self.capabilities:
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-				l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
-				l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
-			# NAR-len model
-			elif "len" in self.capabilities:
-				# +1 to include the stop or mask token
-				n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-				if "ar" in self.capabilities:
-					l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
-					l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
-					l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
-				else:
-					l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-					l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-					l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-			# AR+NAR model
-			else:
-				# +1 to include the stop or mask token
-				n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-				l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-				l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
-				l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-		else:
-			n_resp_tokens = n_audio_tokens + 2
+		if "nar" not in self.capabilities:
+			n_resp_tokens = n_audio_tokens + 1
+			l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-			l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
-			l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
+			l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
+			l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels
+		# NAR-len model
+		elif "len" in self.capabilities:
+			# +1 to include the stop or mask token
+			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+			if "ar" in self.capabilities:
+				l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+				l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
+				l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
+			else:
+				l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+				l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+				l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+		# AR+NAR model
+		else:
+			# +1 to include the stop or mask token
+			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+			l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+			l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)

 		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

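For concreteness, with the defaults above (`n_audio_tokens = 1024`, `causal_size > 0`) the AR+NAR branch yields `n_resp_tokens = 1025`: level 0 keeps a slot for the stop token, later levels drop it. A worked check with those illustrative values:

    n_audio_tokens = 1024
    n_resp_levels = 8
    n_resp_tokens = n_audio_tokens + 1                 # +1 stop/mask slot on level 0
    l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
    assert l_embedding_tokens == [1025] + [1024] * 7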
@@ -632,7 +533,7 @@ class Base(nn.Module):

 		# STT
 		l_classifier_names += [ "stt" ]
-		l_classifier_tokens += [ n_text_tokens ]
+		l_classifier_tokens += [ n_phn_tokens ]

 		# LEN
 		if "len" in self.capabilities:
@@ -641,8 +542,8 @@ class Base(nn.Module):

 		# TEXT => PHN / PHN => TEXT
 		if self.version >= 6:
-			l_classifier_tokens += [ n_raw_text_tokens ]
-			l_classifier_names = l_embedding_names + [ "raw_text" ]
+			l_classifier_tokens += [ n_text_tokens ]
+			l_classifier_names = l_embedding_names + [ "text" ]

 		self.n_vocab = n_vocab
 		self.unified_position_ids = unified_position_ids
@@ -651,7 +552,7 @@ class Base(nn.Module):
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks

-		self.text_emb = Embedding(n_text_tokens, d_model)
+		self.text_emb = Embedding(n_phn_tokens, d_model)
 		self.raw_text_emb = None
 		self.langs_emb = None
 		self.tones_emb = None
@@ -680,7 +581,7 @@ class Base(nn.Module):
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
 			self.audio_emb = None
-		elif self.version < 7:
+		else:
 			self.proms_emb = AudioEmbedding(
 				[n_audio_tokens] * self.n_resp_levels, d_model,
 				sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@@ -691,10 +592,6 @@ class Base(nn.Module):
 				l_embedding_names=l_embedding_names,
 			)
 			self.audio_emb = None
-		else:
-			self.proms_emb = None
-			self.resps_emb = None
-			self.audio_emb = None

 		if self.version >= 3:
 			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@@ -716,34 +613,7 @@ class Base(nn.Module):
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model)
 		if self.version >= 6:
-			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
-
-		if self.version >= 7:
-			self.mask_token = self.stop_token + 1
-			if monolithic_audio_encoder:
-				self.audio_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 2, # stop + masked token
-					n_levels=self.n_resp_levels,
-					token_dim=d_model,
-				)
-			else:
-				self.proms_emb = AudioEncoder(
-					n_tokens=n_audio_tokens,
-					n_levels=self.n_resp_levels,
-					token_dim=d_model,
-				)
-				self.resps_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 2, # stop + masked token
-					n_levels=self.n_resp_levels,
-					token_dim=d_model,
-				)
-
-			self.audio_decoder = AudioDecoder(
-				d_model,
-				d_model * 2,
-				(n_audio_tokens + 1),
-				self.n_resp_levels,
-			)
+			self.raw_text_emb = Embedding(self.n_text_tokens, d_model)

 		if attention_backend == "auto":
 			attention_backend = "sdpa"
@@ -1009,8 +879,8 @@ class Base(nn.Module):
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
 		self,
+		phns_list: list[Tensor] | None = None,
 		text_list: list[Tensor] | None = None,
-		raw_text_list: list[Tensor] | None = None,

 		proms_list: list[Tensor] | None = None,
 		resps_list: list[Tensor] | None = None,
@@ -1023,12 +893,12 @@ class Base(nn.Module):

 		quant_levels: int | list[int] | Tensor | None = None
 	):
-		if text_list and text_list[0] is not None:
+		if phns_list and phns_list[0] is not None:
+			device = phns_list[0].device
+			batch_size = len(phns_list)
+		elif text_list and text_list[0] is not None:
 			device = text_list[0].device
 			batch_size = len(text_list)
-		elif raw_text_list and raw_text_list[0] is not None:
-			device = raw_text_list[0].device
-			batch_size = len(raw_text_list)
 		elif proms_list and proms_list[0] is not None:
 			device = proms_list[0].device
 			batch_size = len(proms_list)
@@ -1054,10 +924,10 @@ class Base(nn.Module):
 			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
 			if task_type in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
-				if text_list is not None and text_list[i] is not None:
+				if phns_list is not None and phns_list[i] is not None:
+					inputs[i].append( ( "phn", phns_list[i] ) )
+				elif text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
-				elif raw_text_list is not None and raw_text_list[i] is not None:
-					inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 				# insert lang token if we're trained for it
 				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 					inputs[i].append( ( "lang", lang_list[i] ) )
@@ -1096,9 +966,6 @@ class Base(nn.Module):
 				if resps_list is not None and resps_list[i] is not None:
 					inputs[i].append( ( "resp", resps_list[i] ) )

-				if self.version >= 7:
-					classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
-
 				inputs[i].append( ("classifier_level", classifier_level) )
 			# Audio length prediction task
 			# Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1108,10 +975,10 @@ class Base(nn.Module):
 					raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")

 				# insert the text prompt
-				if text_list is not None and text_list[i] is not None:
+				if phns_list is not None and phns_list[i] is not None:
+					inputs[i].append( ( "phn", phns_list[i] ) )
+				elif text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
-				elif raw_text_list is not None and raw_text_list[i] is not None:
-					inputs[i].append( ( "raw_text", raw_text_list[i] ) )
 				# insert lang token if we're trained for it
 				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
 					inputs[i].append( ( "lang", lang_list[i] ) )
@ -1147,42 +1014,42 @@ class Base(nn.Module):
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the output text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )

inputs[i].append( ("classifier_level", "stt") )
inputs[i].append( ("classifier_level", "phn") )
# Text phonemizing task
# Sequence: <raw_text><sep><lang><sep><phonemes>
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "phn":
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )

inputs[i].append( ("classifier_level", "stt") )
inputs[i].append( ("classifier_level", "phn") )
# Text de-phonemizing task
# Sequence: <raw_text><sep><lang><sep><phonemes>
# Sequence: <text><sep><lang><sep><phonemes>
elif task_type == "un-phn":
# insert the text prompt
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )
if phns_list is not None and phns_list[i] is not None:
inputs[i].append( ( "phn", phns_list[i] ) )
# insert lang token if we're trained for it
if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
inputs[i].append( ( "lang", lang_list[i] ) )
if self.rvq_l_emb is not None:
inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
# insert the text prompt
if raw_text_list is not None and raw_text_list[i] is not None:
inputs[i].append( ( "raw_text", raw_text_list[i] ) )
if text_list is not None and text_list[i] is not None:
inputs[i].append( ( "text", text_list[i] ) )

inputs[i].append( ("classifier_level", "raw_text") )
inputs[i].append( ("classifier_level", "text") )
else:
raise Exception(f'Unrecognized task: {task_type}')
return inputs

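As a reading aid, a sketch of what one batch entry of the assembled inputs looks like for the renamed phonemize task, following the tuples appended in the new code path above. The tensors and vocab sizes are placeholders.

    import torch

    text = torch.randint(0, 8575, (24,), dtype=torch.int16)  # raw text tokens (placeholder)
    phn = torch.randint(0, 256, (32,), dtype=torch.int16)    # phoneme tokens (placeholder)
    lang = torch.tensor([0], dtype=torch.uint8)

    inputs_i = [
        ("task", "phn"),
        ("text", text),                                          # was ("raw_text", ...)
        ("lang", lang),
        ("quant_level", torch.tensor([0], dtype=torch.int16)),
        ("phn", phn),                                            # was ("text", ...)
        ("classifier_level", "phn"),                             # was "stt"
    ]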
@ -1240,17 +1107,11 @@ class Base(nn.Module):
input if quant_level == 0 else input[:, :quant_level]
)

if self.version < 7:
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)

if self.audio_emb is not None:
return self.audio_emb( input )

return self.proms_emb( input )
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)

# yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0

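A small standalone sketch of the slicing the pre-version-7 branch above performs on prompt codes; the shapes are assumptions for illustration.

    import torch

    quant_level = 3
    prom = torch.randint(0, 1024, (75, 8))  # (frames, codebooks), placeholder codes

    # quant level 0 sees only the first codebook; level q sees codebooks [0, q),
    # and the embedding is asked for level q - 1, since the input sits one below the target
    sliced = prom if prom.dim() == 1 else prom[:, : 1 if quant_level == 0 else quant_level]
    emb_level = 0 if quant_level == 0 else quant_level - 1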
@ -1293,11 +1154,11 @@ class Base(nn.Module):
# *maybe* inject a token for specifying task type
task_type = input
continue
elif name == "text":
elif name == "phn":
embedding = self.text_emb( input )

device = embedding.device
elif name == "raw_text" and self.raw_text_emb is not None:
elif name == "text" and self.raw_text_emb is not None:
embedding = self.raw_text_emb( input )

device = embedding.device

@ -1316,13 +1177,8 @@ class Base(nn.Module):
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.version >= 7:
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
else:
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
# if training NAR-len RVQ level 0
elif dropout_mask is not None:
if dropout_mask is not None:
embedding = self.resps_emb(
# if masked use masked token, else original token
torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),

@ -1494,13 +1350,10 @@ class Base(nn.Module):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)

# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)

if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]

return input
return input if input.dim() == 1 else input[:, quant_level]

def _calc_loss( logit, sequence, causal = True ):
# filter tokens that exceed the vocab size

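A standalone sketch of the prom-target masking kept a few lines up: when audio embedding sums are enabled, prompt positions have no token-level target, so they are filled with the ignore index. The -100 value is an assumption, matching PyTorch's cross-entropy default.

    import torch

    ignore_index = -100                      # assumption: torch.nn.CrossEntropyLoss default
    prom = torch.randint(0, 1024, (75, 8))   # placeholder prompt codes
    target = torch.full_like(prom[..., 0], ignore_index)  # one ignored target per frame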
@ -1579,18 +1432,11 @@ class Base(nn.Module):
token = token[..., 0]
elif name == "resp":
# mask found, apply it
if self.version < 7:
token = input if input.dim() == 1 else input[:, quant_level]

# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
else:
token = input

# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
token = input if input.dim() == 1 else input[:, quant_level]

# mask found, apply it
if dropout_mask is not None:
token = torch.where( dropout_mask, token, self.ignore_index )
# not a special input, inject as-is
else:
token = input

@ -1785,7 +1631,7 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
causal_levels = [ "phn", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]

# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]

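A toy illustration of the membership test above; the level names follow this commit's renames, and the concrete values are made up for the demo.

    n_resp_levels = 8   # placeholder
    noncausal_masks = True  # assumption for the demo
    causal_levels = ["phn", "len", "phn"] + [f"AR:{_}:{_}" for _ in range(n_resp_levels)]

    classifier_levels = ["phn", "NAR:0:0", "AR:0:0"]  # e.g. one entry per batch item
    is_causal = [l in causal_levels for l in classifier_levels] if noncausal_masks else [True for _ in classifier_levels]
    # -> [True, False, True]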
@ -1802,36 +1648,14 @@ class Base(nn.Module):
logits = output.logits
hidden_states = output.hidden_states

# split between the two logit tasks, as audio logits become expanded
if self.version >= 7:
logits = [ logit for logit in logits ]

audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]

decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]

if decoders_indices:
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = self.audio_decoder( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit

if classifiers_indices:
classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
logits[batch_index] = logit
else:
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
if self.classifier is not None:
logits = self.classifier(logits) # * m
# to-do: piece-wise classification, now that there's a head for text
# although again, one single monolithic head would be preferable instead......
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )

# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]

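The version >= 7 branch above routes each batch item by its classifier level: audio levels go through the audio decoder, everything else through the per-level classifiers. A self-contained sketch of just the index split, with stand-in values rather than the model's real heads:

    import torch

    n_resp_levels = 8  # placeholder
    audio_decoder_levels = [f"AR:{i}:{i}" for i in range(n_resp_levels)] + [f"NAR:{i}:{i}" for i in range(n_resp_levels)]

    classifier_levels = ["AR:0:0", "phn", "NAR:1:1"]
    logits = [torch.randn(10, 1024) for _ in classifier_levels]  # per-item hidden states (placeholder)

    decoders_indices = [i for i, l in enumerate(classifier_levels) if l in audio_decoder_levels]
    classifiers_indices = [i for i, l in enumerate(classifier_levels) if l not in audio_decoder_levels]
    # decoders_indices == [0, 2]; classifiers_indices == [1]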
@ -2206,7 +2030,7 @@ if __name__ == "__main__":
seq = seq[:rvq_l, :] if rvq_l > 0 else seq

sep_embd = embds["sep"](zero)
phn_embd = embds["text"](phn)
phn_embd = embds["phn"](phn)
rvq_l_embd = embds["rvq_l"](rvq_l)
lang_embd = embds["lang"](lang)
prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype)
1147 vall_e/models/base_v2.py Normal file
File diff suppressed because it is too large
@ -27,26 +27,27 @@ _logger = logging.getLogger(__name__)
mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")

def train_feeder(engine, batch, teacher=None):
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ text.shape[0] for text in batch["phns"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])

with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
batch_size = len(batch["text"])
batch_size = len(batch["phns"])
engine.current_batch_size = batch_size

output = engine(
text_list=batch["text"],
phns_list=batch["phns"],
proms_list=batch["proms"],
resps_list=batch["resps"],
lang_list=batch["lang"],
tone_list=batch["tone"],
task_list=batch["task"],
raw_text_list=batch["raw_text"],
text_list=batch["text"],

training=True,
)

# get soft targets from teacher
"""
if teacher is not None:
# extract inputs forwarded to model
inputs = output.inputs

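Note the keyword is normalized to `phns_list` here to match the forward signature and the eval call below (the raw diff read `phn_list`, which would not bind). Under the renamed dataset fields, a feeder batch now carries both token streams; a rough sketch, with keys following this diff and placeholder shapes:

    import torch

    batch = {
        "phns":  [torch.randint(0, 256, (32,), dtype=torch.int16)],   # was batch["text"]
        "text":  [torch.randint(0, 8575, (24,), dtype=torch.int16)],  # was batch["raw_text"]
        "proms": [torch.randint(0, 1024, (75, 8))],
        "resps": [torch.randint(0, 1024, (150, 8))],
        "lang":  [torch.tensor([0], dtype=torch.uint8)],
        "tone":  [None], "task": ["tts"],
    }
    tokens_processed = sum(t.shape[0] for t in batch["phns"]) + sum(r.shape[0] for r in batch["resps"])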
@ -99,6 +100,7 @@ def train_feeder(engine, batch, teacher=None):
for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A)
engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
"""

losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")

@ -174,7 +176,7 @@ def run_eval(engines, eval_name, dl, args=None):
for key in batch.keys():
batch[key] = batch[key][:cfg.evaluation.batch_size]

batch_size = len(batch["text"])
batch_size = len(batch["phns"])

"""
# to-do: eval for text tasks

@ -190,8 +192,8 @@ def run_eval(engines, eval_name, dl, args=None):

# random prompts requested
if args and args.eval_random_text_prompts and eval_name == "subtrain":
for i, _ in enumerate(batch["text"]):
batch["text"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
for i, _ in enumerate(batch["phns"]):
batch["phns"][i] = get_random_prompt(tokenized=True).to(device=cfg.device)
batch["resps"][i] = None
"""

@ -200,7 +202,7 @@ def run_eval(engines, eval_name, dl, args=None):
engine = engines[name]

base_kwargs = dict(
text_list=batch["text"],
phns_list=batch["phns"],
proms_list=batch["proms"],
lang_list=batch["lang"],
task_list=batch["task"],

@ -242,22 +244,6 @@ def run_eval(engines, eval_name, dl, args=None):

process( name, batch, resps_list )

"""
# evaluate why it's so slow
if has_stt:
max_steps = max( [ text.shape[0] for text in batch["text"] ] )

kwargs["text_list"] = None
kwargs["task_list"] = [ "stt" for _ in range(batch_size) ]
kwargs["proms_list"] = [ ["stt"] for _ in range(batch_size) ]
kwargs["resps_list"] = batch["resps"]

text_list = engine( **kwargs, max_steps=max_steps, sampling_temperature=0.0)
text_list = [ cfg.tokenizer.decode( text ) for i, text in enumerate( text_list ) ]

_logger.info(f"Validation Metrics (STT): {text_list}")
"""

stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {
f'{name}.{eval_name}': stats,