diff --git a/vall_e/config.py b/vall_e/config.py
index 770dbf1..4a4d71a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -260,6 +260,7 @@ class ModelExperimentalSettings:
 	masking_train_p: float = 0.0 # odds of training with masking
 	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
+	masking_separate_embeddings: bool = False
 
 	# classifier-free guidance shit
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 5a51496..7f7c609 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -204,6 +204,17 @@ def load_engines(training=True, **model_kwargs):
 					continue
 				state[k] = ml.resize_weight( state[k], tokens )
 
+		"""
+		if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
+			state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
+			state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
+
+			del state['classifiers.proj.8.weight']
+			del state['classifiers.proj.8.bias']
+
+			state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+		"""
+
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		# load lora weights if exists
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 611f69f..3648e98 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -108,7 +108,7 @@ class AR_NAR(Base):
 				#p = math.acos(r) / (math.pi * 0.5)
 				#timesteps[i] = 1.0 - clamp(p, 0.0, 1.0)
 				timesteps[i] = random.random()
-			
+
 		# trim resps to only contain all levels below the target level
 		resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
@@ -896,7 +896,7 @@ def example_usage():
 	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 	batch_size = cfg.hyperparameters.batch_size
 
-	cfg.model.experimental.masking_train_p = 0.5
+	cfg.model.experimental.masking_train_p = 1.0
 
 	text_list = [ text ] * batch_size
 	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0464acf..0f49c59 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -476,6 +476,8 @@ class Base(nn.Module):
 		layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
 		layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
 
+		masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
+
 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
 		n_tones = self.config.tones if self.config is not None else 1
@@ -484,7 +486,12 @@ class Base(nn.Module):
 		if "nar" not in self.capabilities:
 			n_resp_tokens = n_audio_tokens + 1
 			l_tokens = [n_resp_tokens] * self.n_resp_levels
-		# AR+NAR model / NAR-len model
+		# NAR-len model
+		elif "len" in self.capabilities and masking_separate_embeddings:
+			# +1 to include the stop or mask token
+			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+		# AR+NAR model
 		else:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
@@ -495,6 +502,7 @@ class Base(nn.Module):
 		self.layerskip = layerskip
 		self.special_tasks = [ "len", "stt" ]
 		self.inject_timestep_embedding = False # results in bad output
+		self.masking_separate_embeddings = masking_separate_embeddings
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -1182,7 +1190,7 @@ class Base(nn.Module):
 				embedding = self.resps_emb(
 					# if masked use masked token, else original token
 					torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
-					offset = 0,
+					offset = -1 if self.masking_separate_embeddings else 0, # pick last
 					quant_level = 0,
 				)
 			# cheat-y way to handle performing STT across all levels
@@ -1325,10 +1333,9 @@ class Base(nn.Module):
 		device = logits[0].device
 		batch_size = len(logits)
 		summed_embeddings_task = [ "stt" ]
-		#classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
-
 		tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+		is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1623,15 +1630,15 @@ class Base(nn.Module):
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
 		tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-		
+		is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
+
 		if self.inject_timestep_embedding:
 			timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
 			timesteps = [ self.time_emb(timestep) if timestep is not None else None for i, timestep in enumerate(timesteps) ]
 		else:
 			timesteps = []
 
-		classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
-
 		output = self._forward(
 			inputs=x,
 			mask=mask,
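
Below is a minimal, hedged sketch (not part of the diff and not repository code) of the bookkeeping this change introduces: with masking_separate_embeddings enabled, l_tokens gains a trailing entry, the masked NAR-len embedding lookup switches to offset = -1, and classifier_quant_levels routes masked NAR-len batch items to -2. The constants (n_audio_tokens, n_resp_levels, causal_size) are assumed example values and classifier_level is a hypothetical helper for illustration only.

# sketch.py -- illustrative only, assumed values
n_audio_tokens = 1024
n_resp_levels = 8
causal_size = 1

# +1 to include the stop or mask token, as in the diff
n_resp_tokens = n_audio_tokens + (1 if causal_size > 0 else 0)

# NAR-len path from the diff: a trailing entry is appended, presumably backing
# the masked pass that the new `offset = -1` embedding lookup selects
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens]
assert len(l_tokens) == n_resp_levels + 1

def classifier_level(task, is_nar_len, quant_level, special_tasks=("len", "stt")):
	"""Mirrors the new classifier_quant_levels expression: -1 for special tasks,
	-2 for masked NAR-len batch items, otherwise the current quantizer level."""
	if task in special_tasks:
		return -1
	if is_nar_len:
		return -2
	return quant_level

print(classifier_level(task="tts", is_nar_len=True, quant_level=0))   # -2
print(classifier_level(task="tts", is_nar_len=False, quant_level=3))  # 3
print(classifier_level(task="len", is_nar_len=False, quant_level=0))  # -1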