swap priority to use nar-len if available, added notes

This commit is contained in:
mrq 2024-11-18 09:40:04 -06:00
parent 069b27570f
commit 6cfdf94bf9
5 changed files with 21 additions and 24 deletions


@ -57,6 +57,8 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
The NAR-len model keeps things simple by:
* training with a fixed masking ratio (80% of the tokens are masked, and the model is trained to predict them from the remaining unmasked tokens); see the sketch after this list
  * [this paper](https://arxiv.org/abs/2406.05478v1) mentions that a fixed ratio during training yields better results than randomly picking a masking ratio.
  * randomly picking a duration is actually very harmful and degrades the model during training.
  * this may only matter when swapping from training on a fixed masking ratio to a random ratio without adding any timestep information.
* not including any specific timestep embedding information
  * some solutions add a (sinusoidally position-encoded) timestep embedding, either on top of the input embeddings or as a normalization weight around the attention head (before and after).
  * this does not seem to be necessary at all, especially when training under a fixed masking ratio.
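
To make the fixed-ratio scheme concrete, below is a minimal sketch of what an 80% masking step could look like; the function name, `mask_token_id`, and the tensor shapes are illustrative assumptions rather than this repo's actual API.

```python
import torch

def apply_fixed_ratio_mask(tokens: torch.Tensor, mask_token_id: int, ratio: float = 0.8):
    """Mask a fixed fraction of positions; loss is only computed on the masked positions."""
    num_masked = int(ratio * tokens.shape[-1])
    # pick `num_masked` random positions along the sequence dimension
    scores = torch.rand(tokens.shape, device=tokens.device)
    masked_positions = scores.argsort(dim=-1)[..., :num_masked]
    is_masked = torch.zeros_like(tokens, dtype=torch.bool)
    is_masked.scatter_(-1, masked_positions, True)
    # replace masked positions with the mask token; the targets stay as the original tokens
    inputs = torch.where(is_masked, torch.full_like(tokens, mask_token_id), tokens)
    return inputs, tokens, is_masked
```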


@ -261,7 +261,7 @@ class ModelExperimentalSettings:
masking_train_p: float = 0.0 # odds of training with masking
masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
masking_ratio: str | float = 0.0 # sets a masking ratio, "random" will randomly pick
masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs, since that's what matters; inputs that do have loss calculated on them affect the loss for the entire sequence
# classifier-free guidance shit
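
For illustration, here is a rough sketch of how a `masking_ratio` setting like the one above could be resolved each training step; the helper name and the sampling bounds are assumptions, not this repo's code.

```python
import random

def resolve_masking_ratio(masking_ratio) -> float:
    # a float (e.g. the new 0.8 default) keeps the ratio fixed for every step,
    # while "random" samples a fresh ratio per step (the bounds here are illustrative)
    if masking_ratio == "random":
        return random.uniform(0.2, 0.8)
    return float(masking_ratio)
```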


@ -223,8 +223,9 @@ class TTS():
lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
text_list = model_ar(
model = model_ar if model_ar is not None else model_nar
if model is not None:
text_list = model(
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
disable_tqdm=not tqdm,
use_lora=use_lora,
@ -254,20 +255,7 @@ class TTS():
# to-do: add in a case for the experimental.hf model
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
elif model_len is not None:
if model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
@ -286,6 +274,19 @@ class TTS():
use_lora=use_lora,
**(sampling_kwargs | kwargs),
)
elif model_ar is not None:
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
else:
raise Exception("!")


@ -277,12 +277,12 @@ class AR_NAR(Base):
# timestep inputs
time_list = [ timestep for _ in range(batch_size) ]
"""
sampling_temperature = temperature * annealing
sampling_cfg = cfg_strength * timestep
"""
sampling_temperature = temperature
sampling_cfg = cfg_strength
"""
# setup inputs
inputs = super().inputs(


@ -1185,7 +1185,6 @@ class Base(nn.Module):
# NAR-len
elif classifier_level == "NAR:0:0":
embedding = self.resps_emb(
# if masked use masked token, else original token
input if input.dim() == 1 else input[:, 0],
#quant_level = 0,
name = classifier_level,
@ -1222,11 +1221,6 @@ class Base(nn.Module):
)
"""
"""
if classifier_level == "AR:0:0":
classifier_level = "NAR:0:0"
"""
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
#offset = 0 if classifier_level.startswith("AR:") else 1,