diff --git a/docs/models.md b/docs/models.md index 530cbbd..89adfbc 100644 --- a/docs/models.md +++ b/docs/models.md @@ -57,6 +57,8 @@ However, having a pure NAR is challenging, as you need to both explicitly provid The NAR-len model keeps things simple by: * training with a fixed masking ratio (80% of the tokens are masked and trained to predict the remaining tokens) * [this paper](https://arxiv.org/abs/2406.05478v1) mentions a fixed ratio during training yields better results than randomly picking a masking ratio. + * randomly picking a masking ratio is actually detrimental and harms the model during training. + * this may only matter if swapping from training on a fixed masking ratio to a random ratio without any timestep information being added. * not including any specific timestep embedding information * some solutions add in the (sinusoidal position'd) timestep embedding, either on top of the input embeddings, or as some normalization weight around the attention head (before and after). * it does not seem to be necessary what-so-ever to require this, especially training under a fixed masking ratio. 
diff --git a/vall_e/config.py b/vall_e/config.py index c3bea46..0f56c9b 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -261,7 +261,7 @@ class ModelExperimentalSettings: masking_train_p: float = 0.0 # odds of training with masking masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on - masking_ratio: str | float = 0.0 # sets a masking ratio, "random" will randomly pick + masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence # classifier-free guidance shit diff --git a/vall_e/inference.py b/vall_e/inference.py index a78debb..60528f7 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -223,8 +223,9 @@ class TTS(): lang = to_device(lang, device=self.device, dtype=torch.uint8) with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): - if model_ar is not None: - text_list = model_ar( + model = model_ar if model_ar is not None else model_nar + if model is not None: + text_list = model( text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], task_list=["stt"], disable_tqdm=not tqdm, use_lora=use_lora, @@ -254,20 +255,7 @@ class TTS(): # to-do: add in case for experimental.hf model with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): - if model_ar is not None: - resps_list = model_ar( - text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"], - disable_tqdm=not tqdm, - use_lora=use_lora, - **sampling_kwargs, - ) - resps_list = model_nar( - text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"], - disable_tqdm=not tqdm, - use_lora=use_lora, - **sampling_kwargs, - ) - elif model_len is not None: + if model_len is not None: len_list = model_len( text_list=[phns], proms_list=[prom], 
task_list=["len"], disable_tqdm=not tqdm, **{"max_steps": 5} ) # don't need more than that kwargs = {} # nasty hardcode to load a reference file and have that as the input target @@ -286,6 +274,19 @@ class TTS(): use_lora=use_lora, **(sampling_kwargs | kwargs), ) + elif model_ar is not None: + resps_list = model_ar( + text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"], + disable_tqdm=not tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) + resps_list = model_nar( + text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"], + disable_tqdm=not tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) else: raise Exception("!") diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index a5eafd9..eb86a10 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -277,12 +277,12 @@ class AR_NAR(Base): # timestep inputs time_list = [ timestep for _ in range(batch_size) ] - """ sampling_temperature = temperature * annealing sampling_cfg = cfg_strength * timestep """ sampling_temperature = temperature sampling_cfg = cfg_strength + """ # setup inputs inputs = super().inputs( diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 90df472..cca36d7 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1185,7 +1185,6 @@ class Base(nn.Module): # NAR-len elif classifier_level == "NAR:0:0": embedding = self.resps_emb( - # if masked use masked token, else original token input if input.dim() == 1 else input[:, 0], #quant_level = 0, name = classifier_level, @@ -1222,11 +1221,6 @@ class Base(nn.Module): ) """ - """ - if classifier_level == "AR:0:0": - classifier_level = "NAR:0:0" - """ - embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], #offset = 0 if classifier_level.startswith("AR:") else 1,