swap priority to use nar-len if available, added notes
parent 069b27570f
commit 6cfdf94bf9
@@ -57,6 +57,8 @@ However, having a pure NAR is challenging, as you need to both explicitly provide
 The NAR-len model keeps things simple by (see the sketch after this hunk):
 * training with a fixed masking ratio (80% of the tokens are masked, and the model is trained to predict them)
   * [this paper](https://arxiv.org/abs/2406.05478v1) notes that a fixed masking ratio during training yields better results than a randomly picked one.
+  * randomly picking a masking ratio is actually detrimental and harms the model during training.
+    * this may only matter when swapping from training on a fixed masking ratio to a random one without also adding timestep information.
 * not including any explicit timestep embedding information
   * some solutions add a (sinusoidally positioned) timestep embedding, either on top of the input embeddings or as a normalization weight around the attention heads (before and after).
   * this does not seem to be necessary whatsoever, especially when training under a fixed masking ratio.
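To make the fixed-ratio scheme above concrete, here is a minimal sketch of one masked training step; every name in it is illustrative, not this repo's actual API:

```python
# Minimal sketch of fixed-ratio (80%) masked training as described in the notes.
# Hypothetical names; not this repo's actual API.
import torch
import torch.nn.functional as F

def masked_train_step(model, tokens, mask_token_id, ratio=0.8):
    # tokens: [batch, seq_len] of discrete audio codes (RVQ level 0)
    batch, seq_len = tokens.shape
    num_masked = int(seq_len * ratio)  # fixed ratio, never randomly drawn
    # pick exactly num_masked positions per sequence
    order = torch.rand(batch, seq_len, device=tokens.device).argsort(dim=-1)
    masked = torch.zeros_like(tokens, dtype=torch.bool).scatter_(1, order[:, :num_masked], True)
    inputs = tokens.masked_fill(masked, mask_token_id)
    # note: no timestep embedding is fed to the model, per the notes above
    logits = model(inputs)  # [batch, seq_len, vocab]
    # loss is computed only on the masked positions
    return F.cross_entropy(logits[masked], tokens[masked])
```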
@@ -261,7 +261,7 @@ class ModelExperimentalSettings:
     masking_train_p: float = 0.0 # odds of training with masking
     masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on

-    masking_ratio: str | float = 0.0 # sets a masking ratio, "random" will randomly pick
+    masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick
     ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs, since that's what matters; inputs that do have loss calculated on them affect the loss for the entire sequence

     # classifier-free guidance shit
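For clarity, the `masking_ratio` field accepts either the literal string `"random"` or a float. A sketch of how such a setting would be resolved each training step (hypothetical helper; the actual handling lives elsewhere in the repo):

```python
import random

def resolve_masking_ratio(masking_ratio: str | float) -> float:
    # "random" re-rolls the ratio every step; a float (e.g. the new
    # 0.8 default) keeps it fixed, which the notes say trains better.
    if masking_ratio == "random":
        return random.uniform(0.0, 1.0)
    return float(masking_ratio)
```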
@@ -223,8 +223,9 @@ class TTS():
         lang = to_device(lang, device=self.device, dtype=torch.uint8)

         with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-            if model_ar is not None:
-                text_list = model_ar(
+            model = model_ar if model_ar is not None else model_nar
+            if model is not None:
+                text_list = model(
                     text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"],
                     disable_tqdm=not tqdm,
                     use_lora=use_lora,
@@ -254,20 +255,7 @@ class TTS():

         # to-do: add in case for experimental.hf model
         with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
-            if model_ar is not None:
-                resps_list = model_ar(
-                    text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
-                    disable_tqdm=not tqdm,
-                    use_lora=use_lora,
-                    **sampling_kwargs,
-                )
-                resps_list = model_nar(
-                    text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
-                    disable_tqdm=not tqdm,
-                    use_lora=use_lora,
-                    **sampling_kwargs,
-                )
-            elif model_len is not None:
+            if model_len is not None:
                 len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that
                 kwargs = {}
                 # nasty hardcode to load a reference file and have that as the input target
@@ -286,6 +274,19 @@ class TTS():
                     use_lora=use_lora,
                     **(sampling_kwargs | kwargs),
                 )
+            elif model_ar is not None:
+                resps_list = model_ar(
+                    text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
+                    disable_tqdm=not tqdm,
+                    use_lora=use_lora,
+                    **sampling_kwargs,
+                )
+                resps_list = model_nar(
+                    text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
+                    disable_tqdm=not tqdm,
+                    use_lora=use_lora,
+                    **sampling_kwargs,
+                )
             else:
                 raise Exception("!")
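These two hunks implement the priority swap in the commit title: inference now prefers the NAR-len path whenever `model_len` is loaded, falling back to the AR+NAR pair only when it is not. `model_len` first predicts a duration (a handful of steps suffices, hence `max_steps: 5`), and the response is then decoded by iterative demasking. A schematic of that loop in the style of standard masked-generative (MaskGIT-like) decoding; this is an illustration, not the repo's exact sampler:

```python
# Schematic NAR-len decode: start fully masked at the predicted duration,
# then iteratively keep the most confident predictions and re-mask the rest.
# Illustrative only; the repo's actual sampler differs in details.
import torch

def nar_len_decode(model_nar, phns, prom, duration, mask_token_id, steps=25):
    resp = torch.full((duration,), mask_token_id, dtype=torch.long)  # fully masked
    for step in range(1, steps + 1):
        logits = model_nar(text_list=[phns], proms_list=[prom], resps_list=[resp])[0]
        confidence, candidates = logits.softmax(dim=-1).max(dim=-1)
        resp = candidates.clone()
        # anneal the number of still-masked tokens down to zero
        keep_masked = int(duration * (1.0 - step / steps))
        resp[confidence.argsort()[:keep_masked]] = mask_token_id  # least confident
    return resp
```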
@@ -277,12 +277,12 @@ class AR_NAR(Base):
         # timestep inputs
         time_list = [ timestep for _ in range(batch_size) ]

+        """
         sampling_temperature = temperature * annealing
         sampling_cfg = cfg_strength * timestep
         """
         sampling_temperature = temperature
         sampling_cfg = cfg_strength
-        """

         # setup inputs
         inputs = super().inputs(
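The toggled block chooses between annealing the sampling temperature and CFG strength with the timestep, and keeping them constant; the commit leaves the constant variant active, matching the note above that timestep conditioning appears unnecessary under a fixed masking ratio. Schematically (illustrative, assuming `timestep` runs from 1.0 down to 0.0 over the demasking loop):

```python
# The two schedules this block toggles between (illustrative sketch).
def annealed(temperature, cfg_strength, timestep, annealing):
    return temperature * annealing, cfg_strength * timestep

def constant(temperature, cfg_strength):
    # what the commit leaves active: no timestep dependence at all
    return temperature, cfg_strength
```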
@@ -1185,7 +1185,6 @@ class Base(nn.Module):
             # NAR-len
             elif classifier_level == "NAR:0:0":
                 embedding = self.resps_emb(
                     # if masked use masked token, else original token
                     input if input.dim() == 1 else input[:, 0],
-                    #quant_level = 0,
                     name = classifier_level,
@@ -1222,11 +1221,6 @@ class Base(nn.Module):
                 )
                 """

-                """
-                if classifier_level == "AR:0:0":
-                    classifier_level = "NAR:0:0"
-                """
-
                 embedding = self.resps_emb(
                     input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
                     #offset = 0 if classifier_level.startswith("AR:") else 1,
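Both hunks key the `resps_emb` lookup off `input.dim()`: a 1D input appears to be a single token stream (level 0, possibly containing mask tokens), while a 2D input carries all RVQ levels and is sliced down. A toy illustration of the `NAR:0:0` branch; the shapes and the mask-token id are assumptions, not values from the repo:

```python
import torch

seq_len, n_levels, mask_token = 4, 8, 1024  # assumed values
masked_input = torch.tensor([5, mask_token, mask_token, 9])     # 1D: level-0 stream
full_input = torch.randint(0, mask_token, (seq_len, n_levels))  # 2D: all RVQ levels

for input in (masked_input, full_input):
    # "if masked use masked token, else original token":
    # 1D passes through as-is, 2D is reduced to its first RVQ level
    level0 = input if input.dim() == 1 else input[:, 0]
    assert level0.shape == (seq_len,)
```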