From 3826f9bae473ab64fa496d6e282b25e2216a5287 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 2 Nov 2024 21:00:21 -0500
Subject: [PATCH] saner mask creation? (it doesnt matter, kv cache wont work)

---
 vall_e/data.py        |  2 ++
 vall_e/emb/similar.py |  2 +-
 vall_e/models/base.py | 21 +++++++++++++--------
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 440a709..81ccd8e 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -513,6 +513,8 @@ def get_task_symmap():
 	}
 
 def _replace_file_extension(path, suffix):
+	if not isinstance( path, Path ):
+		path = Path(path)
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
 
 def _get_quant_extension():
diff --git a/vall_e/emb/similar.py b/vall_e/emb/similar.py
index 3bd7193..f4f3ebb 100644
--- a/vall_e/emb/similar.py
+++ b/vall_e/emb/similar.py
@@ -72,7 +72,7 @@ def process(
 
 	# easy way to load the model and handle encoding audio
 	if tts is None:
-		tts = init_tts( yaml=yaml, restart=False, device=device, dtype=dtype )
+		tts = init_tts( config=yaml, restart=False, device=device, dtype=dtype )
 
 	features = { key: None for key in metadata_keys }
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 59565df..5b1ef1e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -78,9 +78,11 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 	l = list(map(len, x_list))
 	x = rearrange(pad_sequence(x_list), pattern)
 	m = _create_mask(l, x_list[0].device)
+	"""
 	m = m.t().unsqueeze(-1) # (t b 1)
 	m = rearrange(m, pattern)
-	m = m.to(x)
+	"""
+	m = m.to(x).int()
 	return x, m
 
 def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
@@ -835,7 +837,7 @@ class Base(nn.Module):
 		output_hidden_states = False,
 	):
 		x = inputs
-		m = mask.squeeze(-1).int()
+		m = mask #.squeeze(-1).int()
 
 		aux_loss = None
 		attentions = None
@@ -844,7 +846,7 @@ class Base(nn.Module):
 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
 			kwargs = dict(
-				attention_mask=m,
+				#attention_mask=m,
 				inputs_embeds=x,
 				past_key_values=state,
 				position_ids=position_ids,
@@ -1475,7 +1477,9 @@ class Base(nn.Module):
 			return metrics["logits_entropy"] < kwargs["logits_entropy"] and metrics["logits_varentropy"] < kwargs["logits_varentropy"]
 
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
-		x, m = list_to_tensor(x_list)
+
+		x, mask = list_to_tensor(x_list)
+		m = mask.unsqueeze(dim=-1)
 
 		training = self.training
 		device = x.device
@@ -1501,16 +1505,17 @@ class Base(nn.Module):
 			# pad mask
 			shape[2] = 1
 			padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
-			m = torch.cat([m, padding], dim=1)
+			mask = torch.cat([mask, padding], dim=1)
 
 
 		# needs to be done here as we still have our raw inputs
-		position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
+		#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
+		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_quant_levels = [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		output = self._forward(
 			inputs=x,
-			mask=m,
+			mask=mask,
 			state=state,
 			position_ids=position_ids,
 			output_attentions = output_attentions,
@@ -1530,7 +1535,7 @@ class Base(nn.Module):
 				hidden_states[i] = self.classifier(hidden_states[i]) * m
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
-		if self.classifiers is not None:
+		elif self.classifiers is not None:
 			logits = self.classifiers(logits, levels = classifier_quant_levels) * m
 
 		if hidden_states is not None:
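
Note on the net effect of the base.py changes: list_to_tensor() now returns a 2-D integer mask of shape (b, t) (the layout an HF-style attention_mask expects) instead of the old (b, t, 1) float mask, and the caller re-adds the trailing dimension itself via mask.unsqueeze(dim=-1) where logits are masked. A minimal sketch of the patched behavior follows; _create_mask() is not shown in the patch, so its body here is inferred from the call sites and is illustrative only, and the input shapes are made up:

	import torch
	from torch import Tensor
	from torch.nn.utils.rnn import pad_sequence
	from einops import rearrange

	def _create_mask(lengths: list[int], device) -> Tensor:
		# assumed helper: (b, t) boolean mask, True on real tokens, False on padding
		t = max(lengths)
		pos = torch.arange(t, device=device).unsqueeze(0)         # (1, t)
		lens = torch.tensor(lengths, device=device).unsqueeze(1)  # (b, 1)
		return pos < lens                                         # (b, t)

	def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
		l = list(map(len, x_list))
		x = rearrange(pad_sequence(x_list), pattern)              # (b, t, c)
		m = _create_mask(l, x_list[0].device)                     # (b, t) bool
		# pre-patch this became (b, t, 1) in x's dtype; post-patch it stays 2-D
		m = m.to(x).int()                                         # (b, t) int
		return x, m

	x_list = [torch.randn(5, 8), torch.randn(3, 8)]               # made-up shapes
	x, mask = list_to_tensor(x_list)
	assert x.shape == (2, 5, 8) and mask.shape == (2, 5)
	m = mask.unsqueeze(dim=-1)                                    # (b, t, 1), as forward() now does

The subject line's caveat presumably refers to attention_mask being commented out of the HF kwargs above: without it, cached decoding cannot mask padded history, hence "kv cache wont work".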